summaryrefslogtreecommitdiff
path: root/Biz/PodcastItLater/Core.py
diff options
context:
space:
mode:
authorBen Sima <ben@bsima.me>2025-08-13 13:36:30 -0400
committerBen Sima <ben@bsima.me>2025-08-28 12:14:09 -0400
commit0b005c192b2c141c7f6c9bff4a0702361814c21d (patch)
tree3527a76137f6ee4dd970bba17a93617a311149cb /Biz/PodcastItLater/Core.py
parent7de0a3e0abbf1e152423e148d507e17b752a4982 (diff)
Prototype PodcastItLater
This implements a working prototype of PodcastItLater. It basically just works for a single user currently, but the articles are nice to listen to and this is something that we can start to build with.
Diffstat (limited to 'Biz/PodcastItLater/Core.py')
-rw-r--r--Biz/PodcastItLater/Core.py1117
1 files changed, 1117 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py
new file mode 100644
index 0000000..c0d0acf
--- /dev/null
+++ b/Biz/PodcastItLater/Core.py
@@ -0,0 +1,1117 @@
+"""Core, shared logic for PodcastItalater.
+
+Includes:
+- Database models
+- Data access layer
+- Shared types
+"""
+
+# : out podcastitlater-core
+# : dep pytest
+# : dep pytest-asyncio
+# : dep pytest-mock
+import Omni.App as App
+import Omni.Log as Log
+import Omni.Test as Test
+import pathlib
+import pytest
+import secrets
+import sqlite3
+import sys
+import time
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import Any
+
+logger = Log.setup()
+
+
+class Database: # noqa: PLR0904
+ """Data access layer for PodcastItLater database operations."""
+
+ @staticmethod
+ @contextmanager
+ def get_connection(
+ db_path: str = "podcast.db",
+ ) -> Iterator[sqlite3.Connection]:
+ """Context manager for database connections.
+
+ Yields:
+ sqlite3.Connection: Database connection with row factory set.
+ """
+ conn = sqlite3.connect(db_path)
+ conn.row_factory = sqlite3.Row
+ try:
+ yield conn
+ finally:
+ conn.close()
+
    @staticmethod
    def init_db(db_path: str = "podcast.db") -> None:
        """Initialize database with required tables.

        Creates the ``queue`` and ``episodes`` tables and their indexes if
        they do not already exist, then runs the multi-user migration so the
        schema always ends up in its newest form. Safe to call repeatedly.
        """
        with Database.get_connection(db_path) as conn:
            cursor = conn.cursor()

            # Queue table for job processing
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS queue (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    url TEXT,
                    email TEXT,
                    status TEXT DEFAULT 'pending',
                    retry_count INTEGER DEFAULT 0,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    error_message TEXT
                )
            """)

            # Episodes table for completed podcasts
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS episodes (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    title TEXT NOT NULL,
                    content_length INTEGER,
                    audio_url TEXT NOT NULL,
                    duration INTEGER,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Create indexes for performance
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_queue_status ON queue(status)",
            )
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_queue_created "
                "ON queue(created_at)",
            )
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_episodes_created "
                "ON episodes(created_at)",
            )

            conn.commit()
            logger.info("Database initialized successfully")

        # Run migration to add user support (users table + user_id columns)
        Database.migrate_to_multi_user(db_path)
+
+ @staticmethod
+ def add_to_queue(
+ url: str,
+ email: str,
+ user_id: int,
+ db_path: str = "podcast.db",
+ ) -> int:
+ """Insert new job into queue, return job ID.
+
+ Raises:
+ ValueError: If job ID cannot be retrieved after insert.
+ """
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT INTO queue (url, email, user_id) VALUES (?, ?, ?)",
+ (url, email, user_id),
+ )
+ conn.commit()
+ job_id = cursor.lastrowid
+ if job_id is None:
+ msg = "Failed to get job ID after insert"
+ raise ValueError(msg)
+ logger.info("Added job %s to queue: %s", job_id, url)
+ return job_id
+
+ @staticmethod
+ def get_pending_jobs(
+ limit: int = 10,
+ db_path: str = "podcast.db",
+ ) -> list[dict[str, Any]]:
+ """Fetch jobs with status='pending' ordered by creation time."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM queue WHERE status = 'pending' "
+ "ORDER BY created_at ASC LIMIT ?",
+ (limit,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def update_job_status(
+ job_id: int,
+ status: str,
+ error: str | None = None,
+ db_path: str = "podcast.db",
+ ) -> None:
+ """Update job status and error message."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ if status == "error":
+ cursor.execute(
+ "UPDATE queue SET status = ?, error_message = ?, "
+ "retry_count = retry_count + 1 WHERE id = ?",
+ (status, error, job_id),
+ )
+ else:
+ cursor.execute(
+ "UPDATE queue SET status = ? WHERE id = ?",
+ (status, job_id),
+ )
+ conn.commit()
+ logger.info("Updated job %s status to %s", job_id, status)
+
+ @staticmethod
+ def get_job_by_id(
+ job_id: int,
+ db_path: str = "podcast.db",
+ ) -> dict[str, Any] | None:
+ """Fetch single job by ID."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM queue WHERE id = ?", (job_id,))
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def create_episode( # noqa: PLR0913, PLR0917
+ title: str,
+ audio_url: str,
+ duration: int,
+ content_length: int,
+ user_id: int | None = None,
+ db_path: str = "podcast.db",
+ ) -> int:
+ """Insert episode record, return episode ID.
+
+ Raises:
+ ValueError: If episode ID cannot be retrieved after insert.
+ """
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT INTO episodes "
+ "(title, audio_url, duration, content_length, user_id) "
+ "VALUES (?, ?, ?, ?, ?)",
+ (title, audio_url, duration, content_length, user_id),
+ )
+ conn.commit()
+ episode_id = cursor.lastrowid
+ if episode_id is None:
+ msg = "Failed to get episode ID after insert"
+ raise ValueError(msg)
+ logger.info("Created episode %s: %s", episode_id, title)
+ return episode_id
+
+ @staticmethod
+ def get_recent_episodes(
+ limit: int = 20,
+ db_path: str = "podcast.db",
+ ) -> list[dict[str, Any]]:
+ """Get recent episodes for RSS feed generation."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM episodes ORDER BY created_at DESC LIMIT ?",
+ (limit,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_queue_status_summary(db_path: str = "podcast.db") -> dict[str, Any]:
+ """Get queue status summary for web interface."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+
+ # Count jobs by status
+ cursor.execute(
+ "SELECT status, COUNT(*) as count FROM queue GROUP BY status",
+ )
+ rows = cursor.fetchall()
+ status_counts = {row["status"]: row["count"] for row in rows}
+
+ # Get recent jobs
+ cursor.execute(
+ "SELECT * FROM queue ORDER BY created_at DESC LIMIT 10",
+ )
+ rows = cursor.fetchall()
+ recent_jobs = [dict(row) for row in rows]
+
+ return {"status_counts": status_counts, "recent_jobs": recent_jobs}
+
+ @staticmethod
+ def get_queue_status(db_path: str = "podcast.db") -> list[dict[str, Any]]:
+ """Return pending/processing/error items for web interface."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute("""
+ SELECT id, url, email, status, created_at, error_message
+ FROM queue
+ WHERE status IN ('pending', 'processing', 'error')
+ ORDER BY created_at DESC
+ LIMIT 20
+ """)
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_all_episodes(
+ db_path: str = "podcast.db",
+ user_id: int | None = None,
+ ) -> list[dict[str, Any]]:
+ """Return all episodes for RSS feed."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ if user_id:
+ cursor.execute(
+ """
+ SELECT id, title, audio_url, duration, created_at,
+ content_length
+ FROM episodes
+ WHERE user_id = ?
+ ORDER BY created_at DESC
+ """,
+ (user_id,),
+ )
+ else:
+ cursor.execute("""
+ SELECT id, title, audio_url, duration, created_at,
+ content_length
+ FROM episodes
+ ORDER BY created_at DESC
+ """)
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_retryable_jobs(
+ max_retries: int = 3,
+ db_path: str = "podcast.db",
+ ) -> list[dict[str, Any]]:
+ """Get failed jobs that can be retried."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM queue WHERE status = 'error' "
+ "AND retry_count < ? ORDER BY created_at ASC",
+ (max_retries,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def retry_job(job_id: int, db_path: str = "podcast.db") -> None:
+ """Reset a job to pending status for retry."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE queue SET status = 'pending', "
+ "error_message = NULL WHERE id = ?",
+ (job_id,),
+ )
+ conn.commit()
+ logger.info("Reset job %s to pending for retry", job_id)
+
+ @staticmethod
+ def delete_job(job_id: int, db_path: str = "podcast.db") -> None:
+ """Delete a job from the queue."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute("DELETE FROM queue WHERE id = ?", (job_id,))
+ conn.commit()
+ logger.info("Deleted job %s from queue", job_id)
+
+ @staticmethod
+ def get_all_queue_items(
+ db_path: str = "podcast.db",
+ user_id: int | None = None,
+ ) -> list[dict[str, Any]]:
+ """Return all queue items for admin view."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ if user_id:
+ cursor.execute(
+ """
+ SELECT id, url, email, status, retry_count, created_at,
+ error_message
+ FROM queue
+ WHERE user_id = ?
+ ORDER BY created_at DESC
+ """,
+ (user_id,),
+ )
+ else:
+ cursor.execute("""
+ SELECT id, url, email, status, retry_count, created_at,
+ error_message
+ FROM queue
+ ORDER BY created_at DESC
+ """)
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_status_counts(db_path: str = "podcast.db") -> dict[str, int]:
+ """Get count of queue items by status."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute("""
+ SELECT status, COUNT(*) as count
+ FROM queue
+ GROUP BY status
+ """)
+ rows = cursor.fetchall()
+ return {row["status"]: row["count"] for row in rows}
+
+ @staticmethod
+ def get_user_status_counts(
+ user_id: int,
+ db_path: str = "podcast.db",
+ ) -> dict[str, int]:
+ """Get count of queue items by status for a specific user."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT status, COUNT(*) as count
+ FROM queue
+ WHERE user_id = ?
+ GROUP BY status
+ """,
+ (user_id,),
+ )
+ rows = cursor.fetchall()
+ return {row["status"]: row["count"] for row in rows}
+
    @staticmethod
    def migrate_to_multi_user(db_path: str = "podcast.db") -> None:
        """Migrate database to support multiple users.

        Idempotent: the users table is created with IF NOT EXISTS and the
        user_id columns are only added when PRAGMA table_info shows they are
        missing, so this can safely run on every startup.
        """
        with Database.get_connection(db_path) as conn:
            cursor = conn.cursor()

            # Create users table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    email TEXT UNIQUE NOT NULL,
                    token TEXT UNIQUE NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Add user_id columns to existing tables
            # Check if columns already exist to make migration idempotent
            cursor.execute("PRAGMA table_info(queue)")
            queue_info = cursor.fetchall()
            # Column name is the second field of each PRAGMA table_info row.
            queue_columns = [col[1] for col in queue_info]

            if "user_id" not in queue_columns:
                cursor.execute(
                    "ALTER TABLE queue ADD COLUMN user_id INTEGER "
                    "REFERENCES users(id)",
                )

            cursor.execute("PRAGMA table_info(episodes)")
            episodes_info = cursor.fetchall()
            episodes_columns = [col[1] for col in episodes_info]

            if "user_id" not in episodes_columns:
                cursor.execute(
                    "ALTER TABLE episodes ADD COLUMN user_id INTEGER "
                    "REFERENCES users(id)",
                )

            # Create indexes for the new user_id lookups
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_queue_user_id "
                "ON queue(user_id)",
            )
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_episodes_user_id "
                "ON episodes(user_id)",
            )

            conn.commit()
            logger.info("Database migrated to support multiple users")
+
    @staticmethod
    def create_user(email: str, db_path: str = "podcast.db") -> tuple[int, str]:
        """Create a new user and return (user_id, token).

        Idempotent for existing emails: the UNIQUE constraint on users.email
        raises IntegrityError, in which case the stored id and token are
        returned instead of creating a duplicate.

        Raises:
            ValueError: If user ID cannot be retrieved after insert or if user
                not found.
        """
        # Generate a secure token for RSS feed access
        token = secrets.token_urlsafe(32)

        with Database.get_connection(db_path) as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    "INSERT INTO users (email, token) VALUES (?, ?)",
                    (email, token),
                )
                conn.commit()
                user_id = cursor.lastrowid
                if user_id is None:
                    msg = "Failed to get user ID after insert"
                    raise ValueError(msg)
                logger.info("Created user %s with email %s", user_id, email)
            except sqlite3.IntegrityError:
                # User already exists — return the stored credentials so
                # repeat calls behave like a lookup.
                cursor.execute(
                    "SELECT id, token FROM users WHERE email = ?",
                    (email,),
                )
                row = cursor.fetchone()
                if row is None:
                    # IntegrityError for some reason other than a duplicate
                    # email; surface it as a lookup failure.
                    msg = f"User with email {email} not found"
                    raise ValueError(msg) from None
                return int(row["id"]), str(row["token"])
            else:
                # Success path: newly created user with the fresh token.
                return int(user_id), str(token)
+
+ @staticmethod
+ def get_user_by_email(
+ email: str,
+ db_path: str = "podcast.db",
+ ) -> dict[str, Any] | None:
+ """Get user by email address."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM users WHERE email = ?", (email,))
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def get_user_by_token(
+ token: str,
+ db_path: str = "podcast.db",
+ ) -> dict[str, Any] | None:
+ """Get user by RSS token."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM users WHERE token = ?", (token,))
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def get_user_by_id(
+ user_id: int,
+ db_path: str = "podcast.db",
+ ) -> dict[str, Any] | None:
+ """Get user by ID."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def get_user_queue_status(
+ user_id: int,
+ db_path: str = "podcast.db",
+ ) -> list[dict[str, Any]]:
+ """Return pending/processing/error items for a specific user."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT id, url, email, status, created_at, error_message
+ FROM queue
+ WHERE user_id = ? AND
+ status IN ('pending', 'processing', 'error')
+ ORDER BY created_at DESC
+ LIMIT 20
+ """,
+ (user_id,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_user_recent_episodes(
+ user_id: int,
+ limit: int = 20,
+ db_path: str = "podcast.db",
+ ) -> list[dict[str, Any]]:
+ """Get recent episodes for a specific user."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM episodes WHERE user_id = ? "
+ "ORDER BY created_at DESC LIMIT ?",
+ (user_id, limit),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_user_all_episodes(
+ user_id: int,
+ db_path: str = "podcast.db",
+ ) -> list[dict[str, Any]]:
+ """Get all episodes for a specific user for RSS feed."""
+ with Database.get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM episodes WHERE user_id = ? "
+ "ORDER BY created_at DESC",
+ (user_id,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+
class TestDatabase(Test.TestCase):
    """Test the Database class."""

    def setUp(self) -> None:
        """Create a fresh test database file."""
        self.test_db = "test_podcast.db"
        # Remove leftovers from a previous run before initializing.
        pathlib.Path(self.test_db).unlink(missing_ok=True)
        Database.init_db(self.test_db)

    def tearDown(self) -> None:
        """Remove the test database file."""
        pathlib.Path(self.test_db).unlink(missing_ok=True)

    def test_init_db(self) -> None:
        """Verify all tables and indexes are created correctly."""
        with Database.get_connection(self.test_db) as conn:
            cursor = conn.cursor()

            cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
            table_names = {row[0] for row in cursor.fetchall()}
            for expected_table in ("queue", "episodes", "users"):
                self.assertIn(expected_table, table_names)

            cursor.execute("SELECT name FROM sqlite_master WHERE type='index'")
            index_names = {row[0] for row in cursor.fetchall()}
            for expected_index in (
                "idx_queue_status",
                "idx_queue_created",
                "idx_episodes_created",
                "idx_queue_user_id",
                "idx_episodes_user_id",
            ):
                self.assertIn(expected_index, index_names)

    def test_connection_context_manager(self) -> None:
        """Ensure connections are properly closed."""
        with Database.get_connection(self.test_db) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT 1")
            self.assertEqual(cursor.fetchone()[0], 1)

        # Using the cursor after the context manager exits must fail,
        # proving the underlying connection was closed.
        with pytest.raises(sqlite3.ProgrammingError):
            cursor.execute("SELECT 1")

    def test_migration_idempotency(self) -> None:
        """Verify migrations can run multiple times safely."""
        for _ in range(3):
            Database.migrate_to_multi_user(self.test_db)

        # The schema is still intact and queryable afterwards.
        with Database.get_connection(self.test_db) as conn:
            conn.cursor().execute("SELECT * FROM users")
+
+
class TestUserManagement(Test.TestCase):
    """Test user management functionality."""

    def setUp(self) -> None:
        """Create a fresh test database file."""
        self.test_db = "test_podcast.db"
        # Remove leftovers from a previous run before initializing.
        pathlib.Path(self.test_db).unlink(missing_ok=True)
        Database.init_db(self.test_db)

    def tearDown(self) -> None:
        """Remove the test database file."""
        pathlib.Path(self.test_db).unlink(missing_ok=True)

    def test_create_user(self) -> None:
        """Create user with unique email and token."""
        new_id, new_token = Database.create_user(
            "test@example.com",
            self.test_db,
        )
        self.assertIsInstance(new_id, int)
        self.assertIsInstance(new_token, str)
        # Tokens come from secrets.token_urlsafe(32), so they are long.
        self.assertGreater(len(new_token), 20)

    def test_create_duplicate_user(self) -> None:
        """Verify duplicate emails return existing user."""
        first = Database.create_user("test@example.com", self.test_db)
        second = Database.create_user("test@example.com", self.test_db)

        # Both calls succeed and resolve to the same stored user.
        self.assertIsNotNone(first[0])
        self.assertIsNotNone(second[0])
        self.assertEqual(first[0], second[0])
        self.assertEqual(first[1], second[1])

    def test_get_user_by_email(self) -> None:
        """Retrieve user by email."""
        new_id, new_token = Database.create_user(
            "test@example.com",
            self.test_db,
        )
        found = Database.get_user_by_email("test@example.com", self.test_db)
        self.assertIsNotNone(found)
        if found is None:
            self.fail("User should not be None")
        self.assertEqual(found["id"], new_id)
        self.assertEqual(found["email"], "test@example.com")
        self.assertEqual(found["token"], new_token)

    def test_get_user_by_token(self) -> None:
        """Retrieve user by RSS token."""
        new_id, new_token = Database.create_user(
            "test@example.com",
            self.test_db,
        )
        found = Database.get_user_by_token(new_token, self.test_db)
        self.assertIsNotNone(found)
        if found is None:
            self.fail("User should not be None")
        self.assertEqual(found["id"], new_id)
        self.assertEqual(found["email"], "test@example.com")

    def test_get_user_by_id(self) -> None:
        """Retrieve user by ID."""
        new_id, new_token = Database.create_user(
            "test@example.com",
            self.test_db,
        )
        found = Database.get_user_by_id(new_id, self.test_db)
        self.assertIsNotNone(found)
        if found is None:
            self.fail("User should not be None")
        self.assertEqual(found["email"], "test@example.com")
        self.assertEqual(found["token"], new_token)

    def test_invalid_user_lookups(self) -> None:
        """Verify None returned for non-existent users."""
        by_email = Database.get_user_by_email(
            "nobody@example.com",
            self.test_db,
        )
        by_token = Database.get_user_by_token("invalid-token", self.test_db)
        by_id = Database.get_user_by_id(9999, self.test_db)
        self.assertIsNone(by_email)
        self.assertIsNone(by_token)
        self.assertIsNone(by_id)

    def test_token_uniqueness(self) -> None:
        """Ensure tokens are cryptographically unique."""
        tokens = {
            Database.create_user(f"user{i}@example.com", self.test_db)[1]
            for i in range(10)
        }
        # All ten generated tokens must be distinct.
        self.assertEqual(len(tokens), 10)
+
+
class TestQueueOperations(Test.TestCase):
    """Test queue operations."""

    def setUp(self) -> None:
        """Create a fresh test database with one registered user."""
        self.test_db = "test_podcast.db"
        # Remove leftovers from a previous run before initializing.
        pathlib.Path(self.test_db).unlink(missing_ok=True)
        Database.init_db(self.test_db)
        self.user_id, _ = Database.create_user("test@example.com", self.test_db)

    def tearDown(self) -> None:
        """Remove the test database file."""
        pathlib.Path(self.test_db).unlink(missing_ok=True)

    def _enqueue(
        self,
        url: str,
        email: str = "test@example.com",
        user_id: int | None = None,
    ) -> int:
        """Add a job for the default (or an explicit) user, return its ID."""
        owner = self.user_id if user_id is None else user_id
        return Database.add_to_queue(url, email, owner, self.test_db)

    def test_add_to_queue(self) -> None:
        """Add job with user association."""
        job_id = self._enqueue("https://example.com/article")
        self.assertIsInstance(job_id, int)
        self.assertGreater(job_id, 0)

    def test_get_pending_jobs(self) -> None:
        """Retrieve jobs in correct order."""
        created_ids = []
        for n in (1, 2, 3):
            created_ids.append(self._enqueue(f"https://example.com/{n}"))
            time.sleep(0.01)  # nudge created_at timestamps apart

        jobs = Database.get_pending_jobs(limit=10, db_path=self.test_db)

        # Oldest job first.
        self.assertEqual([job["id"] for job in jobs], created_ids)

    def test_update_job_status(self) -> None:
        """Update status and error messages."""
        job_id = self._enqueue("https://example.com")

        # Move to processing: no error bookkeeping involved.
        Database.update_job_status(job_id, "processing", db_path=self.test_db)
        job = Database.get_job_by_id(job_id, self.test_db)
        self.assertIsNotNone(job)
        if job is None:
            self.fail("Job should not be None")
        self.assertEqual(job["status"], "processing")

        # Move to error: message recorded and retry counter bumped.
        Database.update_job_status(
            job_id,
            "error",
            "Network timeout",
            self.test_db,
        )
        job = Database.get_job_by_id(job_id, self.test_db)
        self.assertIsNotNone(job)
        if job is None:
            self.fail("Job should not be None")
        self.assertEqual(job["status"], "error")
        self.assertEqual(job["error_message"], "Network timeout")
        self.assertEqual(job["retry_count"], 1)

    def test_retry_job(self) -> None:
        """Reset failed jobs for retry."""
        job_id = self._enqueue("https://example.com")
        Database.update_job_status(job_id, "error", "Failed", self.test_db)

        Database.retry_job(job_id, self.test_db)

        job = Database.get_job_by_id(job_id, self.test_db)
        self.assertIsNotNone(job)
        if job is None:
            self.fail("Job should not be None")
        self.assertEqual(job["status"], "pending")
        self.assertIsNone(job["error_message"])

    def test_delete_job(self) -> None:
        """Remove jobs from queue."""
        job_id = self._enqueue("https://example.com")

        Database.delete_job(job_id, self.test_db)

        self.assertIsNone(Database.get_job_by_id(job_id, self.test_db))

    def test_get_retryable_jobs(self) -> None:
        """Find jobs eligible for retry."""
        job_id = self._enqueue("https://example.com")
        Database.update_job_status(job_id, "error", "Failed", self.test_db)

        # One failure: still below the limit of three.
        retryable = Database.get_retryable_jobs(
            max_retries=3,
            db_path=self.test_db,
        )
        self.assertEqual(len(retryable), 1)
        self.assertEqual(retryable[0]["id"], job_id)

        # Two more failures push retry_count to the limit.
        Database.update_job_status(
            job_id,
            "error",
            "Failed again",
            self.test_db,
        )
        Database.update_job_status(
            job_id,
            "error",
            "Failed yet again",
            self.test_db,
        )

        retryable = Database.get_retryable_jobs(
            max_retries=3,
            db_path=self.test_db,
        )
        self.assertEqual(len(retryable), 0)

    def test_user_queue_isolation(self) -> None:
        """Ensure users only see their own jobs."""
        other_user, _ = Database.create_user("user2@example.com", self.test_db)

        my_job = self._enqueue("https://example.com/1")
        their_job = self._enqueue(
            "https://example.com/2",
            email="user2@example.com",
            user_id=other_user,
        )

        my_jobs = Database.get_user_queue_status(self.user_id, self.test_db)
        their_jobs = Database.get_user_queue_status(other_user, self.test_db)

        self.assertEqual([job["id"] for job in my_jobs], [my_job])
        self.assertEqual([job["id"] for job in their_jobs], [their_job])

    def test_status_counts(self) -> None:
        """Verify status aggregation queries."""
        self._enqueue("https://example.com/1")
        second = self._enqueue("https://example.com/2")
        third = self._enqueue("https://example.com/3")

        Database.update_job_status(second, "processing", db_path=self.test_db)
        Database.update_job_status(third, "error", "Failed", self.test_db)

        counts = Database.get_user_status_counts(self.user_id, self.test_db)

        self.assertEqual(counts.get("pending", 0), 1)
        self.assertEqual(counts.get("processing", 0), 1)
        self.assertEqual(counts.get("error", 0), 1)
+
+
class TestEpisodeManagement(Test.TestCase):
    """Test episode management functionality."""

    def setUp(self) -> None:
        """Create a fresh test database with one registered user."""
        self.test_db = "test_podcast.db"
        # Remove leftovers from a previous run before initializing.
        pathlib.Path(self.test_db).unlink(missing_ok=True)
        Database.init_db(self.test_db)
        self.user_id, _ = Database.create_user("test@example.com", self.test_db)

    def tearDown(self) -> None:
        """Remove the test database file."""
        pathlib.Path(self.test_db).unlink(missing_ok=True)

    def _make_episode(
        self,
        title: str,
        audio_url: str,
        duration: int,
        content_length: int,
        user_id: int | None = None,
    ) -> int:
        """Create an episode owned by the default (or an explicit) user."""
        owner = self.user_id if user_id is None else user_id
        return Database.create_episode(
            title,
            audio_url,
            duration,
            content_length,
            owner,
            self.test_db,
        )

    def test_create_episode(self) -> None:
        """Create episode with user association."""
        episode_id = self._make_episode(
            "Test Article",
            "https://example.com/audio.mp3",
            300,
            5000,
        )
        self.assertIsInstance(episode_id, int)
        self.assertGreater(episode_id, 0)

    def test_get_recent_episodes(self) -> None:
        """Retrieve episodes in reverse chronological order."""
        created_ids = []
        for n in (1, 2, 3):
            created_ids.append(
                self._make_episode(
                    f"Article {n}",
                    f"url{n}",
                    n * 100,
                    n * 1000,
                ),
            )
            time.sleep(0.01)  # nudge created_at timestamps apart

        episodes = Database.get_recent_episodes(limit=10, db_path=self.test_db)

        # Newest episode first.
        self.assertEqual(
            [episode["id"] for episode in episodes],
            list(reversed(created_ids)),
        )

    def test_get_user_episodes(self) -> None:
        """Ensure user isolation for episodes."""
        other_user, _ = Database.create_user("user2@example.com", self.test_db)

        mine = self._make_episode("User1 Article", "url1", 100, 1000)
        theirs = self._make_episode(
            "User2 Article",
            "url2",
            200,
            2000,
            user_id=other_user,
        )

        my_episodes = Database.get_user_all_episodes(
            self.user_id,
            self.test_db,
        )
        their_episodes = Database.get_user_all_episodes(
            other_user,
            self.test_db,
        )

        self.assertEqual([e["id"] for e in my_episodes], [mine])
        self.assertEqual([e["id"] for e in their_episodes], [theirs])

    def test_episode_metadata(self) -> None:
        """Verify duration and content_length storage."""
        self._make_episode(
            "Test Article",
            "https://example.com/audio.mp3",
            12345,
            98765,
        )

        episode = Database.get_user_all_episodes(self.user_id, self.test_db)[0]

        self.assertEqual(episode["duration"], 12345)
        self.assertEqual(episode["content_length"], 98765)
+
+
def test() -> None:
    """Run the tests."""
    suites = [
        TestDatabase,
        TestUserManagement,
        TestQueueOperations,
        TestEpisodeManagement,
    ]
    Test.run(App.Area.Test, suites)
+
+
def main() -> None:
    """Run all PodcastItLater.Core tests."""
    # Only act when invoked with the "test" argument; otherwise do nothing.
    if "test" not in sys.argv:
        return
    test()