diff options
Diffstat (limited to 'Biz/PodcastItLater/Core.py')
| -rw-r--r-- | Biz/PodcastItLater/Core.py | 410 |
1 files changed, 139 insertions, 271 deletions
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py index 278b97e..0ca945c 100644 --- a/Biz/PodcastItLater/Core.py +++ b/Biz/PodcastItLater/Core.py @@ -13,12 +13,14 @@ Includes: import Omni.App as App import Omni.Log as Log import Omni.Test as Test +import os import pathlib import pytest import secrets import sqlite3 import sys import time +import typing from collections.abc import Iterator from contextlib import contextmanager from typing import Any @@ -26,33 +28,48 @@ from typing import Any logger = Log.setup() +CODEROOT = pathlib.Path(os.getenv("CODEROOT", ".")) +DATA_DIR = pathlib.Path( + os.environ.get("DATA_DIR", CODEROOT / "_/var/podcastitlater/"), +) + +# Constants for UI display +URL_TRUNCATE_LENGTH = 80 +TITLE_TRUNCATE_LENGTH = 50 +ERROR_TRUNCATE_LENGTH = 50 + +# Admin whitelist +ADMIN_EMAILS = ["ben@bensima.com"] + + +def is_admin(user: dict[str, typing.Any] | None) -> bool: + """Check if user is an admin based on email whitelist.""" + if not user: + return False + return user.get("email", "").lower() in [ + email.lower() for email in ADMIN_EMAILS + ] + + class Database: # noqa: PLR0904 """Data access layer for PodcastItLater database operations.""" @staticmethod - def get_default_db_path() -> str: - """Get the default database path based on environment.""" - area = App.from_env() - if area == App.Area.Test: - return "_/var/podcastitlater/podcast.db" - return "/var/podcastitlater/podcast.db" + def teardown() -> None: + """Delete the existing database, for cleanup after tests.""" + db_path = DATA_DIR / "podcast.db" + if db_path.exists(): + db_path.unlink() @staticmethod @contextmanager - def get_connection( - db_path: str | None = None, - ) -> Iterator[sqlite3.Connection]: + def get_connection() -> Iterator[sqlite3.Connection]: """Context manager for database connections. - Args: - db_path: Database file path. If None, uses environment-appropriate - default. - Yields: sqlite3.Connection: Database connection with row factory set. """ - if db_path is None: - db_path = Database.get_default_db_path() + db_path = DATA_DIR / "podcast.db" conn = sqlite3.connect(db_path) conn.row_factory = sqlite3.Row try: @@ -61,16 +78,9 @@ class Database: # noqa: PLR0904 conn.close() @staticmethod - def init_db(db_path: str | None = None) -> None: + def init_db() -> None: """Initialize database with required tables.""" - if db_path is None: - db_path = Database.get_default_db_path() - - # Ensure directory exists - db_dir = pathlib.Path(db_path).parent - db_dir.mkdir(parents=True, exist_ok=True) - - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() # Queue table for job processing @@ -117,26 +127,25 @@ class Database: # noqa: PLR0904 logger.info("Database initialized successfully") # Run migration to add user support - Database.migrate_to_multi_user(db_path) + Database.migrate_to_multi_user() # Run migration to add metadata fields - Database.migrate_add_metadata_fields(db_path) + Database.migrate_add_metadata_fields() # Run migration to add episode metadata fields - Database.migrate_add_episode_metadata(db_path) + Database.migrate_add_episode_metadata() # Run migration to add user status field - Database.migrate_add_user_status(db_path) + Database.migrate_add_user_status() # Run migration to add default titles - Database.migrate_add_default_titles(db_path) + Database.migrate_add_default_titles() @staticmethod - def add_to_queue( # noqa: PLR0913, PLR0917 + def add_to_queue( url: str, email: str, user_id: int, - db_path: str | None = None, title: str | None = None, author: str | None = None, ) -> int: @@ -145,9 +154,7 @@ class Database: # noqa: PLR0904 Raises: ValueError: If job ID cannot be retrieved after insert. """ - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "INSERT INTO queue (url, email, user_id, title, author) " @@ -165,12 +172,9 @@ class Database: # noqa: PLR0904 @staticmethod def get_pending_jobs( limit: int = 10, - db_path: str | None = None, ) -> list[dict[str, Any]]: """Fetch jobs with status='pending' ordered by creation time.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT * FROM queue WHERE status = 'pending' " @@ -185,12 +189,9 @@ class Database: # noqa: PLR0904 job_id: int, status: str, error: str | None = None, - db_path: str | None = None, ) -> None: """Update job status and error message.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() if status == "error": cursor.execute( @@ -209,12 +210,9 @@ class Database: # noqa: PLR0904 @staticmethod def get_job_by_id( job_id: int, - db_path: str | None = None, ) -> dict[str, Any] | None: """Fetch single job by ID.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM queue WHERE id = ?", (job_id,)) row = cursor.fetchone() @@ -229,16 +227,13 @@ class Database: # noqa: PLR0904 user_id: int | None = None, author: str | None = None, original_url: str | None = None, - db_path: str | None = None, ) -> int: """Insert episode record, return episode ID. Raises: ValueError: If episode ID cannot be retrieved after insert. """ - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "INSERT INTO episodes " @@ -265,12 +260,9 @@ class Database: # noqa: PLR0904 @staticmethod def get_recent_episodes( limit: int = 20, - db_path: str | None = None, ) -> list[dict[str, Any]]: """Get recent episodes for RSS feed generation.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT * FROM episodes ORDER BY created_at DESC LIMIT ?", @@ -280,11 +272,9 @@ class Database: # noqa: PLR0904 return [dict(row) for row in rows] @staticmethod - def get_queue_status_summary(db_path: str | None = None) -> dict[str, Any]: + def get_queue_status_summary() -> dict[str, Any]: """Get queue status summary for web interface.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() # Count jobs by status @@ -304,11 +294,9 @@ class Database: # noqa: PLR0904 return {"status_counts": status_counts, "recent_jobs": recent_jobs} @staticmethod - def get_queue_status(db_path: str | None = None) -> list[dict[str, Any]]: + def get_queue_status() -> list[dict[str, Any]]: """Return pending/processing/error items for web interface.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute(""" SELECT id, url, email, status, created_at, error_message, @@ -323,13 +311,10 @@ class Database: # noqa: PLR0904 @staticmethod def get_all_episodes( - db_path: str | None = None, user_id: int | None = None, ) -> list[dict[str, Any]]: """Return all episodes for RSS feed.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() if user_id: cursor.execute( @@ -355,12 +340,9 @@ class Database: # noqa: PLR0904 @staticmethod def get_retryable_jobs( max_retries: int = 3, - db_path: str | None = None, ) -> list[dict[str, Any]]: """Get failed jobs that can be retried.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT * FROM queue WHERE status = 'error' " @@ -371,11 +353,9 @@ class Database: # noqa: PLR0904 return [dict(row) for row in rows] @staticmethod - def retry_job(job_id: int, db_path: str | None = None) -> None: + def retry_job(job_id: int) -> None: """Reset a job to pending status for retry.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "UPDATE queue SET status = 'pending', " @@ -386,11 +366,9 @@ class Database: # noqa: PLR0904 logger.info("Reset job %s to pending for retry", job_id) @staticmethod - def delete_job(job_id: int, db_path: str | None = None) -> None: + def delete_job(job_id: int) -> None: """Delete a job from the queue.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute("DELETE FROM queue WHERE id = ?", (job_id,)) conn.commit() @@ -398,13 +376,10 @@ class Database: # noqa: PLR0904 @staticmethod def get_all_queue_items( - db_path: str | None = None, user_id: int | None = None, ) -> list[dict[str, Any]]: """Return all queue items for admin view.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() if user_id: cursor.execute( @@ -428,11 +403,9 @@ class Database: # noqa: PLR0904 return [dict(row) for row in rows] @staticmethod - def get_status_counts(db_path: str | None = None) -> dict[str, int]: + def get_status_counts() -> dict[str, int]: """Get count of queue items by status.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute(""" SELECT status, COUNT(*) as count @@ -445,12 +418,9 @@ class Database: # noqa: PLR0904 @staticmethod def get_user_status_counts( user_id: int, - db_path: str | None = None, ) -> dict[str, int]: """Get count of queue items by status for a specific user.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( """ @@ -465,11 +435,9 @@ class Database: # noqa: PLR0904 return {row["status"]: row["count"] for row in rows} @staticmethod - def migrate_to_multi_user(db_path: str | None = None) -> None: + def migrate_to_multi_user() -> None: """Migrate database to support multiple users.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() # Create users table @@ -518,11 +486,9 @@ class Database: # noqa: PLR0904 logger.info("Database migrated to support multiple users") @staticmethod - def migrate_add_metadata_fields(db_path: str | None = None) -> None: + def migrate_add_metadata_fields() -> None: """Add title and author fields to queue table.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() # Check if columns already exist @@ -540,11 +506,9 @@ class Database: # noqa: PLR0904 logger.info("Database migrated to support metadata fields") @staticmethod - def migrate_add_episode_metadata(db_path: str | None = None) -> None: + def migrate_add_episode_metadata() -> None: """Add author and original_url fields to episodes table.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() # Check if columns already exist @@ -564,11 +528,9 @@ class Database: # noqa: PLR0904 logger.info("Database migrated to support episode metadata fields") @staticmethod - def migrate_add_user_status(db_path: str | None = None) -> None: + def migrate_add_user_status() -> None: """Add status field to users table.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() # Check if column already exists @@ -592,11 +554,9 @@ class Database: # noqa: PLR0904 logger.info("Database migrated to support user status") @staticmethod - def migrate_add_default_titles(db_path: str | None = None) -> None: + def migrate_add_default_titles() -> None: """Add default titles to queue items that have None titles.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() # Update queue items with NULL titles to have a default @@ -616,19 +576,16 @@ class Database: # noqa: PLR0904 ) @staticmethod - def create_user(email: str, db_path: str | None = None) -> tuple[int, str]: + def create_user(email: str) -> tuple[int, str]: """Create a new user and return (user_id, token). Raises: ValueError: If user ID cannot be retrieved after insert or if user not found. """ - if db_path is None: - db_path = Database.get_default_db_path() # Generate a secure token for RSS feed access token = secrets.token_urlsafe(32) - - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() try: cursor.execute( @@ -658,12 +615,9 @@ class Database: # noqa: PLR0904 @staticmethod def get_user_by_email( email: str, - db_path: str | None = None, ) -> dict[str, Any] | None: """Get user by email address.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE email = ?", (email,)) row = cursor.fetchone() @@ -672,12 +626,9 @@ class Database: # noqa: PLR0904 @staticmethod def get_user_by_token( token: str, - db_path: str | None = None, ) -> dict[str, Any] | None: """Get user by RSS token.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE token = ?", (token,)) row = cursor.fetchone() @@ -686,12 +637,9 @@ class Database: # noqa: PLR0904 @staticmethod def get_user_by_id( user_id: int, - db_path: str | None = None, ) -> dict[str, Any] | None: """Get user by ID.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,)) row = cursor.fetchone() @@ -700,12 +648,9 @@ class Database: # noqa: PLR0904 @staticmethod def get_user_queue_status( user_id: int, - db_path: str | None = None, ) -> list[dict[str, Any]]: """Return pending/processing/error items for a specific user.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( """ @@ -726,12 +671,9 @@ class Database: # noqa: PLR0904 def get_user_recent_episodes( user_id: int, limit: int = 20, - db_path: str | None = None, ) -> list[dict[str, Any]]: """Get recent episodes for a specific user.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT * FROM episodes WHERE user_id = ? " @@ -744,12 +686,9 @@ class Database: # noqa: PLR0904 @staticmethod def get_user_all_episodes( user_id: int, - db_path: str | None = None, ) -> list[dict[str, Any]]: """Get all episodes for a specific user for RSS feed.""" - if db_path is None: - db_path = Database.get_default_db_path() - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT * FROM episodes WHERE user_id = ? " @@ -763,16 +702,13 @@ class Database: # noqa: PLR0904 def update_user_status( user_id: int, status: str, - db_path: str | None = None, ) -> None: """Update user account status.""" - if db_path is None: - db_path = Database.get_default_db_path() if status not in {"pending", "active", "disabled"}: msg = f"Invalid status: {status}" raise ValueError(msg) - with Database.get_connection(db_path) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute( "UPDATE users SET status = ? WHERE id = ?", @@ -785,29 +721,20 @@ class Database: # noqa: PLR0904 class TestDatabase(Test.TestCase): """Test the Database class.""" - def setUp(self) -> None: + @staticmethod + def setUp() -> None: """Set up test database.""" - self.test_db = "_/var/podcastitlater/test_podcast.db" - - # Ensure test directory exists - test_db_dir = pathlib.Path(self.test_db).parent - test_db_dir.mkdir(parents=True, exist_ok=True) - - # Clean up any existing test database - test_db_path = pathlib.Path(self.test_db) - if test_db_path.exists(): - test_db_path.unlink() - Database.init_db(self.test_db) + Database.init_db() def tearDown(self) -> None: """Clean up test database.""" - test_db_path = pathlib.Path(self.test_db) - if test_db_path.exists(): - test_db_path.unlink() + Database.teardown() + # Clear user ID + self.user_id = None def test_init_db(self) -> None: """Verify all tables and indexes are created correctly.""" - with Database.get_connection(self.test_db) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() # Check tables exist @@ -829,7 +756,7 @@ class TestDatabase(Test.TestCase): def test_connection_context_manager(self) -> None: """Ensure connections are properly closed.""" # Get a connection and verify it works - with Database.get_connection(self.test_db) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT 1") result = cursor.fetchone() @@ -842,43 +769,35 @@ class TestDatabase(Test.TestCase): def test_migration_idempotency(self) -> None: """Verify migrations can run multiple times safely.""" # Run migration multiple times - Database.migrate_to_multi_user(self.test_db) - Database.migrate_to_multi_user(self.test_db) - Database.migrate_to_multi_user(self.test_db) + Database.migrate_to_multi_user() + Database.migrate_to_multi_user() + Database.migrate_to_multi_user() # Should still work fine - with Database.get_connection(self.test_db) as conn: + with Database.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users") # Should not raise an error + # Test completed successfully - migration worked + self.assertIsNotNone(conn) class TestUserManagement(Test.TestCase): """Test user management functionality.""" - def setUp(self) -> None: + @staticmethod + def setUp() -> None: """Set up test database.""" - self.test_db = "_/var/podcastitlater/test_podcast.db" - - # Ensure test directory exists - test_db_dir = pathlib.Path(self.test_db).parent - test_db_dir.mkdir(parents=True, exist_ok=True) - - # Clean up any existing test database - test_db_path = pathlib.Path(self.test_db) - if test_db_path.exists(): - test_db_path.unlink() - Database.init_db(self.test_db) + Database.init_db() - def tearDown(self) -> None: + @staticmethod + def tearDown() -> None: """Clean up test database.""" - test_db_path = pathlib.Path(self.test_db) - if test_db_path.exists(): - test_db_path.unlink() + Database.teardown() def test_create_user(self) -> None: """Create user with unique email and token.""" - user_id, token = Database.create_user("test@example.com", self.test_db) + user_id, token = Database.create_user("test@example.com") self.assertIsInstance(user_id, int) self.assertIsInstance(token, str) @@ -889,13 +808,11 @@ class TestUserManagement(Test.TestCase): # Create first user user_id1, token1 = Database.create_user( "test@example.com", - self.test_db, ) # Try to create duplicate user_id2, token2 = Database.create_user( "test@example.com", - self.test_db, ) # Should return same user @@ -906,9 +823,9 @@ class TestUserManagement(Test.TestCase): def test_get_user_by_email(self) -> None: """Retrieve user by email.""" - user_id, token = Database.create_user("test@example.com", self.test_db) + user_id, token = Database.create_user("test@example.com") - user = Database.get_user_by_email("test@example.com", self.test_db) + user = Database.get_user_by_email("test@example.com") self.assertIsNotNone(user) if user is None: self.fail("User should not be None") @@ -918,9 +835,9 @@ class TestUserManagement(Test.TestCase): def test_get_user_by_token(self) -> None: """Retrieve user by RSS token.""" - user_id, token = Database.create_user("test@example.com", self.test_db) + user_id, token = Database.create_user("test@example.com") - user = Database.get_user_by_token(token, self.test_db) + user = Database.get_user_by_token(token) self.assertIsNotNone(user) if user is None: self.fail("User should not be None") @@ -929,9 +846,9 @@ class TestUserManagement(Test.TestCase): def test_get_user_by_id(self) -> None: """Retrieve user by ID.""" - user_id, token = Database.create_user("test@example.com", self.test_db) + user_id, token = Database.create_user("test@example.com") - user = Database.get_user_by_id(user_id, self.test_db) + user = Database.get_user_by_id(user_id) self.assertIsNotNone(user) if user is None: self.fail("User should not be None") @@ -941,12 +858,12 @@ class TestUserManagement(Test.TestCase): def test_invalid_user_lookups(self) -> None: """Verify None returned for non-existent users.""" self.assertIsNone( - Database.get_user_by_email("nobody@example.com", self.test_db), + Database.get_user_by_email("nobody@example.com"), ) self.assertIsNone( - Database.get_user_by_token("invalid-token", self.test_db), + Database.get_user_by_token("invalid-token"), ) - self.assertIsNone(Database.get_user_by_id(9999, self.test_db)) + self.assertIsNone(Database.get_user_by_id(9999)) def test_token_uniqueness(self) -> None: """Ensure tokens are cryptographically unique.""" @@ -954,7 +871,6 @@ class TestUserManagement(Test.TestCase): for i in range(10): _, token = Database.create_user( f"user{i}@example.com", - self.test_db, ) tokens.add(token) @@ -967,24 +883,13 @@ class TestQueueOperations(Test.TestCase): def setUp(self) -> None: """Set up test database with a user.""" - self.test_db = "_/var/podcastitlater/test_podcast.db" - - # Ensure test directory exists - test_db_dir = pathlib.Path(self.test_db).parent - test_db_dir.mkdir(parents=True, exist_ok=True) + Database.init_db() + self.user_id, _ = Database.create_user("test@example.com") - # Clean up any existing test database - test_db_path = pathlib.Path(self.test_db) - if test_db_path.exists(): - test_db_path.unlink() - Database.init_db(self.test_db) - self.user_id, _ = Database.create_user("test@example.com", self.test_db) - - def tearDown(self) -> None: + @staticmethod + def tearDown() -> None: """Clean up test database.""" - test_db_path = pathlib.Path(self.test_db) - if test_db_path.exists(): - test_db_path.unlink() + Database.teardown() def test_add_to_queue(self) -> None: """Add job with user association.""" @@ -992,7 +897,6 @@ class TestQueueOperations(Test.TestCase): "https://example.com/article", "test@example.com", self.user_id, - self.test_db, ) self.assertIsInstance(job_id, int) @@ -1005,25 +909,22 @@ class TestQueueOperations(Test.TestCase): "https://example.com/1", "test@example.com", self.user_id, - self.test_db, ) time.sleep(0.01) # Ensure different timestamps job2 = Database.add_to_queue( "https://example.com/2", "test@example.com", self.user_id, - self.test_db, ) time.sleep(0.01) job3 = Database.add_to_queue( "https://example.com/3", "test@example.com", self.user_id, - self.test_db, ) # Get pending jobs - jobs = Database.get_pending_jobs(limit=10, db_path=self.test_db) + jobs = Database.get_pending_jobs(limit=10) self.assertEqual(len(jobs), 3) # Should be in order of creation (oldest first) @@ -1037,12 +938,11 @@ class TestQueueOperations(Test.TestCase): "https://example.com", "test@example.com", self.user_id, - self.test_db, ) # Update to processing - Database.update_job_status(job_id, "processing", db_path=self.test_db) - job = Database.get_job_by_id(job_id, self.test_db) + Database.update_job_status(job_id, "processing") + job = Database.get_job_by_id(job_id) self.assertIsNotNone(job) if job is None: self.fail("Job should not be None") @@ -1053,9 +953,8 @@ class TestQueueOperations(Test.TestCase): job_id, "error", "Network timeout", - self.test_db, ) - job = Database.get_job_by_id(job_id, self.test_db) + job = Database.get_job_by_id(job_id) self.assertIsNotNone(job) if job is None: self.fail("Job should not be None") @@ -1069,15 +968,14 @@ class TestQueueOperations(Test.TestCase): "https://example.com", "test@example.com", self.user_id, - self.test_db, ) # Set to error - Database.update_job_status(job_id, "error", "Failed", self.test_db) + Database.update_job_status(job_id, "error", "Failed") # Retry - Database.retry_job(job_id, self.test_db) - job = Database.get_job_by_id(job_id, self.test_db) + Database.retry_job(job_id) + job = Database.get_job_by_id(job_id) self.assertIsNotNone(job) if job is None: @@ -1091,14 +989,13 @@ class TestQueueOperations(Test.TestCase): "https://example.com", "test@example.com", self.user_id, - self.test_db, ) # Delete job - Database.delete_job(job_id, self.test_db) + Database.delete_job(job_id) # Should not exist - job = Database.get_job_by_id(job_id, self.test_db) + job = Database.get_job_by_id(job_id) self.assertIsNone(job) def test_get_retryable_jobs(self) -> None: @@ -1108,14 +1005,12 @@ class TestQueueOperations(Test.TestCase): "https://example.com", "test@example.com", self.user_id, - self.test_db, ) - Database.update_job_status(job_id, "error", "Failed", self.test_db) + Database.update_job_status(job_id, "error", "Failed") # Should be retryable retryable = Database.get_retryable_jobs( max_retries=3, - db_path=self.test_db, ) self.assertEqual(len(retryable), 1) self.assertEqual(retryable[0]["id"], job_id) @@ -1125,44 +1020,39 @@ class TestQueueOperations(Test.TestCase): job_id, "error", "Failed again", - self.test_db, ) Database.update_job_status( job_id, "error", "Failed yet again", - self.test_db, ) # Should not be retryable anymore retryable = Database.get_retryable_jobs( max_retries=3, - db_path=self.test_db, ) self.assertEqual(len(retryable), 0) def test_user_queue_isolation(self) -> None: """Ensure users only see their own jobs.""" # Create second user - user2_id, _ = Database.create_user("user2@example.com", self.test_db) + user2_id, _ = Database.create_user("user2@example.com") # Add jobs for both users job1 = Database.add_to_queue( "https://example.com/1", "test@example.com", self.user_id, - self.test_db, ) job2 = Database.add_to_queue( "https://example.com/2", "user2@example.com", user2_id, - self.test_db, ) # Get user-specific queue status - user1_jobs = Database.get_user_queue_status(self.user_id, self.test_db) - user2_jobs = Database.get_user_queue_status(user2_id, self.test_db) + user1_jobs = Database.get_user_queue_status(self.user_id) + user2_jobs = Database.get_user_queue_status(user2_id) self.assertEqual(len(user1_jobs), 1) self.assertEqual(user1_jobs[0]["id"], job1) @@ -1177,26 +1067,23 @@ class TestQueueOperations(Test.TestCase): "https://example.com/1", "test@example.com", self.user_id, - self.test_db, ) job2 = Database.add_to_queue( "https://example.com/2", "test@example.com", self.user_id, - self.test_db, ) job3 = Database.add_to_queue( "https://example.com/3", "test@example.com", self.user_id, - self.test_db, ) - Database.update_job_status(job2, "processing", db_path=self.test_db) - Database.update_job_status(job3, "error", "Failed", self.test_db) + Database.update_job_status(job2, "processing") + Database.update_job_status(job3, "error", "Failed") # Get status counts - counts = Database.get_user_status_counts(self.user_id, self.test_db) + counts = Database.get_user_status_counts(self.user_id) self.assertEqual(counts.get("pending", 0), 1) self.assertEqual(counts.get("processing", 0), 1) @@ -1208,24 +1095,13 @@ class TestEpisodeManagement(Test.TestCase): def setUp(self) -> None: """Set up test database with a user.""" - self.test_db = "_/var/podcastitlater/test_podcast.db" - - # Ensure test directory exists - test_db_dir = pathlib.Path(self.test_db).parent - test_db_dir.mkdir(parents=True, exist_ok=True) + Database.init_db() + self.user_id, _ = Database.create_user("test@example.com") - # Clean up any existing test database - test_db_path = pathlib.Path(self.test_db) - if test_db_path.exists(): - test_db_path.unlink() - Database.init_db(self.test_db) - self.user_id, _ = Database.create_user("test@example.com", self.test_db) - - def tearDown(self) -> None: + @staticmethod + def tearDown() -> None: """Clean up test database.""" - test_db_path = pathlib.Path(self.test_db) - if test_db_path.exists(): - test_db_path.unlink() + Database.teardown() def test_create_episode(self) -> None: """Create episode with user association.""" @@ -1235,7 +1111,6 @@ class TestEpisodeManagement(Test.TestCase): duration=300, content_length=5000, user_id=self.user_id, - db_path=self.test_db, ) self.assertIsInstance(episode_id, int) @@ -1250,7 +1125,6 @@ class TestEpisodeManagement(Test.TestCase): 100, 1000, self.user_id, - self.test_db, ) time.sleep(0.01) ep2 = Database.create_episode( @@ -1259,7 +1133,6 @@ class TestEpisodeManagement(Test.TestCase): 200, 2000, self.user_id, - self.test_db, ) time.sleep(0.01) ep3 = Database.create_episode( @@ -1268,11 +1141,10 @@ class TestEpisodeManagement(Test.TestCase): 300, 3000, self.user_id, - self.test_db, ) # Get recent episodes - episodes = Database.get_recent_episodes(limit=10, db_path=self.test_db) + episodes = Database.get_recent_episodes(limit=10) self.assertEqual(len(episodes), 3) # Should be in reverse chronological order @@ -1283,7 +1155,7 @@ class TestEpisodeManagement(Test.TestCase): def test_get_user_episodes(self) -> None: """Ensure user isolation for episodes.""" # Create second user - user2_id, _ = Database.create_user("user2@example.com", self.test_db) + user2_id, _ = Database.create_user("user2@example.com") # Create episodes for both users ep1 = Database.create_episode( @@ -1292,7 +1164,6 @@ class TestEpisodeManagement(Test.TestCase): 100, 1000, self.user_id, - self.test_db, ) ep2 = Database.create_episode( "User2 Article", @@ -1300,15 +1171,13 @@ class TestEpisodeManagement(Test.TestCase): 200, 2000, user2_id, - self.test_db, ) # Get user-specific episodes user1_episodes = Database.get_user_all_episodes( self.user_id, - self.test_db, ) - user2_episodes = Database.get_user_all_episodes(user2_id, self.test_db) + user2_episodes = Database.get_user_all_episodes(user2_id) self.assertEqual(len(user1_episodes), 1) self.assertEqual(user1_episodes[0]["id"], ep1) @@ -1324,10 +1193,9 @@ class TestEpisodeManagement(Test.TestCase): duration=12345, content_length=98765, user_id=self.user_id, - db_path=self.test_db, ) - episodes = Database.get_user_all_episodes(self.user_id, self.test_db) + episodes = Database.get_user_all_episodes(self.user_id) episode = episodes[0] self.assertEqual(episode["duration"], 12345) |
