From 48fdb6610957d213739f2cc84bc4c9071be909ac Mon Sep 17 00:00:00 2001 From: Ben Sima Date: Wed, 3 Sep 2025 16:46:30 -0400 Subject: Add Environment-Aware Database Path Handling --- Biz/PodcastItLater/Core.py | 145 ++++++++++++++++++++++++++++++++++--------- Biz/PodcastItLater/Web.py | 15 ++++- Biz/PodcastItLater/Worker.py | 16 ++++- 3 files changed, 143 insertions(+), 33 deletions(-) diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py index c0d0acf..6c04db8 100644 --- a/Biz/PodcastItLater/Core.py +++ b/Biz/PodcastItLater/Core.py @@ -29,16 +29,30 @@ logger = Log.setup() class Database: # noqa: PLR0904 """Data access layer for PodcastItLater database operations.""" + @staticmethod + def get_default_db_path() -> str: + """Get the default database path based on environment.""" + area = App.from_env() + if area == App.Area.Test: + return "_/var/podcastitlater/podcast.db" + return "/var/podcastitlater/podcast.db" + @staticmethod @contextmanager def get_connection( - db_path: str = "podcast.db", + db_path: str | None = None, ) -> Iterator[sqlite3.Connection]: """Context manager for database connections. + Args: + db_path: Database file path. If None, uses environment-appropriate + default. + Yields: sqlite3.Connection: Database connection with row factory set. """ + if db_path is None: + db_path = Database.get_default_db_path() conn = sqlite3.connect(db_path) conn.row_factory = sqlite3.Row try: @@ -47,8 +61,15 @@ class Database: # noqa: PLR0904 conn.close() @staticmethod - def init_db(db_path: str = "podcast.db") -> None: + def init_db(db_path: str | None = None) -> None: """Initialize database with required tables.""" + if db_path is None: + db_path = Database.get_default_db_path() + + # Ensure directory exists + db_dir = pathlib.Path(db_path).parent + db_dir.mkdir(parents=True, exist_ok=True) + with Database.get_connection(db_path) as conn: cursor = conn.cursor() @@ -101,13 +122,15 @@ class Database: # noqa: PLR0904 url: str, email: str, user_id: int, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> int: """Insert new job into queue, return job ID. Raises: ValueError: If job ID cannot be retrieved after insert. """ + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute( @@ -125,9 +148,11 @@ class Database: # noqa: PLR0904 @staticmethod def get_pending_jobs( limit: int = 10, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> list[dict[str, Any]]: """Fetch jobs with status='pending' ordered by creation time.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute( @@ -143,9 +168,11 @@ class Database: # noqa: PLR0904 job_id: int, status: str, error: str | None = None, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> None: """Update job status and error message.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() if status == "error": @@ -165,9 +192,11 @@ class Database: # noqa: PLR0904 @staticmethod def get_job_by_id( job_id: int, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> dict[str, Any] | None: """Fetch single job by ID.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM queue WHERE id = ?", (job_id,)) @@ -181,13 +210,15 @@ class Database: # noqa: PLR0904 duration: int, content_length: int, user_id: int | None = None, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> int: """Insert episode record, return episode ID. Raises: ValueError: If episode ID cannot be retrieved after insert. """ + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute( @@ -207,9 +238,11 @@ class Database: # noqa: PLR0904 @staticmethod def get_recent_episodes( limit: int = 20, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> list[dict[str, Any]]: """Get recent episodes for RSS feed generation.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute( @@ -220,8 +253,10 @@ class Database: # noqa: PLR0904 return [dict(row) for row in rows] @staticmethod - def get_queue_status_summary(db_path: str = "podcast.db") -> dict[str, Any]: + def get_queue_status_summary(db_path: str | None = None) -> dict[str, Any]: """Get queue status summary for web interface.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() @@ -242,8 +277,10 @@ class Database: # noqa: PLR0904 return {"status_counts": status_counts, "recent_jobs": recent_jobs} @staticmethod - def get_queue_status(db_path: str = "podcast.db") -> list[dict[str, Any]]: + def get_queue_status(db_path: str | None = None) -> list[dict[str, Any]]: """Return pending/processing/error items for web interface.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute(""" @@ -258,10 +295,12 @@ class Database: # noqa: PLR0904 @staticmethod def get_all_episodes( - db_path: str = "podcast.db", + db_path: str | None = None, user_id: int | None = None, ) -> list[dict[str, Any]]: """Return all episodes for RSS feed.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() if user_id: @@ -288,9 +327,11 @@ class Database: # noqa: PLR0904 @staticmethod def get_retryable_jobs( max_retries: int = 3, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> list[dict[str, Any]]: """Get failed jobs that can be retried.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute( @@ -302,8 +343,10 @@ class Database: # noqa: PLR0904 return [dict(row) for row in rows] @staticmethod - def retry_job(job_id: int, db_path: str = "podcast.db") -> None: + def retry_job(job_id: int, db_path: str | None = None) -> None: """Reset a job to pending status for retry.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute( @@ -315,8 +358,10 @@ class Database: # noqa: PLR0904 logger.info("Reset job %s to pending for retry", job_id) @staticmethod - def delete_job(job_id: int, db_path: str = "podcast.db") -> None: + def delete_job(job_id: int, db_path: str | None = None) -> None: """Delete a job from the queue.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute("DELETE FROM queue WHERE id = ?", (job_id,)) @@ -325,10 +370,12 @@ class Database: # noqa: PLR0904 @staticmethod def get_all_queue_items( - db_path: str = "podcast.db", + db_path: str | None = None, user_id: int | None = None, ) -> list[dict[str, Any]]: """Return all queue items for admin view.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() if user_id: @@ -353,8 +400,10 @@ class Database: # noqa: PLR0904 return [dict(row) for row in rows] @staticmethod - def get_status_counts(db_path: str = "podcast.db") -> dict[str, int]: + def get_status_counts(db_path: str | None = None) -> dict[str, int]: """Get count of queue items by status.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute(""" @@ -368,9 +417,11 @@ class Database: # noqa: PLR0904 @staticmethod def get_user_status_counts( user_id: int, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> dict[str, int]: """Get count of queue items by status for a specific user.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute( @@ -386,8 +437,10 @@ class Database: # noqa: PLR0904 return {row["status"]: row["count"] for row in rows} @staticmethod - def migrate_to_multi_user(db_path: str = "podcast.db") -> None: + def migrate_to_multi_user(db_path: str | None = None) -> None: """Migrate database to support multiple users.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() @@ -437,13 +490,15 @@ class Database: # noqa: PLR0904 logger.info("Database migrated to support multiple users") @staticmethod - def create_user(email: str, db_path: str = "podcast.db") -> tuple[int, str]: + def create_user(email: str, db_path: str | None = None) -> tuple[int, str]: """Create a new user and return (user_id, token). Raises: ValueError: If user ID cannot be retrieved after insert or if user not found. """ + if db_path is None: + db_path = Database.get_default_db_path() # Generate a secure token for RSS feed access token = secrets.token_urlsafe(32) @@ -477,9 +532,11 @@ class Database: # noqa: PLR0904 @staticmethod def get_user_by_email( email: str, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> dict[str, Any] | None: """Get user by email address.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE email = ?", (email,)) @@ -489,9 +546,11 @@ class Database: # noqa: PLR0904 @staticmethod def get_user_by_token( token: str, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> dict[str, Any] | None: """Get user by RSS token.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE token = ?", (token,)) @@ -501,9 +560,11 @@ class Database: # noqa: PLR0904 @staticmethod def get_user_by_id( user_id: int, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> dict[str, Any] | None: """Get user by ID.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,)) @@ -513,9 +574,11 @@ class Database: # noqa: PLR0904 @staticmethod def get_user_queue_status( user_id: int, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> list[dict[str, Any]]: """Return pending/processing/error items for a specific user.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute( @@ -536,9 +599,11 @@ class Database: # noqa: PLR0904 def get_user_recent_episodes( user_id: int, limit: int = 20, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> list[dict[str, Any]]: """Get recent episodes for a specific user.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute( @@ -552,9 +617,11 @@ class Database: # noqa: PLR0904 @staticmethod def get_user_all_episodes( user_id: int, - db_path: str = "podcast.db", + db_path: str | None = None, ) -> list[dict[str, Any]]: """Get all episodes for a specific user for RSS feed.""" + if db_path is None: + db_path = Database.get_default_db_path() with Database.get_connection(db_path) as conn: cursor = conn.cursor() cursor.execute( @@ -571,7 +638,12 @@ class TestDatabase(Test.TestCase): def setUp(self) -> None: """Set up test database.""" - self.test_db = "test_podcast.db" + self.test_db = "_/var/podcastitlater/test_podcast.db" + + # Ensure test directory exists + test_db_dir = pathlib.Path(self.test_db).parent + test_db_dir.mkdir(parents=True, exist_ok=True) + # Clean up any existing test database test_db_path = pathlib.Path(self.test_db) if test_db_path.exists(): @@ -637,7 +709,12 @@ class TestUserManagement(Test.TestCase): def setUp(self) -> None: """Set up test database.""" - self.test_db = "test_podcast.db" + self.test_db = "_/var/podcastitlater/test_podcast.db" + + # Ensure test directory exists + test_db_dir = pathlib.Path(self.test_db).parent + test_db_dir.mkdir(parents=True, exist_ok=True) + # Clean up any existing test database test_db_path = pathlib.Path(self.test_db) if test_db_path.exists(): @@ -741,7 +818,12 @@ class TestQueueOperations(Test.TestCase): def setUp(self) -> None: """Set up test database with a user.""" - self.test_db = "test_podcast.db" + self.test_db = "_/var/podcastitlater/test_podcast.db" + + # Ensure test directory exists + test_db_dir = pathlib.Path(self.test_db).parent + test_db_dir.mkdir(parents=True, exist_ok=True) + # Clean up any existing test database test_db_path = pathlib.Path(self.test_db) if test_db_path.exists(): @@ -977,7 +1059,12 @@ class TestEpisodeManagement(Test.TestCase): def setUp(self) -> None: """Set up test database with a user.""" - self.test_db = "test_podcast.db" + self.test_db = "_/var/podcastitlater/test_podcast.db" + + # Ensure test directory exists + test_db_dir = pathlib.Path(self.test_db).parent + test_db_dir.mkdir(parents=True, exist_ok=True) + # Clean up any existing test database test_db_path = pathlib.Path(self.test_db) if test_db_path.exists(): diff --git a/Biz/PodcastItLater/Web.py b/Biz/PodcastItLater/Web.py index b471a29..86b2099 100644 --- a/Biz/PodcastItLater/Web.py +++ b/Biz/PodcastItLater/Web.py @@ -50,7 +50,14 @@ from typing import override logger = Log.setup() # Configuration -DATABASE_PATH = os.getenv("DATABASE_PATH", "podcast.db") +area = App.from_env() +if area == App.Area.Test: + DATABASE_PATH = os.getenv( + "DATABASE_PATH", + "_/var/podcastitlater/podcast.db", + ) +else: + DATABASE_PATH = os.getenv("DATABASE_PATH", "/var/podcastitlater/podcast.db") BASE_URL = os.getenv("BASE_URL", "http://localhost:8000") PORT = int(os.getenv("PORT", "8000")) @@ -1751,7 +1758,11 @@ class BaseWebTest(Test.TestCase): def setUp(self) -> None: """Set up test database and client.""" # Create a test database context - self.test_db_path = "test_podcast_web.db" + self.test_db_path = "_/var/podcastitlater/test_podcast_web.db" + + # Ensure test directory exists + test_db_dir = pathlib.Path(self.test_db_path).parent + test_db_dir.mkdir(parents=True, exist_ok=True) # Save original database path self._original_db_path = globals()["_test_database_path"] diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py index 834d44b..ce3d432 100644 --- a/Biz/PodcastItLater/Worker.py +++ b/Biz/PodcastItLater/Worker.py @@ -41,7 +41,14 @@ S3_ENDPOINT = os.getenv("S3_ENDPOINT") S3_BUCKET = os.getenv("S3_BUCKET") S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY") S3_SECRET_KEY = os.getenv("S3_SECRET_KEY") -DATABASE_PATH = os.getenv("DATABASE_PATH", "podcast.db") +area = App.from_env() +if area == App.Area.Test: + DATABASE_PATH = os.getenv( + "DATABASE_PATH", + "_/var/podcastitlater/podcast.db", + ) +else: + DATABASE_PATH = os.getenv("DATABASE_PATH", "/var/podcastitlater/podcast.db") # Worker configuration MAX_CONTENT_LENGTH = 5000 # characters for TTS @@ -943,7 +950,12 @@ class TestJobProcessing(Test.TestCase): def setUp(self) -> None: """Set up test environment.""" - self.test_db = "test_podcast_worker.db" + self.test_db = "_/var/podcastitlater/test_podcast_worker.db" + + # Ensure test directory exists + test_db_dir = pathlib.Path(self.test_db).parent + test_db_dir.mkdir(parents=True, exist_ok=True) + # Clean up any existing test database test_db_path = pathlib.Path(self.test_db) if test_db_path.exists(): -- cgit v1.2.3