"""Core, shared logic for PodcastItalater. Includes: - Database models - Data access layer - Shared types """ # : out podcastitlater-core # : dep pytest # : dep pytest-asyncio # : dep pytest-mock import logging import Omni.App as App import Omni.Test as Test import os import pathlib import pytest import secrets import sqlite3 import sys import time import typing from collections.abc import Iterator from contextlib import contextmanager from typing import Any logger = logging.getLogger(__name__) CODEROOT = pathlib.Path(os.getenv("CODEROOT", ".")) DATA_DIR = pathlib.Path( os.environ.get("DATA_DIR", CODEROOT / "_/var/podcastitlater/"), ) # Constants for UI display URL_TRUNCATE_LENGTH = 80 TITLE_TRUNCATE_LENGTH = 50 ERROR_TRUNCATE_LENGTH = 50 # Admin whitelist ADMIN_EMAILS = ["ben@bensima.com"] def is_admin(user: dict[str, typing.Any] | None) -> bool: """Check if user is an admin based on email whitelist.""" if not user: return False return user.get("email", "").lower() in [ email.lower() for email in ADMIN_EMAILS ] class Database: # noqa: PLR0904 """Data access layer for PodcastItLater database operations.""" @staticmethod def teardown() -> None: """Delete the existing database, for cleanup after tests.""" db_path = DATA_DIR / "podcast.db" if db_path.exists(): db_path.unlink() @staticmethod @contextmanager def get_connection() -> Iterator[sqlite3.Connection]: """Context manager for database connections. Yields: sqlite3.Connection: Database connection with row factory set. """ db_path = DATA_DIR / "podcast.db" conn = sqlite3.connect(db_path) conn.row_factory = sqlite3.Row try: yield conn finally: conn.close() @staticmethod def init_db() -> None: """Initialize database with required tables.""" with Database.get_connection() as conn: cursor = conn.cursor() # Queue table for job processing cursor.execute(""" CREATE TABLE IF NOT EXISTS queue ( id INTEGER PRIMARY KEY AUTOINCREMENT, url TEXT, email TEXT, status TEXT DEFAULT 'pending', retry_count INTEGER DEFAULT 0, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, error_message TEXT, title TEXT, author TEXT ) """) # Episodes table for completed podcasts cursor.execute(""" CREATE TABLE IF NOT EXISTS episodes ( id INTEGER PRIMARY KEY AUTOINCREMENT, title TEXT NOT NULL, content_length INTEGER, audio_url TEXT NOT NULL, duration INTEGER, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """) # Create indexes for performance cursor.execute( "CREATE INDEX IF NOT EXISTS idx_queue_status ON queue(status)", ) cursor.execute( "CREATE INDEX IF NOT EXISTS idx_queue_created " "ON queue(created_at)", ) cursor.execute( "CREATE INDEX IF NOT EXISTS idx_episodes_created " "ON episodes(created_at)", ) conn.commit() logger.info("Database initialized successfully") # Run migration to add user support Database.migrate_to_multi_user() # Run migration to add metadata fields Database.migrate_add_metadata_fields() # Run migration to add episode metadata fields Database.migrate_add_episode_metadata() # Run migration to add user status field Database.migrate_add_user_status() # Run migration to add default titles Database.migrate_add_default_titles() # Run migration to add billing fields Database.migrate_add_billing_fields() # Run migration to add stripe events table Database.migrate_add_stripe_events_table() @staticmethod def add_to_queue( url: str, email: str, user_id: int, title: str | None = None, author: str | None = None, ) -> int: """Insert new job into queue with metadata, return job ID. Raises: ValueError: If job ID cannot be retrieved after insert. 
""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "INSERT INTO queue (url, email, user_id, title, author) " "VALUES (?, ?, ?, ?, ?)", (url, email, user_id, title, author), ) conn.commit() job_id = cursor.lastrowid if job_id is None: msg = "Failed to get job ID after insert" raise ValueError(msg) logger.info("Added job %s to queue: %s", job_id, url) return job_id @staticmethod def get_pending_jobs( limit: int = 10, ) -> list[dict[str, Any]]: """Fetch jobs with status='pending' ordered by creation time.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT * FROM queue WHERE status = 'pending' " "ORDER BY created_at ASC LIMIT ?", (limit,), ) rows = cursor.fetchall() return [dict(row) for row in rows] @staticmethod def update_job_status( job_id: int, status: str, error: str | None = None, ) -> None: """Update job status and error message.""" with Database.get_connection() as conn: cursor = conn.cursor() if status == "error": cursor.execute( "UPDATE queue SET status = ?, error_message = ?, " "retry_count = retry_count + 1 WHERE id = ?", (status, error, job_id), ) else: cursor.execute( "UPDATE queue SET status = ? WHERE id = ?", (status, job_id), ) conn.commit() logger.info("Updated job %s status to %s", job_id, status) @staticmethod def get_job_by_id( job_id: int, ) -> dict[str, Any] | None: """Fetch single job by ID.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM queue WHERE id = ?", (job_id,)) row = cursor.fetchone() return dict(row) if row is not None else None @staticmethod def create_episode( # noqa: PLR0913, PLR0917 title: str, audio_url: str, duration: int, content_length: int, user_id: int | None = None, author: str | None = None, original_url: str | None = None, ) -> int: """Insert episode record, return episode ID. Raises: ValueError: If episode ID cannot be retrieved after insert. 
""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "INSERT INTO episodes " "(title, audio_url, duration, content_length, user_id, author, " "original_url) VALUES (?, ?, ?, ?, ?, ?, ?)", ( title, audio_url, duration, content_length, user_id, author, original_url, ), ) conn.commit() episode_id = cursor.lastrowid if episode_id is None: msg = "Failed to get episode ID after insert" raise ValueError(msg) logger.info("Created episode %s: %s", episode_id, title) return episode_id @staticmethod def get_recent_episodes( limit: int = 20, ) -> list[dict[str, Any]]: """Get recent episodes for RSS feed generation.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT * FROM episodes ORDER BY created_at DESC LIMIT ?", (limit,), ) rows = cursor.fetchall() return [dict(row) for row in rows] @staticmethod def get_queue_status_summary() -> dict[str, Any]: """Get queue status summary for web interface.""" with Database.get_connection() as conn: cursor = conn.cursor() # Count jobs by status cursor.execute( "SELECT status, COUNT(*) as count FROM queue GROUP BY status", ) rows = cursor.fetchall() status_counts = {row["status"]: row["count"] for row in rows} # Get recent jobs cursor.execute( "SELECT * FROM queue ORDER BY created_at DESC LIMIT 10", ) rows = cursor.fetchall() recent_jobs = [dict(row) for row in rows] return {"status_counts": status_counts, "recent_jobs": recent_jobs} @staticmethod def get_queue_status() -> list[dict[str, Any]]: """Return pending/processing/error items for web interface.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute(""" SELECT id, url, email, status, created_at, error_message, title, author FROM queue WHERE status IN ('pending', 'processing', 'error') ORDER BY created_at DESC LIMIT 20 """) rows = cursor.fetchall() return [dict(row) for row in rows] @staticmethod def get_all_episodes( user_id: int | None = None, ) -> list[dict[str, Any]]: """Return all episodes for RSS feed.""" with Database.get_connection() as conn: cursor = conn.cursor() if user_id: cursor.execute( """ SELECT id, title, audio_url, duration, created_at, content_length, author, original_url FROM episodes WHERE user_id = ? ORDER BY created_at DESC """, (user_id,), ) else: cursor.execute(""" SELECT id, title, audio_url, duration, created_at, content_length, author, original_url FROM episodes ORDER BY created_at DESC """) rows = cursor.fetchall() return [dict(row) for row in rows] @staticmethod def get_retryable_jobs( max_retries: int = 3, ) -> list[dict[str, Any]]: """Get failed jobs that can be retried.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT * FROM queue WHERE status = 'error' " "AND retry_count < ? 
ORDER BY created_at ASC", (max_retries,), ) rows = cursor.fetchall() return [dict(row) for row in rows] @staticmethod def retry_job(job_id: int) -> None: """Reset a job to pending status for retry.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "UPDATE queue SET status = 'pending', " "error_message = NULL WHERE id = ?", (job_id,), ) conn.commit() logger.info("Reset job %s to pending for retry", job_id) @staticmethod def delete_job(job_id: int) -> None: """Delete a job from the queue.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute("DELETE FROM queue WHERE id = ?", (job_id,)) conn.commit() logger.info("Deleted job %s from queue", job_id) @staticmethod def get_all_queue_items( user_id: int | None = None, ) -> list[dict[str, Any]]: """Return all queue items for admin view.""" with Database.get_connection() as conn: cursor = conn.cursor() if user_id: cursor.execute( """ SELECT id, url, email, status, retry_count, created_at, error_message, title, author FROM queue WHERE user_id = ? ORDER BY created_at DESC """, (user_id,), ) else: cursor.execute(""" SELECT id, url, email, status, retry_count, created_at, error_message, title, author FROM queue ORDER BY created_at DESC """) rows = cursor.fetchall() return [dict(row) for row in rows] @staticmethod def get_status_counts() -> dict[str, int]: """Get count of queue items by status.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute(""" SELECT status, COUNT(*) as count FROM queue GROUP BY status """) rows = cursor.fetchall() return {row["status"]: row["count"] for row in rows} @staticmethod def get_user_status_counts( user_id: int, ) -> dict[str, int]: """Get count of queue items by status for a specific user.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( """ SELECT status, COUNT(*) as count FROM queue WHERE user_id = ? 
                GROUP BY status
                """,
                (user_id,),
            )
            rows = cursor.fetchall()
            return {row["status"]: row["count"] for row in rows}

    @staticmethod
    def migrate_to_multi_user() -> None:
        """Migrate database to support multiple users."""
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Create users table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    email TEXT UNIQUE NOT NULL,
                    token TEXT UNIQUE NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Add user_id columns to existing tables
            # Check if columns already exist to make migration idempotent
            cursor.execute("PRAGMA table_info(queue)")
            queue_info = cursor.fetchall()
            queue_columns = [col[1] for col in queue_info]
            if "user_id" not in queue_columns:
                cursor.execute(
                    "ALTER TABLE queue ADD COLUMN user_id INTEGER "
                    "REFERENCES users(id)",
                )

            cursor.execute("PRAGMA table_info(episodes)")
            episodes_info = cursor.fetchall()
            episodes_columns = [col[1] for col in episodes_info]
            if "user_id" not in episodes_columns:
                cursor.execute(
                    "ALTER TABLE episodes ADD COLUMN user_id INTEGER "
                    "REFERENCES users(id)",
                )

            # Create indexes
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_queue_user_id "
                "ON queue(user_id)",
            )
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_episodes_user_id "
                "ON episodes(user_id)",
            )

            conn.commit()
            logger.info("Database migrated to support multiple users")

    @staticmethod
    def migrate_add_metadata_fields() -> None:
        """Add title and author fields to queue table."""
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Check if columns already exist
            cursor.execute("PRAGMA table_info(queue)")
            queue_info = cursor.fetchall()
            queue_columns = [col[1] for col in queue_info]

            if "title" not in queue_columns:
                cursor.execute("ALTER TABLE queue ADD COLUMN title TEXT")
            if "author" not in queue_columns:
                cursor.execute("ALTER TABLE queue ADD COLUMN author TEXT")

            conn.commit()
            logger.info("Database migrated to support metadata fields")

    @staticmethod
    def migrate_add_episode_metadata() -> None:
        """Add author and original_url fields to episodes table."""
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Check if columns already exist
            cursor.execute("PRAGMA table_info(episodes)")
            episodes_info = cursor.fetchall()
            episodes_columns = [col[1] for col in episodes_info]

            if "author" not in episodes_columns:
                cursor.execute("ALTER TABLE episodes ADD COLUMN author TEXT")
            if "original_url" not in episodes_columns:
                cursor.execute(
                    "ALTER TABLE episodes ADD COLUMN original_url TEXT",
                )

            conn.commit()
            logger.info("Database migrated to support episode metadata fields")

    @staticmethod
    def migrate_add_user_status() -> None:
        """Add status field to users table."""
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Check if column already exists
            cursor.execute("PRAGMA table_info(users)")
            users_info = cursor.fetchall()
            users_columns = [col[1] for col in users_info]

            if "status" not in users_columns:
                # Add status column with default 'active'
                cursor.execute(
                    "ALTER TABLE users ADD COLUMN status TEXT DEFAULT 'active'",
                )

            conn.commit()
            logger.info("Database migrated to support user status")

    @staticmethod
    def migrate_add_billing_fields() -> None:
        """Add billing-related fields to users table."""
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Add columns one by one (SQLite limitation)
            # Note: SQLite ALTER TABLE doesn't support adding UNIQUE constraints
            # We add them without UNIQUE and rely on application logic
            columns_to_add = [
                ("plan_tier", "TEXT NOT NULL DEFAULT 'free'"),
("stripe_customer_id", "TEXT"), ("stripe_subscription_id", "TEXT"), ("subscription_status", "TEXT"), ("current_period_start", "TIMESTAMP"), ("current_period_end", "TIMESTAMP"), ("cancel_at_period_end", "INTEGER NOT NULL DEFAULT 0"), ] for column_name, column_def in columns_to_add: try: query = f"ALTER TABLE users ADD COLUMN {column_name} " cursor.execute(query + column_def) logger.info("Added column users.%s", column_name) except sqlite3.OperationalError as e: # noqa: PERF203 # Column already exists, skip logger.debug( "Column users.%s already exists: %s", column_name, e, ) conn.commit() @staticmethod def migrate_add_stripe_events_table() -> None: """Create stripe_events table for webhook idempotency.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute(""" CREATE TABLE IF NOT EXISTS stripe_events ( id TEXT PRIMARY KEY, type TEXT NOT NULL, payload TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """) cursor.execute( "CREATE INDEX IF NOT EXISTS idx_stripe_events_created " "ON stripe_events(created_at)", ) conn.commit() logger.info("Created stripe_events table") @staticmethod def migrate_add_default_titles() -> None: """Add default titles to queue items that have None titles.""" with Database.get_connection() as conn: cursor = conn.cursor() # Update queue items with NULL titles to have a default cursor.execute(""" UPDATE queue SET title = 'Untitled Article' WHERE title IS NULL """) # Get count of updated rows updated_count = cursor.rowcount conn.commit() logger.info( "Updated %s queue items with default titles", updated_count, ) @staticmethod def create_user(email: str, status: str = "active") -> tuple[int, str]: """Create a new user and return (user_id, token). Args: email: User email address status: Initial status (active or disabled) Raises: ValueError: If user ID cannot be retrieved after insert or if user not found, or if status is invalid. 
""" if status not in {"pending", "active", "disabled"}: msg = f"Invalid status: {status}" raise ValueError(msg) # Generate a secure token for RSS feed access token = secrets.token_urlsafe(32) with Database.get_connection() as conn: cursor = conn.cursor() try: cursor.execute( "INSERT INTO users (email, token, status) VALUES (?, ?, ?)", (email, token, status), ) conn.commit() user_id = cursor.lastrowid if user_id is None: msg = "Failed to get user ID after insert" raise ValueError(msg) logger.info( "Created user %s with email %s (status: %s)", user_id, email, status, ) except sqlite3.IntegrityError: # User already exists cursor.execute( "SELECT id, token FROM users WHERE email = ?", (email,), ) row = cursor.fetchone() if row is None: msg = f"User with email {email} not found" raise ValueError(msg) from None return int(row["id"]), str(row["token"]) else: return int(user_id), str(token) @staticmethod def get_user_by_email( email: str, ) -> dict[str, Any] | None: """Get user by email address.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE email = ?", (email,)) row = cursor.fetchone() return dict(row) if row is not None else None @staticmethod def get_user_by_token( token: str, ) -> dict[str, Any] | None: """Get user by RSS token.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE token = ?", (token,)) row = cursor.fetchone() return dict(row) if row is not None else None @staticmethod def get_user_by_id( user_id: int, ) -> dict[str, Any] | None: """Get user by ID.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,)) row = cursor.fetchone() return dict(row) if row is not None else None @staticmethod def get_user_queue_status( user_id: int, ) -> list[dict[str, Any]]: """Return pending/processing/error items for a specific user.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( """ SELECT id, url, email, status, created_at, error_message, title, author FROM queue WHERE user_id = ? AND status IN ('pending', 'processing', 'error') ORDER BY created_at DESC LIMIT 20 """, (user_id,), ) rows = cursor.fetchall() return [dict(row) for row in rows] @staticmethod def get_user_recent_episodes( user_id: int, limit: int = 20, ) -> list[dict[str, Any]]: """Get recent episodes for a specific user.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT * FROM episodes WHERE user_id = ? " "ORDER BY created_at DESC LIMIT ?", (user_id, limit), ) rows = cursor.fetchall() return [dict(row) for row in rows] @staticmethod def get_user_all_episodes( user_id: int, ) -> list[dict[str, Any]]: """Get all episodes for a specific user for RSS feed.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT * FROM episodes WHERE user_id = ? " "ORDER BY created_at DESC", (user_id,), ) rows = cursor.fetchall() return [dict(row) for row in rows] @staticmethod def update_user_status( user_id: int, status: str, ) -> None: """Update user account status.""" if status not in {"pending", "active", "disabled"}: msg = f"Invalid status: {status}" raise ValueError(msg) with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "UPDATE users SET status = ? 
WHERE id = ?", (status, user_id), ) conn.commit() logger.info("Updated user %s status to %s", user_id, status) @staticmethod def set_user_stripe_customer(user_id: int, customer_id: str) -> None: """Link Stripe customer ID to user.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "UPDATE users SET stripe_customer_id = ? WHERE id = ?", (customer_id, user_id), ) conn.commit() logger.info( "Linked user %s to Stripe customer %s", user_id, customer_id, ) @staticmethod def update_user_subscription( # noqa: PLR0913, PLR0917 user_id: int, subscription_id: str, status: str, period_start: Any, period_end: Any, tier: str, cancel_at_period_end: bool, # noqa: FBT001 ) -> None: """Update user subscription details.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( """ UPDATE users SET stripe_subscription_id = ?, subscription_status = ?, current_period_start = ?, current_period_end = ?, plan_tier = ?, cancel_at_period_end = ? WHERE id = ? """, ( subscription_id, status, period_start.isoformat(), period_end.isoformat(), tier, 1 if cancel_at_period_end else 0, user_id, ), ) conn.commit() logger.info( "Updated user %s subscription: tier=%s, status=%s", user_id, tier, status, ) @staticmethod def update_subscription_status(user_id: int, status: str) -> None: """Update only the subscription status (e.g., past_due).""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "UPDATE users SET subscription_status = ? WHERE id = ?", (status, user_id), ) conn.commit() logger.info( "Updated user %s subscription status to %s", user_id, status, ) @staticmethod def downgrade_to_free(user_id: int) -> None: """Downgrade user to free tier and clear subscription data.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( """ UPDATE users SET plan_tier = 'free', subscription_status = 'canceled', stripe_subscription_id = NULL, current_period_start = NULL, current_period_end = NULL, cancel_at_period_end = 0 WHERE id = ? """, (user_id,), ) conn.commit() logger.info("Downgraded user %s to free tier", user_id) @staticmethod def get_user_by_stripe_customer_id( customer_id: str, ) -> dict[str, Any] | None: """Get user by Stripe customer ID.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT * FROM users WHERE stripe_customer_id = ?", (customer_id,), ) row = cursor.fetchone() return dict(row) if row is not None else None @staticmethod def has_processed_stripe_event(event_id: str) -> bool: """Check if Stripe event has already been processed.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT id FROM stripe_events WHERE id = ?", (event_id,), ) return cursor.fetchone() is not None @staticmethod def mark_stripe_event_processed( event_id: str, event_type: str, payload: bytes, ) -> None: """Mark Stripe event as processed for idempotency.""" with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "INSERT OR IGNORE INTO stripe_events (id, type, payload) " "VALUES (?, ?, ?)", (event_id, event_type, payload.decode("utf-8")), ) conn.commit() @staticmethod def get_usage( user_id: int, period_start: Any, period_end: Any, ) -> dict[str, int]: """Get usage stats for user in period. 
        Returns:
            dict with keys: articles (int), minutes (int)
        """
        with Database.get_connection() as conn:
            cursor = conn.cursor()
            # Count articles created in period
            cursor.execute(
                """
                SELECT COUNT(*) as count, SUM(duration) as total_seconds
                FROM episodes
                WHERE user_id = ? AND created_at >= ? AND created_at < ?
                """,
                (user_id, period_start.isoformat(), period_end.isoformat()),
            )
            row = cursor.fetchone()
            articles = row["count"] if row else 0
            total_seconds = (
                row["total_seconds"] if row and row["total_seconds"] else 0
            )
            minutes = total_seconds // 60
            return {"articles": articles, "minutes": minutes}


class TestDatabase(Test.TestCase):
    """Test the Database class."""

    @staticmethod
    def setUp() -> None:
        """Set up test database."""
        Database.init_db()

    def tearDown(self) -> None:
        """Clean up test database."""
        Database.teardown()
        # Clear user ID
        self.user_id = None

    def test_init_db(self) -> None:
        """Verify all tables and indexes are created correctly."""
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Check tables exist
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = {row[0] for row in cursor.fetchall()}
            self.assertIn("queue", tables)
            self.assertIn("episodes", tables)
            self.assertIn("users", tables)

            # Check indexes exist
            cursor.execute("SELECT name FROM sqlite_master WHERE type='index'")
            indexes = {row[0] for row in cursor.fetchall()}
            self.assertIn("idx_queue_status", indexes)
            self.assertIn("idx_queue_created", indexes)
            self.assertIn("idx_episodes_created", indexes)
            self.assertIn("idx_queue_user_id", indexes)
            self.assertIn("idx_episodes_user_id", indexes)

    def test_connection_context_manager(self) -> None:
        """Ensure connections are properly closed."""
        # Get a connection and verify it works
        with Database.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT 1")
            result = cursor.fetchone()
            self.assertEqual(result[0], 1)

        # Connection should be closed after context manager
        with pytest.raises(sqlite3.ProgrammingError):
            cursor.execute("SELECT 1")

    def test_migration_idempotency(self) -> None:
        """Verify migrations can run multiple times safely."""
        # Run migration multiple times
        Database.migrate_to_multi_user()
        Database.migrate_to_multi_user()
        Database.migrate_to_multi_user()

        # Should still work fine
        with Database.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM users")  # Should not raise an error
            # Test completed successfully - migration worked
            self.assertIsNotNone(conn)


class TestUserManagement(Test.TestCase):
    """Test user management functionality."""

    @staticmethod
    def setUp() -> None:
        """Set up test database."""
        Database.init_db()

    @staticmethod
    def tearDown() -> None:
        """Clean up test database."""
        Database.teardown()

    def test_create_user(self) -> None:
        """Create user with unique email and token."""
        user_id, token = Database.create_user("test@example.com")

        self.assertIsInstance(user_id, int)
        self.assertIsInstance(token, str)
        self.assertGreater(len(token), 20)  # Should be a secure token

    def test_create_duplicate_user(self) -> None:
        """Verify duplicate emails return existing user."""
        # Create first user
        user_id1, token1 = Database.create_user(
            "test@example.com",
        )

        # Try to create duplicate
        user_id2, token2 = Database.create_user(
            "test@example.com",
        )

        # Should return same user
        self.assertIsNotNone(user_id1)
        self.assertIsNotNone(user_id2)
        self.assertEqual(user_id1, user_id2)
        self.assertEqual(token1, token2)

    def test_get_user_by_email(self) -> None:
        """Retrieve user by email."""
Database.create_user("test@example.com") user = Database.get_user_by_email("test@example.com") self.assertIsNotNone(user) if user is None: self.fail("User should not be None") self.assertEqual(user["id"], user_id) self.assertEqual(user["email"], "test@example.com") self.assertEqual(user["token"], token) def test_get_user_by_token(self) -> None: """Retrieve user by RSS token.""" user_id, token = Database.create_user("test@example.com") user = Database.get_user_by_token(token) self.assertIsNotNone(user) if user is None: self.fail("User should not be None") self.assertEqual(user["id"], user_id) self.assertEqual(user["email"], "test@example.com") def test_get_user_by_id(self) -> None: """Retrieve user by ID.""" user_id, token = Database.create_user("test@example.com") user = Database.get_user_by_id(user_id) self.assertIsNotNone(user) if user is None: self.fail("User should not be None") self.assertEqual(user["email"], "test@example.com") self.assertEqual(user["token"], token) def test_invalid_user_lookups(self) -> None: """Verify None returned for non-existent users.""" self.assertIsNone( Database.get_user_by_email("nobody@example.com"), ) self.assertIsNone( Database.get_user_by_token("invalid-token"), ) self.assertIsNone(Database.get_user_by_id(9999)) def test_token_uniqueness(self) -> None: """Ensure tokens are cryptographically unique.""" tokens = set() for i in range(10): _, token = Database.create_user( f"user{i}@example.com", ) tokens.add(token) # All tokens should be unique self.assertEqual(len(tokens), 10) class TestQueueOperations(Test.TestCase): """Test queue operations.""" def setUp(self) -> None: """Set up test database with a user.""" Database.init_db() self.user_id, _ = Database.create_user("test@example.com") @staticmethod def tearDown() -> None: """Clean up test database.""" Database.teardown() def test_add_to_queue(self) -> None: """Add job with user association.""" job_id = Database.add_to_queue( "https://example.com/article", "test@example.com", self.user_id, ) self.assertIsInstance(job_id, int) self.assertGreater(job_id, 0) def test_get_pending_jobs(self) -> None: """Retrieve jobs in correct order.""" # Add multiple jobs job1 = Database.add_to_queue( "https://example.com/1", "test@example.com", self.user_id, ) time.sleep(0.01) # Ensure different timestamps job2 = Database.add_to_queue( "https://example.com/2", "test@example.com", self.user_id, ) time.sleep(0.01) job3 = Database.add_to_queue( "https://example.com/3", "test@example.com", self.user_id, ) # Get pending jobs jobs = Database.get_pending_jobs(limit=10) self.assertEqual(len(jobs), 3) # Should be in order of creation (oldest first) self.assertEqual(jobs[0]["id"], job1) self.assertEqual(jobs[1]["id"], job2) self.assertEqual(jobs[2]["id"], job3) def test_update_job_status(self) -> None: """Update status and error messages.""" job_id = Database.add_to_queue( "https://example.com", "test@example.com", self.user_id, ) # Update to processing Database.update_job_status(job_id, "processing") job = Database.get_job_by_id(job_id) self.assertIsNotNone(job) if job is None: self.fail("Job should not be None") self.assertEqual(job["status"], "processing") # Update to error with message Database.update_job_status( job_id, "error", "Network timeout", ) job = Database.get_job_by_id(job_id) self.assertIsNotNone(job) if job is None: self.fail("Job should not be None") self.assertEqual(job["status"], "error") self.assertEqual(job["error_message"], "Network timeout") self.assertEqual(job["retry_count"], 1) def test_retry_job(self) -> 
None: """Reset failed jobs for retry.""" job_id = Database.add_to_queue( "https://example.com", "test@example.com", self.user_id, ) # Set to error Database.update_job_status(job_id, "error", "Failed") # Retry Database.retry_job(job_id) job = Database.get_job_by_id(job_id) self.assertIsNotNone(job) if job is None: self.fail("Job should not be None") self.assertEqual(job["status"], "pending") self.assertIsNone(job["error_message"]) def test_delete_job(self) -> None: """Remove jobs from queue.""" job_id = Database.add_to_queue( "https://example.com", "test@example.com", self.user_id, ) # Delete job Database.delete_job(job_id) # Should not exist job = Database.get_job_by_id(job_id) self.assertIsNone(job) def test_get_retryable_jobs(self) -> None: """Find jobs eligible for retry.""" # Add job and mark as error job_id = Database.add_to_queue( "https://example.com", "test@example.com", self.user_id, ) Database.update_job_status(job_id, "error", "Failed") # Should be retryable retryable = Database.get_retryable_jobs( max_retries=3, ) self.assertEqual(len(retryable), 1) self.assertEqual(retryable[0]["id"], job_id) # Exceed retry limit Database.update_job_status( job_id, "error", "Failed again", ) Database.update_job_status( job_id, "error", "Failed yet again", ) # Should not be retryable anymore retryable = Database.get_retryable_jobs( max_retries=3, ) self.assertEqual(len(retryable), 0) def test_user_queue_isolation(self) -> None: """Ensure users only see their own jobs.""" # Create second user user2_id, _ = Database.create_user("user2@example.com") # Add jobs for both users job1 = Database.add_to_queue( "https://example.com/1", "test@example.com", self.user_id, ) job2 = Database.add_to_queue( "https://example.com/2", "user2@example.com", user2_id, ) # Get user-specific queue status user1_jobs = Database.get_user_queue_status(self.user_id) user2_jobs = Database.get_user_queue_status(user2_id) self.assertEqual(len(user1_jobs), 1) self.assertEqual(user1_jobs[0]["id"], job1) self.assertEqual(len(user2_jobs), 1) self.assertEqual(user2_jobs[0]["id"], job2) def test_status_counts(self) -> None: """Verify status aggregation queries.""" # Add jobs with different statuses Database.add_to_queue( "https://example.com/1", "test@example.com", self.user_id, ) job2 = Database.add_to_queue( "https://example.com/2", "test@example.com", self.user_id, ) job3 = Database.add_to_queue( "https://example.com/3", "test@example.com", self.user_id, ) Database.update_job_status(job2, "processing") Database.update_job_status(job3, "error", "Failed") # Get status counts counts = Database.get_user_status_counts(self.user_id) self.assertEqual(counts.get("pending", 0), 1) self.assertEqual(counts.get("processing", 0), 1) self.assertEqual(counts.get("error", 0), 1) class TestEpisodeManagement(Test.TestCase): """Test episode management functionality.""" def setUp(self) -> None: """Set up test database with a user.""" Database.init_db() self.user_id, _ = Database.create_user("test@example.com") @staticmethod def tearDown() -> None: """Clean up test database.""" Database.teardown() def test_create_episode(self) -> None: """Create episode with user association.""" episode_id = Database.create_episode( title="Test Article", audio_url="https://example.com/audio.mp3", duration=300, content_length=5000, user_id=self.user_id, ) self.assertIsInstance(episode_id, int) self.assertGreater(episode_id, 0) def test_get_recent_episodes(self) -> None: """Retrieve episodes in reverse chronological order.""" # Create multiple episodes ep1 = 
Database.create_episode( "Article 1", "url1", 100, 1000, self.user_id, ) time.sleep(0.01) ep2 = Database.create_episode( "Article 2", "url2", 200, 2000, self.user_id, ) time.sleep(0.01) ep3 = Database.create_episode( "Article 3", "url3", 300, 3000, self.user_id, ) # Get recent episodes episodes = Database.get_recent_episodes(limit=10) self.assertEqual(len(episodes), 3) # Should be in reverse chronological order self.assertEqual(episodes[0]["id"], ep3) self.assertEqual(episodes[1]["id"], ep2) self.assertEqual(episodes[2]["id"], ep1) def test_get_user_episodes(self) -> None: """Ensure user isolation for episodes.""" # Create second user user2_id, _ = Database.create_user("user2@example.com") # Create episodes for both users ep1 = Database.create_episode( "User1 Article", "url1", 100, 1000, self.user_id, ) ep2 = Database.create_episode( "User2 Article", "url2", 200, 2000, user2_id, ) # Get user-specific episodes user1_episodes = Database.get_user_all_episodes( self.user_id, ) user2_episodes = Database.get_user_all_episodes(user2_id) self.assertEqual(len(user1_episodes), 1) self.assertEqual(user1_episodes[0]["id"], ep1) self.assertEqual(len(user2_episodes), 1) self.assertEqual(user2_episodes[0]["id"], ep2) def test_episode_metadata(self) -> None: """Verify duration and content_length storage.""" Database.create_episode( title="Test Article", audio_url="https://example.com/audio.mp3", duration=12345, content_length=98765, user_id=self.user_id, ) episodes = Database.get_user_all_episodes(self.user_id) episode = episodes[0] self.assertEqual(episode["duration"], 12345) self.assertEqual(episode["content_length"], 98765) def test() -> None: """Run the tests.""" Test.run( App.Area.Test, [ TestDatabase, TestUserManagement, TestQueueOperations, TestEpisodeManagement, ], ) def main() -> None: """Run all PodcastItLater.Core tests.""" if "test" in sys.argv: test()