Diffstat (limited to 'Biz/PodcastItLater/Core.py')
-rw-r--r--  Biz/PodcastItLater/Core.py | 2174
1 file changed, 2174 insertions, 0 deletions
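
The new module below bundles PodcastItLater's SQLite schema, its migrations, and a static-method data-access layer (class Database), together with the test suite. As orientation before the diff, a minimal sketch of how a caller drives the queue, assuming the repository root is on sys.path; the names are from this commit, while the URL and email are placeholders:

    from Biz.PodcastItLater.Core import Database

    Database.init_db()  # create tables and run all migrations
    user_id, token = Database.create_user("reader@example.com")
    job_id = Database.add_to_queue(
        "https://example.com/post", "reader@example.com", user_id,
    )
    for job in Database.get_pending_jobs(limit=5):
        Database.update_job_status(job["id"], "processing")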
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py new file mode 100644 index 0000000..3a88f22 --- /dev/null +++ b/Biz/PodcastItLater/Core.py @@ -0,0 +1,2174 @@ +"""Core, shared logic for PodcastItLater. + +Includes: +- Database models +- Data access layer +- Shared types +""" + +# : out podcastitlater-core +# : dep pytest +# : dep pytest-asyncio +# : dep pytest-mock +import hashlib +import logging +import Omni.App as App +import Omni.Test as Test +import os +import pathlib +import pytest +import secrets +import sqlite3 +import sys +import time +import typing +import urllib.parse +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Any + +logger = logging.getLogger(__name__) + + +CODEROOT = pathlib.Path(os.getenv("CODEROOT", ".")) +DATA_DIR = pathlib.Path( + os.environ.get("DATA_DIR", CODEROOT / "_/var/podcastitlater/"), +) + +# Constants for UI display +URL_TRUNCATE_LENGTH = 80 +TITLE_TRUNCATE_LENGTH = 50 +ERROR_TRUNCATE_LENGTH = 50 + +# Admin whitelist +ADMIN_EMAILS = ["ben@bensima.com", "admin@example.com"] + + +def is_admin(user: dict[str, typing.Any] | None) -> bool: + """Check if user is an admin based on email whitelist.""" + if not user: + return False + return user.get("email", "").lower() in [ + email.lower() for email in ADMIN_EMAILS + ] + + +def normalize_url(url: str) -> str: + """Normalize URL for comparison and hashing. + + Normalizes: + - Protocol (http/https) + - Domain case (lowercase) + - www prefix (removed) + - Trailing slash (removed) + - Preserves query params and fragments as they may be meaningful + + Args: + url: URL to normalize + + Returns: + Normalized URL string + """ + parsed = urllib.parse.urlparse(url.strip()) + + # Normalize domain to lowercase, remove www prefix + domain = parsed.netloc.lower() + domain = domain.removeprefix("www.") + + # Normalize path - remove trailing slash unless it's the root + path = parsed.path.rstrip("/") if parsed.path != "/" else "/" + + # Rebuild URL with normalized components + # Use https as the canonical protocol + return urllib.parse.urlunparse(( + "https", # Always use https + domain, + path, + parsed.params, + parsed.query, + parsed.fragment, + )) + + +def hash_url(url: str) -> str: + """Generate a hash of a URL for deduplication. + + Args: + url: URL to hash + + Returns: + SHA256 hash of the normalized URL + """ + normalized = normalize_url(url) + return hashlib.sha256(normalized.encode()).hexdigest() + + +class Database: # noqa: PLR0904 + """Data access layer for PodcastItLater database operations.""" + + @staticmethod + def teardown() -> None: + """Delete the existing database, for cleanup after tests.""" + db_path = DATA_DIR / "podcast.db" + if db_path.exists(): + db_path.unlink() + + @staticmethod + @contextmanager + def get_connection() -> Iterator[sqlite3.Connection]: + """Context manager for database connections. + + Yields: + sqlite3.Connection: Database connection with row factory set.
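+ + Example (illustrative sketch): + + with Database.get_connection() as conn: + rows = conn.execute("SELECT 1").fetchall()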
+ """ + db_path = DATA_DIR / "podcast.db" + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + try: + yield conn + finally: + conn.close() + + @staticmethod + def init_db() -> None: + """Initialize database with required tables.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Queue table for job processing + cursor.execute(""" + CREATE TABLE IF NOT EXISTS queue ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + url TEXT, + email TEXT, + status TEXT DEFAULT 'pending', + retry_count INTEGER DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + error_message TEXT, + title TEXT, + author TEXT + ) + """) + + # Episodes table for completed podcasts + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episodes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + title TEXT NOT NULL, + content_length INTEGER, + audio_url TEXT NOT NULL, + duration INTEGER, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Create indexes for performance + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_queue_status ON queue(status)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_queue_created " + "ON queue(created_at)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episodes_created " + "ON episodes(created_at)", + ) + + conn.commit() + logger.info("Database initialized successfully") + + # Run migration to add user support + Database.migrate_to_multi_user() + + # Run migration to add metadata fields + Database.migrate_add_metadata_fields() + + # Run migration to add episode metadata fields + Database.migrate_add_episode_metadata() + + # Run migration to add user status field + Database.migrate_add_user_status() + + # Run migration to add default titles + Database.migrate_add_default_titles() + + # Run migration to add billing fields + Database.migrate_add_billing_fields() + + # Run migration to add stripe events table + Database.migrate_add_stripe_events_table() + + # Run migration to add public feed features + Database.migrate_add_public_feed() + + @staticmethod + def add_to_queue( + url: str, + email: str, + user_id: int, + title: str | None = None, + author: str | None = None, + ) -> int: + """Insert new job into queue with metadata, return job ID. + + Raises: + ValueError: If job ID cannot be retrieved after insert. 
+ """ + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO queue (url, email, user_id, title, author) " + "VALUES (?, ?, ?, ?, ?)", + (url, email, user_id, title, author), + ) + conn.commit() + job_id = cursor.lastrowid + if job_id is None: + msg = "Failed to get job ID after insert" + raise ValueError(msg) + logger.info("Added job %s to queue: %s", job_id, url) + return job_id + + @staticmethod + def get_pending_jobs( + limit: int = 10, + ) -> list[dict[str, Any]]: + """Fetch jobs with status='pending' ordered by creation time.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM queue WHERE status = 'pending' " + "ORDER BY created_at ASC LIMIT ?", + (limit,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def update_job_status( + job_id: int, + status: str, + error: str | None = None, + ) -> None: + """Update job status and error message.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + if error is not None: + if status == "error": + cursor.execute( + "UPDATE queue SET status = ?, error_message = ?, " + "retry_count = retry_count + 1 WHERE id = ?", + (status, error, job_id), + ) + else: + cursor.execute( + "UPDATE queue SET status = ?, " + "error_message = ? WHERE id = ?", + (status, error, job_id), + ) + else: + cursor.execute( + "UPDATE queue SET status = ? WHERE id = ?", + (status, job_id), + ) + conn.commit() + logger.info("Updated job %s status to %s", job_id, status) + + @staticmethod + def get_job_by_id( + job_id: int, + ) -> dict[str, Any] | None: + """Fetch single job by ID.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM queue WHERE id = ?", (job_id,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def create_episode( # noqa: PLR0913, PLR0917 + title: str, + audio_url: str, + duration: int, + content_length: int, + user_id: int | None = None, + author: str | None = None, + original_url: str | None = None, + original_url_hash: str | None = None, + ) -> int: + """Insert episode record, return episode ID. + + Raises: + ValueError: If episode ID cannot be retrieved after insert. 
+ """ + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO episodes " + "(title, audio_url, duration, content_length, user_id, " + "author, original_url, original_url_hash) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ( + title, + audio_url, + duration, + content_length, + user_id, + author, + original_url, + original_url_hash, + ), + ) + conn.commit() + episode_id = cursor.lastrowid + if episode_id is None: + msg = "Failed to get episode ID after insert" + raise ValueError(msg) + logger.info("Created episode %s: %s", episode_id, title) + return episode_id + + @staticmethod + def get_recent_episodes( + limit: int = 20, + ) -> list[dict[str, Any]]: + """Get recent episodes for RSS feed generation.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes ORDER BY created_at DESC LIMIT ?", + (limit,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_queue_status_summary() -> dict[str, Any]: + """Get queue status summary for web interface.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Count jobs by status + cursor.execute( + "SELECT status, COUNT(*) as count FROM queue GROUP BY status", + ) + rows = cursor.fetchall() + status_counts = {row["status"]: row["count"] for row in rows} + + # Get recent jobs + cursor.execute( + "SELECT * FROM queue ORDER BY created_at DESC LIMIT 10", + ) + rows = cursor.fetchall() + recent_jobs = [dict(row) for row in rows] + + return {"status_counts": status_counts, "recent_jobs": recent_jobs} + + @staticmethod + def get_queue_status() -> list[dict[str, Any]]: + """Return pending/processing/error items for web interface.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT id, url, email, status, created_at, error_message, + title, author + FROM queue + WHERE status IN ( + 'pending', 'processing', 'extracting', + 'synthesizing', 'uploading', 'error' + ) + ORDER BY created_at DESC + LIMIT 20 + """) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_episode_by_id(episode_id: int) -> dict[str, Any] | None: + """Fetch single episode by ID.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT id, title, audio_url, duration, created_at, + content_length, author, original_url, user_id, is_public + FROM episodes + WHERE id = ? + """, + (episode_id,), + ) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_all_episodes( + user_id: int | None = None, + ) -> list[dict[str, Any]]: + """Return all episodes for RSS feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + if user_id: + cursor.execute( + """ + SELECT id, title, audio_url, duration, created_at, + content_length, author, original_url + FROM episodes + WHERE user_id = ? 
+ ORDER BY created_at DESC + """, + (user_id,), + ) + else: + cursor.execute(""" + SELECT id, title, audio_url, duration, created_at, + content_length, author, original_url + FROM episodes + ORDER BY created_at DESC + """) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_retryable_jobs( + max_retries: int = 3, + ) -> list[dict[str, Any]]: + """Get failed jobs that can be retried.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM queue WHERE status = 'error' " + "AND retry_count < ? ORDER BY created_at ASC", + (max_retries,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def retry_job(job_id: int) -> None: + """Reset a job to pending status for retry.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE queue SET status = 'pending', " + "error_message = NULL WHERE id = ?", + (job_id,), + ) + conn.commit() + logger.info("Reset job %s to pending for retry", job_id) + + @staticmethod + def delete_job(job_id: int) -> None: + """Delete a job from the queue.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM queue WHERE id = ?", (job_id,)) + conn.commit() + logger.info("Deleted job %s from queue", job_id) + + @staticmethod + def get_all_queue_items( + user_id: int | None = None, + ) -> list[dict[str, Any]]: + """Return all queue items for admin view.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + if user_id: + cursor.execute( + """ + SELECT id, url, email, status, retry_count, created_at, + error_message, title, author + FROM queue + WHERE user_id = ? + ORDER BY created_at DESC + """, + (user_id,), + ) + else: + cursor.execute(""" + SELECT id, url, email, status, retry_count, created_at, + error_message, title, author + FROM queue + ORDER BY created_at DESC + """) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_status_counts() -> dict[str, int]: + """Get count of queue items by status.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT status, COUNT(*) as count + FROM queue + GROUP BY status + """) + rows = cursor.fetchall() + return {row["status"]: row["count"] for row in rows} + + @staticmethod + def get_user_status_counts( + user_id: int, + ) -> dict[str, int]: + """Get count of queue items by status for a specific user.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT status, COUNT(*) as count + FROM queue + WHERE user_id = ? 
+ GROUP BY status + """, + (user_id,), + ) + rows = cursor.fetchall() + return {row["status"]: row["count"] for row in rows} + + @staticmethod + def migrate_to_multi_user() -> None: + """Migrate database to support multiple users.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Create users table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT UNIQUE NOT NULL, + token TEXT UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Add user_id columns to existing tables + # Check if columns already exist to make migration idempotent + cursor.execute("PRAGMA table_info(queue)") + queue_info = cursor.fetchall() + queue_columns = [col[1] for col in queue_info] + + if "user_id" not in queue_columns: + cursor.execute( + "ALTER TABLE queue ADD COLUMN user_id INTEGER " + "REFERENCES users(id)", + ) + + cursor.execute("PRAGMA table_info(episodes)") + episodes_info = cursor.fetchall() + episodes_columns = [col[1] for col in episodes_info] + + if "user_id" not in episodes_columns: + cursor.execute( + "ALTER TABLE episodes ADD COLUMN user_id INTEGER " + "REFERENCES users(id)", + ) + + # Create indexes + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_queue_user_id " + "ON queue(user_id)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episodes_user_id " + "ON episodes(user_id)", + ) + + conn.commit() + logger.info("Database migrated to support multiple users") + + @staticmethod + def migrate_add_metadata_fields() -> None: + """Add title and author fields to queue table.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Check if columns already exist + cursor.execute("PRAGMA table_info(queue)") + queue_info = cursor.fetchall() + queue_columns = [col[1] for col in queue_info] + + if "title" not in queue_columns: + cursor.execute("ALTER TABLE queue ADD COLUMN title TEXT") + + if "author" not in queue_columns: + cursor.execute("ALTER TABLE queue ADD COLUMN author TEXT") + + conn.commit() + logger.info("Database migrated to support metadata fields") + + @staticmethod + def migrate_add_episode_metadata() -> None: + """Add author and original_url fields to episodes table.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Check if columns already exist + cursor.execute("PRAGMA table_info(episodes)") + episodes_info = cursor.fetchall() + episodes_columns = [col[1] for col in episodes_info] + + if "author" not in episodes_columns: + cursor.execute("ALTER TABLE episodes ADD COLUMN author TEXT") + + if "original_url" not in episodes_columns: + cursor.execute( + "ALTER TABLE episodes ADD COLUMN original_url TEXT", + ) + + conn.commit() + logger.info("Database migrated to support episode metadata fields") + + @staticmethod + def migrate_add_user_status() -> None: + """Add status field to users table.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Check if column already exists + cursor.execute("PRAGMA table_info(users)") + users_info = cursor.fetchall() + users_columns = [col[1] for col in users_info] + + if "status" not in users_columns: + # Add status column with default 'active' + cursor.execute( + "ALTER TABLE users ADD COLUMN status TEXT DEFAULT 'active'", + ) + + conn.commit() + logger.info("Database migrated to support user status") + + @staticmethod + def migrate_add_billing_fields() -> None: + """Add billing-related fields to users table.""" + with Database.get_connection() as conn: + cursor = 
conn.cursor() + + # Add columns one by one (SQLite limitation) + # Note: SQLite ALTER TABLE doesn't support adding UNIQUE constraints + # We add them without UNIQUE and rely on application logic + columns_to_add = [ + ("plan_tier", "TEXT NOT NULL DEFAULT 'free'"), + ("stripe_customer_id", "TEXT"), + ("stripe_subscription_id", "TEXT"), + ("subscription_status", "TEXT"), + ("current_period_start", "TIMESTAMP"), + ("current_period_end", "TIMESTAMP"), + ("cancel_at_period_end", "INTEGER NOT NULL DEFAULT 0"), + ] + + for column_name, column_def in columns_to_add: + try: + query = f"ALTER TABLE users ADD COLUMN {column_name} " + cursor.execute(query + column_def) + logger.info("Added column users.%s", column_name) + except sqlite3.OperationalError as e: # noqa: PERF203 + # Column already exists, skip + logger.debug( + "Column users.%s already exists: %s", + column_name, + e, + ) + + conn.commit() + + @staticmethod + def migrate_add_stripe_events_table() -> None: + """Create stripe_events table for webhook idempotency.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS stripe_events ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + payload TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_stripe_events_created " + "ON stripe_events(created_at)", + ) + conn.commit() + logger.info("Created stripe_events table") + + @staticmethod + def migrate_add_public_feed() -> None: + """Add is_public column and related tables for public feed feature.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Add is_public column to episodes + cursor.execute("PRAGMA table_info(episodes)") + episodes_info = cursor.fetchall() + episodes_columns = [col[1] for col in episodes_info] + + if "is_public" not in episodes_columns: + cursor.execute( + "ALTER TABLE episodes ADD COLUMN is_public INTEGER " + "NOT NULL DEFAULT 0", + ) + logger.info("Added is_public column to episodes") + + # Create user_episodes junction table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS user_episodes ( + user_id INTEGER NOT NULL, + episode_id INTEGER NOT NULL, + added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (user_id, episode_id), + FOREIGN KEY (user_id) REFERENCES users(id), + FOREIGN KEY (episode_id) REFERENCES episodes(id) + ) + """) + logger.info("Created user_episodes junction table") + + # Create index on episode_id for reverse lookups + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_user_episodes_episode " + "ON user_episodes(episode_id)", + ) + + # Add original_url_hash column to episodes + if "original_url_hash" not in episodes_columns: + cursor.execute( + "ALTER TABLE episodes ADD COLUMN original_url_hash TEXT", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episodes_url_hash " + "ON episodes(original_url_hash)", + ) + logger.info("Added original_url_hash column to episodes") + + # Create episode_metrics table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_metrics ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + episode_id INTEGER NOT NULL, + user_id INTEGER, + event_type TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (episode_id) REFERENCES episodes(id), + FOREIGN KEY (user_id) REFERENCES users(id) + ) + """) + logger.info("Created episode_metrics table") + + # Create indexes for metrics queries + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episode_metrics_episode " + "ON 
episode_metrics(episode_id)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episode_metrics_type " + "ON episode_metrics(event_type)", + ) + + # Create index on is_public for efficient public feed queries + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episodes_public " + "ON episodes(is_public)", + ) + + conn.commit() + logger.info("Database migrated for public feed feature") + + @staticmethod + def migrate_add_default_titles() -> None: + """Add default titles to queue items that have None titles.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Update queue items with NULL titles to have a default + cursor.execute(""" + UPDATE queue + SET title = 'Untitled Article' + WHERE title IS NULL + """) + + # Get count of updated rows + updated_count = cursor.rowcount + + conn.commit() + logger.info( + "Updated %s queue items with default titles", + updated_count, + ) + + @staticmethod + def create_user(email: str, status: str = "active") -> tuple[int, str]: + """Create a new user and return (user_id, token). + + Args: + email: User email address + status: Initial status (active or disabled) + + Raises: + ValueError: If user ID cannot be retrieved after insert or if user + not found, or if status is invalid. + """ + if status not in {"pending", "active", "disabled"}: + msg = f"Invalid status: {status}" + raise ValueError(msg) + + # Generate a secure token for RSS feed access + token = secrets.token_urlsafe(32) + with Database.get_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + "INSERT INTO users (email, token, status) VALUES (?, ?, ?)", + (email, token, status), + ) + conn.commit() + user_id = cursor.lastrowid + if user_id is None: + msg = "Failed to get user ID after insert" + raise ValueError(msg) + logger.info( + "Created user %s with email %s (status: %s)", + user_id, + email, + status, + ) + except sqlite3.IntegrityError: + # User already exists + cursor.execute( + "SELECT id, token FROM users WHERE email = ?", + (email,), + ) + row = cursor.fetchone() + if row is None: + msg = f"User with email {email} not found" + raise ValueError(msg) from None + return int(row["id"]), str(row["token"]) + else: + return int(user_id), str(token) + + @staticmethod + def get_user_by_email( + email: str, + ) -> dict[str, Any] | None: + """Get user by email address.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE email = ?", (email,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_user_by_token( + token: str, + ) -> dict[str, Any] | None: + """Get user by RSS token.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE token = ?", (token,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_user_by_id( + user_id: int, + ) -> dict[str, Any] | None: + """Get user by ID.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_queue_position(job_id: int) -> int | None: + """Get position of job in pending queue.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + # Get created_at of this job + cursor.execute( + "SELECT created_at FROM queue WHERE id = ?", + (job_id,), + ) + row = cursor.fetchone() + if not row: + return None 
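+ # Position is 1-based; jobs sharing a created_at timestamp share a position.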
+ created_at = row[0] + + # Count pending items created before or at same time + cursor.execute( + """ + SELECT COUNT(*) FROM queue + WHERE status = 'pending' AND created_at <= ? + """, + (created_at,), + ) + return int(cursor.fetchone()[0]) + + @staticmethod + def get_user_queue_status( + user_id: int, + ) -> list[dict[str, Any]]: + """Return pending/processing/error items for a specific user.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT id, url, email, status, created_at, error_message, + title, author + FROM queue + WHERE user_id = ? AND + status IN ( + 'pending', 'processing', 'extracting', + 'synthesizing', 'uploading', 'error' + ) + ORDER BY created_at DESC + LIMIT 20 + """, + (user_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_user_recent_episodes( + user_id: int, + limit: int = 20, + ) -> list[dict[str, Any]]: + """Get recent episodes for a specific user.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes WHERE user_id = ? " + "ORDER BY created_at DESC LIMIT ?", + (user_id, limit), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_user_all_episodes( + user_id: int, + ) -> list[dict[str, Any]]: + """Get all episodes for a specific user for RSS feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes WHERE user_id = ? " + "ORDER BY created_at DESC", + (user_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def update_user_status( + user_id: int, + status: str, + ) -> None: + """Update user account status.""" + if status not in {"pending", "active", "disabled"}: + msg = f"Invalid status: {status}" + raise ValueError(msg) + + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET status = ? WHERE id = ?", + (status, user_id), + ) + conn.commit() + logger.info("Updated user %s status to %s", user_id, status) + + @staticmethod + def delete_user(user_id: int) -> None: + """Delete user and all associated data.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # 1. Get owned episode IDs + cursor.execute( + "SELECT id FROM episodes WHERE user_id = ?", + (user_id,), + ) + owned_episode_ids = [row[0] for row in cursor.fetchall()] + + # 2. Delete references to owned episodes + if owned_episode_ids: + # Construct placeholders for IN clause + placeholders = ",".join("?" * len(owned_episode_ids)) + + # Delete from user_episodes where these episodes are referenced + query = f"DELETE FROM user_episodes WHERE episode_id IN ({placeholders})" # noqa: S608, E501 + cursor.execute(query, tuple(owned_episode_ids)) + + # Delete metrics for these episodes + query = f"DELETE FROM episode_metrics WHERE episode_id IN ({placeholders})" # noqa: S608, E501 + cursor.execute(query, tuple(owned_episode_ids)) + + # 3. Delete owned episodes + cursor.execute("DELETE FROM episodes WHERE user_id = ?", (user_id,)) + + # 4. Delete user's data referencing others or themselves + cursor.execute( + "DELETE FROM user_episodes WHERE user_id = ?", + (user_id,), + ) + cursor.execute( + "DELETE FROM episode_metrics WHERE user_id = ?", + (user_id,), + ) + cursor.execute("DELETE FROM queue WHERE user_id = ?", (user_id,)) + + # 5. 
Delete user + cursor.execute("DELETE FROM users WHERE id = ?", (user_id,)) + + conn.commit() + logger.info("Deleted user %s and all associated data", user_id) + + @staticmethod + def update_user_email(user_id: int, new_email: str) -> None: + """Update user's email address. + + Args: + user_id: ID of the user to update + new_email: New email address + + Raises: + ValueError: If email is already taken by another user + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + "UPDATE users SET email = ? WHERE id = ?", + (new_email, user_id), + ) + conn.commit() + logger.info("Updated user %s email to %s", user_id, new_email) + except sqlite3.IntegrityError: + msg = f"Email {new_email} is already taken" + raise ValueError(msg) from None + + @staticmethod + def mark_episode_public(episode_id: int) -> None: + """Mark an episode as public.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE episodes SET is_public = 1 WHERE id = ?", + (episode_id,), + ) + conn.commit() + logger.info("Marked episode %s as public", episode_id) + + @staticmethod + def unmark_episode_public(episode_id: int) -> None: + """Mark an episode as private (not public).""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE episodes SET is_public = 0 WHERE id = ?", + (episode_id,), + ) + conn.commit() + logger.info("Unmarked episode %s as public", episode_id) + + @staticmethod + def get_public_episodes(limit: int = 50) -> list[dict[str, Any]]: + """Get public episodes for public feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT id, title, audio_url, duration, created_at, + content_length, author, original_url + FROM episodes + WHERE is_public = 1 + ORDER BY created_at DESC + LIMIT ? + """, + (limit,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def add_episode_to_user(user_id: int, episode_id: int) -> None: + """Add an episode to a user's feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + "INSERT INTO user_episodes (user_id, episode_id) " + "VALUES (?, ?)", + (user_id, episode_id), + ) + conn.commit() + logger.info( + "Added episode %s to user %s feed", + episode_id, + user_id, + ) + except sqlite3.IntegrityError: + # Episode already in user's feed + logger.info( + "Episode %s already in user %s feed", + episode_id, + user_id, + ) + + @staticmethod + def user_has_episode(user_id: int, episode_id: int) -> bool: + """Check if a user has an episode in their feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT 1 FROM user_episodes " + "WHERE user_id = ? AND episode_id = ?", + (user_id, episode_id), + ) + return cursor.fetchone() is not None + + @staticmethod + def track_episode_metric( + episode_id: int, + event_type: str, + user_id: int | None = None, + ) -> None: + """Track an episode metric event. 
+ + Args: + episode_id: ID of the episode + event_type: Type of event ('added', 'played', 'downloaded') + user_id: Optional user ID (None for anonymous events) + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO episode_metrics (episode_id, user_id, event_type) " + "VALUES (?, ?, ?)", + (episode_id, user_id, event_type), + ) + conn.commit() + logger.info( + "Tracked %s event for episode %s (user: %s)", + event_type, + episode_id, + user_id or "anonymous", + ) + + @staticmethod + def get_user_episodes(user_id: int) -> list[dict[str, Any]]: + """Get all episodes in a user's feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT e.id, e.title, e.audio_url, e.duration, e.created_at, + e.content_length, e.author, e.original_url, e.is_public, + ue.added_at + FROM episodes e + JOIN user_episodes ue ON e.id = ue.episode_id + WHERE ue.user_id = ? + ORDER BY ue.added_at DESC + """, + (user_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_episode_by_url_hash(url_hash: str) -> dict[str, Any] | None: + """Get episode by original URL hash.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes WHERE original_url_hash = ?", + (url_hash,), + ) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_metrics_summary() -> dict[str, Any]: + """Get aggregate metrics summary for admin dashboard. + + Returns: + dict with keys: + - total_episodes: Total number of episodes + - total_plays: Total play events + - total_downloads: Total download events + - total_adds: Total add events + - most_played: List of top 10 most played episodes + - most_downloaded: List of top 10 most downloaded episodes + - most_added: List of top 10 most added episodes + - total_users: Total number of users + - active_subscriptions: Number of active subscriptions + - submissions_24h: Submissions in last 24 hours + - submissions_7d: Submissions in last 7 days + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Get total episodes + cursor.execute("SELECT COUNT(*) as count FROM episodes") + total_episodes = cursor.fetchone()["count"] + + # Get event counts + cursor.execute( + "SELECT COUNT(*) as count FROM episode_metrics " + "WHERE event_type = 'played'", + ) + total_plays = cursor.fetchone()["count"] + + cursor.execute( + "SELECT COUNT(*) as count FROM episode_metrics " + "WHERE event_type = 'downloaded'", + ) + total_downloads = cursor.fetchone()["count"] + + cursor.execute( + "SELECT COUNT(*) as count FROM episode_metrics " + "WHERE event_type = 'added'", + ) + total_adds = cursor.fetchone()["count"] + + # Get most played episodes + cursor.execute( + """ + SELECT e.id, e.title, e.author, COUNT(*) as play_count + FROM episode_metrics em + JOIN episodes e ON em.episode_id = e.id + WHERE em.event_type = 'played' + GROUP BY em.episode_id + ORDER BY play_count DESC + LIMIT 10 + """, + ) + most_played = [dict(row) for row in cursor.fetchall()] + + # Get most downloaded episodes + cursor.execute( + """ + SELECT e.id, e.title, e.author, COUNT(*) as download_count + FROM episode_metrics em + JOIN episodes e ON em.episode_id = e.id + WHERE em.event_type = 'downloaded' + GROUP BY em.episode_id + ORDER BY download_count DESC + LIMIT 10 + """, + ) + most_downloaded = [dict(row) for row in cursor.fetchall()] + + # Get most added episodes + cursor.execute( + """ + 
SELECT e.id, e.title, e.author, COUNT(*) as add_count + FROM episode_metrics em + JOIN episodes e ON em.episode_id = e.id + WHERE em.event_type = 'added' + GROUP BY em.episode_id + ORDER BY add_count DESC + LIMIT 10 + """, + ) + most_added = [dict(row) for row in cursor.fetchall()] + + # Get user metrics + cursor.execute("SELECT COUNT(*) as count FROM users") + total_users = cursor.fetchone()["count"] + + cursor.execute( + "SELECT COUNT(*) as count FROM users " + "WHERE subscription_status = 'active'", + ) + active_subscriptions = cursor.fetchone()["count"] + + # Get recent submission metrics + cursor.execute( + "SELECT COUNT(*) as count FROM queue " + "WHERE created_at >= datetime('now', '-1 day')", + ) + submissions_24h = cursor.fetchone()["count"] + + cursor.execute( + "SELECT COUNT(*) as count FROM queue " + "WHERE created_at >= datetime('now', '-7 days')", + ) + submissions_7d = cursor.fetchone()["count"] + + return { + "total_episodes": total_episodes, + "total_plays": total_plays, + "total_downloads": total_downloads, + "total_adds": total_adds, + "most_played": most_played, + "most_downloaded": most_downloaded, + "most_added": most_added, + "total_users": total_users, + "active_subscriptions": active_subscriptions, + "submissions_24h": submissions_24h, + "submissions_7d": submissions_7d, + } + + @staticmethod + def track_episode_event( + episode_id: int, + event_type: str, + user_id: int | None = None, + ) -> None: + """Track an episode event (added, played, downloaded).""" + if event_type not in {"added", "played", "downloaded"}: + msg = f"Invalid event type: {event_type}" + raise ValueError(msg) + + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO episode_metrics " + "(episode_id, user_id, event_type) VALUES (?, ?, ?)", + (episode_id, user_id, event_type), + ) + conn.commit() + logger.info( + "Tracked %s event for episode %s", + event_type, + episode_id, + ) + + @staticmethod + def get_episode_metrics(episode_id: int) -> dict[str, int]: + """Get aggregated metrics for an episode.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT event_type, COUNT(*) as count + FROM episode_metrics + WHERE episode_id = ? + GROUP BY event_type + """, + (episode_id,), + ) + rows = cursor.fetchall() + return {row["event_type"]: row["count"] for row in rows} + + @staticmethod + def get_episode_metric_events(episode_id: int) -> list[dict[str, Any]]: + """Get raw metric events for an episode (for testing).""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT id, episode_id, user_id, event_type, created_at + FROM episode_metrics + WHERE episode_id = ? + ORDER BY created_at DESC + """, + (episode_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def set_user_stripe_customer(user_id: int, customer_id: str) -> None: + """Link Stripe customer ID to user.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET stripe_customer_id = ? 
WHERE id = ?", + (customer_id, user_id), + ) + conn.commit() + logger.info( + "Linked user %s to Stripe customer %s", + user_id, + customer_id, + ) + + @staticmethod + def update_user_subscription( # noqa: PLR0913, PLR0917 + user_id: int, + subscription_id: str, + status: str, + period_start: Any, + period_end: Any, + tier: str, + cancel_at_period_end: bool, # noqa: FBT001 + ) -> None: + """Update user subscription details.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE users SET + stripe_subscription_id = ?, + subscription_status = ?, + current_period_start = ?, + current_period_end = ?, + plan_tier = ?, + cancel_at_period_end = ? + WHERE id = ? + """, + ( + subscription_id, + status, + period_start.isoformat(), + period_end.isoformat(), + tier, + 1 if cancel_at_period_end else 0, + user_id, + ), + ) + conn.commit() + logger.info( + "Updated user %s subscription: tier=%s, status=%s", + user_id, + tier, + status, + ) + + @staticmethod + def update_subscription_status(user_id: int, status: str) -> None: + """Update only the subscription status (e.g., past_due).""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET subscription_status = ? WHERE id = ?", + (status, user_id), + ) + conn.commit() + logger.info( + "Updated user %s subscription status to %s", + user_id, + status, + ) + + @staticmethod + def downgrade_to_free(user_id: int) -> None: + """Downgrade user to free tier and clear subscription data.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE users SET + plan_tier = 'free', + subscription_status = 'canceled', + stripe_subscription_id = NULL, + current_period_start = NULL, + current_period_end = NULL, + cancel_at_period_end = 0 + WHERE id = ? + """, + (user_id,), + ) + conn.commit() + logger.info("Downgraded user %s to free tier", user_id) + + @staticmethod + def get_user_by_stripe_customer_id( + customer_id: str, + ) -> dict[str, Any] | None: + """Get user by Stripe customer ID.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM users WHERE stripe_customer_id = ?", + (customer_id,), + ) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def has_processed_stripe_event(event_id: str) -> bool: + """Check if Stripe event has already been processed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT id FROM stripe_events WHERE id = ?", + (event_id,), + ) + return cursor.fetchone() is not None + + @staticmethod + def mark_stripe_event_processed( + event_id: str, + event_type: str, + payload: bytes, + ) -> None: + """Mark Stripe event as processed for idempotency.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT OR IGNORE INTO stripe_events (id, type, payload) " + "VALUES (?, ?, ?)", + (event_id, event_type, payload.decode("utf-8")), + ) + conn.commit() + + @staticmethod + def get_usage( + user_id: int, + period_start: Any, + period_end: Any, + ) -> dict[str, int]: + """Get usage stats for user in period. + + Counts episodes added to user's feed (via user_episodes table) + during the billing period, regardless of who created them. 
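+ + Example (illustrative; period bounds are datetime objects): + + usage = Database.get_usage(user_id, period_start, period_end) + # usage == {"articles": <count>, "minutes": <total audio minutes>}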
+ + Returns: + dict with keys: articles (int), minutes (int) + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Count articles added to user's feed in period + # Uses user_episodes junction table to track when episodes + # were added, which correctly handles shared/existing episodes + cursor.execute( + """ + SELECT COUNT(*) as count, SUM(e.duration) as total_seconds + FROM user_episodes ue + JOIN episodes e ON e.id = ue.episode_id + WHERE ue.user_id = ? AND ue.added_at >= ? AND ue.added_at < ? + """, + (user_id, period_start.isoformat(), period_end.isoformat()), + ) + row = cursor.fetchone() + + articles = row["count"] if row else 0 + total_seconds = ( + row["total_seconds"] if row and row["total_seconds"] else 0 + ) + minutes = total_seconds // 60 + + return {"articles": articles, "minutes": minutes} + + +class TestDatabase(Test.TestCase): + """Test the Database class.""" + + @staticmethod + def setUp() -> None: + """Set up test database.""" + Database.init_db() + + def tearDown(self) -> None: + """Clean up test database.""" + Database.teardown() + # Clear user ID + self.user_id = None + + def test_init_db(self) -> None: + """Verify all tables and indexes are created correctly.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Check tables exist + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = {row[0] for row in cursor.fetchall()} + self.assertIn("queue", tables) + self.assertIn("episodes", tables) + self.assertIn("users", tables) + + # Check indexes exist + cursor.execute("SELECT name FROM sqlite_master WHERE type='index'") + indexes = {row[0] for row in cursor.fetchall()} + self.assertIn("idx_queue_status", indexes) + self.assertIn("idx_queue_created", indexes) + self.assertIn("idx_episodes_created", indexes) + self.assertIn("idx_queue_user_id", indexes) + self.assertIn("idx_episodes_user_id", indexes) + + def test_connection_context_manager(self) -> None: + """Ensure connections are properly closed.""" + # Get a connection and verify it works + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1") + result = cursor.fetchone() + self.assertEqual(result[0], 1) + + # Connection should be closed after context manager + with pytest.raises(sqlite3.ProgrammingError): + cursor.execute("SELECT 1") + + def test_migration_idempotency(self) -> None: + """Verify migrations can run multiple times safely.""" + # Run migration multiple times + Database.migrate_to_multi_user() + Database.migrate_to_multi_user() + Database.migrate_to_multi_user() + + # Should still work fine + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users") + # Should not raise an error + # Test completed successfully - migration worked + self.assertIsNotNone(conn) + + def test_get_metrics_summary_extended(self) -> None: + """Verify extended metrics summary.""" + # Create some data + user_id, _ = Database.create_user("test@example.com") + Database.create_episode( + "Test Article", + "url", + 100, + 1000, + user_id, + ) + + # Create a queue item + Database.add_to_queue( + "https://example.com", + "test@example.com", + user_id, + ) + + metrics = Database.get_metrics_summary() + + self.assertIn("total_users", metrics) + self.assertIn("active_subscriptions", metrics) + self.assertIn("submissions_24h", metrics) + self.assertIn("submissions_7d", metrics) + + self.assertEqual(metrics["total_users"], 1) + self.assertEqual(metrics["submissions_24h"], 1) + 
self.assertEqual(metrics["submissions_7d"], 1) + + +class TestUserManagement(Test.TestCase): + """Test user management functionality.""" + + @staticmethod + def setUp() -> None: + """Set up test database.""" + Database.init_db() + + @staticmethod + def tearDown() -> None: + """Clean up test database.""" + Database.teardown() + + def test_create_user(self) -> None: + """Create user with unique email and token.""" + user_id, token = Database.create_user("test@example.com") + + self.assertIsInstance(user_id, int) + self.assertIsInstance(token, str) + self.assertGreater(len(token), 20) # Should be a secure token + + def test_create_duplicate_user(self) -> None: + """Verify duplicate emails return existing user.""" + # Create first user + user_id1, token1 = Database.create_user( + "test@example.com", + ) + + # Try to create duplicate + user_id2, token2 = Database.create_user( + "test@example.com", + ) + + # Should return same user + self.assertIsNotNone(user_id1) + self.assertIsNotNone(user_id2) + self.assertEqual(user_id1, user_id2) + self.assertEqual(token1, token2) + + def test_get_user_by_email(self) -> None: + """Retrieve user by email.""" + user_id, token = Database.create_user("test@example.com") + + user = Database.get_user_by_email("test@example.com") + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["id"], user_id) + self.assertEqual(user["email"], "test@example.com") + self.assertEqual(user["token"], token) + + def test_get_user_by_token(self) -> None: + """Retrieve user by RSS token.""" + user_id, token = Database.create_user("test@example.com") + + user = Database.get_user_by_token(token) + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["id"], user_id) + self.assertEqual(user["email"], "test@example.com") + + def test_get_user_by_id(self) -> None: + """Retrieve user by ID.""" + user_id, token = Database.create_user("test@example.com") + + user = Database.get_user_by_id(user_id) + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["email"], "test@example.com") + self.assertEqual(user["token"], token) + + def test_invalid_user_lookups(self) -> None: + """Verify None returned for non-existent users.""" + self.assertIsNone( + Database.get_user_by_email("nobody@example.com"), + ) + self.assertIsNone( + Database.get_user_by_token("invalid-token"), + ) + self.assertIsNone(Database.get_user_by_id(9999)) + + def test_token_uniqueness(self) -> None: + """Ensure tokens are cryptographically unique.""" + tokens = set() + for i in range(10): + _, token = Database.create_user( + f"user{i}@example.com", + ) + tokens.add(token) + + # All tokens should be unique + self.assertEqual(len(tokens), 10) + + def test_delete_user(self) -> None: + """Test user deletion and cleanup.""" + # Create user + user_id, _ = Database.create_user("delete_me@example.com") + + # Create some data for the user + Database.add_to_queue( + "https://example.com/article", + "delete_me@example.com", + user_id, + ) + + ep_id = Database.create_episode( + title="Test Episode", + audio_url="url", + duration=100, + content_length=1000, + user_id=user_id, + ) + Database.add_episode_to_user(user_id, ep_id) + Database.track_episode_metric(ep_id, "played", user_id) + + # Delete user + Database.delete_user(user_id) + + # Verify user is gone + self.assertIsNone(Database.get_user_by_id(user_id)) + + # Verify queue items are gone + queue = 
Database.get_user_queue_status(user_id) + self.assertEqual(len(queue), 0) + + # Verify episodes are gone (direct lookup) + self.assertIsNone(Database.get_episode_by_id(ep_id)) + + def test_update_user_email(self) -> None: + """Update user email address.""" + user_id, _ = Database.create_user("old@example.com") + + # Update email + Database.update_user_email(user_id, "new@example.com") + + # Verify update + user = Database.get_user_by_id(user_id) + self.assertIsNotNone(user) + if user: + self.assertEqual(user["email"], "new@example.com") + + # Old email should not exist + self.assertIsNone(Database.get_user_by_email("old@example.com")) + + @staticmethod + def test_update_user_email_duplicate() -> None: + """Cannot update to an existing email.""" + user_id1, _ = Database.create_user("user1@example.com") + Database.create_user("user2@example.com") + + # Try to update user1 to user2's email + with pytest.raises(ValueError, match="already taken"): + Database.update_user_email(user_id1, "user2@example.com") + + +class TestQueueOperations(Test.TestCase): + """Test queue operations.""" + + def setUp(self) -> None: + """Set up test database with a user.""" + Database.init_db() + self.user_id, _ = Database.create_user("test@example.com") + + @staticmethod + def tearDown() -> None: + """Clean up test database.""" + Database.teardown() + + def test_add_to_queue(self) -> None: + """Add job with user association.""" + job_id = Database.add_to_queue( + "https://example.com/article", + "test@example.com", + self.user_id, + ) + + self.assertIsInstance(job_id, int) + self.assertGreater(job_id, 0) + + def test_get_pending_jobs(self) -> None: + """Retrieve jobs in correct order.""" + # Add multiple jobs + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) # Ensure different timestamps + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + # Get pending jobs + jobs = Database.get_pending_jobs(limit=10) + + self.assertEqual(len(jobs), 3) + # Should be in order of creation (oldest first) + self.assertEqual(jobs[0]["id"], job1) + self.assertEqual(jobs[1]["id"], job2) + self.assertEqual(jobs[2]["id"], job3) + + def test_update_job_status(self) -> None: + """Update status and error messages.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + + # Update to processing + Database.update_job_status(job_id, "processing") + job = Database.get_job_by_id(job_id) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "processing") + + # Update to error with message + Database.update_job_status( + job_id, + "error", + "Network timeout", + ) + job = Database.get_job_by_id(job_id) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "error") + self.assertEqual(job["error_message"], "Network timeout") + self.assertEqual(job["retry_count"], 1) + + def test_retry_job(self) -> None: + """Reset failed jobs for retry.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + + # Set to error + Database.update_job_status(job_id, "error", "Failed") + + # Retry + Database.retry_job(job_id) + job = Database.get_job_by_id(job_id) + + self.assertIsNotNone(job) + if job is None: + 
self.fail("Job should not be None") + self.assertEqual(job["status"], "pending") + self.assertIsNone(job["error_message"]) + + def test_delete_job(self) -> None: + """Remove jobs from queue.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + + # Delete job + Database.delete_job(job_id) + + # Should not exist + job = Database.get_job_by_id(job_id) + self.assertIsNone(job) + + def test_get_retryable_jobs(self) -> None: + """Find jobs eligible for retry.""" + # Add job and mark as error + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + Database.update_job_status(job_id, "error", "Failed") + + # Should be retryable + retryable = Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 1) + self.assertEqual(retryable[0]["id"], job_id) + + # Exceed retry limit + Database.update_job_status( + job_id, + "error", + "Failed again", + ) + Database.update_job_status( + job_id, + "error", + "Failed yet again", + ) + + # Should not be retryable anymore + retryable = Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 0) + + def test_user_queue_isolation(self) -> None: + """Ensure users only see their own jobs.""" + # Create second user + user2_id, _ = Database.create_user("user2@example.com") + + # Add jobs for both users + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + ) + job2 = Database.add_to_queue( + "https://example.com/2", + "user2@example.com", + user2_id, + ) + + # Get user-specific queue status + user1_jobs = Database.get_user_queue_status(self.user_id) + user2_jobs = Database.get_user_queue_status(user2_id) + + self.assertEqual(len(user1_jobs), 1) + self.assertEqual(user1_jobs[0]["id"], job1) + + self.assertEqual(len(user2_jobs), 1) + self.assertEqual(user2_jobs[0]["id"], job2) + + def test_status_counts(self) -> None: + """Verify status aggregation queries.""" + # Add jobs with different statuses + Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + ) + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + Database.update_job_status(job2, "processing") + Database.update_job_status(job3, "error", "Failed") + + # Get status counts + counts = Database.get_user_status_counts(self.user_id) + + self.assertEqual(counts.get("pending", 0), 1) + self.assertEqual(counts.get("processing", 0), 1) + self.assertEqual(counts.get("error", 0), 1) + + def test_queue_position(self) -> None: + """Verify queue position calculation.""" + # Add multiple pending jobs + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + # Check positions + self.assertEqual(Database.get_queue_position(job1), 1) + self.assertEqual(Database.get_queue_position(job2), 2) + self.assertEqual(Database.get_queue_position(job3), 3) + + # Move job 2 to processing + Database.update_job_status(job2, "processing") + + # Check positions (job 3 should now be 2nd pending job) + self.assertEqual(Database.get_queue_position(job1), 1) + 
self.assertIsNone(Database.get_queue_position(job2)) + self.assertEqual(Database.get_queue_position(job3), 2) + + +class TestEpisodeManagement(Test.TestCase): + """Test episode management functionality.""" + + def setUp(self) -> None: + """Set up test database with a user.""" + Database.init_db() + self.user_id, _ = Database.create_user("test@example.com") + + @staticmethod + def tearDown() -> None: + """Clean up test database.""" + Database.teardown() + + def test_create_episode(self) -> None: + """Create episode with user association.""" + episode_id = Database.create_episode( + title="Test Article", + audio_url="https://example.com/audio.mp3", + duration=300, + content_length=5000, + user_id=self.user_id, + ) + + self.assertIsInstance(episode_id, int) + self.assertGreater(episode_id, 0) + + def test_get_recent_episodes(self) -> None: + """Retrieve episodes in reverse chronological order.""" + # Create multiple episodes + ep1 = Database.create_episode( + "Article 1", + "url1", + 100, + 1000, + self.user_id, + ) + time.sleep(0.01) + ep2 = Database.create_episode( + "Article 2", + "url2", + 200, + 2000, + self.user_id, + ) + time.sleep(0.01) + ep3 = Database.create_episode( + "Article 3", + "url3", + 300, + 3000, + self.user_id, + ) + + # Get recent episodes + episodes = Database.get_recent_episodes(limit=10) + + self.assertEqual(len(episodes), 3) + # Should be in reverse chronological order + self.assertEqual(episodes[0]["id"], ep3) + self.assertEqual(episodes[1]["id"], ep2) + self.assertEqual(episodes[2]["id"], ep1) + + def test_get_user_episodes(self) -> None: + """Ensure user isolation for episodes.""" + # Create second user + user2_id, _ = Database.create_user("user2@example.com") + + # Create episodes for both users + ep1 = Database.create_episode( + "User1 Article", + "url1", + 100, + 1000, + self.user_id, + ) + ep2 = Database.create_episode( + "User2 Article", + "url2", + 200, + 2000, + user2_id, + ) + + # Get user-specific episodes + user1_episodes = Database.get_user_all_episodes( + self.user_id, + ) + user2_episodes = Database.get_user_all_episodes(user2_id) + + self.assertEqual(len(user1_episodes), 1) + self.assertEqual(user1_episodes[0]["id"], ep1) + + self.assertEqual(len(user2_episodes), 1) + self.assertEqual(user2_episodes[0]["id"], ep2) + + def test_episode_metadata(self) -> None: + """Verify duration and content_length storage.""" + Database.create_episode( + title="Test Article", + audio_url="https://example.com/audio.mp3", + duration=12345, + content_length=98765, + user_id=self.user_id, + ) + + episodes = Database.get_user_all_episodes(self.user_id) + episode = episodes[0] + + self.assertEqual(episode["duration"], 12345) + self.assertEqual(episode["content_length"], 98765) + + +def test() -> None: + """Run the tests.""" + Test.run( + App.Area.Test, + [ + TestDatabase, + TestUserManagement, + TestQueueOperations, + TestEpisodeManagement, + ], + ) + + +def main() -> None: + """Run all PodcastItLater.Core tests.""" + if "test" in sys.argv: + test()
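
Two illustrative sketches follow (not part of the commit). First, the URL-deduplication helpers near the top of the file; the expected values follow directly from the normalization rules documented in normalize_url:

    from Biz.PodcastItLater.Core import hash_url, normalize_url

    # Protocol, host case, "www." and the trailing slash all normalize away.
    assert normalize_url("HTTP://www.Example.com/post/") == "https://example.com/post"
    # Equivalent spellings of a URL therefore hash identically, enabling dedup.
    assert hash_url("http://www.example.com/post/") == hash_url("https://example.com/post")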

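Second, continuing the sketch above the diff, the happy path a worker would follow once audio has been produced; every call is defined in this commit, and the URL, audio location, and metadata values are placeholders:

    url = "https://example.com/post"
    episode_id = Database.create_episode(
        title="Post",
        audio_url="https://cdn.example.com/a.mp3",
        duration=300,
        content_length=5000,
        user_id=user_id,
        original_url=url,
        original_url_hash=hash_url(url),
    )
    Database.add_episode_to_user(user_id, episode_id)  # junction table; feeds get_usage()
    Database.track_episode_event(episode_id, "added", user_id)
    # Any terminal status string drops the job from the pending/processing views.
    Database.update_job_status(job_id, "completed")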