diff options
Diffstat (limited to 'Biz/PodcastItLater/Core.py')
| -rw-r--r-- | Biz/PodcastItLater/Core.py | 1117 |
1 files changed, 1117 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py new file mode 100644 index 0000000..c0d0acf --- /dev/null +++ b/Biz/PodcastItLater/Core.py @@ -0,0 +1,1117 @@ +"""Core, shared logic for PodcastItalater. + +Includes: +- Database models +- Data access layer +- Shared types +""" + +# : out podcastitlater-core +# : dep pytest +# : dep pytest-asyncio +# : dep pytest-mock +import Omni.App as App +import Omni.Log as Log +import Omni.Test as Test +import pathlib +import pytest +import secrets +import sqlite3 +import sys +import time +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Any + +logger = Log.setup() + + +class Database: # noqa: PLR0904 + """Data access layer for PodcastItLater database operations.""" + + @staticmethod + @contextmanager + def get_connection( + db_path: str = "podcast.db", + ) -> Iterator[sqlite3.Connection]: + """Context manager for database connections. + + Yields: + sqlite3.Connection: Database connection with row factory set. + """ + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + try: + yield conn + finally: + conn.close() + + @staticmethod + def init_db(db_path: str = "podcast.db") -> None: + """Initialize database with required tables.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + + # Queue table for job processing + cursor.execute(""" + CREATE TABLE IF NOT EXISTS queue ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + url TEXT, + email TEXT, + status TEXT DEFAULT 'pending', + retry_count INTEGER DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + error_message TEXT + ) + """) + + # Episodes table for completed podcasts + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episodes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + title TEXT NOT NULL, + content_length INTEGER, + audio_url TEXT NOT NULL, + duration INTEGER, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Create indexes for performance + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_queue_status ON queue(status)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_queue_created " + "ON queue(created_at)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episodes_created " + "ON episodes(created_at)", + ) + + conn.commit() + logger.info("Database initialized successfully") + + # Run migration to add user support + Database.migrate_to_multi_user(db_path) + + @staticmethod + def add_to_queue( + url: str, + email: str, + user_id: int, + db_path: str = "podcast.db", + ) -> int: + """Insert new job into queue, return job ID. + + Raises: + ValueError: If job ID cannot be retrieved after insert. + """ + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO queue (url, email, user_id) VALUES (?, ?, ?)", + (url, email, user_id), + ) + conn.commit() + job_id = cursor.lastrowid + if job_id is None: + msg = "Failed to get job ID after insert" + raise ValueError(msg) + logger.info("Added job %s to queue: %s", job_id, url) + return job_id + + @staticmethod + def get_pending_jobs( + limit: int = 10, + db_path: str = "podcast.db", + ) -> list[dict[str, Any]]: + """Fetch jobs with status='pending' ordered by creation time.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM queue WHERE status = 'pending' " + "ORDER BY created_at ASC LIMIT ?", + (limit,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def update_job_status( + job_id: int, + status: str, + error: str | None = None, + db_path: str = "podcast.db", + ) -> None: + """Update job status and error message.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + if status == "error": + cursor.execute( + "UPDATE queue SET status = ?, error_message = ?, " + "retry_count = retry_count + 1 WHERE id = ?", + (status, error, job_id), + ) + else: + cursor.execute( + "UPDATE queue SET status = ? WHERE id = ?", + (status, job_id), + ) + conn.commit() + logger.info("Updated job %s status to %s", job_id, status) + + @staticmethod + def get_job_by_id( + job_id: int, + db_path: str = "podcast.db", + ) -> dict[str, Any] | None: + """Fetch single job by ID.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM queue WHERE id = ?", (job_id,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def create_episode( # noqa: PLR0913, PLR0917 + title: str, + audio_url: str, + duration: int, + content_length: int, + user_id: int | None = None, + db_path: str = "podcast.db", + ) -> int: + """Insert episode record, return episode ID. + + Raises: + ValueError: If episode ID cannot be retrieved after insert. + """ + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO episodes " + "(title, audio_url, duration, content_length, user_id) " + "VALUES (?, ?, ?, ?, ?)", + (title, audio_url, duration, content_length, user_id), + ) + conn.commit() + episode_id = cursor.lastrowid + if episode_id is None: + msg = "Failed to get episode ID after insert" + raise ValueError(msg) + logger.info("Created episode %s: %s", episode_id, title) + return episode_id + + @staticmethod + def get_recent_episodes( + limit: int = 20, + db_path: str = "podcast.db", + ) -> list[dict[str, Any]]: + """Get recent episodes for RSS feed generation.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes ORDER BY created_at DESC LIMIT ?", + (limit,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_queue_status_summary(db_path: str = "podcast.db") -> dict[str, Any]: + """Get queue status summary for web interface.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + + # Count jobs by status + cursor.execute( + "SELECT status, COUNT(*) as count FROM queue GROUP BY status", + ) + rows = cursor.fetchall() + status_counts = {row["status"]: row["count"] for row in rows} + + # Get recent jobs + cursor.execute( + "SELECT * FROM queue ORDER BY created_at DESC LIMIT 10", + ) + rows = cursor.fetchall() + recent_jobs = [dict(row) for row in rows] + + return {"status_counts": status_counts, "recent_jobs": recent_jobs} + + @staticmethod + def get_queue_status(db_path: str = "podcast.db") -> list[dict[str, Any]]: + """Return pending/processing/error items for web interface.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT id, url, email, status, created_at, error_message + FROM queue + WHERE status IN ('pending', 'processing', 'error') + ORDER BY created_at DESC + LIMIT 20 + """) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_all_episodes( + db_path: str = "podcast.db", + user_id: int | None = None, + ) -> list[dict[str, Any]]: + """Return all episodes for RSS feed.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + if user_id: + cursor.execute( + """ + SELECT id, title, audio_url, duration, created_at, + content_length + FROM episodes + WHERE user_id = ? + ORDER BY created_at DESC + """, + (user_id,), + ) + else: + cursor.execute(""" + SELECT id, title, audio_url, duration, created_at, + content_length + FROM episodes + ORDER BY created_at DESC + """) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_retryable_jobs( + max_retries: int = 3, + db_path: str = "podcast.db", + ) -> list[dict[str, Any]]: + """Get failed jobs that can be retried.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM queue WHERE status = 'error' " + "AND retry_count < ? ORDER BY created_at ASC", + (max_retries,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def retry_job(job_id: int, db_path: str = "podcast.db") -> None: + """Reset a job to pending status for retry.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE queue SET status = 'pending', " + "error_message = NULL WHERE id = ?", + (job_id,), + ) + conn.commit() + logger.info("Reset job %s to pending for retry", job_id) + + @staticmethod + def delete_job(job_id: int, db_path: str = "podcast.db") -> None: + """Delete a job from the queue.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM queue WHERE id = ?", (job_id,)) + conn.commit() + logger.info("Deleted job %s from queue", job_id) + + @staticmethod + def get_all_queue_items( + db_path: str = "podcast.db", + user_id: int | None = None, + ) -> list[dict[str, Any]]: + """Return all queue items for admin view.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + if user_id: + cursor.execute( + """ + SELECT id, url, email, status, retry_count, created_at, + error_message + FROM queue + WHERE user_id = ? + ORDER BY created_at DESC + """, + (user_id,), + ) + else: + cursor.execute(""" + SELECT id, url, email, status, retry_count, created_at, + error_message + FROM queue + ORDER BY created_at DESC + """) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_status_counts(db_path: str = "podcast.db") -> dict[str, int]: + """Get count of queue items by status.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT status, COUNT(*) as count + FROM queue + GROUP BY status + """) + rows = cursor.fetchall() + return {row["status"]: row["count"] for row in rows} + + @staticmethod + def get_user_status_counts( + user_id: int, + db_path: str = "podcast.db", + ) -> dict[str, int]: + """Get count of queue items by status for a specific user.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT status, COUNT(*) as count + FROM queue + WHERE user_id = ? + GROUP BY status + """, + (user_id,), + ) + rows = cursor.fetchall() + return {row["status"]: row["count"] for row in rows} + + @staticmethod + def migrate_to_multi_user(db_path: str = "podcast.db") -> None: + """Migrate database to support multiple users.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + + # Create users table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT UNIQUE NOT NULL, + token TEXT UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Add user_id columns to existing tables + # Check if columns already exist to make migration idempotent + cursor.execute("PRAGMA table_info(queue)") + queue_info = cursor.fetchall() + queue_columns = [col[1] for col in queue_info] + + if "user_id" not in queue_columns: + cursor.execute( + "ALTER TABLE queue ADD COLUMN user_id INTEGER " + "REFERENCES users(id)", + ) + + cursor.execute("PRAGMA table_info(episodes)") + episodes_info = cursor.fetchall() + episodes_columns = [col[1] for col in episodes_info] + + if "user_id" not in episodes_columns: + cursor.execute( + "ALTER TABLE episodes ADD COLUMN user_id INTEGER " + "REFERENCES users(id)", + ) + + # Create indexes + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_queue_user_id " + "ON queue(user_id)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episodes_user_id " + "ON episodes(user_id)", + ) + + conn.commit() + logger.info("Database migrated to support multiple users") + + @staticmethod + def create_user(email: str, db_path: str = "podcast.db") -> tuple[int, str]: + """Create a new user and return (user_id, token). + + Raises: + ValueError: If user ID cannot be retrieved after insert or if user + not found. + """ + # Generate a secure token for RSS feed access + token = secrets.token_urlsafe(32) + + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + try: + cursor.execute( + "INSERT INTO users (email, token) VALUES (?, ?)", + (email, token), + ) + conn.commit() + user_id = cursor.lastrowid + if user_id is None: + msg = "Failed to get user ID after insert" + raise ValueError(msg) + logger.info("Created user %s with email %s", user_id, email) + except sqlite3.IntegrityError: + # User already exists + cursor.execute( + "SELECT id, token FROM users WHERE email = ?", + (email,), + ) + row = cursor.fetchone() + if row is None: + msg = f"User with email {email} not found" + raise ValueError(msg) from None + return int(row["id"]), str(row["token"]) + else: + return int(user_id), str(token) + + @staticmethod + def get_user_by_email( + email: str, + db_path: str = "podcast.db", + ) -> dict[str, Any] | None: + """Get user by email address.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE email = ?", (email,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_user_by_token( + token: str, + db_path: str = "podcast.db", + ) -> dict[str, Any] | None: + """Get user by RSS token.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE token = ?", (token,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_user_by_id( + user_id: int, + db_path: str = "podcast.db", + ) -> dict[str, Any] | None: + """Get user by ID.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_user_queue_status( + user_id: int, + db_path: str = "podcast.db", + ) -> list[dict[str, Any]]: + """Return pending/processing/error items for a specific user.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT id, url, email, status, created_at, error_message + FROM queue + WHERE user_id = ? AND + status IN ('pending', 'processing', 'error') + ORDER BY created_at DESC + LIMIT 20 + """, + (user_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_user_recent_episodes( + user_id: int, + limit: int = 20, + db_path: str = "podcast.db", + ) -> list[dict[str, Any]]: + """Get recent episodes for a specific user.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes WHERE user_id = ? " + "ORDER BY created_at DESC LIMIT ?", + (user_id, limit), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_user_all_episodes( + user_id: int, + db_path: str = "podcast.db", + ) -> list[dict[str, Any]]: + """Get all episodes for a specific user for RSS feed.""" + with Database.get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes WHERE user_id = ? " + "ORDER BY created_at DESC", + (user_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + +class TestDatabase(Test.TestCase): + """Test the Database class.""" + + def setUp(self) -> None: + """Set up test database.""" + self.test_db = "test_podcast.db" + # Clean up any existing test database + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + Database.init_db(self.test_db) + + def tearDown(self) -> None: + """Clean up test database.""" + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + + def test_init_db(self) -> None: + """Verify all tables and indexes are created correctly.""" + with Database.get_connection(self.test_db) as conn: + cursor = conn.cursor() + + # Check tables exist + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = {row[0] for row in cursor.fetchall()} + self.assertIn("queue", tables) + self.assertIn("episodes", tables) + self.assertIn("users", tables) + + # Check indexes exist + cursor.execute("SELECT name FROM sqlite_master WHERE type='index'") + indexes = {row[0] for row in cursor.fetchall()} + self.assertIn("idx_queue_status", indexes) + self.assertIn("idx_queue_created", indexes) + self.assertIn("idx_episodes_created", indexes) + self.assertIn("idx_queue_user_id", indexes) + self.assertIn("idx_episodes_user_id", indexes) + + def test_connection_context_manager(self) -> None: + """Ensure connections are properly closed.""" + # Get a connection and verify it works + with Database.get_connection(self.test_db) as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1") + result = cursor.fetchone() + self.assertEqual(result[0], 1) + + # Connection should be closed after context manager + with pytest.raises(sqlite3.ProgrammingError): + cursor.execute("SELECT 1") + + def test_migration_idempotency(self) -> None: + """Verify migrations can run multiple times safely.""" + # Run migration multiple times + Database.migrate_to_multi_user(self.test_db) + Database.migrate_to_multi_user(self.test_db) + Database.migrate_to_multi_user(self.test_db) + + # Should still work fine + with Database.get_connection(self.test_db) as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users") + # Should not raise an error + + +class TestUserManagement(Test.TestCase): + """Test user management functionality.""" + + def setUp(self) -> None: + """Set up test database.""" + self.test_db = "test_podcast.db" + # Clean up any existing test database + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + Database.init_db(self.test_db) + + def tearDown(self) -> None: + """Clean up test database.""" + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + + def test_create_user(self) -> None: + """Create user with unique email and token.""" + user_id, token = Database.create_user("test@example.com", self.test_db) + + self.assertIsInstance(user_id, int) + self.assertIsInstance(token, str) + self.assertGreater(len(token), 20) # Should be a secure token + + def test_create_duplicate_user(self) -> None: + """Verify duplicate emails return existing user.""" + # Create first user + user_id1, token1 = Database.create_user( + "test@example.com", + self.test_db, + ) + + # Try to create duplicate + user_id2, token2 = Database.create_user( + "test@example.com", + self.test_db, + ) + + # Should return same user + self.assertIsNotNone(user_id1) + self.assertIsNotNone(user_id2) + self.assertEqual(user_id1, user_id2) + self.assertEqual(token1, token2) + + def test_get_user_by_email(self) -> None: + """Retrieve user by email.""" + user_id, token = Database.create_user("test@example.com", self.test_db) + + user = Database.get_user_by_email("test@example.com", self.test_db) + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["id"], user_id) + self.assertEqual(user["email"], "test@example.com") + self.assertEqual(user["token"], token) + + def test_get_user_by_token(self) -> None: + """Retrieve user by RSS token.""" + user_id, token = Database.create_user("test@example.com", self.test_db) + + user = Database.get_user_by_token(token, self.test_db) + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["id"], user_id) + self.assertEqual(user["email"], "test@example.com") + + def test_get_user_by_id(self) -> None: + """Retrieve user by ID.""" + user_id, token = Database.create_user("test@example.com", self.test_db) + + user = Database.get_user_by_id(user_id, self.test_db) + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["email"], "test@example.com") + self.assertEqual(user["token"], token) + + def test_invalid_user_lookups(self) -> None: + """Verify None returned for non-existent users.""" + self.assertIsNone( + Database.get_user_by_email("nobody@example.com", self.test_db), + ) + self.assertIsNone( + Database.get_user_by_token("invalid-token", self.test_db), + ) + self.assertIsNone(Database.get_user_by_id(9999, self.test_db)) + + def test_token_uniqueness(self) -> None: + """Ensure tokens are cryptographically unique.""" + tokens = set() + for i in range(10): + _, token = Database.create_user( + f"user{i}@example.com", + self.test_db, + ) + tokens.add(token) + + # All tokens should be unique + self.assertEqual(len(tokens), 10) + + +class TestQueueOperations(Test.TestCase): + """Test queue operations.""" + + def setUp(self) -> None: + """Set up test database with a user.""" + self.test_db = "test_podcast.db" + # Clean up any existing test database + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + Database.init_db(self.test_db) + self.user_id, _ = Database.create_user("test@example.com", self.test_db) + + def tearDown(self) -> None: + """Clean up test database.""" + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + + def test_add_to_queue(self) -> None: + """Add job with user association.""" + job_id = Database.add_to_queue( + "https://example.com/article", + "test@example.com", + self.user_id, + self.test_db, + ) + + self.assertIsInstance(job_id, int) + self.assertGreater(job_id, 0) + + def test_get_pending_jobs(self) -> None: + """Retrieve jobs in correct order.""" + # Add multiple jobs + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + self.test_db, + ) + time.sleep(0.01) # Ensure different timestamps + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + self.test_db, + ) + time.sleep(0.01) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + self.test_db, + ) + + # Get pending jobs + jobs = Database.get_pending_jobs(limit=10, db_path=self.test_db) + + self.assertEqual(len(jobs), 3) + # Should be in order of creation (oldest first) + self.assertEqual(jobs[0]["id"], job1) + self.assertEqual(jobs[1]["id"], job2) + self.assertEqual(jobs[2]["id"], job3) + + def test_update_job_status(self) -> None: + """Update status and error messages.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + self.test_db, + ) + + # Update to processing + Database.update_job_status(job_id, "processing", db_path=self.test_db) + job = Database.get_job_by_id(job_id, self.test_db) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "processing") + + # Update to error with message + Database.update_job_status( + job_id, + "error", + "Network timeout", + self.test_db, + ) + job = Database.get_job_by_id(job_id, self.test_db) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "error") + self.assertEqual(job["error_message"], "Network timeout") + self.assertEqual(job["retry_count"], 1) + + def test_retry_job(self) -> None: + """Reset failed jobs for retry.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + self.test_db, + ) + + # Set to error + Database.update_job_status(job_id, "error", "Failed", self.test_db) + + # Retry + Database.retry_job(job_id, self.test_db) + job = Database.get_job_by_id(job_id, self.test_db) + + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "pending") + self.assertIsNone(job["error_message"]) + + def test_delete_job(self) -> None: + """Remove jobs from queue.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + self.test_db, + ) + + # Delete job + Database.delete_job(job_id, self.test_db) + + # Should not exist + job = Database.get_job_by_id(job_id, self.test_db) + self.assertIsNone(job) + + def test_get_retryable_jobs(self) -> None: + """Find jobs eligible for retry.""" + # Add job and mark as error + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + self.test_db, + ) + Database.update_job_status(job_id, "error", "Failed", self.test_db) + + # Should be retryable + retryable = Database.get_retryable_jobs( + max_retries=3, + db_path=self.test_db, + ) + self.assertEqual(len(retryable), 1) + self.assertEqual(retryable[0]["id"], job_id) + + # Exceed retry limit + Database.update_job_status( + job_id, + "error", + "Failed again", + self.test_db, + ) + Database.update_job_status( + job_id, + "error", + "Failed yet again", + self.test_db, + ) + + # Should not be retryable anymore + retryable = Database.get_retryable_jobs( + max_retries=3, + db_path=self.test_db, + ) + self.assertEqual(len(retryable), 0) + + def test_user_queue_isolation(self) -> None: + """Ensure users only see their own jobs.""" + # Create second user + user2_id, _ = Database.create_user("user2@example.com", self.test_db) + + # Add jobs for both users + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + self.test_db, + ) + job2 = Database.add_to_queue( + "https://example.com/2", + "user2@example.com", + user2_id, + self.test_db, + ) + + # Get user-specific queue status + user1_jobs = Database.get_user_queue_status(self.user_id, self.test_db) + user2_jobs = Database.get_user_queue_status(user2_id, self.test_db) + + self.assertEqual(len(user1_jobs), 1) + self.assertEqual(user1_jobs[0]["id"], job1) + + self.assertEqual(len(user2_jobs), 1) + self.assertEqual(user2_jobs[0]["id"], job2) + + def test_status_counts(self) -> None: + """Verify status aggregation queries.""" + # Add jobs with different statuses + Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + self.test_db, + ) + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + self.test_db, + ) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + self.test_db, + ) + + Database.update_job_status(job2, "processing", db_path=self.test_db) + Database.update_job_status(job3, "error", "Failed", self.test_db) + + # Get status counts + counts = Database.get_user_status_counts(self.user_id, self.test_db) + + self.assertEqual(counts.get("pending", 0), 1) + self.assertEqual(counts.get("processing", 0), 1) + self.assertEqual(counts.get("error", 0), 1) + + +class TestEpisodeManagement(Test.TestCase): + """Test episode management functionality.""" + + def setUp(self) -> None: + """Set up test database with a user.""" + self.test_db = "test_podcast.db" + # Clean up any existing test database + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + Database.init_db(self.test_db) + self.user_id, _ = Database.create_user("test@example.com", self.test_db) + + def tearDown(self) -> None: + """Clean up test database.""" + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + + def test_create_episode(self) -> None: + """Create episode with user association.""" + episode_id = Database.create_episode( + title="Test Article", + audio_url="https://example.com/audio.mp3", + duration=300, + content_length=5000, + user_id=self.user_id, + db_path=self.test_db, + ) + + self.assertIsInstance(episode_id, int) + self.assertGreater(episode_id, 0) + + def test_get_recent_episodes(self) -> None: + """Retrieve episodes in reverse chronological order.""" + # Create multiple episodes + ep1 = Database.create_episode( + "Article 1", + "url1", + 100, + 1000, + self.user_id, + self.test_db, + ) + time.sleep(0.01) + ep2 = Database.create_episode( + "Article 2", + "url2", + 200, + 2000, + self.user_id, + self.test_db, + ) + time.sleep(0.01) + ep3 = Database.create_episode( + "Article 3", + "url3", + 300, + 3000, + self.user_id, + self.test_db, + ) + + # Get recent episodes + episodes = Database.get_recent_episodes(limit=10, db_path=self.test_db) + + self.assertEqual(len(episodes), 3) + # Should be in reverse chronological order + self.assertEqual(episodes[0]["id"], ep3) + self.assertEqual(episodes[1]["id"], ep2) + self.assertEqual(episodes[2]["id"], ep1) + + def test_get_user_episodes(self) -> None: + """Ensure user isolation for episodes.""" + # Create second user + user2_id, _ = Database.create_user("user2@example.com", self.test_db) + + # Create episodes for both users + ep1 = Database.create_episode( + "User1 Article", + "url1", + 100, + 1000, + self.user_id, + self.test_db, + ) + ep2 = Database.create_episode( + "User2 Article", + "url2", + 200, + 2000, + user2_id, + self.test_db, + ) + + # Get user-specific episodes + user1_episodes = Database.get_user_all_episodes( + self.user_id, + self.test_db, + ) + user2_episodes = Database.get_user_all_episodes(user2_id, self.test_db) + + self.assertEqual(len(user1_episodes), 1) + self.assertEqual(user1_episodes[0]["id"], ep1) + + self.assertEqual(len(user2_episodes), 1) + self.assertEqual(user2_episodes[0]["id"], ep2) + + def test_episode_metadata(self) -> None: + """Verify duration and content_length storage.""" + Database.create_episode( + title="Test Article", + audio_url="https://example.com/audio.mp3", + duration=12345, + content_length=98765, + user_id=self.user_id, + db_path=self.test_db, + ) + + episodes = Database.get_user_all_episodes(self.user_id, self.test_db) + episode = episodes[0] + + self.assertEqual(episode["duration"], 12345) + self.assertEqual(episode["content_length"], 98765) + + +def test() -> None: + """Run the tests.""" + Test.run( + App.Area.Test, + [ + TestDatabase, + TestUserManagement, + TestQueueOperations, + TestEpisodeManagement, + ], + ) + + +def main() -> None: + """Run all PodcastItLater.Core tests.""" + if "test" in sys.argv: + test() |
