| author    | Ben Sima <ben@bsima.me> | 2025-08-13 13:36:30 -0400 |
| committer | Ben Sima <ben@bsima.me> | 2025-08-28 12:14:09 -0400 |
| commit    | 0b005c192b2c141c7f6c9bff4a0702361814c21d (patch) | |
| tree      | 3527a76137f6ee4dd970bba17a93617a311149cb /Biz/PodcastItLater | |
| parent    | 7de0a3e0abbf1e152423e148d507e17b752a4982 (diff) | |
Prototype PodcastItLater
This implements a working prototype of PodcastItLater. For now it really only
works for a single user, but the articles are nice to listen to, and this gives
us something to start building on.
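
For review, here is a minimal sketch of the happy path through the Core data
layer this commit introduces. The database path, URLs, and the audio step are
placeholders; the calls match the signatures in Core.py below.

import Biz.PodcastItLater.Core as Core

db = "demo.db"
Core.Database.init_db(db)  # creates tables, then runs the multi-user migration

# Users are keyed by email; the returned token guards their RSS feed URL.
user_id, token = Core.Database.create_user("reader@example.com", db)

# Submissions (web form or Mailgun webhook) enter the queue as 'pending'.
job_id = Core.Database.add_to_queue(
    "https://example.com/article", "reader@example.com", user_id, db,
)

# The worker polls pending jobs and records each finished episode.
for job in Core.Database.get_pending_jobs(limit=10, db_path=db):
    Core.Database.update_job_status(job["id"], "processing", db_path=db)
    # ... extract the article, synthesize audio, upload the mp3 ...
    Core.Database.create_episode(
        title="Example Article",
        audio_url="https://cdn.example.com/episode.mp3",  # placeholder host
        duration=300,
        content_length=5000,
        user_id=job["user_id"],
        db_path=db,
    )
    Core.Database.update_job_status(job["id"], "completed", db_path=db)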
Diffstat (limited to 'Biz/PodcastItLater')
| -rw-r--r-- | Biz/PodcastItLater/Core.py | 1117 |
| -rw-r--r-- | Biz/PodcastItLater/Web.nix | 91 |
| -rw-r--r-- | Biz/PodcastItLater/Web.py | 1939 |
| -rw-r--r-- | Biz/PodcastItLater/Worker.nix | 58 |
| -rw-r--r-- | Biz/PodcastItLater/Worker.py | 1194 |
5 files changed, 4399 insertions, 0 deletions
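
One piece worth calling out before the diff: the web service's passwordless
login rides on itsdangerous' timestamped serializer. Roughly, the round trip
between the /login and /auth/verify handlers in Web.py works like this (the
secret and user values here are illustrative):

from itsdangerous import URLSafeTimedSerializer

serializer = URLSafeTimedSerializer("dev-secret-key")  # SECRET_KEY in production

# /login: sign the user's identity into a timestamped token and email it
token = serializer.dumps({"user_id": 1, "email": "reader@example.com"})
magic_link = f"https://podcastitlater.bensima.com/auth/verify?token={token}"

# /auth/verify: loads() rejects tokens older than MAGIC_LINK_MAX_AGE (1 hour)
data = serializer.loads(token, max_age=3600)
assert data["user_id"] == 1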
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py
new file mode 100644
index 0000000..c0d0acf
--- /dev/null
+++ b/Biz/PodcastItLater/Core.py
@@ -0,0 +1,1117 @@
+"""Core, shared logic for PodcastItLater.
+
+Includes:
+- Database models
+- Data access layer
+- Shared types
+"""
+
+# : out podcastitlater-core
+# : dep pytest
+# : dep pytest-asyncio
+# : dep pytest-mock
+import Omni.App as App
+import Omni.Log as Log
+import Omni.Test as Test
+import pathlib
+import pytest
+import secrets
+import sqlite3
+import sys
+import time
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import Any
+
+logger = Log.setup()
+
+
+class Database:  # noqa: PLR0904
+    """Data access layer for PodcastItLater database operations."""
+
+    @staticmethod
+    @contextmanager
+    def get_connection(
+        db_path: str = "podcast.db",
+    ) -> Iterator[sqlite3.Connection]:
+        """Context manager for database connections.
+
+        Yields:
+            sqlite3.Connection: Database connection with row factory set.
+        """
+        conn = sqlite3.connect(db_path)
+        conn.row_factory = sqlite3.Row
+        try:
+            yield conn
+        finally:
+            conn.close()
+
+    @staticmethod
+    def init_db(db_path: str = "podcast.db") -> None:
+        """Initialize database with required tables."""
+        with Database.get_connection(db_path) as conn:
+            cursor = conn.cursor()
+
+            # Queue table for job processing
+            cursor.execute("""
+                CREATE TABLE IF NOT EXISTS queue (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    url TEXT,
+                    email TEXT,
+                    status TEXT DEFAULT 'pending',
+                    retry_count INTEGER DEFAULT 0,
+                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                    error_message TEXT
+                )
+            """)
+
+            # Episodes table for completed podcasts
+            cursor.execute("""
+                CREATE TABLE IF NOT EXISTS episodes (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    title TEXT NOT NULL,
+                    content_length INTEGER,
+                    audio_url TEXT NOT NULL,
+                    duration INTEGER,
+                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+                )
+            """)
+
+            # Create indexes for performance
+            cursor.execute(
+                "CREATE INDEX IF NOT EXISTS idx_queue_status ON queue(status)",
+            )
+            cursor.execute(
+                "CREATE INDEX IF NOT EXISTS idx_queue_created "
+                "ON queue(created_at)",
+            )
+            cursor.execute(
+                "CREATE INDEX IF NOT EXISTS idx_episodes_created "
+                "ON episodes(created_at)",
+            )
+
+            conn.commit()
+            logger.info("Database initialized successfully")
+
+        # Run migration to add user support
+        Database.migrate_to_multi_user(db_path)
+
+    @staticmethod
+    def add_to_queue(
+        url: str,
+        email: str,
+        user_id: int,
+        db_path: str = "podcast.db",
+    ) -> int:
+        """Insert new job into queue, return job ID.
+
+        Raises:
+            ValueError: If job ID cannot be retrieved after insert.
+ """ + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO queue (url, email, user_id) VALUES (?, ?, ?)", + (url, email, user_id), + ) + conn.commit() + job_id = cursor.lastrowid + if job_id is None: + msg = "Failed to get job ID after insert" + raise ValueError(msg) + logger.info("Added job %s to queue: %s", job_id, url) + return job_id + + @staticmethod + def get_pending_jobs( + limit: int = 10, + db_path: str = "podcast.db", + ) -> list[dict[str, Any]]: + """Fetch jobs with status='pending' ordered by creation time.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM queue WHERE status = 'pending' " + "ORDER BY created_at ASC LIMIT ?", + (limit,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def update_job_status( + job_id: int, + status: str, + error: str | None = None, + db_path: str = "podcast.db", + ) -> None: + """Update job status and error message.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + if status == "error": + cursor.execute( + "UPDATE queue SET status = ?, error_message = ?, " + "retry_count = retry_count + 1 WHERE id = ?", + (status, error, job_id), + ) + else: + cursor.execute( + "UPDATE queue SET status = ? WHERE id = ?", + (status, job_id), + ) + conn.commit() + logger.info("Updated job %s status to %s", job_id, status) + + @staticmethod + def get_job_by_id( + job_id: int, + db_path: str = "podcast.db", + ) -> dict[str, Any] | None: + """Fetch single job by ID.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM queue WHERE id = ?", (job_id,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def create_episode( # noqa: PLR0913, PLR0917 + title: str, + audio_url: str, + duration: int, + content_length: int, + user_id: int | None = None, + db_path: str = "podcast.db", + ) -> int: + """Insert episode record, return episode ID. + + Raises: + ValueError: If episode ID cannot be retrieved after insert. 
+ """ + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO episodes " + "(title, audio_url, duration, content_length, user_id) " + "VALUES (?, ?, ?, ?, ?)", + (title, audio_url, duration, content_length, user_id), + ) + conn.commit() + episode_id = cursor.lastrowid + if episode_id is None: + msg = "Failed to get episode ID after insert" + raise ValueError(msg) + logger.info("Created episode %s: %s", episode_id, title) + return episode_id + + @staticmethod + def get_recent_episodes( + limit: int = 20, + db_path: str = "podcast.db", + ) -> list[dict[str, Any]]: + """Get recent episodes for RSS feed generation.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes ORDER BY created_at DESC LIMIT ?", + (limit,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_queue_status_summary(db_path: str = "podcast.db") -> dict[str, Any]: + """Get queue status summary for web interface.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + + # Count jobs by status + cursor.execute( + "SELECT status, COUNT(*) as count FROM queue GROUP BY status", + ) + rows = cursor.fetchall() + status_counts = {row["status"]: row["count"] for row in rows} + + # Get recent jobs + cursor.execute( + "SELECT * FROM queue ORDER BY created_at DESC LIMIT 10", + ) + rows = cursor.fetchall() + recent_jobs = [dict(row) for row in rows] + + return {"status_counts": status_counts, "recent_jobs": recent_jobs} + + @staticmethod + def get_queue_status(db_path: str = "podcast.db") -> list[dict[str, Any]]: + """Return pending/processing/error items for web interface.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT id, url, email, status, created_at, error_message + FROM queue + WHERE status IN ('pending', 'processing', 'error') + ORDER BY created_at DESC + LIMIT 20 + """) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_all_episodes( + db_path: str = "podcast.db", + user_id: int | None = None, + ) -> list[dict[str, Any]]: + """Return all episodes for RSS feed.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + if user_id: + cursor.execute( + """ + SELECT id, title, audio_url, duration, created_at, + content_length + FROM episodes + WHERE user_id = ? + ORDER BY created_at DESC + """, + (user_id,), + ) + else: + cursor.execute(""" + SELECT id, title, audio_url, duration, created_at, + content_length + FROM episodes + ORDER BY created_at DESC + """) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_retryable_jobs( + max_retries: int = 3, + db_path: str = "podcast.db", + ) -> list[dict[str, Any]]: + """Get failed jobs that can be retried.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM queue WHERE status = 'error' " + "AND retry_count < ? 
ORDER BY created_at ASC", + (max_retries,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def retry_job(job_id: int, db_path: str = "podcast.db") -> None: + """Reset a job to pending status for retry.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE queue SET status = 'pending', " + "error_message = NULL WHERE id = ?", + (job_id,), + ) + conn.commit() + logger.info("Reset job %s to pending for retry", job_id) + + @staticmethod + def delete_job(job_id: int, db_path: str = "podcast.db") -> None: + """Delete a job from the queue.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM queue WHERE id = ?", (job_id,)) + conn.commit() + logger.info("Deleted job %s from queue", job_id) + + @staticmethod + def get_all_queue_items( + db_path: str = "podcast.db", + user_id: int | None = None, + ) -> list[dict[str, Any]]: + """Return all queue items for admin view.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + if user_id: + cursor.execute( + """ + SELECT id, url, email, status, retry_count, created_at, + error_message + FROM queue + WHERE user_id = ? + ORDER BY created_at DESC + """, + (user_id,), + ) + else: + cursor.execute(""" + SELECT id, url, email, status, retry_count, created_at, + error_message + FROM queue + ORDER BY created_at DESC + """) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_status_counts(db_path: str = "podcast.db") -> dict[str, int]: + """Get count of queue items by status.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT status, COUNT(*) as count + FROM queue + GROUP BY status + """) + rows = cursor.fetchall() + return {row["status"]: row["count"] for row in rows} + + @staticmethod + def get_user_status_counts( + user_id: int, + db_path: str = "podcast.db", + ) -> dict[str, int]: + """Get count of queue items by status for a specific user.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT status, COUNT(*) as count + FROM queue + WHERE user_id = ? 
+ GROUP BY status + """, + (user_id,), + ) + rows = cursor.fetchall() + return {row["status"]: row["count"] for row in rows} + + @staticmethod + def migrate_to_multi_user(db_path: str = "podcast.db") -> None: + """Migrate database to support multiple users.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + + # Create users table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT UNIQUE NOT NULL, + token TEXT UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Add user_id columns to existing tables + # Check if columns already exist to make migration idempotent + cursor.execute("PRAGMA table_info(queue)") + queue_info = cursor.fetchall() + queue_columns = [col[1] for col in queue_info] + + if "user_id" not in queue_columns: + cursor.execute( + "ALTER TABLE queue ADD COLUMN user_id INTEGER " + "REFERENCES users(id)", + ) + + cursor.execute("PRAGMA table_info(episodes)") + episodes_info = cursor.fetchall() + episodes_columns = [col[1] for col in episodes_info] + + if "user_id" not in episodes_columns: + cursor.execute( + "ALTER TABLE episodes ADD COLUMN user_id INTEGER " + "REFERENCES users(id)", + ) + + # Create indexes + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_queue_user_id " + "ON queue(user_id)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episodes_user_id " + "ON episodes(user_id)", + ) + + conn.commit() + logger.info("Database migrated to support multiple users") + + @staticmethod + def create_user(email: str, db_path: str = "podcast.db") -> tuple[int, str]: + """Create a new user and return (user_id, token). + + Raises: + ValueError: If user ID cannot be retrieved after insert or if user + not found. + """ + # Generate a secure token for RSS feed access + token = secrets.token_urlsafe(32) + + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + try: + cursor.execute( + "INSERT INTO users (email, token) VALUES (?, ?)", + (email, token), + ) + conn.commit() + user_id = cursor.lastrowid + if user_id is None: + msg = "Failed to get user ID after insert" + raise ValueError(msg) + logger.info("Created user %s with email %s", user_id, email) + except sqlite3.IntegrityError: + # User already exists + cursor.execute( + "SELECT id, token FROM users WHERE email = ?", + (email,), + ) + row = cursor.fetchone() + if row is None: + msg = f"User with email {email} not found" + raise ValueError(msg) from None + return int(row["id"]), str(row["token"]) + else: + return int(user_id), str(token) + + @staticmethod + def get_user_by_email( + email: str, + db_path: str = "podcast.db", + ) -> dict[str, Any] | None: + """Get user by email address.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE email = ?", (email,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_user_by_token( + token: str, + db_path: str = "podcast.db", + ) -> dict[str, Any] | None: + """Get user by RSS token.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE token = ?", (token,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_user_by_id( + user_id: int, + db_path: str = "podcast.db", + ) -> dict[str, Any] | None: + """Get user by ID.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + 
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_user_queue_status( + user_id: int, + db_path: str = "podcast.db", + ) -> list[dict[str, Any]]: + """Return pending/processing/error items for a specific user.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT id, url, email, status, created_at, error_message + FROM queue + WHERE user_id = ? AND + status IN ('pending', 'processing', 'error') + ORDER BY created_at DESC + LIMIT 20 + """, + (user_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_user_recent_episodes( + user_id: int, + limit: int = 20, + db_path: str = "podcast.db", + ) -> list[dict[str, Any]]: + """Get recent episodes for a specific user.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes WHERE user_id = ? " + "ORDER BY created_at DESC LIMIT ?", + (user_id, limit), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_user_all_episodes( + user_id: int, + db_path: str = "podcast.db", + ) -> list[dict[str, Any]]: + """Get all episodes for a specific user for RSS feed.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes WHERE user_id = ? " + "ORDER BY created_at DESC", + (user_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + +class TestDatabase(Test.TestCase): + """Test the Database class.""" + + def setUp(self) -> None: + """Set up test database.""" + self.test_db = "test_podcast.db" + # Clean up any existing test database + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + Database.init_db(self.test_db) + + def tearDown(self) -> None: + """Clean up test database.""" + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + + def test_init_db(self) -> None: + """Verify all tables and indexes are created correctly.""" + with Database.get_connection(self.test_db) as conn: + cursor = conn.cursor() + + # Check tables exist + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = {row[0] for row in cursor.fetchall()} + self.assertIn("queue", tables) + self.assertIn("episodes", tables) + self.assertIn("users", tables) + + # Check indexes exist + cursor.execute("SELECT name FROM sqlite_master WHERE type='index'") + indexes = {row[0] for row in cursor.fetchall()} + self.assertIn("idx_queue_status", indexes) + self.assertIn("idx_queue_created", indexes) + self.assertIn("idx_episodes_created", indexes) + self.assertIn("idx_queue_user_id", indexes) + self.assertIn("idx_episodes_user_id", indexes) + + def test_connection_context_manager(self) -> None: + """Ensure connections are properly closed.""" + # Get a connection and verify it works + with Database.get_connection(self.test_db) as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1") + result = cursor.fetchone() + self.assertEqual(result[0], 1) + + # Connection should be closed after context manager + with pytest.raises(sqlite3.ProgrammingError): + cursor.execute("SELECT 1") + + def test_migration_idempotency(self) -> None: + """Verify migrations can run multiple times safely.""" + # Run migration multiple times + Database.migrate_to_multi_user(self.test_db) + 
Database.migrate_to_multi_user(self.test_db) + Database.migrate_to_multi_user(self.test_db) + + # Should still work fine + with Database.get_connection(self.test_db) as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users") + # Should not raise an error + + +class TestUserManagement(Test.TestCase): + """Test user management functionality.""" + + def setUp(self) -> None: + """Set up test database.""" + self.test_db = "test_podcast.db" + # Clean up any existing test database + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + Database.init_db(self.test_db) + + def tearDown(self) -> None: + """Clean up test database.""" + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + + def test_create_user(self) -> None: + """Create user with unique email and token.""" + user_id, token = Database.create_user("test@example.com", self.test_db) + + self.assertIsInstance(user_id, int) + self.assertIsInstance(token, str) + self.assertGreater(len(token), 20) # Should be a secure token + + def test_create_duplicate_user(self) -> None: + """Verify duplicate emails return existing user.""" + # Create first user + user_id1, token1 = Database.create_user( + "test@example.com", + self.test_db, + ) + + # Try to create duplicate + user_id2, token2 = Database.create_user( + "test@example.com", + self.test_db, + ) + + # Should return same user + self.assertIsNotNone(user_id1) + self.assertIsNotNone(user_id2) + self.assertEqual(user_id1, user_id2) + self.assertEqual(token1, token2) + + def test_get_user_by_email(self) -> None: + """Retrieve user by email.""" + user_id, token = Database.create_user("test@example.com", self.test_db) + + user = Database.get_user_by_email("test@example.com", self.test_db) + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["id"], user_id) + self.assertEqual(user["email"], "test@example.com") + self.assertEqual(user["token"], token) + + def test_get_user_by_token(self) -> None: + """Retrieve user by RSS token.""" + user_id, token = Database.create_user("test@example.com", self.test_db) + + user = Database.get_user_by_token(token, self.test_db) + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["id"], user_id) + self.assertEqual(user["email"], "test@example.com") + + def test_get_user_by_id(self) -> None: + """Retrieve user by ID.""" + user_id, token = Database.create_user("test@example.com", self.test_db) + + user = Database.get_user_by_id(user_id, self.test_db) + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["email"], "test@example.com") + self.assertEqual(user["token"], token) + + def test_invalid_user_lookups(self) -> None: + """Verify None returned for non-existent users.""" + self.assertIsNone( + Database.get_user_by_email("nobody@example.com", self.test_db), + ) + self.assertIsNone( + Database.get_user_by_token("invalid-token", self.test_db), + ) + self.assertIsNone(Database.get_user_by_id(9999, self.test_db)) + + def test_token_uniqueness(self) -> None: + """Ensure tokens are cryptographically unique.""" + tokens = set() + for i in range(10): + _, token = Database.create_user( + f"user{i}@example.com", + self.test_db, + ) + tokens.add(token) + + # All tokens should be unique + self.assertEqual(len(tokens), 10) + + +class TestQueueOperations(Test.TestCase): + """Test queue operations.""" + + def 
setUp(self) -> None: + """Set up test database with a user.""" + self.test_db = "test_podcast.db" + # Clean up any existing test database + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + Database.init_db(self.test_db) + self.user_id, _ = Database.create_user("test@example.com", self.test_db) + + def tearDown(self) -> None: + """Clean up test database.""" + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + + def test_add_to_queue(self) -> None: + """Add job with user association.""" + job_id = Database.add_to_queue( + "https://example.com/article", + "test@example.com", + self.user_id, + self.test_db, + ) + + self.assertIsInstance(job_id, int) + self.assertGreater(job_id, 0) + + def test_get_pending_jobs(self) -> None: + """Retrieve jobs in correct order.""" + # Add multiple jobs + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + self.test_db, + ) + time.sleep(0.01) # Ensure different timestamps + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + self.test_db, + ) + time.sleep(0.01) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + self.test_db, + ) + + # Get pending jobs + jobs = Database.get_pending_jobs(limit=10, db_path=self.test_db) + + self.assertEqual(len(jobs), 3) + # Should be in order of creation (oldest first) + self.assertEqual(jobs[0]["id"], job1) + self.assertEqual(jobs[1]["id"], job2) + self.assertEqual(jobs[2]["id"], job3) + + def test_update_job_status(self) -> None: + """Update status and error messages.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + self.test_db, + ) + + # Update to processing + Database.update_job_status(job_id, "processing", db_path=self.test_db) + job = Database.get_job_by_id(job_id, self.test_db) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "processing") + + # Update to error with message + Database.update_job_status( + job_id, + "error", + "Network timeout", + self.test_db, + ) + job = Database.get_job_by_id(job_id, self.test_db) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "error") + self.assertEqual(job["error_message"], "Network timeout") + self.assertEqual(job["retry_count"], 1) + + def test_retry_job(self) -> None: + """Reset failed jobs for retry.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + self.test_db, + ) + + # Set to error + Database.update_job_status(job_id, "error", "Failed", self.test_db) + + # Retry + Database.retry_job(job_id, self.test_db) + job = Database.get_job_by_id(job_id, self.test_db) + + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "pending") + self.assertIsNone(job["error_message"]) + + def test_delete_job(self) -> None: + """Remove jobs from queue.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + self.test_db, + ) + + # Delete job + Database.delete_job(job_id, self.test_db) + + # Should not exist + job = Database.get_job_by_id(job_id, self.test_db) + self.assertIsNone(job) + + def test_get_retryable_jobs(self) -> None: + """Find jobs eligible for retry.""" + # Add job and mark as error + job_id = 
Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + self.test_db, + ) + Database.update_job_status(job_id, "error", "Failed", self.test_db) + + # Should be retryable + retryable = Database.get_retryable_jobs( + max_retries=3, + db_path=self.test_db, + ) + self.assertEqual(len(retryable), 1) + self.assertEqual(retryable[0]["id"], job_id) + + # Exceed retry limit + Database.update_job_status( + job_id, + "error", + "Failed again", + self.test_db, + ) + Database.update_job_status( + job_id, + "error", + "Failed yet again", + self.test_db, + ) + + # Should not be retryable anymore + retryable = Database.get_retryable_jobs( + max_retries=3, + db_path=self.test_db, + ) + self.assertEqual(len(retryable), 0) + + def test_user_queue_isolation(self) -> None: + """Ensure users only see their own jobs.""" + # Create second user + user2_id, _ = Database.create_user("user2@example.com", self.test_db) + + # Add jobs for both users + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + self.test_db, + ) + job2 = Database.add_to_queue( + "https://example.com/2", + "user2@example.com", + user2_id, + self.test_db, + ) + + # Get user-specific queue status + user1_jobs = Database.get_user_queue_status(self.user_id, self.test_db) + user2_jobs = Database.get_user_queue_status(user2_id, self.test_db) + + self.assertEqual(len(user1_jobs), 1) + self.assertEqual(user1_jobs[0]["id"], job1) + + self.assertEqual(len(user2_jobs), 1) + self.assertEqual(user2_jobs[0]["id"], job2) + + def test_status_counts(self) -> None: + """Verify status aggregation queries.""" + # Add jobs with different statuses + Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + self.test_db, + ) + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + self.test_db, + ) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + self.test_db, + ) + + Database.update_job_status(job2, "processing", db_path=self.test_db) + Database.update_job_status(job3, "error", "Failed", self.test_db) + + # Get status counts + counts = Database.get_user_status_counts(self.user_id, self.test_db) + + self.assertEqual(counts.get("pending", 0), 1) + self.assertEqual(counts.get("processing", 0), 1) + self.assertEqual(counts.get("error", 0), 1) + + +class TestEpisodeManagement(Test.TestCase): + """Test episode management functionality.""" + + def setUp(self) -> None: + """Set up test database with a user.""" + self.test_db = "test_podcast.db" + # Clean up any existing test database + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + Database.init_db(self.test_db) + self.user_id, _ = Database.create_user("test@example.com", self.test_db) + + def tearDown(self) -> None: + """Clean up test database.""" + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + + def test_create_episode(self) -> None: + """Create episode with user association.""" + episode_id = Database.create_episode( + title="Test Article", + audio_url="https://example.com/audio.mp3", + duration=300, + content_length=5000, + user_id=self.user_id, + db_path=self.test_db, + ) + + self.assertIsInstance(episode_id, int) + self.assertGreater(episode_id, 0) + + def test_get_recent_episodes(self) -> None: + """Retrieve episodes in reverse chronological order.""" + # Create multiple episodes + ep1 = Database.create_episode( + "Article 
1", + "url1", + 100, + 1000, + self.user_id, + self.test_db, + ) + time.sleep(0.01) + ep2 = Database.create_episode( + "Article 2", + "url2", + 200, + 2000, + self.user_id, + self.test_db, + ) + time.sleep(0.01) + ep3 = Database.create_episode( + "Article 3", + "url3", + 300, + 3000, + self.user_id, + self.test_db, + ) + + # Get recent episodes + episodes = Database.get_recent_episodes(limit=10, db_path=self.test_db) + + self.assertEqual(len(episodes), 3) + # Should be in reverse chronological order + self.assertEqual(episodes[0]["id"], ep3) + self.assertEqual(episodes[1]["id"], ep2) + self.assertEqual(episodes[2]["id"], ep1) + + def test_get_user_episodes(self) -> None: + """Ensure user isolation for episodes.""" + # Create second user + user2_id, _ = Database.create_user("user2@example.com", self.test_db) + + # Create episodes for both users + ep1 = Database.create_episode( + "User1 Article", + "url1", + 100, + 1000, + self.user_id, + self.test_db, + ) + ep2 = Database.create_episode( + "User2 Article", + "url2", + 200, + 2000, + user2_id, + self.test_db, + ) + + # Get user-specific episodes + user1_episodes = Database.get_user_all_episodes( + self.user_id, + self.test_db, + ) + user2_episodes = Database.get_user_all_episodes(user2_id, self.test_db) + + self.assertEqual(len(user1_episodes), 1) + self.assertEqual(user1_episodes[0]["id"], ep1) + + self.assertEqual(len(user2_episodes), 1) + self.assertEqual(user2_episodes[0]["id"], ep2) + + def test_episode_metadata(self) -> None: + """Verify duration and content_length storage.""" + Database.create_episode( + title="Test Article", + audio_url="https://example.com/audio.mp3", + duration=12345, + content_length=98765, + user_id=self.user_id, + db_path=self.test_db, + ) + + episodes = Database.get_user_all_episodes(self.user_id, self.test_db) + episode = episodes[0] + + self.assertEqual(episode["duration"], 12345) + self.assertEqual(episode["content_length"], 98765) + + +def test() -> None: + """Run the tests.""" + Test.run( + App.Area.Test, + [ + TestDatabase, + TestUserManagement, + TestQueueOperations, + TestEpisodeManagement, + ], + ) + + +def main() -> None: + """Run all PodcastItLater.Core tests.""" + if "test" in sys.argv: + test() diff --git a/Biz/PodcastItLater/Web.nix b/Biz/PodcastItLater/Web.nix new file mode 100644 index 0000000..692d39e --- /dev/null +++ b/Biz/PodcastItLater/Web.nix @@ -0,0 +1,91 @@ +{ + options, + lib, + config, + ... +}: let + cfg = config.services.podcastitlater-web; + rootDomain = "bensima.com"; + ports = import ../../Omni/Cloud/Ports.nix; +in { + options.services.podcastitlater-web = { + enable = lib.mkEnableOption "Enable the PodcastItLater web service"; + port = lib.mkOption { + type = lib.types.int; + default = 8000; + description = '' + The port on which PodcastItLater web will listen for + incoming HTTP traffic. 
+ ''; + }; + dataDir = lib.mkOption { + type = lib.types.path; + default = "/var/podcastitlater"; + description = "Data directory for PodcastItLater (shared with worker)"; + }; + package = lib.mkOption { + type = lib.types.package; + description = "PodcastItLater web package to use"; + }; + }; + config = lib.mkIf cfg.enable { + systemd.services.podcastitlater-web = { + path = [cfg.package]; + wantedBy = ["multi-user.target"]; + preStart = '' + # Create data directory if it doesn't exist + mkdir -p ${cfg.dataDir} + + # Manual step: create this file with secrets + # MAILGUN_WEBHOOK_KEY=your-mailgun-webhook-key + # SECRET_KEY=your-secret-key-for-sessions + # SESSION_SECRET=your-session-secret + # EMAIL_FROM=noreply@podcastitlater.bensima.com + # SMTP_SERVER=smtp.mailgun.org + # SMTP_PASSWORD=your-smtp-password + test -f /run/podcastitlater/env + ''; + script = '' + ${cfg.package}/bin/podcastitlater-web + ''; + description = '' + PodcastItLater Web Service + ''; + serviceConfig = { + Environment = [ + "PORT=${toString cfg.port}" + "AREA=Live" + "DATABASE_PATH=${cfg.dataDir}/podcast.db" + "BASE_URL=https://podcastitlater.${rootDomain}" + ]; + EnvironmentFile = "/run/podcastitlater/env"; + KillSignal = "INT"; + Type = "simple"; + Restart = "on-abort"; + RestartSec = "1"; + }; + }; + + # Nginx configuration + services.nginx = { + enable = true; + recommendedGzipSettings = true; + recommendedOptimisation = true; + recommendedProxySettings = true; + recommendedTlsSettings = true; + statusPage = true; + + virtualHosts."podcastitlater.${rootDomain}" = { + forceSSL = true; + enableACME = true; + locations."/" = { + proxyPass = "http://localhost:${toString cfg.port}"; + proxyWebsockets = true; + }; + }; + }; + + # Ensure firewall allows web traffic + networking.firewall.allowedTCPPorts = [ports.ssh ports.http ports.https]; + }; +} diff --git a/Biz/PodcastItLater/Web.py b/Biz/PodcastItLater/Web.py new file mode 100644 index 0000000..792803c --- /dev/null +++ b/Biz/PodcastItLater/Web.py @@ -0,0 +1,1939 @@ +""" +PodcastItLater Web Service. + +Web frontend for converting articles to podcast episodes via email submission. +Provides ludic + htmx interface, mailgun webhook, and RSS feed generation. 
+""" + +# : out podcastitlater-web +# : dep ludic +# : dep feedgen +# : dep httpx +# : dep itsdangerous +# : dep uvicorn +# : dep pytest +# : dep pytest-asyncio +# : dep pytest-mock +# : dep starlette +import Biz.EmailAgent +import Biz.PodcastItLater.Core as Core +import hashlib +import hmac +import ludic.catalog.layouts as layouts +import ludic.catalog.pages as pages +import ludic.html as html +import Omni.App as App +import Omni.Log as Log +import Omni.Test as Test +import os +import pathlib +import re +import sys +import tempfile +import time +import typing +import urllib.parse +import uvicorn +from datetime import datetime +from datetime import timezone +from feedgen.feed import FeedGenerator # type: ignore[import-untyped] +from itsdangerous import URLSafeTimedSerializer +from ludic.attrs import Attrs +from ludic.components import Component +from ludic.types import AnyChildren +from ludic.web import LudicApp +from ludic.web import Request +from ludic.web.datastructures import FormData +from ludic.web.responses import Response +from starlette.middleware.sessions import SessionMiddleware +from starlette.responses import RedirectResponse +from starlette.testclient import TestClient +from typing import override + +logger = Log.setup() + +# Configuration +DATABASE_PATH = os.getenv("DATABASE_PATH", "podcast.db") +MAILGUN_WEBHOOK_KEY = os.getenv("MAILGUN_WEBHOOK_KEY", "") +BASE_URL = os.getenv("BASE_URL", "http://localhost:8000") +PORT = int(os.getenv("PORT", "8000")) + +# Authentication configuration +MAGIC_LINK_MAX_AGE = 3600 # 1 hour +SESSION_MAX_AGE = 30 * 24 * 3600 # 30 days +EMAIL_FROM = os.getenv("EMAIL_FROM", "noreply@podcastitlater.com") +SMTP_SERVER = os.getenv("SMTP_SERVER", "smtp.mailgun.org") +SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "") + +# Initialize serializer for magic links +magic_link_serializer = URLSafeTimedSerializer( + os.getenv("SECRET_KEY", "dev-secret-key"), +) + +# Test database path override for testing +_test_database_path: str | None = None + + +# Constants +URL_TRUNCATE_LENGTH = 80 +TITLE_TRUNCATE_LENGTH = 50 +ERROR_TRUNCATE_LENGTH = 50 + +RSS_CONFIG = { + "title": "Ben's Article Podcast", + "description": "Web articles converted to audio", + "author": "Ben Sima", + "language": "en-US", + "base_url": BASE_URL, +} + + +def send_magic_link(email: str, token: str) -> None: + """Send magic link email to user.""" + subject = "Login to PodcastItLater" + + # Create temporary file for email body + with tempfile.NamedTemporaryFile( + mode="w", + suffix=".txt", + delete=False, + encoding="utf-8", + ) as f: + body_text_path = pathlib.Path(f.name) + + # Create email body + magic_link = f"{BASE_URL}/auth/verify?token={token}" + body_text_path.write_text(f""" +Hello, + +Click this link to login to PodcastItLater: +{magic_link} + +This link will expire in 1 hour. + +If you didn't request this, please ignore this email. 
+ +Best, +PodcastItLater +""") + + try: + Biz.EmailAgent.send_email( + to_addrs=[email], + from_addr=EMAIL_FROM, + smtp_server=SMTP_SERVER, + password=SMTP_PASSWORD, + subject=subject, + body_text=body_text_path, + ) + finally: + # Clean up temporary file + body_text_path.unlink(missing_ok=True) + + +class LoginFormAttrs(Attrs): + """Attributes for LoginForm component.""" + + error: str | None + + +class LoginForm(Component[AnyChildren, LoginFormAttrs]): + """Simple email-based login/registration form.""" + + @override + def render(self) -> html.div: + error = self.attrs.get("error") + return html.div( + html.h2("Login / Register"), + html.form( + html.div( + html.label("Email:", for_="email"), + html.input( + type="email", + id="email", + name="email", + placeholder="your@email.com", + required=True, + style={ + "width": "100%", + "padding": "8px", + "margin": "4px 0", + }, + ), + ), + html.button( + "Continue", + type="submit", + style={ + "padding": "10px 20px", + "background": "#007cba", + "color": "white", + "border": "none", + "cursor": "pointer", + }, + ), + hx_post="/login", + hx_target="#login-result", + hx_swap="innerHTML", + ), + html.div( + error or "", + id="login-result", + style={"margin-top": "10px", "color": "#dc3545"} + if error + else {"margin-top": "10px"}, + ), + ) + + +class SubmitForm(Component[AnyChildren, Attrs]): + """Article submission form with HTMX.""" + + @override + def render(self) -> html.div: + return html.div( + html.h2("Submit Article"), + html.form( + html.div( + html.label("Article URL:", for_="url"), + html.input( + type="url", + id="url", + name="url", + placeholder="https://example.com/article", + required=True, + style={ + "width": "100%", + "padding": "8px", + "margin": "4px 0", + }, + ), + ), + html.button( + "Submit", + type="submit", + style={ + "padding": "10px 20px", + "background": "#007cba", + "color": "white", + "border": "none", + "cursor": "pointer", + }, + ), + hx_post="/submit", + hx_target="#submit-result", + hx_swap="innerHTML", + ), + html.div(id="submit-result", style={"margin-top": "10px"}), + ) + + +class QueueStatusAttrs(Attrs): + """Attributes for QueueStatus component.""" + + items: list[dict[str, typing.Any]] + + +class QueueStatus(Component[AnyChildren, QueueStatusAttrs]): + """Display queue items with auto-refresh.""" + + @override + def render(self) -> html.div: + items = self.attrs["items"] + if not items: + return html.div( + html.h3("Queue Status"), + html.p("No items in queue"), + hx_get="/status", + hx_trigger="every 30s", + hx_swap="outerHTML", + ) + + queue_items = [] + for item in items: + status_color = { + "pending": "#ffa500", + "processing": "#007cba", + "error": "#dc3545", + }.get(item["status"], "#6c757d") + + queue_items.append( + html.div( + html.strong(f"#{item['id']} "), + html.span( + item["status"].upper(), + style={"color": status_color, "font-weight": "bold"}, + ), + html.br(), + html.small( + item["url"][:URL_TRUNCATE_LENGTH] + + ( + "..." 
+ if len(item["url"]) > URL_TRUNCATE_LENGTH + else "" + ), + ), + html.br(), + html.small(f"Created: {item['created_at']}"), + *( + [ + html.br(), + html.small( + f"Error: {item['error_message']}", + style={"color": "#dc3545"}, + ), + ] + if item["error_message"] + else [] + ), + style={ + "border": "1px solid #ddd", + "padding": "10px", + "margin": "5px 0", + "border-radius": "4px", + }, + ), + ) + + return html.div( + html.h3("Queue Status"), + *queue_items, + hx_get="/status", + hx_trigger="every 30s", + hx_swap="outerHTML", + ) + + +class EpisodeListAttrs(Attrs): + """Attributes for EpisodeList component.""" + + episodes: list[dict[str, typing.Any]] + + +class EpisodeList(Component[AnyChildren, EpisodeListAttrs]): + """List recent episodes with audio player.""" + + @override + def render(self) -> html.div: + episodes = self.attrs["episodes"] + if not episodes: + return html.div( + html.h3("Recent Episodes"), + html.p("No episodes yet"), + ) + + episode_items = [] + for episode in episodes: + duration_str = ( + f"{episode['duration']}s" if episode["duration"] else "Unknown" + ) + episode_items.append( + html.div( + html.h4(episode["title"]), + html.audio( + html.source( + src=episode["audio_url"], + type="audio/mpeg", + ), + "Your browser does not support the audio element.", + controls=True, + style={"width": "100%"}, + ), + html.small( + f"Duration: {duration_str} | " + f"Created: {episode['created_at']}", + ), + style={ + "border": "1px solid #ddd", + "padding": "15px", + "margin": "10px 0", + "border-radius": "4px", + }, + ), + ) + + return html.div(html.h3("Recent Episodes"), *episode_items) + + +class AdminViewAttrs(Attrs): + """Attributes for AdminView component.""" + + queue_items: list[dict[str, typing.Any]] + episodes: list[dict[str, typing.Any]] + status_counts: dict[str, int] + + +class AdminView(Component[AnyChildren, AdminViewAttrs]): + """Admin view showing all queue items and episodes in tables.""" + + @override + def render(self) -> pages.HtmlPage: + queue_items = self.attrs["queue_items"] + episodes = self.attrs["episodes"] + status_counts = self.attrs.get("status_counts", {}) + + return pages.HtmlPage( + pages.Head( + title="PodcastItLater - Admin Queue Status", + htmx_version="1.9.10", + load_styles=True, + ), + pages.Body( + layouts.Center( + layouts.Stack( + html.h1("PodcastItLater Admin - Queue Status"), + html.div( + html.a( + "← Back to Home", + href="/", + style={"color": "#007cba"}, + ), + style={"margin-bottom": "20px"}, + ), + # Status Summary + html.div( + html.h2("Status Summary"), + html.div( + *[ + html.span( + f"{status.upper()}: {count}", + style={ + "margin-right": "20px", + "padding": "5px 10px", + "background": ( + AdminView._get_status_color( + status, + ) + ), + "color": "white", + "border-radius": "4px", + }, + ) + for status, count in status_counts.items() + ], + style={"margin-bottom": "20px"}, + ), + ), + # Queue Items Table + html.div( + html.h2("Queue Items"), + html.div( + html.table( + html.thead( + html.tr( + html.th( + "ID", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + html.th( + "URL", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + html.th( + "Email", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + html.th( + "Status", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + html.th( + "Retries", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + html.th( + "Created", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + html.th( + 
"Error", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + html.th( + "Actions", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + ), + ), + html.tbody( + *[ + html.tr( + html.td( + str(item["id"]), + style={"padding": "10px"}, + ), + html.td( + html.div( + item["url"][ + :TITLE_TRUNCATE_LENGTH + ] + + ( + "..." + if ( + len(item["url"]) + > TITLE_TRUNCATE_LENGTH # noqa: E501 + ) + else "" + ), + title=item["url"], + style={ + "max-width": ( + "300px" + ), + "overflow": ( + "hidden" + ), + "text-overflow": ( + "ellipsis" + ), + }, + ), + style={"padding": "10px"}, + ), + html.td( + item["email"] or "-", + style={"padding": "10px"}, + ), + html.td( + html.span( + item["status"], + style={ + "color": ( + AdminView._get_status_color( + item[ + "status" + ], + ) + ), + }, + ), + style={"padding": "10px"}, + ), + html.td( + str( + item.get( + "retry_count", + 0, + ), + ), + style={"padding": "10px"}, + ), + html.td( + item["created_at"], + style={"padding": "10px"}, + ), + html.td( + html.div( + item["error_message"][ + :ERROR_TRUNCATE_LENGTH + ] + + "..." + if item["error_message"] + and len( + item[ + "error_message" + ], + ) + > ERROR_TRUNCATE_LENGTH + else item[ + "error_message" + ] + or "-", + title=item[ + "error_message" + ] + or "", + style={ + "max-width": ( + "200px" + ), + "overflow": ( + "hidden" + ), + "text-overflow": ( + "ellipsis" + ), + }, + ), + style={"padding": "10px"}, + ), + html.td( + html.div( + html.button( + "Retry", + hx_post=f"/queue/{item['id']}/retry", + hx_target="body", + hx_swap="outerHTML", + style={ + "margin-right": ( # noqa: E501 + "5px" + ), + "padding": ( + "5px 10px" + ), + "background": ( + "#28a745" + ), + "color": ( + "white" + ), + "border": ( + "none" + ), + "cursor": ( + "pointer" + ), + "border-radius": ( # noqa: E501 + "3px" + ), + }, + disabled=item[ + "status" + ] + == "completed", + ) + if item["status"] + != "completed" + else "", + html.button( + "Delete", + hx_delete=f"/queue/{item['id']}", + hx_confirm=( + "Are you sure " + "you want to " + "delete this " + "queue item?" + ), + hx_target="body", + hx_swap="outerHTML", + style={ + "padding": ( + "5px 10px" + ), + "background": ( + "#dc3545" + ), + "color": ( + "white" + ), + "border": ( + "none" + ), + "cursor": ( + "pointer" + ), + "border-radius": ( # noqa: E501 + "3px" + ), + }, + ), + style={ + "display": "flex", + "gap": "5px", + }, + ), + style={"padding": "10px"}, + ), + ) + for item in queue_items + ], + ), + style={ + "width": "100%", + "border-collapse": "collapse", + "border": "1px solid #ddd", + }, + ), + style={ + "overflow-x": "auto", + "margin-bottom": "30px", + }, + ), + ), + # Episodes Table + html.div( + html.h2("Completed Episodes"), + html.div( + html.table( + html.thead( + html.tr( + html.th( + "ID", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + html.th( + "Title", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + html.th( + "Audio URL", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + html.th( + "Duration", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + html.th( + "Content Length", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + html.th( + "Created", + style={ + "padding": "10px", + "text-align": "left", + }, + ), + ), + ), + html.tbody( + *[ + html.tr( + html.td( + str(episode["id"]), + style={"padding": "10px"}, + ), + html.td( + episode["title"][ + :TITLE_TRUNCATE_LENGTH + ] + + ( + "..." 
+ if len(episode["title"]) + > TITLE_TRUNCATE_LENGTH + else "" + ), + style={"padding": "10px"}, + ), + html.td( + html.a( + "Listen", + href=episode[ + "audio_url" + ], + target="_blank", + style={ + "color": "#007cba", + }, + ), + style={"padding": "10px"}, + ), + html.td( + f"{episode['duration']}s" + if episode["duration"] + else "-", + style={"padding": "10px"}, + ), + html.td( + ( + f"{episode['content_length']:,} chars" # noqa: E501 + ) + if episode["content_length"] + else "-", + style={"padding": "10px"}, + ), + html.td( + episode["created_at"], + style={"padding": "10px"}, + ), + ) + for episode in episodes + ], + ), + style={ + "width": "100%", + "border-collapse": "collapse", + "border": "1px solid #ddd", + }, + ), + style={"overflow-x": "auto"}, + ), + ), + html.style(""" + body { + font-family: Arial, sans-serif; + max-width: 1200px; + margin: 0 auto; + padding: 20px; + } + h1, h2 { color: #333; } + table { background: white; } + thead { background: #f8f9fa; } + tbody tr:nth-child(even) { background: #f8f9fa; } + tbody tr:hover { background: #e9ecef; } + """), + ), + ), + htmx_version="1.9.10", + hx_get="/queue-status", + hx_trigger="every 10s", + hx_swap="outerHTML", + ), + ) + + @staticmethod + def _get_status_color(status: str) -> str: + """Get color for status display.""" + return { + "pending": "#ffa500", + "processing": "#007cba", + "completed": "#28a745", + "error": "#dc3545", + }.get(status, "#6c757d") + + +class HomePageAttrs(Attrs): + """Attributes for HomePage component.""" + + queue_items: list[dict[str, typing.Any]] + episodes: list[dict[str, typing.Any]] + user: dict[str, typing.Any] | None + error: str | None + + +class HomePage(Component[AnyChildren, HomePageAttrs]): + """Main page combining all components.""" + + @override + def render(self) -> pages.HtmlPage: + queue_items = self.attrs["queue_items"] + episodes = self.attrs["episodes"] + user = self.attrs.get("user") + + return pages.HtmlPage( + pages.Head( + title="PodcastItLater", + htmx_version="1.9.10", + load_styles=True, + ), + pages.Body( + layouts.Center( + layouts.Stack( + html.h1("PodcastItLater"), + html.p("Convert web articles to podcast episodes"), + html.div( + # Show error if present + html.div( + self.attrs.get("error", "") or "", + style={ + "color": "#dc3545", + "margin-bottom": "10px", + }, + ) + if self.attrs.get("error") + else html.div(), + # Show user info and logout if logged in + html.div( + html.p(f"Logged in as: {user['email']}"), + html.p( + "Your RSS Feed: ", + html.code( + f"{BASE_URL}/feed/{user['token']}.xml", + ), + ), + html.div( + html.a( + "View Queue Status", + href="/queue-status", + style={ + "color": "#007cba", + "margin-right": "15px", + }, + ), + html.a( + "Logout", + href="/logout", + style={"color": "#dc3545"}, + ), + ), + style={ + "background": "#f8f9fa", + "padding": "15px", + "border-radius": "4px", + "margin-bottom": "20px", + }, + ) + if user + else LoginForm(error=self.attrs.get("error")), + # Only show submit form and content if logged in + html.div( + SubmitForm(), + QueueStatus(items=queue_items), + EpisodeList(episodes=episodes), + classes=["container"], + ) + if user + else html.div(), + ), + html.style(""" + body { + font-family: Arial, sans-serif; + max-width: 800px; + margin: 0 auto; + padding: 20px; + } + h1 { color: #333; } + .container { display: grid; gap: 20px; } + """), + ), + ), + htmx_version="1.9.10", + ), + ) + + +def get_database_path() -> str: + """Get the current database path, using test override if set.""" + return ( + 
_test_database_path + if _test_database_path is not None + else DATABASE_PATH + ) + + +# Initialize database on startup +Core.Database.init_db(get_database_path()) + +# Create ludic app with session support +app = LudicApp() +app.add_middleware( + SessionMiddleware, + secret_key=os.getenv("SESSION_SECRET", "dev-secret-key"), + max_age=SESSION_MAX_AGE, # 30 days + same_site="lax", + https_only=App.from_env() == App.Area.Live, # HTTPS only in production +) + + +def extract_urls_from_text(text: str) -> list[str]: + """Extract HTTP/HTTPS URLs from text.""" + url_pattern = r'https?://[^\s<>"\']+[^\s<>"\'.,;!?]' + return re.findall(url_pattern, text) + + +def verify_mailgun_signature( + token: str, + timestamp: str, + signature: str, +) -> bool: + """Verify Mailgun webhook signature.""" + if not MAILGUN_WEBHOOK_KEY: + return True # Skip verification if no key configured + + value = f"{timestamp}{token}" + expected = hmac.new( + MAILGUN_WEBHOOK_KEY.encode(), + value.encode(), + hashlib.sha256, + ).hexdigest() + return hmac.compare_digest(signature, expected) + + +@app.get("/") +def index(request: Request) -> HomePage: + """Display main page with form and status.""" + user_id = request.session.get("user_id") + user = None + queue_items = [] + episodes = [] + error = request.query_params.get("error") + + # Map error codes to user-friendly messages + error_messages = { + "invalid_link": "Invalid login link", + "expired_link": "Login link has expired. Please request a new one.", + "user_not_found": "User not found. Please try logging in again.", + } + error_message = error_messages.get(error) if error else None + + if user_id: + user = Core.Database.get_user_by_id(user_id, get_database_path()) + if user: + # Get user-specific queue items and episodes + queue_items = Core.Database.get_user_queue_status( + user_id, + get_database_path(), + ) + episodes = Core.Database.get_user_recent_episodes( + user_id, + 10, + get_database_path(), + ) + + return HomePage( + queue_items=queue_items, + episodes=episodes, + user=user, + error=error_message, + ) + + +@app.post("/login") +def login(request: Request, data: FormData) -> Response: + """Handle login/registration.""" + try: + email_raw = data.get("email", "") + email = email_raw.strip().lower() if isinstance(email_raw, str) else "" + + if not email: + return Response( + '<div style="color: #dc3545;">Email is required</div>', + status_code=400, + ) + + area = App.from_env() + + if area == App.Area.Test: + # Development mode: instant login + user = Core.Database.get_user_by_email(email, get_database_path()) + if not user: + user_id, token = Core.Database.create_user( + email, + get_database_path(), + ) + user = {"id": user_id, "email": email, "token": token} + + # Set session with extended lifetime + request.session["user_id"] = user["id"] + request.session["permanent"] = True + + return Response( + '<div style="color: #28a745;">✓ Logged in (dev mode)</div>', + status_code=200, + headers={"HX-Redirect": "/"}, + ) + + # Production mode: send magic link + # Get or create user + user = Core.Database.get_user_by_email(email, get_database_path()) + if not user: + user_id, token = Core.Database.create_user( + email, + get_database_path(), + ) + user = {"id": user_id, "email": email, "token": token} + + # Generate magic link token + magic_token = magic_link_serializer.dumps({ + "user_id": user["id"], + "email": email, + }) + + # Send email + send_magic_link(email, magic_token) + + return Response( + f'<div style="color: #28a745;">✓ Magic link sent to {email}. 
' + f"Check your email!</div>", + status_code=200, + ) + + except Exception as e: + logger.exception("Login error") + return Response( + f'<div style="color: #dc3545;">Error: {e!s}</div>', + status_code=500, + ) + + +@app.get("/auth/verify") +def verify_magic_link(request: Request) -> Response: + """Verify magic link and log user in.""" + token = request.query_params.get("token") + + if not token: + return RedirectResponse("/?error=invalid_link") + + try: + # Verify token + data = magic_link_serializer.loads(token, max_age=MAGIC_LINK_MAX_AGE) + user_id = data["user_id"] + + # Verify user still exists + user = Core.Database.get_user_by_id(user_id, get_database_path()) + if not user: + return RedirectResponse("/?error=user_not_found") + + # Set session with extended lifetime + request.session["user_id"] = user_id + request.session["permanent"] = True + + return RedirectResponse("/") + + except Exception: # noqa: BLE001 + return RedirectResponse("/?error=expired_link") + + +@app.get("/logout") +def logout(request: Request) -> Response: + """Handle logout.""" + request.session.clear() + return Response( + "", + status_code=302, + headers={"Location": "/"}, + ) + + +@app.post("/submit") +def submit_article(request: Request, data: FormData) -> html.div: + """Handle manual form submission.""" + try: + # Check if user is logged in + user_id = request.session.get("user_id") + if not user_id: + return html.div( + "Error: Please login first", + style={"color": "#dc3545"}, + ) + + user = Core.Database.get_user_by_id(user_id, get_database_path()) + if not user: + return html.div( + "Error: Invalid session", + style={"color": "#dc3545"}, + ) + + url_raw = data.get("url", "") + url = url_raw.strip() if isinstance(url_raw, str) else "" + + if not url: + return html.div( + "Error: URL is required", + style={"color": "#dc3545"}, + ) + + # Basic URL validation + parsed = urllib.parse.urlparse(url) + if not parsed.scheme or not parsed.netloc: + return html.div( + "Error: Invalid URL format", + style={"color": "#dc3545"}, + ) + + job_id = Core.Database.add_to_queue( + url, + user["email"], + user_id, + get_database_path(), + ) + return html.div( + f"✓ Article submitted successfully! 
Job ID: {job_id}", + style={"color": "#28a745", "font-weight": "bold"}, + ) + + except Exception as e: # noqa: BLE001 + return html.div(f"Error: {e!s}", style={"color": "#dc3545"}) + + +@app.post("/webhook/mailgun") +def mailgun_webhook(request: Request, data: FormData) -> Response: # noqa: ARG001 + """Process email submissions.""" + try: + # Verify signature + token_raw = data.get("token", "") + timestamp_raw = data.get("timestamp", "") + signature_raw = data.get("signature", "") + + token = token_raw if isinstance(token_raw, str) else "" + timestamp = timestamp_raw if isinstance(timestamp_raw, str) else "" + signature = signature_raw if isinstance(signature_raw, str) else "" + + if not verify_mailgun_signature(token, timestamp, signature): + return Response("Unauthorized", status_code=401) + + # Extract email data + sender_raw = data.get("sender", "") + body_plain_raw = data.get("body-plain", "") + + sender = sender_raw if isinstance(sender_raw, str) else "" + body_plain = body_plain_raw if isinstance(body_plain_raw, str) else "" + + # Auto-create user if doesn't exist + user = Core.Database.get_user_by_email(sender, get_database_path()) + if not user: + user_id, token = Core.Database.create_user( + sender, + get_database_path(), + ) + logger.info("Auto-created user %s for email %s", user_id, sender) + else: + user_id = user["id"] + + # Look for URLs in email body + urls = extract_urls_from_text(body_plain) + + if urls: + # Use first URL found + url = urls[0] + Core.Database.add_to_queue( + url, + sender, + user_id, + get_database_path(), + ) + return Response("OK - URL queued") + # No URL found, treat body as content + # For MVP, we'll skip this case + return Response("OK - No URL found") + + except Exception: # noqa: BLE001 + return Response("Error", status_code=500) + + +@app.get("/feed/{token}.xml") +def rss_feed(request: Request, token: str) -> Response: # noqa: ARG001 + """Generate user-specific RSS podcast feed.""" + try: + # Validate token and get user + user = Core.Database.get_user_by_token(token, get_database_path()) + if not user: + return Response("Invalid feed token", status_code=404) + + # Get episodes for this user only + episodes = Core.Database.get_user_all_episodes( + user["id"], + get_database_path(), + ) + + # Extract first name from email for personalization + email_name = user["email"].split("@")[0].split(".")[0].title() + + fg = FeedGenerator() + fg.title(f"{email_name}'s Article Podcast") + fg.description(f"Web articles converted to audio for {user['email']}") + fg.author(name=RSS_CONFIG["author"]) + fg.language(RSS_CONFIG["language"]) + fg.link(href=f"{RSS_CONFIG['base_url']}/feed/{token}.xml") + fg.id(f"{RSS_CONFIG['base_url']}/feed/{token}.xml") + + for episode in episodes: + fe = fg.add_entry() + fe.id(f"{RSS_CONFIG['base_url']}/episode/{episode['id']}") + fe.title(episode["title"]) + fe.description(f"Episode {episode['id']}: {episode['title']}") + fe.enclosure( + episode["audio_url"], + str(episode.get("content_length", 0)), + "audio/mpeg", + ) + # SQLite timestamps don't have timezone info, so add UTC + created_at = datetime.fromisoformat(episode["created_at"]) + if created_at.tzinfo is None: + created_at = created_at.replace(tzinfo=timezone.utc) + fe.pubDate(created_at) + + rss_str = fg.rss_str(pretty=True) + return Response( + rss_str, + media_type="application/rss+xml; charset=utf-8", + ) + + except Exception as e: # noqa: BLE001 + return Response(f"Error generating feed: {e}", status_code=500) + + +@app.get("/status") +def queue_status(request: Request) 
+
+
+@app.get("/feed/{token}.xml")
+def rss_feed(request: Request, token: str) -> Response:  # noqa: ARG001
+    """Generate user-specific RSS podcast feed."""
+    try:
+        # Validate token and get user
+        user = Core.Database.get_user_by_token(token, get_database_path())
+        if not user:
+            return Response("Invalid feed token", status_code=404)
+
+        # Get episodes for this user only
+        episodes = Core.Database.get_user_all_episodes(
+            user["id"],
+            get_database_path(),
+        )
+
+        # Extract first name from email for personalization
+        email_name = user["email"].split("@")[0].split(".")[0].title()
+
+        fg = FeedGenerator()
+        fg.title(f"{email_name}'s Article Podcast")
+        fg.description(f"Web articles converted to audio for {user['email']}")
+        fg.author(name=RSS_CONFIG["author"])
+        fg.language(RSS_CONFIG["language"])
+        fg.link(href=f"{RSS_CONFIG['base_url']}/feed/{token}.xml")
+        fg.id(f"{RSS_CONFIG['base_url']}/feed/{token}.xml")
+
+        for episode in episodes:
+            fe = fg.add_entry()
+            fe.id(f"{RSS_CONFIG['base_url']}/episode/{episode['id']}")
+            fe.title(episode["title"])
+            fe.description(f"Episode {episode['id']}: {episode['title']}")
+            fe.enclosure(
+                episode["audio_url"],
+                str(episode.get("content_length", 0)),
+                "audio/mpeg",
+            )
+            # SQLite timestamps don't have timezone info, so add UTC
+            created_at = datetime.fromisoformat(episode["created_at"])
+            if created_at.tzinfo is None:
+                created_at = created_at.replace(tzinfo=timezone.utc)
+            fe.pubDate(created_at)
+
+        rss_str = fg.rss_str(pretty=True)
+        return Response(
+            rss_str,
+            media_type="application/rss+xml; charset=utf-8",
+        )
+
+    except Exception as e:  # noqa: BLE001
+        return Response(f"Error generating feed: {e}", status_code=500)
+
+
+@app.get("/status")
+def queue_status(request: Request) -> QueueStatus:  # noqa: ARG001
+    """Return HTMX endpoint for live queue updates."""
+    queue_items = Core.Database.get_queue_status(get_database_path())
+    return QueueStatus(items=queue_items)
+
+
+@app.get("/queue-status")
+def admin_queue_status(request: Request) -> AdminView | Response:
+    """Return admin view showing all queue items and episodes."""
+    # Check if user is logged in
+    user_id = request.session.get("user_id")
+    if not user_id:
+        # Redirect to login
+        return Response(
+            "",
+            status_code=302,
+            headers={"Location": "/"},
+        )
+
+    user = Core.Database.get_user_by_id(user_id, get_database_path())
+    if not user:
+        # Invalid session
+        return Response(
+            "",
+            status_code=302,
+            headers={"Location": "/"},
+        )
+
+    # For now, all logged-in users can see their own data
+    # Later we can add an admin flag to see all data
+    all_queue_items = Core.Database.get_all_queue_items(
+        get_database_path(),
+        user_id,
+    )
+    all_episodes = Core.Database.get_all_episodes(get_database_path(), user_id)
+    status_counts = Core.Database.get_user_status_counts(
+        user_id,
+        get_database_path(),
+    )
+
+    return AdminView(
+        queue_items=all_queue_items,
+        episodes=all_episodes,
+        status_counts=status_counts,
+    )
+
+
+@app.post("/queue/{job_id}/retry")
+def retry_queue_item(request: Request, job_id: int) -> Response:
+    """Retry a failed queue item."""
+    try:
+        # Check if user owns this job
+        user_id = request.session.get("user_id")
+        if not user_id:
+            return Response("Unauthorized", status_code=401)
+
+        job = Core.Database.get_job_by_id(job_id, get_database_path())
+        if job is None or job.get("user_id") != user_id:
+            return Response("Forbidden", status_code=403)
+
+        Core.Database.retry_job(job_id, get_database_path())
+        # Redirect back to admin view
+        return Response(
+            "",
+            status_code=200,
+            headers={"HX-Redirect": "/queue-status"},
+        )
+    except Exception as e:  # noqa: BLE001
+        return Response(
+            f"Error retrying job: {e!s}",
+            status_code=500,
+        )
+
+
+@app.delete("/queue/{job_id}")
+def delete_queue_item(request: Request, job_id: int) -> Response:
+    """Delete a queue item."""
+    try:
+        # Check if user owns this job
+        user_id = request.session.get("user_id")
+        if not user_id:
+            return Response("Unauthorized", status_code=401)
+
+        job = Core.Database.get_job_by_id(job_id, get_database_path())
+        if job is None or job.get("user_id") != user_id:
+            return Response("Forbidden", status_code=403)
+
+        Core.Database.delete_job(job_id, get_database_path())
+        # Redirect back to admin view
+        return Response(
+            "",
+            status_code=200,
+            headers={"HX-Redirect": "/queue-status"},
+        )
+    except Exception as e:  # noqa: BLE001
+        return Response(
+            f"Error deleting job: {e!s}",
+            status_code=500,
+        )
+
+
+class BaseWebTest(Test.TestCase):
+    """Base class for web tests with database setup."""
+
+    def setUp(self) -> None:
+        """Set up test database and client."""
+        # Create a test database context
+        self.test_db_path = "test_podcast_web.db"
+
+        # Save original database path
+        self._original_db_path = globals()["_test_database_path"]
+        globals()["_test_database_path"] = self.test_db_path
+
+        # Clean up any existing test database
+        db_file = pathlib.Path(self.test_db_path)
+        if db_file.exists():
+            db_file.unlink()
+
+        # Initialize test database
+        Core.Database.init_db(self.test_db_path)
+
+        # Create test client
+        self.client = TestClient(app)
+
+    def tearDown(self) -> None:
+        """Clean up test database."""
+        # Clean up test database file
+        db_file = pathlib.Path(self.test_db_path)
+        if db_file.exists():
+            db_file.unlink()
+
+        # Restore original database path
+        globals()["_test_database_path"] = self._original_db_path
+
+
+class TestAuthentication(BaseWebTest):
+    """Test authentication functionality."""
+
+    def test_login_new_user(self) -> None:
+        """Auto-create user on first login."""
+        response = self.client.post("/login", data={"email": "new@example.com"})
+
+        self.assertEqual(response.status_code, 200)
+        self.assertIn("HX-Redirect", response.headers)
+        self.assertEqual(response.headers["HX-Redirect"], "/")
+
+        # Verify user was created
+        user = Core.Database.get_user_by_email(
+            "new@example.com",
+            get_database_path(),
+        )
+        self.assertIsNotNone(user)
+
+    def test_login_existing_user(self) -> None:
+        """Login with existing email."""
+        # Create user first
+        Core.Database.create_user("existing@example.com", get_database_path())
+
+        response = self.client.post(
+            "/login",
+            data={"email": "existing@example.com"},
+        )
+
+        self.assertEqual(response.status_code, 200)
+        self.assertIn("HX-Redirect", response.headers)
+
+    def test_login_invalid_email(self) -> None:
+        """Reject missing or malformed emails."""
+        response = self.client.post("/login", data={"email": ""})
+
+        self.assertEqual(response.status_code, 400)
+        self.assertIn("Email is required", response.text)
+
+    def test_session_persistence(self) -> None:
+        """Verify session across requests."""
+        # Login
+        self.client.post("/login", data={"email": "test@example.com"})
+
+        # Access protected page
+        response = self.client.get("/")
+
+        # Should see logged-in content
+        self.assertIn("Logged in as: test@example.com", response.text)
+
+    def test_protected_routes(self) -> None:
+        """Ensure auth required for user actions."""
+        # Try to submit without login
+        response = self.client.post(
+            "/submit",
+            data={"url": "https://example.com"},
+        )
+
+        self.assertIn("Please login first", response.text)
+
+
+class TestArticleSubmission(BaseWebTest):
+    """Test article submission functionality."""
+
+    def setUp(self) -> None:
+        """Set up test client with logged-in user."""
+        super().setUp()
+        # Login
+        self.client.post("/login", data={"email": "test@example.com"})
+
+    def test_submit_valid_url(self) -> None:
+        """Accept well-formed URLs."""
+        response = self.client.post(
+            "/submit",
+            data={"url": "https://example.com/article"},
+        )
+
+        self.assertEqual(response.status_code, 200)
+        self.assertIn("Article submitted successfully", response.text)
+        self.assertIn("Job ID:", response.text)
+
+    def test_submit_invalid_url(self) -> None:
+        """Reject malformed URLs."""
+        response = self.client.post("/submit", data={"url": "not-a-url"})
+
+        self.assertIn("Invalid URL format", response.text)
+
+    def test_submit_without_auth(self) -> None:
+        """Reject unauthenticated submissions."""
+        # Clear session
+        self.client.get("/logout")
+
+        response = self.client.post(
+            "/submit",
+            data={"url": "https://example.com"},
+        )
+
+        self.assertIn("Please login first", response.text)
+
+    def test_submit_creates_job(self) -> None:
+        """Verify job creation in database."""
+        response = self.client.post(
+            "/submit",
+            data={"url": "https://example.com/test"},
+        )
+
+        # Extract job ID from response
+        match = re.search(r"Job ID: (\d+)", response.text)
+        self.assertIsNotNone(match)
+        if match is None:
+            self.fail("Job ID not found in response")
+        job_id = int(match.group(1))
+
+        # Verify job in database
+        job = Core.Database.get_job_by_id(job_id, get_database_path())
+        self.assertIsNotNone(job)
+        if job is None:  # Type guard for mypy
+            self.fail("Job should not be None")
+        self.assertEqual(job["url"], "https://example.com/test")
"https://example.com/test") + self.assertEqual(job["status"], "pending") + + def test_htmx_response(self) -> None: + """Ensure proper HTMX response format.""" + response = self.client.post( + "/submit", + data={"url": "https://example.com"}, + ) + + # Should return HTML fragment, not full page + self.assertNotIn("<!DOCTYPE", response.text) + self.assertIn("<div", response.text) + + +class TestRSSFeed(BaseWebTest): + """Test RSS feed generation.""" + + def setUp(self) -> None: + """Set up test client and create test data.""" + super().setUp() + + # Create user and episodes + self.user_id, self.token = Core.Database.create_user( + "test@example.com", + get_database_path(), + ) + + # Create test episodes + Core.Database.create_episode( + "Episode 1", + "https://example.com/ep1.mp3", + 300, + 5000, + self.user_id, + get_database_path(), + ) + Core.Database.create_episode( + "Episode 2", + "https://example.com/ep2.mp3", + 600, + 10000, + self.user_id, + get_database_path(), + ) + + def test_feed_generation(self) -> None: + """Generate valid RSS XML.""" + response = self.client.get(f"/feed/{self.token}.xml") + + self.assertEqual(response.status_code, 200) + self.assertEqual( + response.headers["content-type"], + "application/rss+xml; charset=utf-8", + ) + + # Verify RSS structure + self.assertIn("<?xml", response.text) + self.assertIn("<rss", response.text) + self.assertIn("<channel>", response.text) + self.assertIn("<item>", response.text) + + def test_feed_user_isolation(self) -> None: + """Only show user's episodes.""" + # Create another user with episodes + user2_id, _ = Core.Database.create_user( + "other@example.com", + get_database_path(), + ) + Core.Database.create_episode( + "Other Episode", + "https://example.com/other.mp3", + 400, + 6000, + user2_id, + get_database_path(), + ) + + # Get first user's feed + response = self.client.get(f"/feed/{self.token}.xml") + + # Should only have user's episodes + self.assertIn("Episode 1", response.text) + self.assertIn("Episode 2", response.text) + self.assertNotIn("Other Episode", response.text) + + def test_feed_invalid_token(self) -> None: + """Return 404 for bad tokens.""" + response = self.client.get("/feed/invalid-token.xml") + + self.assertEqual(response.status_code, 404) + + def test_feed_metadata(self) -> None: + """Verify personalized feed titles.""" + response = self.client.get(f"/feed/{self.token}.xml") + + # Should personalize based on email + self.assertIn("Test's Article Podcast", response.text) + self.assertIn("test@example.com", response.text) + + def test_feed_episode_order(self) -> None: + """Ensure reverse chronological order.""" + response = self.client.get(f"/feed/{self.token}.xml") + + # Episode 2 should appear before Episode 1 + ep2_pos = response.text.find("Episode 2") + ep1_pos = response.text.find("Episode 1") + self.assertLess(ep2_pos, ep1_pos) + + def test_feed_enclosures(self) -> None: + """Verify audio URLs and metadata.""" + response = self.client.get(f"/feed/{self.token}.xml") + + # Check enclosure tags + self.assertIn("<enclosure", response.text) + self.assertIn('type="audio/mpeg"', response.text) + self.assertIn("https://example.com/ep1.mp3", response.text) + self.assertIn("https://example.com/ep2.mp3", response.text) + + +class TestWebhook(BaseWebTest): + """Test Mailgun webhook functionality.""" + + def test_mailgun_signature_valid(self) -> None: + """Accept valid signatures.""" + # Save original key + original_key = globals()["MAILGUN_WEBHOOK_KEY"] + globals()["MAILGUN_WEBHOOK_KEY"] = "test-key" + + try: + # 
+            # Generate valid signature
+            timestamp = str(int(time.time()))
+            token = "test-token"  # noqa: S105
+
+            value = f"{timestamp}{token}"
+            signature = hmac.new(
+                b"test-key",
+                value.encode(),
+                hashlib.sha256,
+            ).hexdigest()
+
+            response = self.client.post(
+                "/webhook/mailgun",
+                data={
+                    "token": token,
+                    "timestamp": timestamp,
+                    "signature": signature,
+                    "sender": "test@example.com",
+                    "body-plain": "Check out https://example.com/article",
+                },
+            )
+
+            self.assertEqual(response.status_code, 200)
+        finally:
+            globals()["MAILGUN_WEBHOOK_KEY"] = original_key
+
+    def test_mailgun_signature_invalid(self) -> None:
+        """Reject invalid signatures."""
+        # Save original key
+        original_key = globals()["MAILGUN_WEBHOOK_KEY"]
+        globals()["MAILGUN_WEBHOOK_KEY"] = "test-key"
+
+        try:
+            response = self.client.post(
+                "/webhook/mailgun",
+                data={
+                    "token": "test-token",
+                    "timestamp": "12345",
+                    "signature": "invalid",
+                    "sender": "test@example.com",
+                    "body-plain": "https://example.com",
+                },
+            )
+
+            self.assertEqual(response.status_code, 401)
+        finally:
+            globals()["MAILGUN_WEBHOOK_KEY"] = original_key
+
+    def test_webhook_url_extraction(self) -> None:
+        """Extract URLs from email body."""
+        # Save original key
+        original_key = globals()["MAILGUN_WEBHOOK_KEY"]
+        globals()["MAILGUN_WEBHOOK_KEY"] = ""
+
+        try:
+            response = self.client.post(
+                "/webhook/mailgun",
+                data={
+                    "sender": "test@example.com",
+                    "body-plain": (
+                        "Hey, check this out: "
+                        "https://example.com/article and also "
+                        "https://example.com/other"
+                    ),
+                },
+            )
+
+            self.assertEqual(response.status_code, 200)
+
+            # Should queue first URL
+            jobs = Core.Database.get_pending_jobs(db_path=get_database_path())
+            self.assertEqual(len(jobs), 1)
+            self.assertEqual(jobs[0]["url"], "https://example.com/article")
+        finally:
+            globals()["MAILGUN_WEBHOOK_KEY"] = original_key
+
+    def test_webhook_auto_create_user(self) -> None:
+        """Create user on first email."""
+        # Save original key
+        original_key = globals()["MAILGUN_WEBHOOK_KEY"]
+        globals()["MAILGUN_WEBHOOK_KEY"] = ""
+
+        try:
+            response = self.client.post(
+                "/webhook/mailgun",
+                data={
+                    "sender": "newuser@example.com",
+                    "body-plain": "https://example.com/article",
+                },
+            )
+
+            self.assertEqual(response.status_code, 200)
+
+            # User should be created
+            user = Core.Database.get_user_by_email(
+                "newuser@example.com",
+                get_database_path(),
+            )
+            self.assertIsNotNone(user)
+        finally:
+            globals()["MAILGUN_WEBHOOK_KEY"] = original_key
+
+    def test_webhook_multiple_urls(self) -> None:
+        """Handle emails with multiple URLs."""
+        # Save original key
+        original_key = globals()["MAILGUN_WEBHOOK_KEY"]
+        globals()["MAILGUN_WEBHOOK_KEY"] = ""
+
+        try:
+            response = self.client.post(
+                "/webhook/mailgun",
+                data={
+                    "sender": "test@example.com",
+                    "body-plain": (
+                        "URLs: https://example.com/1 "
+                        "https://example.com/2 https://example.com/3"
+                    ),
+                },
+            )
+
+            self.assertEqual(response.status_code, 200)
+
+            # Should only queue first URL
+            jobs = Core.Database.get_pending_jobs(db_path=get_database_path())
+            self.assertEqual(len(jobs), 1)
+            self.assertEqual(jobs[0]["url"], "https://example.com/1")
+        finally:
+            globals()["MAILGUN_WEBHOOK_KEY"] = original_key
+
+    def test_webhook_no_urls(self) -> None:
+        """Handle emails without URLs gracefully."""
+        # Save original key
+        original_key = globals()["MAILGUN_WEBHOOK_KEY"]
+        globals()["MAILGUN_WEBHOOK_KEY"] = ""
+
+        try:
+            response = self.client.post(
+                "/webhook/mailgun",
+                data={
+                    "sender": "test@example.com",
+                    "body-plain": "This email has no URLs",
+                },
+            )
+
+            self.assertEqual(response.status_code, 200)
+            self.assertIn("No URL found", response.text)
+        finally:
+            globals()["MAILGUN_WEBHOOK_KEY"] = original_key
+
+
+class TestAdminInterface(BaseWebTest):
+    """Test admin interface functionality."""
+
+    def setUp(self) -> None:
+        """Set up test client with logged-in user."""
+        super().setUp()
+
+        # Create and login user
+        self.user_id, _ = Core.Database.create_user(
+            "test@example.com",
+            get_database_path(),
+        )
+        self.client.post("/login", data={"email": "test@example.com"})
+
+        # Create test data
+        self.job_id = Core.Database.add_to_queue(
+            "https://example.com/test",
+            "test@example.com",
+            self.user_id,
+            get_database_path(),
+        )
+
+    def test_queue_status_view(self) -> None:
+        """Verify queue display."""
+        response = self.client.get("/queue-status")
+
+        self.assertEqual(response.status_code, 200)
+        self.assertIn("Queue Status", response.text)
+        self.assertIn("https://example.com/test", response.text)
+
+    def test_retry_action(self) -> None:
+        """Test retry button functionality."""
+        # Set job to error state
+        Core.Database.update_job_status(
+            self.job_id,
+            "error",
+            "Failed",
+            get_database_path(),
+        )
+
+        # Retry
+        response = self.client.post(f"/queue/{self.job_id}/retry")
+
+        self.assertEqual(response.status_code, 200)
+        self.assertIn("HX-Redirect", response.headers)
+
+        # Job should be pending again
+        job = Core.Database.get_job_by_id(self.job_id, get_database_path())
+        self.assertIsNotNone(job)
+        if job is not None:
+            self.assertEqual(job["status"], "pending")
+
+    def test_delete_action(self) -> None:
+        """Test delete button functionality."""
+        response = self.client.delete(f"/queue/{self.job_id}")
+
+        self.assertEqual(response.status_code, 200)
+        self.assertIn("HX-Redirect", response.headers)
+
+        # Job should be gone
+        job = Core.Database.get_job_by_id(self.job_id, get_database_path())
+        self.assertIsNone(job)
+
+    def test_user_data_isolation(self) -> None:
+        """Ensure users only see own data."""
+        # Create another user's job
+        user2_id, _ = Core.Database.create_user(
+            "other@example.com",
+            get_database_path(),
+        )
+        Core.Database.add_to_queue(
+            "https://example.com/other",
+            "other@example.com",
+            user2_id,
+            get_database_path(),
+        )
+
+        # View queue status
+        response = self.client.get("/queue-status")
+
+        # Should only see own job
+        self.assertIn("https://example.com/test", response.text)
+        self.assertNotIn("https://example.com/other", response.text)
+
+    def test_status_summary(self) -> None:
+        """Verify status counts display."""
+        # Create jobs with different statuses
+        Core.Database.update_job_status(
+            self.job_id,
+            "error",
+            "Failed",
+            get_database_path(),
+        )
+        job2 = Core.Database.add_to_queue(
+            "https://example.com/2",
+            "test@example.com",
+            self.user_id,
+            get_database_path(),
+        )
+        Core.Database.update_job_status(
+            job2,
+            "processing",
+            db_path=get_database_path(),
+        )
+
+        response = self.client.get("/queue-status")
+
+        # Should show status counts
+        self.assertIn("ERROR: 1", response.text)
+        self.assertIn("PROCESSING: 1", response.text)
+
+
+def test() -> None:
+    """Run all tests for the web module."""
+    Test.run(
+        App.Area.Test,
+        [
+            TestAuthentication,
+            TestArticleSubmission,
+            TestRSSFeed,
+            TestWebhook,
+            TestAdminInterface,
+        ],
+    )
+
+
+def main() -> None:
+    """Run the web server."""
+    if "test" in sys.argv:
+        test()
+    else:
+        uvicorn.run(app, host="0.0.0.0", port=PORT)  # noqa: S104
diff --git a/Biz/PodcastItLater/Worker.nix b/Biz/PodcastItLater/Worker.nix
new file mode 100644
index 0000000..14aed9d
--- /dev/null
+++ b/Biz/PodcastItLater/Worker.nix
@@ -0,0 +1,58 @@
+{
+  options,
+  lib,
+  config,
+  pkgs,
+  ...
+}: let
+  cfg = config.services.podcastitlater-worker;
+in {
+  options.services.podcastitlater-worker = {
+    enable = lib.mkEnableOption "Enable the PodcastItLater worker service";
+    dataDir = lib.mkOption {
+      type = lib.types.path;
+      default = "/var/podcastitlater";
+      description = "Data directory for PodcastItLater (shared with web)";
+    };
+    package = lib.mkOption {
+      type = lib.types.package;
+      description = "PodcastItLater worker package to use";
+    };
+  };
+  config = lib.mkIf cfg.enable {
+    systemd.services.podcastitlater-worker = {
+      path = [cfg.package pkgs.ffmpeg]; # ffmpeg needed for pydub
+      wantedBy = ["multi-user.target"];
+      after = ["network.target"];
+      preStart = ''
+        # Create data directory if it doesn't exist
+        mkdir -p ${cfg.dataDir}
+
+        # Manual step: create this file with secrets
+        # OPENAI_API_KEY=your-openai-api-key
+        # S3_ENDPOINT=https://your-s3-endpoint.digitaloceanspaces.com
+        # S3_BUCKET=your-bucket-name
+        # S3_ACCESS_KEY=your-s3-access-key
+        # S3_SECRET_KEY=your-s3-secret-key
+        test -f /run/podcastitlater/worker-env
+      '';
+      script = ''
+        ${cfg.package}/bin/podcastitlater-worker
+      '';
+      description = ''
+        PodcastItLater Worker Service - processes articles to podcasts
+      '';
+      serviceConfig = {
+        Environment = [
+          "AREA=Live"
+          "DATABASE_PATH=${cfg.dataDir}/podcast.db"
+        ];
+        EnvironmentFile = "/run/podcastitlater/worker-env";
+        KillSignal = "INT";
+        Type = "simple";
+        Restart = "always";
+        RestartSec = "10";
+      };
+    };
+  };
+}
diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py
new file mode 100644
index 0000000..af51260
--- /dev/null
+++ b/Biz/PodcastItLater/Worker.py
@@ -0,0 +1,1194 @@
+"""Background worker for processing article-to-podcast conversions."""
+
+# : dep boto3
+# : dep botocore
+# : dep openai
+# : dep pydub
+# : dep pytest
+# : dep pytest-asyncio
+# : dep pytest-mock
+# : dep trafilatura
+# : out podcastitlater-worker
+# : run ffmpeg
+import Biz.PodcastItLater.Core as Core
+import boto3  # type: ignore[import-untyped]
+import io
+import json
+import Omni.App as App
+import Omni.Log as Log
+import Omni.Test as Test
+import openai
+import os
+import pathlib
+import pytest
+import sys
+import time
+import trafilatura
+import typing
+import unittest.mock
+from botocore.exceptions import ClientError  # type: ignore[import-untyped]
+from datetime import datetime
+from datetime import timedelta
+from datetime import timezone
+from pydub import AudioSegment  # type: ignore[import-untyped]
+from typing import Any
+
+logger = Log.setup()
+
+# Configuration from environment variables
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+S3_ENDPOINT = os.getenv("S3_ENDPOINT")
+S3_BUCKET = os.getenv("S3_BUCKET")
+S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY")
+S3_SECRET_KEY = os.getenv("S3_SECRET_KEY")
+DATABASE_PATH = os.getenv("DATABASE_PATH", "podcast.db")
+
+# Worker configuration
+MAX_CONTENT_LENGTH = 5000  # characters for TTS
+POLL_INTERVAL = 30  # seconds
+MAX_RETRIES = 3
+TTS_MODEL = "tts-1"
+TTS_VOICE = "alloy"
+
+
+class ArticleProcessor:
+    """Handles the complete article-to-podcast conversion pipeline."""
+
+    def __init__(self) -> None:
+        """Initialize the processor with required services.
+
+        Raises:
+            ValueError: If OPENAI_API_KEY environment variable is not set.
+ """ + if not OPENAI_API_KEY: + msg = "OPENAI_API_KEY environment variable is required" + raise ValueError(msg) + + self.openai_client: openai.OpenAI = openai.OpenAI( + api_key=OPENAI_API_KEY, + ) + + # Initialize S3 client for Digital Ocean Spaces + if all([S3_ENDPOINT, S3_BUCKET, S3_ACCESS_KEY, S3_SECRET_KEY]): + self.s3_client: Any = boto3.client( + "s3", + endpoint_url=S3_ENDPOINT, + aws_access_key_id=S3_ACCESS_KEY, + aws_secret_access_key=S3_SECRET_KEY, + ) + else: + logger.warning("S3 configuration incomplete, uploads will fail") + self.s3_client = None + + @staticmethod + def extract_article_content(url: str) -> tuple[str, str]: + """Extract title and content from article URL using trafilatura. + + Raises: + ValueError: If content cannot be downloaded or extracted. + """ + try: + downloaded = trafilatura.fetch_url(url) + if not downloaded: + msg = f"Failed to download content from {url}" + raise ValueError(msg) # noqa: TRY301 + + # Extract with metadata + result = trafilatura.extract( + downloaded, + include_comments=False, + include_tables=False, + with_metadata=True, + output_format="json", + ) + + if not result: + msg = f"Failed to extract content from {url}" + raise ValueError(msg) # noqa: TRY301 + + data = json.loads(result) + + title = data.get("title", "Untitled Article") + content = data.get("text", "") + + if not content: + msg = f"No content extracted from {url}" + raise ValueError(msg) # noqa: TRY301 + + # Don't truncate - we'll handle length in text_to_speech + logger.info("Extracted article: %s (%d chars)", title, len(content)) + except Exception: + logger.exception("Failed to extract content from %s", url) + raise + else: + return title, content + + def text_to_speech(self, text: str, title: str) -> bytes: + """Convert text to speech using OpenAI TTS API. + + Uses LLM to prepare text, then handles chunking and concatenation. + + Raises: + ValueError: If no chunks are generated from text. 
+ """ + try: + # Use LLM to prepare and chunk the text + chunks = prepare_text_for_tts(text, title) + + if not chunks: + msg = "No chunks generated from text" + raise ValueError(msg) # noqa: TRY301 + + logger.info("Processing %d chunks for TTS", len(chunks)) + + # Generate audio for each chunk + audio_segments = [] + for i, chunk in enumerate(chunks): + logger.info( + "Generating TTS for chunk %d/%d (%d chars)", + i + 1, + len(chunks), + len(chunk), + ) + + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunk, + response_format="mp3", + ) + + # Convert bytes to AudioSegment + audio_segment = AudioSegment.from_mp3( + io.BytesIO(response.content), + ) + audio_segments.append(audio_segment) + + # Small delay between API calls to be respectful + if i < len(chunks) - 1: + time.sleep(0.5) + + # Concatenate all audio segments + combined_audio = audio_segments[0] + for segment in audio_segments[1:]: + # Add a small silence between chunks for natural pacing + silence = AudioSegment.silent(duration=300) + combined_audio = combined_audio + silence + segment + + # Export combined audio to bytes + output_buffer = io.BytesIO() + combined_audio.export(output_buffer, format="mp3", bitrate="128k") + audio_data = output_buffer.getvalue() + + logger.info( + "Generated combined TTS audio: %d bytes", + len(audio_data), + ) + except Exception: + logger.exception("TTS generation failed") + raise + else: + return audio_data + + def upload_to_s3(self, audio_data: bytes, filename: str) -> str: + """Upload audio file to S3-compatible storage and return public URL. + + Raises: + ValueError: If S3 client is not configured. + ClientError: If S3 upload fails. + """ + if not self.s3_client: + msg = "S3 client not configured" + raise ValueError(msg) + + try: + # Upload file + self.s3_client.put_object( + Bucket=S3_BUCKET, + Key=filename, + Body=audio_data, + ContentType="audio/mpeg", + ACL="public-read", + ) + + # Construct public URL + audio_url = f"{S3_ENDPOINT}/{S3_BUCKET}/{filename}" + logger.info("Uploaded audio to: %s", audio_url) + except ClientError: + logger.exception("S3 upload failed") + raise + else: + return audio_url + + @staticmethod + def estimate_duration(audio_data: bytes) -> int: + """Estimate audio duration in seconds based on file size and bitrate.""" + # Rough estimation: MP3 at 128kbps = ~16KB per second + estimated_seconds = len(audio_data) // 16000 + return max(1, estimated_seconds) # Minimum 1 second + + @staticmethod + def generate_filename(job_id: int, title: str) -> str: + """Generate unique filename for audio file.""" + timestamp = int(datetime.now(tz=timezone.utc).timestamp()) + # Create safe filename from title + safe_title = "".join( + c for c in title if c.isalnum() or c in {" ", "-", "_"} + ).rstrip() + safe_title = safe_title.replace(" ", "_")[:50] # Limit length + return f"episode_{timestamp}_{job_id}_{safe_title}.mp3" + + def process_job(self, job: dict[str, Any]) -> None: + """Process a single job through the complete pipeline.""" + job_id = job["id"] + url = job["url"] + + try: + logger.info("Processing job %d: %s", job_id, url) + + # Update status to processing + Core.Database.update_job_status( + job_id, + "processing", + db_path=DATABASE_PATH, + ) + + # Step 1: Extract article content + title, content = ArticleProcessor.extract_article_content(url) + + # Step 2: Generate audio + audio_data = self.text_to_speech(content, title) + + # Step 3: Upload to S3 + filename = ArticleProcessor.generate_filename(job_id, title) + audio_url = 
+
+    def process_job(self, job: dict[str, Any]) -> None:
+        """Process a single job through the complete pipeline."""
+        job_id = job["id"]
+        url = job["url"]
+
+        try:
+            logger.info("Processing job %d: %s", job_id, url)
+
+            # Update status to processing
+            Core.Database.update_job_status(
+                job_id,
+                "processing",
+                db_path=DATABASE_PATH,
+            )
+
+            # Step 1: Extract article content
+            title, content = ArticleProcessor.extract_article_content(url)
+
+            # Step 2: Generate audio
+            audio_data = self.text_to_speech(content, title)
+
+            # Step 3: Upload to S3
+            filename = ArticleProcessor.generate_filename(job_id, title)
+            audio_url = self.upload_to_s3(audio_data, filename)
+
+            # Step 4: Calculate duration
+            duration = ArticleProcessor.estimate_duration(audio_data)
+
+            # Step 5: Create episode record
+            episode_id = Core.Database.create_episode(
+                title=title,
+                audio_url=audio_url,
+                duration=duration,
+                content_length=len(content),
+                user_id=job.get("user_id"),
+                db_path=DATABASE_PATH,
+            )
+
+            # Step 6: Mark job as complete
+            Core.Database.update_job_status(
+                job_id,
+                "completed",
+                db_path=DATABASE_PATH,
+            )
+
+            logger.info(
+                "Successfully processed job %d -> episode %d",
+                job_id,
+                episode_id,
+            )
+
+        except Exception as e:
+            error_msg = str(e)
+            logger.exception("Job %d failed: %s", job_id, error_msg)
+            Core.Database.update_job_status(
+                job_id,
+                "error",
+                error_msg,
+                DATABASE_PATH,
+            )
+            raise
+
+
+def prepare_text_for_tts(text: str, title: str) -> list[str]:
+    """Use LLM to prepare text for TTS, returning chunks ready for speech.
+
+    First splits text mechanically, then has LLM edit each chunk.
+    """
+    # First, split the text into manageable chunks
+    raw_chunks = split_text_into_chunks(text, max_chars=3000)
+
+    logger.info("Split article into %d raw chunks", len(raw_chunks))
+
+    # Edit each chunk; the first gets an intro and the last a closing
+    edited_chunks = []
+
+    for i, chunk in enumerate(raw_chunks):
+        is_first = i == 0
+        is_last = i == len(raw_chunks) - 1
+
+        try:
+            edited_chunk = edit_chunk_for_speech(
+                chunk,
+                title=title if is_first else None,
+                is_first=is_first,
+                is_last=is_last,
+            )
+            edited_chunks.append(edited_chunk)
+        except Exception:
+            logger.exception("Failed to edit chunk %d", i + 1)
+            # Fall back to raw chunk if LLM fails
+            if is_first:
+                edited_chunks.append(
+                    f"This is an audio version of {title}. {chunk}",
+                )
+            elif is_last:
+                edited_chunks.append(f"{chunk} This concludes the article.")
+            else:
+                edited_chunks.append(chunk)
+
+    return edited_chunks
+
+
+def split_text_into_chunks(text: str, max_chars: int = 3000) -> list[str]:
+    """Split text into chunks at sentence boundaries."""
+    chunks = []
+    current_chunk = ""
+
+    # Split into paragraphs first
+    paragraphs = text.split("\n\n")
+
+    for para in paragraphs:
+        para_stripped = para.strip()
+        if not para_stripped:
+            continue
+
+        # If paragraph itself is too long, split by sentences
+        if len(para_stripped) > max_chars:
+            sentences = para_stripped.split(". ")
+            for sentence in sentences:
+                if len(current_chunk) + len(sentence) + 2 < max_chars:
+                    current_chunk += sentence + ". "
+                else:
+                    if current_chunk:
+                        chunks.append(current_chunk.strip())
+                    current_chunk = sentence + ". "
+        # If adding this paragraph would exceed limit, start new chunk
+        elif len(current_chunk) + len(para_stripped) + 2 > max_chars:
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+            current_chunk = para_stripped + " "
+        else:
+            current_chunk += para_stripped + " "
+
+    # Don't forget the last chunk
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
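Since this splitter feeds both the LLM pass and the TTS calls, it helps to see its behavior concretely. A minimal sketch with illustrative values (the worker itself calls it with max_chars=3000):

```python
text = ("First sentence. " * 10) + "\n\n" + "A short closing paragraph."

# Splits land on sentence/paragraph boundaries, never mid-word,
# and every chunk stays within the requested budget.
for chunk in split_text_into_chunks(text, max_chars=80):
    print(len(chunk), repr(chunk[:40]))
    assert len(chunk) <= 80
```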
+
+
+def edit_chunk_for_speech(
+    chunk: str,
+    title: str | None = None,
+    *,
+    is_first: bool = False,
+    is_last: bool = False,
+) -> str:
+    """Use LLM to lightly edit a single chunk for speech.
+
+    Raises:
+        ValueError: If no content is returned from LLM.
+    """
+    system_prompt = (
+        "You are a podcast script editor. Your job is to lightly edit text "
+        "to make it sound natural when spoken aloud.\n\n"
+        "Guidelines:\n"
+    )
+    system_prompt += """
+- Remove URLs and email addresses, replacing with descriptive phrases
+- Convert bullet points and lists into flowing sentences
+- Fix any awkward phrasing for speech
+- Remove references like "click here" or "see below"
+- Keep edits minimal - preserve the original content and style
+- Do NOT add commentary or explanations
+- Return ONLY the edited text, no JSON or formatting
+"""
+
+    user_prompt = chunk
+
+    # Add intro/outro if needed
+    if is_first and title:
+        user_prompt = (
+            f"Add a brief intro mentioning this is an audio version of "
+            f"'{title}', then edit this text:\n\n{chunk}"
+        )
+    elif is_last:
+        user_prompt = f"Edit this text and add a brief closing:\n\n{chunk}"
+
+    try:
+        client: openai.OpenAI = openai.OpenAI(api_key=OPENAI_API_KEY)
+        response = client.chat.completions.create(
+            model="gpt-4o-mini",
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ],
+            temperature=0.3,  # Lower temperature for more consistent edits
+            max_tokens=4000,
+        )
+
+        content = response.choices[0].message.content
+        if not content:
+            msg = "No content returned from LLM"
+            raise ValueError(msg)  # noqa: TRY301
+
+        # Ensure the chunk isn't too long
+        max_chunk_length = 4000
+        if len(content) > max_chunk_length:
+            # Truncate at sentence boundary
+            sentences = content.split(". ")
+            truncated = ""
+            for sentence in sentences:
+                if len(truncated) + len(sentence) + 2 < max_chunk_length:
+                    truncated += sentence + ". "
+                else:
+                    break
+            content = truncated.strip()
+
+    except Exception:
+        logger.exception("LLM chunk editing failed")
+        raise
+    else:
+        return content
+
+
+def parse_datetime_with_timezone(created_at: str | datetime) -> datetime:
+    """Parse datetime string and ensure it has timezone info."""
+    if isinstance(created_at, str):
+        # Handle timezone-aware datetime strings
+        if created_at.endswith("Z"):
+            created_at = created_at[:-1] + "+00:00"
+        last_attempt = datetime.fromisoformat(created_at)
+        if last_attempt.tzinfo is None:
+            last_attempt = last_attempt.replace(tzinfo=timezone.utc)
+    else:
+        last_attempt = created_at
+        if last_attempt.tzinfo is None:
+            last_attempt = last_attempt.replace(tzinfo=timezone.utc)
+    return last_attempt
+
+
+def should_retry_job(job: dict[str, Any], max_retries: int) -> bool:
+    """Check if a job should be retried based on retry count and backoff time.
+
+    Uses exponential backoff to determine if enough time has passed.
+    """
+    retry_count = job["retry_count"]
+    if retry_count >= max_retries:
+        return False
+
+    # Exponential backoff: 30s, 60s, 120s
+    backoff_time = 30 * (2**retry_count)
+    last_attempt = parse_datetime_with_timezone(job["created_at"])
+    time_since_attempt = datetime.now(tz=timezone.utc) - last_attempt
+
+    return time_since_attempt > timedelta(seconds=backoff_time)
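Concretely, with MAX_RETRIES = 3 the schedule works out to waits of 30s, 60s, and 120s after the first, second, and third failures, after which the job stays in the error state. A small sketch of that arithmetic:

```python
MAX_RETRIES = 3

for retry_count in range(MAX_RETRIES + 1):
    if retry_count >= MAX_RETRIES:
        print(f"retry_count={retry_count}: give up, job stays in error")
    else:
        # Mirrors the 30 * (2**retry_count) backoff in should_retry_job
        backoff_time = 30 * (2**retry_count)
        print(f"retry_count={retry_count}: eligible after {backoff_time}s")
```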
+ """ + retry_count = job["retry_count"] + if retry_count >= max_retries: + return False + + # Exponential backoff: 30s, 60s, 120s + backoff_time = 30 * (2**retry_count) + last_attempt = parse_datetime_with_timezone(job["created_at"]) + time_since_attempt = datetime.now(tz=timezone.utc) - last_attempt + + return time_since_attempt > timedelta(seconds=backoff_time) + + +def process_pending_jobs(processor: ArticleProcessor) -> None: + """Process all pending jobs.""" + pending_jobs = Core.Database.get_pending_jobs( + limit=5, + db_path=DATABASE_PATH, + ) + + for job in pending_jobs: + current_job = job["id"] + try: + processor.process_job(job) + except Exception as e: + # Ensure job is marked as error even if process_job didn't handle it + logger.exception("Failed to process job: %d", current_job) + # Check if job is still in processing state + current_status = Core.Database.get_job_by_id( + current_job, + DATABASE_PATH, + ) + if current_status and current_status.get("status") == "processing": + Core.Database.update_job_status( + current_job, + "error", + str(e), + DATABASE_PATH, + ) + continue + + +def process_retryable_jobs() -> None: + """Check and retry failed jobs with exponential backoff.""" + retryable_jobs = Core.Database.get_retryable_jobs( + MAX_RETRIES, + DATABASE_PATH, + ) + + for job in retryable_jobs: + if should_retry_job(job, MAX_RETRIES): + logger.info( + "Retrying job %d (attempt %d)", + job["id"], + job["retry_count"] + 1, + ) + Core.Database.update_job_status( + job["id"], + "pending", + db_path=DATABASE_PATH, + ) + + +def main_loop() -> None: + """Poll for jobs and process them in a continuous loop.""" + processor = ArticleProcessor() + logger.info("Worker started, polling for jobs...") + + while True: + try: + process_pending_jobs(processor) + process_retryable_jobs() + + # Check if there's any work + pending_jobs = Core.Database.get_pending_jobs( + limit=1, + db_path=DATABASE_PATH, + ) + retryable_jobs = Core.Database.get_retryable_jobs( + MAX_RETRIES, + DATABASE_PATH, + ) + + if not pending_jobs and not retryable_jobs: + logger.debug("No jobs to process, sleeping...") + + except Exception: + logger.exception("Error in main loop") + + time.sleep(POLL_INTERVAL) + + +def move() -> None: + """Make the worker move.""" + try: + # Initialize database + Core.Database.init_db(DATABASE_PATH) + + # Start main processing loop + main_loop() + + except KeyboardInterrupt: + logger.info("Worker stopped by user") + except Exception: + logger.exception("Worker crashed") + raise + + +class TestArticleExtraction(Test.TestCase): + """Test article extraction functionality.""" + + def test_extract_valid_article(self) -> None: + """Extract from well-formed HTML.""" + # Mock trafilatura.fetch_url and extract + mock_html = ( + "<html><body><h1>Test Article</h1><p>Content here</p></body></html>" + ) + mock_result = json.dumps({ + "title": "Test Article", + "text": "Content here", + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content = ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertEqual(title, "Test Article") + self.assertEqual(content, "Content here") + + def test_extract_missing_title(self) -> None: + """Handle articles without titles.""" + mock_html = "<html><body><p>Content without title</p></body></html>" + mock_result = json.dumps({"text": "Content without title"}) + + with ( + unittest.mock.patch( + 
"trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content = ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertEqual(title, "Untitled Article") + self.assertEqual(content, "Content without title") + + def test_extract_empty_content(self) -> None: + """Handle empty articles.""" + mock_html = "<html><body></body></html>" + mock_result = json.dumps({"title": "Empty Article", "text": ""}) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + pytest.raises(ValueError, match="No content extracted") as cm, + ): + ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertIn("No content extracted", str(cm.value)) + + def test_extract_network_error(self) -> None: + """Handle connection failures.""" + with ( + unittest.mock.patch("trafilatura.fetch_url", return_value=None), + pytest.raises(ValueError, match="Failed to download") as cm, + ): + ArticleProcessor.extract_article_content("https://example.com") + + self.assertIn("Failed to download", str(cm.value)) + + @staticmethod + def test_extract_timeout() -> None: + """Handle slow responses.""" + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + side_effect=TimeoutError("Timeout"), + ), + pytest.raises(TimeoutError), + ): + ArticleProcessor.extract_article_content("https://example.com") + + def test_content_sanitization(self) -> None: + """Remove unwanted elements.""" + mock_html = """ + <html><body> + <h1>Article</h1> + <p>Good content</p> + <script>alert('bad')</script> + <table><tr><td>data</td></tr></table> + </body></html> + """ + mock_result = json.dumps({ + "title": "Article", + "text": "Good content", # Tables and scripts removed + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + _title, content = ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertEqual(content, "Good content") + self.assertNotIn("script", content) + self.assertNotIn("table", content) + + +class TestTextToSpeech(Test.TestCase): + """Test text-to-speech functionality.""" + + def setUp(self) -> None: + """Set up mocks.""" + # Mock OpenAI API key + self.env_patcher = unittest.mock.patch.dict( + os.environ, + {"OPENAI_API_KEY": "test-key"}, + ) + self.env_patcher.start() + + # Mock OpenAI response + self.mock_audio_response: unittest.mock.MagicMock = ( + unittest.mock.MagicMock() + ) + self.mock_audio_response.content = b"fake-audio-data" + + # Mock AudioSegment to avoid ffmpeg issues in tests + self.mock_audio_segment: unittest.mock.MagicMock = ( + unittest.mock.MagicMock() + ) + self.mock_audio_segment.export.return_value = None + self.audio_segment_patcher = unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + return_value=self.mock_audio_segment, + ) + self.audio_segment_patcher.start() + + # Mock the concatenation operations + self.mock_audio_segment.__add__.return_value = self.mock_audio_segment + + def tearDown(self) -> None: + """Clean up mocks.""" + self.env_patcher.stop() + self.audio_segment_patcher.stop() + + def test_tts_generation(self) -> None: + """Generate audio from text.""" + + # Mock the export to write test audio data + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + 
+            buffer.write(b"test-audio-output")
+            buffer.seek(0)
+
+        self.mock_audio_segment.export.side_effect = mock_export
+
+        # Mock OpenAI client
+        mock_client = unittest.mock.MagicMock()
+        mock_audio = unittest.mock.MagicMock()
+        mock_speech = unittest.mock.MagicMock()
+        mock_speech.create.return_value = self.mock_audio_response
+        mock_audio.speech = mock_speech
+        mock_client.audio = mock_audio
+
+        with (
+            unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+            unittest.mock.patch(
+                "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+                return_value=["Test content"],
+            ),
+        ):
+            processor = ArticleProcessor()
+            audio_data = processor.text_to_speech(
+                "Test content",
+                "Test Title",
+            )
+
+        self.assertIsInstance(audio_data, bytes)
+        self.assertEqual(audio_data, b"test-audio-output")
+
+    def test_tts_chunking(self) -> None:
+        """Handle long articles with chunking."""
+        long_text = "Long content " * 1000
+        chunks = ["Chunk 1", "Chunk 2", "Chunk 3"]
+
+        def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
+            buffer.write(b"test-audio-output")
+            buffer.seek(0)
+
+        self.mock_audio_segment.export.side_effect = mock_export
+
+        # Mock OpenAI client; AudioSegment.silent is patched in the with block
+        mock_client = unittest.mock.MagicMock()
+        mock_audio = unittest.mock.MagicMock()
+        mock_speech = unittest.mock.MagicMock()
+        mock_speech.create.return_value = self.mock_audio_response
+        mock_audio.speech = mock_speech
+        mock_client.audio = mock_audio
+
+        with (
+            unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+            unittest.mock.patch(
+                "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+                return_value=chunks,
+            ),
+            unittest.mock.patch(
+                "pydub.AudioSegment.silent",
+                return_value=self.mock_audio_segment,
+            ),
+        ):
+            processor = ArticleProcessor()
+            audio_data = processor.text_to_speech(
+                long_text,
+                "Long Article",
+            )
+
+        # Should have called TTS for each chunk
+        self.assertIsInstance(audio_data, bytes)
+        self.assertEqual(audio_data, b"test-audio-output")
+
+    def test_tts_empty_text(self) -> None:
+        """Handle empty input."""
+        with unittest.mock.patch(
+            "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+            return_value=[],
+        ):
+            processor = ArticleProcessor()
+            with pytest.raises(ValueError, match="No chunks generated") as cm:
+                processor.text_to_speech("", "Empty")
+
+        self.assertIn("No chunks generated", str(cm.value))
+
+    def test_tts_special_characters(self) -> None:
+        """Handle unicode and special chars."""
+        special_text = 'Unicode: 你好世界 Émojis: 🎙️📰 Special: <>&"'
+
+        def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
+            buffer.write(b"test-audio-output")
+            buffer.seek(0)
+
+        self.mock_audio_segment.export.side_effect = mock_export
+
+        # Mock OpenAI client
+        mock_client = unittest.mock.MagicMock()
+        mock_audio = unittest.mock.MagicMock()
+        mock_speech = unittest.mock.MagicMock()
+        mock_speech.create.return_value = self.mock_audio_response
+        mock_audio.speech = mock_speech
+        mock_client.audio = mock_audio
+
+        with (
+            unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+            unittest.mock.patch(
+                "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+                return_value=[special_text],
+            ),
+        ):
+            processor = ArticleProcessor()
+            audio_data = processor.text_to_speech(
+                special_text,
+                "Special",
+            )
+
+        self.assertIsInstance(audio_data, bytes)
+        self.assertEqual(audio_data, b"test-audio-output")
+
+    def test_llm_text_preparation(self) -> None:
+        """Verify the mechanical text splitting that feeds LLM editing."""
+        # Test the actual text preparation functions
+        chunks = split_text_into_chunks("Short text", max_chars=100)
+        self.assertEqual(len(chunks), 1)
+        self.assertEqual(chunks[0], "Short text")
+
+        # Test long text splitting
+        long_text = "Sentence one. " * 100
+        chunks = split_text_into_chunks(long_text, max_chars=100)
+        self.assertGreater(len(chunks), 1)
+        for chunk in chunks:
+            self.assertLessEqual(len(chunk), 100)
+
+    @staticmethod
+    def test_llm_failure_fallback() -> None:
+        """Handle LLM API failures."""
+        # Mock LLM failure
+        with unittest.mock.patch("openai.OpenAI") as mock_openai:
+            mock_client = mock_openai.return_value
+            mock_client.chat.completions.create.side_effect = Exception(
+                "API Error",
+            )
+
+            # edit_chunk_for_speech re-raises the error; the fallback to
+            # raw text happens in prepare_text_for_tts
+            with pytest.raises(Exception, match="API Error"):
+                edit_chunk_for_speech("Test chunk", "Title", is_first=True)
+
+    def test_chunk_concatenation(self) -> None:
+        """Verify audio joining."""
+        # Mock multiple audio segments
+        chunks = ["Chunk 1", "Chunk 2"]
+
+        def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
+            buffer.write(b"test-audio-output")
+            buffer.seek(0)
+
+        self.mock_audio_segment.export.side_effect = mock_export
+
+        # Mock OpenAI client
+        mock_client = unittest.mock.MagicMock()
+        mock_audio = unittest.mock.MagicMock()
+        mock_speech = unittest.mock.MagicMock()
+        mock_speech.create.return_value = self.mock_audio_response
+        mock_audio.speech = mock_speech
+        mock_client.audio = mock_audio
+
+        with (
+            unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+            unittest.mock.patch(
+                "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+                return_value=chunks,
+            ),
+            unittest.mock.patch(
+                "pydub.AudioSegment.silent",
+                return_value=self.mock_audio_segment,
+            ),
+        ):
+            processor = ArticleProcessor()
+            audio_data = processor.text_to_speech("Test", "Title")
+
+        # Should produce combined audio
+        self.assertIsInstance(audio_data, bytes)
+        self.assertEqual(audio_data, b"test-audio-output")
+
+
+class TestJobProcessing(Test.TestCase):
+    """Test job processing functionality."""
+
+    def setUp(self) -> None:
+        """Set up test environment."""
+        self.test_db = "test_podcast_worker.db"
+        # Clean up any existing test database
+        test_db_path = pathlib.Path(self.test_db)
+        if test_db_path.exists():
+            test_db_path.unlink()
+        Core.Database.init_db(self.test_db)
+
+        # Create test user and job
+        self.user_id, _ = Core.Database.create_user(
+            "test@example.com",
+            self.test_db,
+        )
+        self.job_id = Core.Database.add_to_queue(
+            "https://example.com/article",
+            "test@example.com",
+            self.user_id,
+            self.test_db,
+        )
+
+        # Mock environment
+        self.env_patcher = unittest.mock.patch.dict(
+            os.environ,
+            {
+                "OPENAI_API_KEY": "test-key",
+                "S3_ENDPOINT": "https://s3.example.com",
+                "S3_BUCKET": "test-bucket",
+                "S3_ACCESS_KEY": "test-access",
+                "S3_SECRET_KEY": "test-secret",
+            },
+        )
+        self.env_patcher.start()
+
+    def tearDown(self) -> None:
+        """Clean up."""
+        self.env_patcher.stop()
+        test_db_path = pathlib.Path(self.test_db)
+        if test_db_path.exists():
+            test_db_path.unlink()
+
+    def test_process_job_success(self) -> None:
+        """Complete pipeline execution."""
+        processor = ArticleProcessor()
+        job = Core.Database.get_job_by_id(self.job_id, self.test_db)
+        if job is None:
+            msg = "no job found for %s"
+            raise Test.TestError(msg, self.job_id)
+
+        # Mock all external calls
+        with (
+            unittest.mock.patch.object(
+                ArticleProcessor,
+                "extract_article_content",
+                return_value=("Test Title", "Test content"),
+            ),
+            unittest.mock.patch.object(
+                ArticleProcessor,
+                "text_to_speech",
+                return_value=b"audio-data",
+            ),
+            unittest.mock.patch.object(
+                processor,
+                "upload_to_s3",
+                return_value="https://s3.example.com/audio.mp3",
+            ),
+            unittest.mock.patch(
+                "Biz.PodcastItLater.Core.Database.update_job_status",
+            ) as mock_update,
+            unittest.mock.patch(
+                "Biz.PodcastItLater.Core.Database.create_episode",
+            ) as mock_create,
+        ):
+            mock_create.return_value = 1
+            processor.process_job(job)
+
+        # Verify job was marked complete; the assertion must include the
+        # db_path kwarg that process_job actually passes
+        mock_update.assert_called_with(
+            self.job_id,
+            "completed",
+            db_path=DATABASE_PATH,
+        )
+        mock_create.assert_called_once()
+
+    def test_process_job_extraction_failure(self) -> None:
+        """Handle bad URLs."""
+        processor = ArticleProcessor()
+        job = Core.Database.get_job_by_id(self.job_id, self.test_db)
+        if job is None:
+            msg = "no job found for %s"
+            raise Test.TestError(msg, self.job_id)
+
+        with (
+            unittest.mock.patch.object(
+                ArticleProcessor,
+                "extract_article_content",
+                side_effect=ValueError("Bad URL"),
+            ),
+            unittest.mock.patch(
+                "Biz.PodcastItLater.Core.Database.update_job_status",
+            ) as mock_update,
+            pytest.raises(ValueError, match="Bad URL"),
+        ):
+            processor.process_job(job)
+
+        # Job should be marked as error (positional args match process_job)
+        mock_update.assert_called_with(
+            self.job_id,
+            "error",
+            "Bad URL",
+            DATABASE_PATH,
+        )
+
+    def test_process_job_tts_failure(self) -> None:
+        """Handle TTS errors."""
+        processor = ArticleProcessor()
+        job = Core.Database.get_job_by_id(self.job_id, self.test_db)
+        if job is None:
+            msg = "no job found for %s"
+            raise Test.TestError(msg, self.job_id)
+
+        with (
+            unittest.mock.patch.object(
+                ArticleProcessor,
+                "extract_article_content",
+                return_value=("Title", "Content"),
+            ),
+            unittest.mock.patch.object(
+                ArticleProcessor,
+                "text_to_speech",
+                side_effect=Exception("TTS Error"),
+            ),
+            unittest.mock.patch(
+                "Biz.PodcastItLater.Core.Database.update_job_status",
+            ) as mock_update,
+            pytest.raises(Exception, match="TTS Error"),
+        ):
+            processor.process_job(job)
+
+        mock_update.assert_called_with(
+            self.job_id,
+            "error",
+            "TTS Error",
+            DATABASE_PATH,
+        )
+
+    def test_process_job_s3_failure(self) -> None:
+        """Handle upload errors."""
+        processor = ArticleProcessor()
+        job = Core.Database.get_job_by_id(self.job_id, self.test_db)
+        if job is None:
+            msg = "no job found for %s"
+            raise Test.TestError(msg, self.job_id)
+
+        with (
+            unittest.mock.patch.object(
+                ArticleProcessor,
+                "extract_article_content",
+                return_value=("Title", "Content"),
+            ),
+            unittest.mock.patch.object(
+                ArticleProcessor,
+                "text_to_speech",
+                return_value=b"audio",
+            ),
+            unittest.mock.patch.object(
+                processor,
+                "upload_to_s3",
+                side_effect=ClientError({}, "PutObject"),
+            ),
+            unittest.mock.patch(
+                "Biz.PodcastItLater.Core.Database.update_job_status",
+            ),
+            pytest.raises(ClientError),
+        ):
+            processor.process_job(job)
+
+    def test_job_retry_logic(self) -> None:
+        """Verify exponential backoff."""
+        # Set job to error with retry count
+        Core.Database.update_job_status(
+            self.job_id,
+            "error",
+            "First failure",
+            self.test_db,
+        )
+        Core.Database.update_job_status(
+            self.job_id,
+            "error",
+            "Second failure",
+            self.test_db,
+        )
+
+        job = Core.Database.get_job_by_id(self.job_id, self.test_db)
+        if job is None:
+            msg = "no job found for %s"
+            raise Test.TestError(msg, self.job_id)
+
+        self.assertEqual(job["retry_count"], 2)
+
+        # Should be retryable
+        retryable = Core.Database.get_retryable_jobs(
+            max_retries=3,
+            db_path=self.test_db,
+        )
+        self.assertEqual(len(retryable), 1)
+
+    def test_max_retries(self) -> None:
+        """Stop after max attempts."""
+        # Exceed retry limit
+        for i in range(4):
+            Core.Database.update_job_status(
+                self.job_id,
+                "error",
+                f"Failure {i}",
+                self.test_db,
+            )
+
+        # Should not be retryable
+        retryable = Core.Database.get_retryable_jobs(
+            max_retries=3,
+            db_path=self.test_db,
+        )
+        self.assertEqual(len(retryable), 0)
+
+    def test_concurrent_processing(self) -> None:
+        """Handle multiple jobs."""
+        # Create multiple jobs
+        job2 = Core.Database.add_to_queue(
+            "https://example.com/2",
+            "test@example.com",
+            self.user_id,
+            self.test_db,
+        )
+        job3 = Core.Database.add_to_queue(
+            "https://example.com/3",
+            "test@example.com",
+            self.user_id,
+            self.test_db,
+        )
+
+        # Get pending jobs
+        jobs = Core.Database.get_pending_jobs(limit=5, db_path=self.test_db)
+
+        self.assertEqual(len(jobs), 3)
+        self.assertEqual({j["id"] for j in jobs}, {self.job_id, job2, job3})
+
+
+def test() -> None:
+    """Run the tests."""
+    Test.run(
+        App.Area.Test,
+        [
+            TestArticleExtraction,
+            TestTextToSpeech,
+            TestJobProcessing,
+        ],
+    )
+
+
+def main() -> None:
+    """Entry point for the worker."""
+    if "test" in sys.argv:
+        test()
+    else:
+        move()
