Diffstat (limited to 'Biz/PodcastItLater/Core.py')
-rw-r--r--  Biz/PodcastItLater/Core.py | 2174
1 file changed, 2174 insertions(+), 0 deletions(-)
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py
new file mode 100644
index 0000000..3a88f22
--- /dev/null
+++ b/Biz/PodcastItLater/Core.py
@@ -0,0 +1,2174 @@
+"""Core, shared logic for PodcastItalater.
+
+Includes:
+- Database models
+- Data access layer
+- Shared types
+"""
+
+# : out podcastitlater-core
+# : dep pytest
+# : dep pytest-asyncio
+# : dep pytest-mock
+import hashlib
+import logging
+import Omni.App as App
+import Omni.Test as Test
+import os
+import pathlib
+import pytest
+import secrets
+import sqlite3
+import sys
+import time
+import typing
+import urllib.parse
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+CODEROOT = pathlib.Path(os.getenv("CODEROOT", "."))
+DATA_DIR = pathlib.Path(
+ os.environ.get("DATA_DIR", CODEROOT / "_/var/podcastitlater/"),
+)
+
+# Constants for UI display
+URL_TRUNCATE_LENGTH = 80
+TITLE_TRUNCATE_LENGTH = 50
+ERROR_TRUNCATE_LENGTH = 50
+
+# Admin whitelist
+ADMIN_EMAILS = ["ben@bensima.com", "admin@example.com"]
+
+
+def is_admin(user: dict[str, typing.Any] | None) -> bool:
+ """Check if user is an admin based on email whitelist."""
+ if not user:
+ return False
+ return user.get("email", "").lower() in [
+ email.lower() for email in ADMIN_EMAILS
+ ]
+
+
+def normalize_url(url: str) -> str:
+ """Normalize URL for comparison and hashing.
+
+ Normalizes:
+ - Protocol (http/https)
+ - Domain case (lowercase)
+ - www prefix (removed)
+ - Trailing slash (removed)
+ - Preserves query params and fragments as they may be meaningful
+
+ Args:
+ url: URL to normalize
+
+ Returns:
+ Normalized URL string
+ """
+ parsed = urllib.parse.urlparse(url.strip())
+
+ # Normalize domain to lowercase, remove www prefix
+ domain = parsed.netloc.lower()
+ domain = domain.removeprefix("www.")
+
+ # Normalize path - remove trailing slash unless it's the root
+ path = parsed.path.rstrip("/") if parsed.path != "/" else "/"
+
+ # Rebuild URL with normalized components
+ # Use https as the canonical protocol
+ return urllib.parse.urlunparse((
+ "https", # Always use https
+ domain,
+ path,
+ parsed.params,
+ parsed.query,
+ parsed.fragment,
+ ))
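+
+# A quick illustration of the rules above (hypothetical URLs; the values
+# follow directly from the normalization steps):
+#
+#     normalize_url("HTTP://www.Example.com/Article/")
+#     # -> "https://example.com/Article"
+#     normalize_url("http://example.com/?q=1")
+#     # -> "https://example.com/?q=1"   (query preserved)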
+
+
+def hash_url(url: str) -> str:
+ """Generate a hash of a URL for deduplication.
+
+ Args:
+ url: URL to hash
+
+ Returns:
+ SHA256 hash of the normalized URL
+ """
+ normalized = normalize_url(url)
+ return hashlib.sha256(normalized.encode()).hexdigest()
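+
+# Deduplication depends on variants of the same link colliding; for example
+# (hypothetical URLs), both of these produce the same digest:
+#
+#     hash_url("http://www.example.com/post/")
+#     hash_url("https://example.com/post")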
+
+
+class Database: # noqa: PLR0904
+ """Data access layer for PodcastItLater database operations."""
+
+ @staticmethod
+ def teardown() -> None:
+ """Delete the existing database, for cleanup after tests."""
+ db_path = DATA_DIR / "podcast.db"
+ if db_path.exists():
+ db_path.unlink()
+
+ @staticmethod
+ @contextmanager
+ def get_connection() -> Iterator[sqlite3.Connection]:
+ """Context manager for database connections.
+
+ Yields:
+ sqlite3.Connection: Database connection with row factory set.
+ """
+        db_path = DATA_DIR / "podcast.db"
+        # sqlite3 will not create missing parent directories, so make sure
+        # the data directory exists before connecting
+        db_path.parent.mkdir(parents=True, exist_ok=True)
+        conn = sqlite3.connect(db_path)
+ conn.row_factory = sqlite3.Row
+ try:
+ yield conn
+ finally:
+ conn.close()
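+
+    # Minimal usage sketch: rows come back dict-like via sqlite3.Row, and
+    # the connection is closed when the block exits.
+    #
+    #     with Database.get_connection() as conn:
+    #         row = conn.execute("SELECT 1 AS one").fetchone()
+    #         assert row["one"] == 1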
+
+ @staticmethod
+ def init_db() -> None:
+ """Initialize database with required tables."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Queue table for job processing
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS queue (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ url TEXT,
+ email TEXT,
+ status TEXT DEFAULT 'pending',
+ retry_count INTEGER DEFAULT 0,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ error_message TEXT,
+ title TEXT,
+ author TEXT
+ )
+ """)
+
+ # Episodes table for completed podcasts
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS episodes (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ title TEXT NOT NULL,
+ content_length INTEGER,
+ audio_url TEXT NOT NULL,
+ duration INTEGER,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ )
+ """)
+
+ # Create indexes for performance
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS idx_queue_status ON queue(status)",
+ )
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS idx_queue_created "
+ "ON queue(created_at)",
+ )
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS idx_episodes_created "
+ "ON episodes(created_at)",
+ )
+
+ conn.commit()
+ logger.info("Database initialized successfully")
+
+ # Run migration to add user support
+ Database.migrate_to_multi_user()
+
+ # Run migration to add metadata fields
+ Database.migrate_add_metadata_fields()
+
+ # Run migration to add episode metadata fields
+ Database.migrate_add_episode_metadata()
+
+ # Run migration to add user status field
+ Database.migrate_add_user_status()
+
+ # Run migration to add default titles
+ Database.migrate_add_default_titles()
+
+ # Run migration to add billing fields
+ Database.migrate_add_billing_fields()
+
+ # Run migration to add stripe events table
+ Database.migrate_add_stripe_events_table()
+
+ # Run migration to add public feed features
+ Database.migrate_add_public_feed()
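+
+    # init_db is safe to call on every startup: table creation uses
+    # IF NOT EXISTS and each migration inspects the schema before altering
+    # it, so repeated runs are no-ops (see test_migration_idempotency).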
+
+ @staticmethod
+ def add_to_queue(
+ url: str,
+ email: str,
+ user_id: int,
+ title: str | None = None,
+ author: str | None = None,
+ ) -> int:
+ """Insert new job into queue with metadata, return job ID.
+
+ Raises:
+ ValueError: If job ID cannot be retrieved after insert.
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT INTO queue (url, email, user_id, title, author) "
+ "VALUES (?, ?, ?, ?, ?)",
+ (url, email, user_id, title, author),
+ )
+ conn.commit()
+ job_id = cursor.lastrowid
+ if job_id is None:
+ msg = "Failed to get job ID after insert"
+ raise ValueError(msg)
+ logger.info("Added job %s to queue: %s", job_id, url)
+ return job_id
+
+ @staticmethod
+ def get_pending_jobs(
+ limit: int = 10,
+ ) -> list[dict[str, Any]]:
+ """Fetch jobs with status='pending' ordered by creation time."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM queue WHERE status = 'pending' "
+ "ORDER BY created_at ASC LIMIT ?",
+ (limit,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def update_job_status(
+ job_id: int,
+ status: str,
+ error: str | None = None,
+ ) -> None:
+ """Update job status and error message."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ if error is not None:
+ if status == "error":
+ cursor.execute(
+ "UPDATE queue SET status = ?, error_message = ?, "
+ "retry_count = retry_count + 1 WHERE id = ?",
+ (status, error, job_id),
+ )
+ else:
+ cursor.execute(
+ "UPDATE queue SET status = ?, "
+ "error_message = ? WHERE id = ?",
+ (status, error, job_id),
+ )
+ else:
+ cursor.execute(
+ "UPDATE queue SET status = ? WHERE id = ?",
+ (status, job_id),
+ )
+ conn.commit()
+ logger.info("Updated job %s status to %s", job_id, status)
+
+ @staticmethod
+ def get_job_by_id(
+ job_id: int,
+ ) -> dict[str, Any] | None:
+ """Fetch single job by ID."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM queue WHERE id = ?", (job_id,))
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def create_episode( # noqa: PLR0913, PLR0917
+ title: str,
+ audio_url: str,
+ duration: int,
+ content_length: int,
+ user_id: int | None = None,
+ author: str | None = None,
+ original_url: str | None = None,
+ original_url_hash: str | None = None,
+ ) -> int:
+ """Insert episode record, return episode ID.
+
+ Raises:
+ ValueError: If episode ID cannot be retrieved after insert.
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT INTO episodes "
+ "(title, audio_url, duration, content_length, user_id, "
+ "author, original_url, original_url_hash) "
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
+ (
+ title,
+ audio_url,
+ duration,
+ content_length,
+ user_id,
+ author,
+ original_url,
+ original_url_hash,
+ ),
+ )
+ conn.commit()
+ episode_id = cursor.lastrowid
+ if episode_id is None:
+ msg = "Failed to get episode ID after insert"
+ raise ValueError(msg)
+ logger.info("Created episode %s: %s", episode_id, title)
+ return episode_id
+
+ @staticmethod
+ def get_recent_episodes(
+ limit: int = 20,
+ ) -> list[dict[str, Any]]:
+ """Get recent episodes for RSS feed generation."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM episodes ORDER BY created_at DESC LIMIT ?",
+ (limit,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_queue_status_summary() -> dict[str, Any]:
+ """Get queue status summary for web interface."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Count jobs by status
+ cursor.execute(
+ "SELECT status, COUNT(*) as count FROM queue GROUP BY status",
+ )
+ rows = cursor.fetchall()
+ status_counts = {row["status"]: row["count"] for row in rows}
+
+ # Get recent jobs
+ cursor.execute(
+ "SELECT * FROM queue ORDER BY created_at DESC LIMIT 10",
+ )
+ rows = cursor.fetchall()
+ recent_jobs = [dict(row) for row in rows]
+
+ return {"status_counts": status_counts, "recent_jobs": recent_jobs}
+
+ @staticmethod
+ def get_queue_status() -> list[dict[str, Any]]:
+ """Return pending/processing/error items for web interface."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("""
+ SELECT id, url, email, status, created_at, error_message,
+ title, author
+ FROM queue
+ WHERE status IN (
+ 'pending', 'processing', 'extracting',
+ 'synthesizing', 'uploading', 'error'
+ )
+ ORDER BY created_at DESC
+ LIMIT 20
+ """)
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_episode_by_id(episode_id: int) -> dict[str, Any] | None:
+ """Fetch single episode by ID."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT id, title, audio_url, duration, created_at,
+ content_length, author, original_url, user_id, is_public
+ FROM episodes
+ WHERE id = ?
+ """,
+ (episode_id,),
+ )
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def get_all_episodes(
+ user_id: int | None = None,
+ ) -> list[dict[str, Any]]:
+ """Return all episodes for RSS feed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ if user_id:
+ cursor.execute(
+ """
+ SELECT id, title, audio_url, duration, created_at,
+ content_length, author, original_url
+ FROM episodes
+ WHERE user_id = ?
+ ORDER BY created_at DESC
+ """,
+ (user_id,),
+ )
+ else:
+ cursor.execute("""
+ SELECT id, title, audio_url, duration, created_at,
+ content_length, author, original_url
+ FROM episodes
+ ORDER BY created_at DESC
+ """)
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_retryable_jobs(
+ max_retries: int = 3,
+ ) -> list[dict[str, Any]]:
+ """Get failed jobs that can be retried."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM queue WHERE status = 'error' "
+ "AND retry_count < ? ORDER BY created_at ASC",
+ (max_retries,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def retry_job(job_id: int) -> None:
+ """Reset a job to pending status for retry."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE queue SET status = 'pending', "
+ "error_message = NULL WHERE id = ?",
+ (job_id,),
+ )
+ conn.commit()
+ logger.info("Reset job %s to pending for retry", job_id)
+
+ @staticmethod
+ def delete_job(job_id: int) -> None:
+ """Delete a job from the queue."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("DELETE FROM queue WHERE id = ?", (job_id,))
+ conn.commit()
+ logger.info("Deleted job %s from queue", job_id)
+
+ @staticmethod
+ def get_all_queue_items(
+ user_id: int | None = None,
+ ) -> list[dict[str, Any]]:
+ """Return all queue items for admin view."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ if user_id:
+ cursor.execute(
+ """
+ SELECT id, url, email, status, retry_count, created_at,
+ error_message, title, author
+ FROM queue
+ WHERE user_id = ?
+ ORDER BY created_at DESC
+ """,
+ (user_id,),
+ )
+ else:
+ cursor.execute("""
+ SELECT id, url, email, status, retry_count, created_at,
+ error_message, title, author
+ FROM queue
+ ORDER BY created_at DESC
+ """)
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_status_counts() -> dict[str, int]:
+ """Get count of queue items by status."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("""
+ SELECT status, COUNT(*) as count
+ FROM queue
+ GROUP BY status
+ """)
+ rows = cursor.fetchall()
+ return {row["status"]: row["count"] for row in rows}
+
+ @staticmethod
+ def get_user_status_counts(
+ user_id: int,
+ ) -> dict[str, int]:
+ """Get count of queue items by status for a specific user."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT status, COUNT(*) as count
+ FROM queue
+ WHERE user_id = ?
+ GROUP BY status
+ """,
+ (user_id,),
+ )
+ rows = cursor.fetchall()
+ return {row["status"]: row["count"] for row in rows}
+
+ @staticmethod
+ def migrate_to_multi_user() -> None:
+ """Migrate database to support multiple users."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Create users table
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS users (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ email TEXT UNIQUE NOT NULL,
+ token TEXT UNIQUE NOT NULL,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ )
+ """)
+
+ # Add user_id columns to existing tables
+ # Check if columns already exist to make migration idempotent
+ cursor.execute("PRAGMA table_info(queue)")
+ queue_info = cursor.fetchall()
+ queue_columns = [col[1] for col in queue_info]
+
+ if "user_id" not in queue_columns:
+ cursor.execute(
+ "ALTER TABLE queue ADD COLUMN user_id INTEGER "
+ "REFERENCES users(id)",
+ )
+
+ cursor.execute("PRAGMA table_info(episodes)")
+ episodes_info = cursor.fetchall()
+ episodes_columns = [col[1] for col in episodes_info]
+
+ if "user_id" not in episodes_columns:
+ cursor.execute(
+ "ALTER TABLE episodes ADD COLUMN user_id INTEGER "
+ "REFERENCES users(id)",
+ )
+
+ # Create indexes
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS idx_queue_user_id "
+ "ON queue(user_id)",
+ )
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS idx_episodes_user_id "
+ "ON episodes(user_id)",
+ )
+
+ conn.commit()
+ logger.info("Database migrated to support multiple users")
+
+ @staticmethod
+ def migrate_add_metadata_fields() -> None:
+ """Add title and author fields to queue table."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Check if columns already exist
+ cursor.execute("PRAGMA table_info(queue)")
+ queue_info = cursor.fetchall()
+ queue_columns = [col[1] for col in queue_info]
+
+ if "title" not in queue_columns:
+ cursor.execute("ALTER TABLE queue ADD COLUMN title TEXT")
+
+ if "author" not in queue_columns:
+ cursor.execute("ALTER TABLE queue ADD COLUMN author TEXT")
+
+ conn.commit()
+ logger.info("Database migrated to support metadata fields")
+
+ @staticmethod
+ def migrate_add_episode_metadata() -> None:
+ """Add author and original_url fields to episodes table."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Check if columns already exist
+ cursor.execute("PRAGMA table_info(episodes)")
+ episodes_info = cursor.fetchall()
+ episodes_columns = [col[1] for col in episodes_info]
+
+ if "author" not in episodes_columns:
+ cursor.execute("ALTER TABLE episodes ADD COLUMN author TEXT")
+
+ if "original_url" not in episodes_columns:
+ cursor.execute(
+ "ALTER TABLE episodes ADD COLUMN original_url TEXT",
+ )
+
+ conn.commit()
+ logger.info("Database migrated to support episode metadata fields")
+
+ @staticmethod
+ def migrate_add_user_status() -> None:
+ """Add status field to users table."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Check if column already exists
+ cursor.execute("PRAGMA table_info(users)")
+ users_info = cursor.fetchall()
+ users_columns = [col[1] for col in users_info]
+
+ if "status" not in users_columns:
+ # Add status column with default 'active'
+ cursor.execute(
+ "ALTER TABLE users ADD COLUMN status TEXT DEFAULT 'active'",
+ )
+
+ conn.commit()
+ logger.info("Database migrated to support user status")
+
+ @staticmethod
+ def migrate_add_billing_fields() -> None:
+ """Add billing-related fields to users table."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Add columns one by one (SQLite limitation)
+ # Note: SQLite ALTER TABLE doesn't support adding UNIQUE constraints
+ # We add them without UNIQUE and rely on application logic
+ columns_to_add = [
+ ("plan_tier", "TEXT NOT NULL DEFAULT 'free'"),
+ ("stripe_customer_id", "TEXT"),
+ ("stripe_subscription_id", "TEXT"),
+ ("subscription_status", "TEXT"),
+ ("current_period_start", "TIMESTAMP"),
+ ("current_period_end", "TIMESTAMP"),
+ ("cancel_at_period_end", "INTEGER NOT NULL DEFAULT 0"),
+ ]
+
+ for column_name, column_def in columns_to_add:
+ try:
+ query = f"ALTER TABLE users ADD COLUMN {column_name} "
+ cursor.execute(query + column_def)
+ logger.info("Added column users.%s", column_name)
+ except sqlite3.OperationalError as e: # noqa: PERF203
+ # Column already exists, skip
+ logger.debug(
+ "Column users.%s already exists: %s",
+ column_name,
+ e,
+ )
+
+ conn.commit()
+
+ @staticmethod
+ def migrate_add_stripe_events_table() -> None:
+ """Create stripe_events table for webhook idempotency."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS stripe_events (
+ id TEXT PRIMARY KEY,
+ type TEXT NOT NULL,
+ payload TEXT NOT NULL,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ )
+ """)
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS idx_stripe_events_created "
+ "ON stripe_events(created_at)",
+ )
+ conn.commit()
+ logger.info("Created stripe_events table")
+
+ @staticmethod
+ def migrate_add_public_feed() -> None:
+ """Add is_public column and related tables for public feed feature."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Add is_public column to episodes
+ cursor.execute("PRAGMA table_info(episodes)")
+ episodes_info = cursor.fetchall()
+ episodes_columns = [col[1] for col in episodes_info]
+
+ if "is_public" not in episodes_columns:
+ cursor.execute(
+ "ALTER TABLE episodes ADD COLUMN is_public INTEGER "
+ "NOT NULL DEFAULT 0",
+ )
+ logger.info("Added is_public column to episodes")
+
+ # Create user_episodes junction table
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS user_episodes (
+ user_id INTEGER NOT NULL,
+ episode_id INTEGER NOT NULL,
+ added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ PRIMARY KEY (user_id, episode_id),
+ FOREIGN KEY (user_id) REFERENCES users(id),
+ FOREIGN KEY (episode_id) REFERENCES episodes(id)
+ )
+ """)
+ logger.info("Created user_episodes junction table")
+
+ # Create index on episode_id for reverse lookups
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS idx_user_episodes_episode "
+ "ON user_episodes(episode_id)",
+ )
+
+ # Add original_url_hash column to episodes
+ if "original_url_hash" not in episodes_columns:
+ cursor.execute(
+ "ALTER TABLE episodes ADD COLUMN original_url_hash TEXT",
+ )
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS idx_episodes_url_hash "
+ "ON episodes(original_url_hash)",
+ )
+ logger.info("Added original_url_hash column to episodes")
+
+ # Create episode_metrics table
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS episode_metrics (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ episode_id INTEGER NOT NULL,
+ user_id INTEGER,
+ event_type TEXT NOT NULL,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (episode_id) REFERENCES episodes(id),
+ FOREIGN KEY (user_id) REFERENCES users(id)
+ )
+ """)
+ logger.info("Created episode_metrics table")
+
+ # Create indexes for metrics queries
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS idx_episode_metrics_episode "
+ "ON episode_metrics(episode_id)",
+ )
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS idx_episode_metrics_type "
+ "ON episode_metrics(event_type)",
+ )
+
+ # Create index on is_public for efficient public feed queries
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS idx_episodes_public "
+ "ON episodes(is_public)",
+ )
+
+ conn.commit()
+ logger.info("Database migrated for public feed feature")
+
+ @staticmethod
+ def migrate_add_default_titles() -> None:
+ """Add default titles to queue items that have None titles."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Update queue items with NULL titles to have a default
+ cursor.execute("""
+ UPDATE queue
+ SET title = 'Untitled Article'
+ WHERE title IS NULL
+ """)
+
+ # Get count of updated rows
+ updated_count = cursor.rowcount
+
+ conn.commit()
+ logger.info(
+ "Updated %s queue items with default titles",
+ updated_count,
+ )
+
+ @staticmethod
+ def create_user(email: str, status: str = "active") -> tuple[int, str]:
+ """Create a new user and return (user_id, token).
+
+ Args:
+ email: User email address
+            status: Initial status (pending, active, or disabled)
+
+        Raises:
+            ValueError: If status is invalid, if the user ID cannot be
+                retrieved after insert, or if an existing user row cannot
+                be found.
+ """
+ if status not in {"pending", "active", "disabled"}:
+ msg = f"Invalid status: {status}"
+ raise ValueError(msg)
+
+ # Generate a secure token for RSS feed access
+ token = secrets.token_urlsafe(32)
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ try:
+ cursor.execute(
+ "INSERT INTO users (email, token, status) VALUES (?, ?, ?)",
+ (email, token, status),
+ )
+ conn.commit()
+ user_id = cursor.lastrowid
+ if user_id is None:
+ msg = "Failed to get user ID after insert"
+ raise ValueError(msg)
+ logger.info(
+ "Created user %s with email %s (status: %s)",
+ user_id,
+ email,
+ status,
+ )
+ except sqlite3.IntegrityError:
+ # User already exists
+ cursor.execute(
+ "SELECT id, token FROM users WHERE email = ?",
+ (email,),
+ )
+ row = cursor.fetchone()
+ if row is None:
+ msg = f"User with email {email} not found"
+ raise ValueError(msg) from None
+ return int(row["id"]), str(row["token"])
+ else:
+ return int(user_id), str(token)
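+
+    # create_user is idempotent per email (illustrative): a second call
+    # with the same address returns the existing id and token rather than
+    # raising.
+    #
+    #     uid1, tok1 = Database.create_user("reader@example.com")
+    #     uid2, tok2 = Database.create_user("reader@example.com")
+    #     assert (uid1, tok1) == (uid2, tok2)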
+
+ @staticmethod
+ def get_user_by_email(
+ email: str,
+ ) -> dict[str, Any] | None:
+ """Get user by email address."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM users WHERE email = ?", (email,))
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def get_user_by_token(
+ token: str,
+ ) -> dict[str, Any] | None:
+ """Get user by RSS token."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM users WHERE token = ?", (token,))
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def get_user_by_id(
+ user_id: int,
+ ) -> dict[str, Any] | None:
+ """Get user by ID."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+    @staticmethod
+    def get_queue_position(job_id: int) -> int | None:
+        """Get 1-based position of a job in the pending queue.
+
+        Returns None if the job does not exist or is no longer pending.
+        """
+        with Database.get_connection() as conn:
+            cursor = conn.cursor()
+            # Fetch created_at only if the job is still pending
+            cursor.execute(
+                "SELECT created_at FROM queue "
+                "WHERE id = ? AND status = 'pending'",
+                (job_id,),
+            )
+            row = cursor.fetchone()
+            if not row:
+                return None
+            created_at = row[0]
+
+            # Count pending items ahead of this one; id breaks ties since
+            # CURRENT_TIMESTAMP has only one-second resolution
+            cursor.execute(
+                """
+                SELECT COUNT(*) FROM queue
+                WHERE status = 'pending'
+                AND (created_at < ?
+                     OR (created_at = ? AND id <= ?))
+                """,
+                (created_at, created_at, job_id),
+            )
+            return int(cursor.fetchone()[0])
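+
+    # Illustrative behavior (hypothetical jobs): positions are 1-based and
+    # only defined while a job is pending.
+    #
+    #     first = Database.add_to_queue("https://e.com/1", "a@b.com", 1)
+    #     second = Database.add_to_queue("https://e.com/2", "a@b.com", 1)
+    #     Database.get_queue_position(first)   # -> 1
+    #     Database.update_job_status(first, "processing")
+    #     Database.get_queue_position(first)   # -> None
+    #     Database.get_queue_position(second)  # -> 1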
+
+ @staticmethod
+ def get_user_queue_status(
+ user_id: int,
+ ) -> list[dict[str, Any]]:
+ """Return pending/processing/error items for a specific user."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT id, url, email, status, created_at, error_message,
+ title, author
+ FROM queue
+ WHERE user_id = ? AND
+ status IN (
+ 'pending', 'processing', 'extracting',
+ 'synthesizing', 'uploading', 'error'
+ )
+ ORDER BY created_at DESC
+ LIMIT 20
+ """,
+ (user_id,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_user_recent_episodes(
+ user_id: int,
+ limit: int = 20,
+ ) -> list[dict[str, Any]]:
+ """Get recent episodes for a specific user."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM episodes WHERE user_id = ? "
+ "ORDER BY created_at DESC LIMIT ?",
+ (user_id, limit),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_user_all_episodes(
+ user_id: int,
+ ) -> list[dict[str, Any]]:
+ """Get all episodes for a specific user for RSS feed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM episodes WHERE user_id = ? "
+ "ORDER BY created_at DESC",
+ (user_id,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def update_user_status(
+ user_id: int,
+ status: str,
+ ) -> None:
+ """Update user account status."""
+ if status not in {"pending", "active", "disabled"}:
+ msg = f"Invalid status: {status}"
+ raise ValueError(msg)
+
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE users SET status = ? WHERE id = ?",
+ (status, user_id),
+ )
+ conn.commit()
+ logger.info("Updated user %s status to %s", user_id, status)
+
+ @staticmethod
+ def delete_user(user_id: int) -> None:
+ """Delete user and all associated data."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # 1. Get owned episode IDs
+ cursor.execute(
+ "SELECT id FROM episodes WHERE user_id = ?",
+ (user_id,),
+ )
+ owned_episode_ids = [row[0] for row in cursor.fetchall()]
+
+ # 2. Delete references to owned episodes
+ if owned_episode_ids:
+ # Construct placeholders for IN clause
+ placeholders = ",".join("?" * len(owned_episode_ids))
+
+ # Delete from user_episodes where these episodes are referenced
+ query = f"DELETE FROM user_episodes WHERE episode_id IN ({placeholders})" # noqa: S608, E501
+ cursor.execute(query, tuple(owned_episode_ids))
+
+ # Delete metrics for these episodes
+ query = f"DELETE FROM episode_metrics WHERE episode_id IN ({placeholders})" # noqa: S608, E501
+ cursor.execute(query, tuple(owned_episode_ids))
+
+ # 3. Delete owned episodes
+ cursor.execute("DELETE FROM episodes WHERE user_id = ?", (user_id,))
+
+ # 4. Delete user's data referencing others or themselves
+ cursor.execute(
+ "DELETE FROM user_episodes WHERE user_id = ?",
+ (user_id,),
+ )
+ cursor.execute(
+ "DELETE FROM episode_metrics WHERE user_id = ?",
+ (user_id,),
+ )
+ cursor.execute("DELETE FROM queue WHERE user_id = ?", (user_id,))
+
+ # 5. Delete user
+ cursor.execute("DELETE FROM users WHERE id = ?", (user_id,))
+
+ conn.commit()
+ logger.info("Deleted user %s and all associated data", user_id)
+
+ @staticmethod
+ def update_user_email(user_id: int, new_email: str) -> None:
+ """Update user's email address.
+
+ Args:
+ user_id: ID of the user to update
+ new_email: New email address
+
+ Raises:
+ ValueError: If email is already taken by another user
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ try:
+ cursor.execute(
+ "UPDATE users SET email = ? WHERE id = ?",
+ (new_email, user_id),
+ )
+ conn.commit()
+ logger.info("Updated user %s email to %s", user_id, new_email)
+ except sqlite3.IntegrityError:
+ msg = f"Email {new_email} is already taken"
+ raise ValueError(msg) from None
+
+ @staticmethod
+ def mark_episode_public(episode_id: int) -> None:
+ """Mark an episode as public."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE episodes SET is_public = 1 WHERE id = ?",
+ (episode_id,),
+ )
+ conn.commit()
+ logger.info("Marked episode %s as public", episode_id)
+
+ @staticmethod
+ def unmark_episode_public(episode_id: int) -> None:
+ """Mark an episode as private (not public)."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE episodes SET is_public = 0 WHERE id = ?",
+ (episode_id,),
+ )
+ conn.commit()
+ logger.info("Unmarked episode %s as public", episode_id)
+
+ @staticmethod
+ def get_public_episodes(limit: int = 50) -> list[dict[str, Any]]:
+ """Get public episodes for public feed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT id, title, audio_url, duration, created_at,
+ content_length, author, original_url
+ FROM episodes
+ WHERE is_public = 1
+ ORDER BY created_at DESC
+ LIMIT ?
+ """,
+ (limit,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def add_episode_to_user(user_id: int, episode_id: int) -> None:
+ """Add an episode to a user's feed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ try:
+ cursor.execute(
+ "INSERT INTO user_episodes (user_id, episode_id) "
+ "VALUES (?, ?)",
+ (user_id, episode_id),
+ )
+ conn.commit()
+ logger.info(
+ "Added episode %s to user %s feed",
+ episode_id,
+ user_id,
+ )
+ except sqlite3.IntegrityError:
+ # Episode already in user's feed
+ logger.info(
+ "Episode %s already in user %s feed",
+ episode_id,
+ user_id,
+ )
+
+ @staticmethod
+ def user_has_episode(user_id: int, episode_id: int) -> bool:
+ """Check if a user has an episode in their feed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT 1 FROM user_episodes "
+ "WHERE user_id = ? AND episode_id = ?",
+ (user_id, episode_id),
+ )
+ return cursor.fetchone() is not None
+
+ @staticmethod
+ def track_episode_metric(
+ episode_id: int,
+ event_type: str,
+ user_id: int | None = None,
+ ) -> None:
+ """Track an episode metric event.
+
+ Args:
+ episode_id: ID of the episode
+ event_type: Type of event ('added', 'played', 'downloaded')
+ user_id: Optional user ID (None for anonymous events)
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT INTO episode_metrics (episode_id, user_id, event_type) "
+ "VALUES (?, ?, ?)",
+ (episode_id, user_id, event_type),
+ )
+ conn.commit()
+ logger.info(
+ "Tracked %s event for episode %s (user: %s)",
+ event_type,
+ episode_id,
+ user_id or "anonymous",
+ )
+
+ @staticmethod
+ def get_user_episodes(user_id: int) -> list[dict[str, Any]]:
+ """Get all episodes in a user's feed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT e.id, e.title, e.audio_url, e.duration, e.created_at,
+ e.content_length, e.author, e.original_url, e.is_public,
+ ue.added_at
+ FROM episodes e
+ JOIN user_episodes ue ON e.id = ue.episode_id
+ WHERE ue.user_id = ?
+ ORDER BY ue.added_at DESC
+ """,
+ (user_id,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_episode_by_url_hash(url_hash: str) -> dict[str, Any] | None:
+ """Get episode by original URL hash."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM episodes WHERE original_url_hash = ?",
+ (url_hash,),
+ )
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def get_metrics_summary() -> dict[str, Any]:
+ """Get aggregate metrics summary for admin dashboard.
+
+ Returns:
+ dict with keys:
+ - total_episodes: Total number of episodes
+ - total_plays: Total play events
+ - total_downloads: Total download events
+ - total_adds: Total add events
+ - most_played: List of top 10 most played episodes
+ - most_downloaded: List of top 10 most downloaded episodes
+ - most_added: List of top 10 most added episodes
+ - total_users: Total number of users
+ - active_subscriptions: Number of active subscriptions
+ - submissions_24h: Submissions in last 24 hours
+ - submissions_7d: Submissions in last 7 days
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Get total episodes
+ cursor.execute("SELECT COUNT(*) as count FROM episodes")
+ total_episodes = cursor.fetchone()["count"]
+
+ # Get event counts
+ cursor.execute(
+ "SELECT COUNT(*) as count FROM episode_metrics "
+ "WHERE event_type = 'played'",
+ )
+ total_plays = cursor.fetchone()["count"]
+
+ cursor.execute(
+ "SELECT COUNT(*) as count FROM episode_metrics "
+ "WHERE event_type = 'downloaded'",
+ )
+ total_downloads = cursor.fetchone()["count"]
+
+ cursor.execute(
+ "SELECT COUNT(*) as count FROM episode_metrics "
+ "WHERE event_type = 'added'",
+ )
+ total_adds = cursor.fetchone()["count"]
+
+ # Get most played episodes
+ cursor.execute(
+ """
+ SELECT e.id, e.title, e.author, COUNT(*) as play_count
+ FROM episode_metrics em
+ JOIN episodes e ON em.episode_id = e.id
+ WHERE em.event_type = 'played'
+ GROUP BY em.episode_id
+ ORDER BY play_count DESC
+ LIMIT 10
+ """,
+ )
+ most_played = [dict(row) for row in cursor.fetchall()]
+
+ # Get most downloaded episodes
+ cursor.execute(
+ """
+ SELECT e.id, e.title, e.author, COUNT(*) as download_count
+ FROM episode_metrics em
+ JOIN episodes e ON em.episode_id = e.id
+ WHERE em.event_type = 'downloaded'
+ GROUP BY em.episode_id
+ ORDER BY download_count DESC
+ LIMIT 10
+ """,
+ )
+ most_downloaded = [dict(row) for row in cursor.fetchall()]
+
+ # Get most added episodes
+ cursor.execute(
+ """
+ SELECT e.id, e.title, e.author, COUNT(*) as add_count
+ FROM episode_metrics em
+ JOIN episodes e ON em.episode_id = e.id
+ WHERE em.event_type = 'added'
+ GROUP BY em.episode_id
+ ORDER BY add_count DESC
+ LIMIT 10
+ """,
+ )
+ most_added = [dict(row) for row in cursor.fetchall()]
+
+ # Get user metrics
+ cursor.execute("SELECT COUNT(*) as count FROM users")
+ total_users = cursor.fetchone()["count"]
+
+ cursor.execute(
+ "SELECT COUNT(*) as count FROM users "
+ "WHERE subscription_status = 'active'",
+ )
+ active_subscriptions = cursor.fetchone()["count"]
+
+ # Get recent submission metrics
+ cursor.execute(
+ "SELECT COUNT(*) as count FROM queue "
+ "WHERE created_at >= datetime('now', '-1 day')",
+ )
+ submissions_24h = cursor.fetchone()["count"]
+
+ cursor.execute(
+ "SELECT COUNT(*) as count FROM queue "
+ "WHERE created_at >= datetime('now', '-7 days')",
+ )
+ submissions_7d = cursor.fetchone()["count"]
+
+ return {
+ "total_episodes": total_episodes,
+ "total_plays": total_plays,
+ "total_downloads": total_downloads,
+ "total_adds": total_adds,
+ "most_played": most_played,
+ "most_downloaded": most_downloaded,
+ "most_added": most_added,
+ "total_users": total_users,
+ "active_subscriptions": active_subscriptions,
+ "submissions_24h": submissions_24h,
+ "submissions_7d": submissions_7d,
+ }
+
+ @staticmethod
+ def track_episode_event(
+ episode_id: int,
+ event_type: str,
+ user_id: int | None = None,
+ ) -> None:
+ """Track an episode event (added, played, downloaded)."""
+ if event_type not in {"added", "played", "downloaded"}:
+ msg = f"Invalid event type: {event_type}"
+ raise ValueError(msg)
+
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT INTO episode_metrics "
+ "(episode_id, user_id, event_type) VALUES (?, ?, ?)",
+ (episode_id, user_id, event_type),
+ )
+ conn.commit()
+ logger.info(
+ "Tracked %s event for episode %s",
+ event_type,
+ episode_id,
+ )
+
+ @staticmethod
+ def get_episode_metrics(episode_id: int) -> dict[str, int]:
+ """Get aggregated metrics for an episode."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT event_type, COUNT(*) as count
+ FROM episode_metrics
+ WHERE episode_id = ?
+ GROUP BY event_type
+ """,
+ (episode_id,),
+ )
+ rows = cursor.fetchall()
+ return {row["event_type"]: row["count"] for row in rows}
+
+ @staticmethod
+ def get_episode_metric_events(episode_id: int) -> list[dict[str, Any]]:
+ """Get raw metric events for an episode (for testing)."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT id, episode_id, user_id, event_type, created_at
+ FROM episode_metrics
+ WHERE episode_id = ?
+ ORDER BY created_at DESC
+ """,
+ (episode_id,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def set_user_stripe_customer(user_id: int, customer_id: str) -> None:
+ """Link Stripe customer ID to user."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE users SET stripe_customer_id = ? WHERE id = ?",
+ (customer_id, user_id),
+ )
+ conn.commit()
+ logger.info(
+ "Linked user %s to Stripe customer %s",
+ user_id,
+ customer_id,
+ )
+
+ @staticmethod
+ def update_user_subscription( # noqa: PLR0913, PLR0917
+ user_id: int,
+ subscription_id: str,
+ status: str,
+ period_start: Any,
+ period_end: Any,
+ tier: str,
+ cancel_at_period_end: bool, # noqa: FBT001
+ ) -> None:
+ """Update user subscription details."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ UPDATE users SET
+ stripe_subscription_id = ?,
+ subscription_status = ?,
+ current_period_start = ?,
+ current_period_end = ?,
+ plan_tier = ?,
+ cancel_at_period_end = ?
+ WHERE id = ?
+ """,
+ (
+ subscription_id,
+ status,
+ period_start.isoformat(),
+ period_end.isoformat(),
+ tier,
+ 1 if cancel_at_period_end else 0,
+ user_id,
+ ),
+ )
+ conn.commit()
+ logger.info(
+ "Updated user %s subscription: tier=%s, status=%s",
+ user_id,
+ tier,
+ status,
+ )
+
+ @staticmethod
+ def update_subscription_status(user_id: int, status: str) -> None:
+ """Update only the subscription status (e.g., past_due)."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE users SET subscription_status = ? WHERE id = ?",
+ (status, user_id),
+ )
+ conn.commit()
+ logger.info(
+ "Updated user %s subscription status to %s",
+ user_id,
+ status,
+ )
+
+ @staticmethod
+ def downgrade_to_free(user_id: int) -> None:
+ """Downgrade user to free tier and clear subscription data."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ UPDATE users SET
+ plan_tier = 'free',
+ subscription_status = 'canceled',
+ stripe_subscription_id = NULL,
+ current_period_start = NULL,
+ current_period_end = NULL,
+ cancel_at_period_end = 0
+ WHERE id = ?
+ """,
+ (user_id,),
+ )
+ conn.commit()
+ logger.info("Downgraded user %s to free tier", user_id)
+
+ @staticmethod
+ def get_user_by_stripe_customer_id(
+ customer_id: str,
+ ) -> dict[str, Any] | None:
+ """Get user by Stripe customer ID."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM users WHERE stripe_customer_id = ?",
+ (customer_id,),
+ )
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def has_processed_stripe_event(event_id: str) -> bool:
+ """Check if Stripe event has already been processed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT id FROM stripe_events WHERE id = ?",
+ (event_id,),
+ )
+ return cursor.fetchone() is not None
+
+ @staticmethod
+ def mark_stripe_event_processed(
+ event_id: str,
+ event_type: str,
+ payload: bytes,
+ ) -> None:
+ """Mark Stripe event as processed for idempotency."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT OR IGNORE INTO stripe_events (id, type, payload) "
+ "VALUES (?, ?, ?)",
+ (event_id, event_type, payload.decode("utf-8")),
+ )
+ conn.commit()
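+
+    # Intended webhook pattern, a sketch (handle_event, event, and
+    # payload_bytes are hypothetical caller-side names standing in for the
+    # verified webhook data): check first, record after handling, so a
+    # redelivered Stripe event is processed at most once.
+    #
+    #     if not Database.has_processed_stripe_event(event["id"]):
+    #         handle_event(event)
+    #         Database.mark_stripe_event_processed(
+    #             event["id"], event["type"], payload_bytes,
+    #         )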
+
+ @staticmethod
+ def get_usage(
+ user_id: int,
+ period_start: Any,
+ period_end: Any,
+ ) -> dict[str, int]:
+ """Get usage stats for user in period.
+
+ Counts episodes added to user's feed (via user_episodes table)
+ during the billing period, regardless of who created them.
+
+ Returns:
+ dict with keys: articles (int), minutes (int)
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Count articles added to user's feed in period
+ # Uses user_episodes junction table to track when episodes
+ # were added, which correctly handles shared/existing episodes
+ cursor.execute(
+ """
+ SELECT COUNT(*) as count, SUM(e.duration) as total_seconds
+ FROM user_episodes ue
+ JOIN episodes e ON e.id = ue.episode_id
+ WHERE ue.user_id = ? AND ue.added_at >= ? AND ue.added_at < ?
+ """,
+ (user_id, period_start.isoformat(), period_end.isoformat()),
+ )
+ row = cursor.fetchone()
+
+ articles = row["count"] if row else 0
+ total_seconds = (
+ row["total_seconds"] if row and row["total_seconds"] else 0
+ )
+ minutes = total_seconds // 60
+
+ return {"articles": articles, "minutes": minutes}
+
+
+class TestDatabase(Test.TestCase):
+ """Test the Database class."""
+
+ @staticmethod
+ def setUp() -> None:
+ """Set up test database."""
+ Database.init_db()
+
+    @staticmethod
+    def tearDown() -> None:
+        """Clean up test database."""
+        Database.teardown()
+
+ def test_init_db(self) -> None:
+ """Verify all tables and indexes are created correctly."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Check tables exist
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
+ tables = {row[0] for row in cursor.fetchall()}
+ self.assertIn("queue", tables)
+ self.assertIn("episodes", tables)
+ self.assertIn("users", tables)
+
+ # Check indexes exist
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='index'")
+ indexes = {row[0] for row in cursor.fetchall()}
+ self.assertIn("idx_queue_status", indexes)
+ self.assertIn("idx_queue_created", indexes)
+ self.assertIn("idx_episodes_created", indexes)
+ self.assertIn("idx_queue_user_id", indexes)
+ self.assertIn("idx_episodes_user_id", indexes)
+
+ def test_connection_context_manager(self) -> None:
+ """Ensure connections are properly closed."""
+ # Get a connection and verify it works
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT 1")
+ result = cursor.fetchone()
+ self.assertEqual(result[0], 1)
+
+ # Connection should be closed after context manager
+ with pytest.raises(sqlite3.ProgrammingError):
+ cursor.execute("SELECT 1")
+
+ def test_migration_idempotency(self) -> None:
+ """Verify migrations can run multiple times safely."""
+ # Run migration multiple times
+ Database.migrate_to_multi_user()
+ Database.migrate_to_multi_user()
+ Database.migrate_to_multi_user()
+
+ # Should still work fine
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM users")
+ # Should not raise an error
+ # Test completed successfully - migration worked
+ self.assertIsNotNone(conn)
+
+ def test_get_metrics_summary_extended(self) -> None:
+ """Verify extended metrics summary."""
+ # Create some data
+ user_id, _ = Database.create_user("test@example.com")
+ Database.create_episode(
+ "Test Article",
+ "url",
+ 100,
+ 1000,
+ user_id,
+ )
+
+ # Create a queue item
+ Database.add_to_queue(
+ "https://example.com",
+ "test@example.com",
+ user_id,
+ )
+
+ metrics = Database.get_metrics_summary()
+
+ self.assertIn("total_users", metrics)
+ self.assertIn("active_subscriptions", metrics)
+ self.assertIn("submissions_24h", metrics)
+ self.assertIn("submissions_7d", metrics)
+
+ self.assertEqual(metrics["total_users"], 1)
+ self.assertEqual(metrics["submissions_24h"], 1)
+ self.assertEqual(metrics["submissions_7d"], 1)
+
+
+class TestUserManagement(Test.TestCase):
+ """Test user management functionality."""
+
+ @staticmethod
+ def setUp() -> None:
+ """Set up test database."""
+ Database.init_db()
+
+ @staticmethod
+ def tearDown() -> None:
+ """Clean up test database."""
+ Database.teardown()
+
+ def test_create_user(self) -> None:
+ """Create user with unique email and token."""
+ user_id, token = Database.create_user("test@example.com")
+
+ self.assertIsInstance(user_id, int)
+ self.assertIsInstance(token, str)
+ self.assertGreater(len(token), 20) # Should be a secure token
+
+ def test_create_duplicate_user(self) -> None:
+ """Verify duplicate emails return existing user."""
+ # Create first user
+ user_id1, token1 = Database.create_user(
+ "test@example.com",
+ )
+
+ # Try to create duplicate
+ user_id2, token2 = Database.create_user(
+ "test@example.com",
+ )
+
+ # Should return same user
+ self.assertIsNotNone(user_id1)
+ self.assertIsNotNone(user_id2)
+ self.assertEqual(user_id1, user_id2)
+ self.assertEqual(token1, token2)
+
+ def test_get_user_by_email(self) -> None:
+ """Retrieve user by email."""
+ user_id, token = Database.create_user("test@example.com")
+
+ user = Database.get_user_by_email("test@example.com")
+ self.assertIsNotNone(user)
+ if user is None:
+ self.fail("User should not be None")
+ self.assertEqual(user["id"], user_id)
+ self.assertEqual(user["email"], "test@example.com")
+ self.assertEqual(user["token"], token)
+
+ def test_get_user_by_token(self) -> None:
+ """Retrieve user by RSS token."""
+ user_id, token = Database.create_user("test@example.com")
+
+ user = Database.get_user_by_token(token)
+ self.assertIsNotNone(user)
+ if user is None:
+ self.fail("User should not be None")
+ self.assertEqual(user["id"], user_id)
+ self.assertEqual(user["email"], "test@example.com")
+
+ def test_get_user_by_id(self) -> None:
+ """Retrieve user by ID."""
+ user_id, token = Database.create_user("test@example.com")
+
+ user = Database.get_user_by_id(user_id)
+ self.assertIsNotNone(user)
+ if user is None:
+ self.fail("User should not be None")
+ self.assertEqual(user["email"], "test@example.com")
+ self.assertEqual(user["token"], token)
+
+ def test_invalid_user_lookups(self) -> None:
+ """Verify None returned for non-existent users."""
+ self.assertIsNone(
+ Database.get_user_by_email("nobody@example.com"),
+ )
+ self.assertIsNone(
+ Database.get_user_by_token("invalid-token"),
+ )
+ self.assertIsNone(Database.get_user_by_id(9999))
+
+ def test_token_uniqueness(self) -> None:
+ """Ensure tokens are cryptographically unique."""
+ tokens = set()
+ for i in range(10):
+ _, token = Database.create_user(
+ f"user{i}@example.com",
+ )
+ tokens.add(token)
+
+ # All tokens should be unique
+ self.assertEqual(len(tokens), 10)
+
+ def test_delete_user(self) -> None:
+ """Test user deletion and cleanup."""
+ # Create user
+ user_id, _ = Database.create_user("delete_me@example.com")
+
+ # Create some data for the user
+ Database.add_to_queue(
+ "https://example.com/article",
+ "delete_me@example.com",
+ user_id,
+ )
+
+ ep_id = Database.create_episode(
+ title="Test Episode",
+ audio_url="url",
+ duration=100,
+ content_length=1000,
+ user_id=user_id,
+ )
+ Database.add_episode_to_user(user_id, ep_id)
+ Database.track_episode_metric(ep_id, "played", user_id)
+
+ # Delete user
+ Database.delete_user(user_id)
+
+ # Verify user is gone
+ self.assertIsNone(Database.get_user_by_id(user_id))
+
+ # Verify queue items are gone
+ queue = Database.get_user_queue_status(user_id)
+ self.assertEqual(len(queue), 0)
+
+ # Verify episodes are gone (direct lookup)
+ self.assertIsNone(Database.get_episode_by_id(ep_id))
+
+ def test_update_user_email(self) -> None:
+ """Update user email address."""
+ user_id, _ = Database.create_user("old@example.com")
+
+ # Update email
+ Database.update_user_email(user_id, "new@example.com")
+
+ # Verify update
+ user = Database.get_user_by_id(user_id)
+ self.assertIsNotNone(user)
+ if user:
+ self.assertEqual(user["email"], "new@example.com")
+
+ # Old email should not exist
+ self.assertIsNone(Database.get_user_by_email("old@example.com"))
+
+ @staticmethod
+ def test_update_user_email_duplicate() -> None:
+ """Cannot update to an existing email."""
+ user_id1, _ = Database.create_user("user1@example.com")
+ Database.create_user("user2@example.com")
+
+ # Try to update user1 to user2's email
+ with pytest.raises(ValueError, match="already taken"):
+ Database.update_user_email(user_id1, "user2@example.com")
+
+
+class TestQueueOperations(Test.TestCase):
+ """Test queue operations."""
+
+ def setUp(self) -> None:
+ """Set up test database with a user."""
+ Database.init_db()
+ self.user_id, _ = Database.create_user("test@example.com")
+
+ @staticmethod
+ def tearDown() -> None:
+ """Clean up test database."""
+ Database.teardown()
+
+ def test_add_to_queue(self) -> None:
+ """Add job with user association."""
+ job_id = Database.add_to_queue(
+ "https://example.com/article",
+ "test@example.com",
+ self.user_id,
+ )
+
+ self.assertIsInstance(job_id, int)
+ self.assertGreater(job_id, 0)
+
+ def test_get_pending_jobs(self) -> None:
+ """Retrieve jobs in correct order."""
+ # Add multiple jobs
+ job1 = Database.add_to_queue(
+ "https://example.com/1",
+ "test@example.com",
+ self.user_id,
+ )
+        time.sleep(0.01) # timestamps share 1s resolution; id breaks ties
+ job2 = Database.add_to_queue(
+ "https://example.com/2",
+ "test@example.com",
+ self.user_id,
+ )
+ time.sleep(0.01)
+ job3 = Database.add_to_queue(
+ "https://example.com/3",
+ "test@example.com",
+ self.user_id,
+ )
+
+ # Get pending jobs
+ jobs = Database.get_pending_jobs(limit=10)
+
+ self.assertEqual(len(jobs), 3)
+ # Should be in order of creation (oldest first)
+ self.assertEqual(jobs[0]["id"], job1)
+ self.assertEqual(jobs[1]["id"], job2)
+ self.assertEqual(jobs[2]["id"], job3)
+
+ def test_update_job_status(self) -> None:
+ """Update status and error messages."""
+ job_id = Database.add_to_queue(
+ "https://example.com",
+ "test@example.com",
+ self.user_id,
+ )
+
+ # Update to processing
+ Database.update_job_status(job_id, "processing")
+ job = Database.get_job_by_id(job_id)
+ self.assertIsNotNone(job)
+ if job is None:
+ self.fail("Job should not be None")
+ self.assertEqual(job["status"], "processing")
+
+ # Update to error with message
+ Database.update_job_status(
+ job_id,
+ "error",
+ "Network timeout",
+ )
+ job = Database.get_job_by_id(job_id)
+ self.assertIsNotNone(job)
+ if job is None:
+ self.fail("Job should not be None")
+ self.assertEqual(job["status"], "error")
+ self.assertEqual(job["error_message"], "Network timeout")
+ self.assertEqual(job["retry_count"], 1)
+
+ def test_retry_job(self) -> None:
+ """Reset failed jobs for retry."""
+ job_id = Database.add_to_queue(
+ "https://example.com",
+ "test@example.com",
+ self.user_id,
+ )
+
+ # Set to error
+ Database.update_job_status(job_id, "error", "Failed")
+
+ # Retry
+ Database.retry_job(job_id)
+ job = Database.get_job_by_id(job_id)
+
+ self.assertIsNotNone(job)
+ if job is None:
+ self.fail("Job should not be None")
+ self.assertEqual(job["status"], "pending")
+ self.assertIsNone(job["error_message"])
+
+ def test_delete_job(self) -> None:
+ """Remove jobs from queue."""
+ job_id = Database.add_to_queue(
+ "https://example.com",
+ "test@example.com",
+ self.user_id,
+ )
+
+ # Delete job
+ Database.delete_job(job_id)
+
+ # Should not exist
+ job = Database.get_job_by_id(job_id)
+ self.assertIsNone(job)
+
+ def test_get_retryable_jobs(self) -> None:
+ """Find jobs eligible for retry."""
+ # Add job and mark as error
+ job_id = Database.add_to_queue(
+ "https://example.com",
+ "test@example.com",
+ self.user_id,
+ )
+ Database.update_job_status(job_id, "error", "Failed")
+
+ # Should be retryable
+ retryable = Database.get_retryable_jobs(
+ max_retries=3,
+ )
+ self.assertEqual(len(retryable), 1)
+ self.assertEqual(retryable[0]["id"], job_id)
+
+ # Exceed retry limit
+ Database.update_job_status(
+ job_id,
+ "error",
+ "Failed again",
+ )
+ Database.update_job_status(
+ job_id,
+ "error",
+ "Failed yet again",
+ )
+
+ # Should not be retryable anymore
+ retryable = Database.get_retryable_jobs(
+ max_retries=3,
+ )
+ self.assertEqual(len(retryable), 0)
+
+ def test_user_queue_isolation(self) -> None:
+ """Ensure users only see their own jobs."""
+ # Create second user
+ user2_id, _ = Database.create_user("user2@example.com")
+
+ # Add jobs for both users
+ job1 = Database.add_to_queue(
+ "https://example.com/1",
+ "test@example.com",
+ self.user_id,
+ )
+ job2 = Database.add_to_queue(
+ "https://example.com/2",
+ "user2@example.com",
+ user2_id,
+ )
+
+ # Get user-specific queue status
+ user1_jobs = Database.get_user_queue_status(self.user_id)
+ user2_jobs = Database.get_user_queue_status(user2_id)
+
+ self.assertEqual(len(user1_jobs), 1)
+ self.assertEqual(user1_jobs[0]["id"], job1)
+
+ self.assertEqual(len(user2_jobs), 1)
+ self.assertEqual(user2_jobs[0]["id"], job2)
+
+ def test_status_counts(self) -> None:
+ """Verify status aggregation queries."""
+ # Add jobs with different statuses
+ Database.add_to_queue(
+ "https://example.com/1",
+ "test@example.com",
+ self.user_id,
+ )
+ job2 = Database.add_to_queue(
+ "https://example.com/2",
+ "test@example.com",
+ self.user_id,
+ )
+ job3 = Database.add_to_queue(
+ "https://example.com/3",
+ "test@example.com",
+ self.user_id,
+ )
+
+ Database.update_job_status(job2, "processing")
+ Database.update_job_status(job3, "error", "Failed")
+
+ # Get status counts
+ counts = Database.get_user_status_counts(self.user_id)
+
+ self.assertEqual(counts.get("pending", 0), 1)
+ self.assertEqual(counts.get("processing", 0), 1)
+ self.assertEqual(counts.get("error", 0), 1)
+
+ def test_queue_position(self) -> None:
+ """Verify queue position calculation."""
+ # Add multiple pending jobs
+ job1 = Database.add_to_queue(
+ "https://example.com/1",
+ "test@example.com",
+ self.user_id,
+ )
+ time.sleep(0.01)
+ job2 = Database.add_to_queue(
+ "https://example.com/2",
+ "test@example.com",
+ self.user_id,
+ )
+ time.sleep(0.01)
+ job3 = Database.add_to_queue(
+ "https://example.com/3",
+ "test@example.com",
+ self.user_id,
+ )
+
+ # Check positions
+ self.assertEqual(Database.get_queue_position(job1), 1)
+ self.assertEqual(Database.get_queue_position(job2), 2)
+ self.assertEqual(Database.get_queue_position(job3), 3)
+
+ # Move job 2 to processing
+ Database.update_job_status(job2, "processing")
+
+ # Check positions (job 3 should now be 2nd pending job)
+ self.assertEqual(Database.get_queue_position(job1), 1)
+ self.assertIsNone(Database.get_queue_position(job2))
+ self.assertEqual(Database.get_queue_position(job3), 2)
+
+
+class TestEpisodeManagement(Test.TestCase):
+ """Test episode management functionality."""
+
+ def setUp(self) -> None:
+ """Set up test database with a user."""
+ Database.init_db()
+ self.user_id, _ = Database.create_user("test@example.com")
+
+ @staticmethod
+ def tearDown() -> None:
+ """Clean up test database."""
+ Database.teardown()
+
+ def test_create_episode(self) -> None:
+ """Create episode with user association."""
+ episode_id = Database.create_episode(
+ title="Test Article",
+ audio_url="https://example.com/audio.mp3",
+ duration=300,
+ content_length=5000,
+ user_id=self.user_id,
+ )
+
+ self.assertIsInstance(episode_id, int)
+ self.assertGreater(episode_id, 0)
+
+ def test_get_recent_episodes(self) -> None:
+ """Retrieve episodes in reverse chronological order."""
+ # Create multiple episodes
+ ep1 = Database.create_episode(
+ "Article 1",
+ "url1",
+ 100,
+ 1000,
+ self.user_id,
+ )
+ time.sleep(0.01)
+ ep2 = Database.create_episode(
+ "Article 2",
+ "url2",
+ 200,
+ 2000,
+ self.user_id,
+ )
+ time.sleep(0.01)
+ ep3 = Database.create_episode(
+ "Article 3",
+ "url3",
+ 300,
+ 3000,
+ self.user_id,
+ )
+
+ # Get recent episodes
+ episodes = Database.get_recent_episodes(limit=10)
+
+ self.assertEqual(len(episodes), 3)
+ # Should be in reverse chronological order
+ self.assertEqual(episodes[0]["id"], ep3)
+ self.assertEqual(episodes[1]["id"], ep2)
+ self.assertEqual(episodes[2]["id"], ep1)
+
+ def test_get_user_episodes(self) -> None:
+ """Ensure user isolation for episodes."""
+ # Create second user
+ user2_id, _ = Database.create_user("user2@example.com")
+
+ # Create episodes for both users
+ ep1 = Database.create_episode(
+ "User1 Article",
+ "url1",
+ 100,
+ 1000,
+ self.user_id,
+ )
+ ep2 = Database.create_episode(
+ "User2 Article",
+ "url2",
+ 200,
+ 2000,
+ user2_id,
+ )
+
+ # Get user-specific episodes
+ user1_episodes = Database.get_user_all_episodes(
+ self.user_id,
+ )
+ user2_episodes = Database.get_user_all_episodes(user2_id)
+
+ self.assertEqual(len(user1_episodes), 1)
+ self.assertEqual(user1_episodes[0]["id"], ep1)
+
+ self.assertEqual(len(user2_episodes), 1)
+ self.assertEqual(user2_episodes[0]["id"], ep2)
+
+ def test_episode_metadata(self) -> None:
+ """Verify duration and content_length storage."""
+ Database.create_episode(
+ title="Test Article",
+ audio_url="https://example.com/audio.mp3",
+ duration=12345,
+ content_length=98765,
+ user_id=self.user_id,
+ )
+
+ episodes = Database.get_user_all_episodes(self.user_id)
+ episode = episodes[0]
+
+ self.assertEqual(episode["duration"], 12345)
+ self.assertEqual(episode["content_length"], 98765)
+
+
+def test() -> None:
+ """Run the tests."""
+ Test.run(
+ App.Area.Test,
+ [
+ TestDatabase,
+ TestUserManagement,
+ TestQueueOperations,
+ TestEpisodeManagement,
+ ],
+ )
+
+
+def main() -> None:
+ """Run all PodcastItLater.Core tests."""
+ if "test" in sys.argv:
+ test()