summaryrefslogtreecommitdiff
path: root/Biz
diff options
context:
space:
mode:
Diffstat (limited to 'Biz')
-rw-r--r--Biz/PodcastItLater/Core.py145
-rw-r--r--Biz/PodcastItLater/Web.py15
-rw-r--r--Biz/PodcastItLater/Worker.py16
3 files changed, 143 insertions, 33 deletions
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py
index c0d0acf..6c04db8 100644
--- a/Biz/PodcastItLater/Core.py
+++ b/Biz/PodcastItLater/Core.py
@@ -30,15 +30,29 @@ class Database: # noqa: PLR0904
"""Data access layer for PodcastItLater database operations."""
@staticmethod
+ def get_default_db_path() -> str:
+ """Get the default database path based on environment."""
+ area = App.from_env()
+ if area == App.Area.Test:
+ return "_/var/podcastitlater/podcast.db"
+ return "/var/podcastitlater/podcast.db"
+
+ @staticmethod
@contextmanager
def get_connection(
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> Iterator[sqlite3.Connection]:
"""Context manager for database connections.
+ Args:
+ db_path: Database file path. If None, uses environment-appropriate
+ default.
+
Yields:
sqlite3.Connection: Database connection with row factory set.
"""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row
try:
@@ -47,8 +61,15 @@ class Database: # noqa: PLR0904
conn.close()
@staticmethod
- def init_db(db_path: str = "podcast.db") -> None:
+ def init_db(db_path: str | None = None) -> None:
"""Initialize database with required tables."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
+
+ # Ensure directory exists
+ db_dir = pathlib.Path(db_path).parent
+ db_dir.mkdir(parents=True, exist_ok=True)
+
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
@@ -101,13 +122,15 @@ class Database: # noqa: PLR0904
url: str,
email: str,
user_id: int,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> int:
"""Insert new job into queue, return job ID.
Raises:
ValueError: If job ID cannot be retrieved after insert.
"""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
@@ -125,9 +148,11 @@ class Database: # noqa: PLR0904
@staticmethod
def get_pending_jobs(
limit: int = 10,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> list[dict[str, Any]]:
"""Fetch jobs with status='pending' ordered by creation time."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
@@ -143,9 +168,11 @@ class Database: # noqa: PLR0904
job_id: int,
status: str,
error: str | None = None,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> None:
"""Update job status and error message."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
if status == "error":
@@ -165,9 +192,11 @@ class Database: # noqa: PLR0904
@staticmethod
def get_job_by_id(
job_id: int,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> dict[str, Any] | None:
"""Fetch single job by ID."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM queue WHERE id = ?", (job_id,))
@@ -181,13 +210,15 @@ class Database: # noqa: PLR0904
duration: int,
content_length: int,
user_id: int | None = None,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> int:
"""Insert episode record, return episode ID.
Raises:
ValueError: If episode ID cannot be retrieved after insert.
"""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
@@ -207,9 +238,11 @@ class Database: # noqa: PLR0904
@staticmethod
def get_recent_episodes(
limit: int = 20,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> list[dict[str, Any]]:
"""Get recent episodes for RSS feed generation."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
@@ -220,8 +253,10 @@ class Database: # noqa: PLR0904
return [dict(row) for row in rows]
@staticmethod
- def get_queue_status_summary(db_path: str = "podcast.db") -> dict[str, Any]:
+ def get_queue_status_summary(db_path: str | None = None) -> dict[str, Any]:
"""Get queue status summary for web interface."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
@@ -242,8 +277,10 @@ class Database: # noqa: PLR0904
return {"status_counts": status_counts, "recent_jobs": recent_jobs}
@staticmethod
- def get_queue_status(db_path: str = "podcast.db") -> list[dict[str, Any]]:
+ def get_queue_status(db_path: str | None = None) -> list[dict[str, Any]]:
"""Return pending/processing/error items for web interface."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
@@ -258,10 +295,12 @@ class Database: # noqa: PLR0904
@staticmethod
def get_all_episodes(
- db_path: str = "podcast.db",
+ db_path: str | None = None,
user_id: int | None = None,
) -> list[dict[str, Any]]:
"""Return all episodes for RSS feed."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
if user_id:
@@ -288,9 +327,11 @@ class Database: # noqa: PLR0904
@staticmethod
def get_retryable_jobs(
max_retries: int = 3,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> list[dict[str, Any]]:
"""Get failed jobs that can be retried."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
@@ -302,8 +343,10 @@ class Database: # noqa: PLR0904
return [dict(row) for row in rows]
@staticmethod
- def retry_job(job_id: int, db_path: str = "podcast.db") -> None:
+ def retry_job(job_id: int, db_path: str | None = None) -> None:
"""Reset a job to pending status for retry."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
@@ -315,8 +358,10 @@ class Database: # noqa: PLR0904
logger.info("Reset job %s to pending for retry", job_id)
@staticmethod
- def delete_job(job_id: int, db_path: str = "podcast.db") -> None:
+ def delete_job(job_id: int, db_path: str | None = None) -> None:
"""Delete a job from the queue."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM queue WHERE id = ?", (job_id,))
@@ -325,10 +370,12 @@ class Database: # noqa: PLR0904
@staticmethod
def get_all_queue_items(
- db_path: str = "podcast.db",
+ db_path: str | None = None,
user_id: int | None = None,
) -> list[dict[str, Any]]:
"""Return all queue items for admin view."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
if user_id:
@@ -353,8 +400,10 @@ class Database: # noqa: PLR0904
return [dict(row) for row in rows]
@staticmethod
- def get_status_counts(db_path: str = "podcast.db") -> dict[str, int]:
+ def get_status_counts(db_path: str | None = None) -> dict[str, int]:
"""Get count of queue items by status."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
@@ -368,9 +417,11 @@ class Database: # noqa: PLR0904
@staticmethod
def get_user_status_counts(
user_id: int,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> dict[str, int]:
"""Get count of queue items by status for a specific user."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
@@ -386,8 +437,10 @@ class Database: # noqa: PLR0904
return {row["status"]: row["count"] for row in rows}
@staticmethod
- def migrate_to_multi_user(db_path: str = "podcast.db") -> None:
+ def migrate_to_multi_user(db_path: str | None = None) -> None:
"""Migrate database to support multiple users."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
@@ -437,13 +490,15 @@ class Database: # noqa: PLR0904
logger.info("Database migrated to support multiple users")
@staticmethod
- def create_user(email: str, db_path: str = "podcast.db") -> tuple[int, str]:
+ def create_user(email: str, db_path: str | None = None) -> tuple[int, str]:
"""Create a new user and return (user_id, token).
Raises:
ValueError: If user ID cannot be retrieved after insert or if user
not found.
"""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
# Generate a secure token for RSS feed access
token = secrets.token_urlsafe(32)
@@ -477,9 +532,11 @@ class Database: # noqa: PLR0904
@staticmethod
def get_user_by_email(
email: str,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> dict[str, Any] | None:
"""Get user by email address."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE email = ?", (email,))
@@ -489,9 +546,11 @@ class Database: # noqa: PLR0904
@staticmethod
def get_user_by_token(
token: str,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> dict[str, Any] | None:
"""Get user by RSS token."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE token = ?", (token,))
@@ -501,9 +560,11 @@ class Database: # noqa: PLR0904
@staticmethod
def get_user_by_id(
user_id: int,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> dict[str, Any] | None:
"""Get user by ID."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
@@ -513,9 +574,11 @@ class Database: # noqa: PLR0904
@staticmethod
def get_user_queue_status(
user_id: int,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> list[dict[str, Any]]:
"""Return pending/processing/error items for a specific user."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
@@ -536,9 +599,11 @@ class Database: # noqa: PLR0904
def get_user_recent_episodes(
user_id: int,
limit: int = 20,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> list[dict[str, Any]]:
"""Get recent episodes for a specific user."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
@@ -552,9 +617,11 @@ class Database: # noqa: PLR0904
@staticmethod
def get_user_all_episodes(
user_id: int,
- db_path: str = "podcast.db",
+ db_path: str | None = None,
) -> list[dict[str, Any]]:
"""Get all episodes for a specific user for RSS feed."""
+ if db_path is None:
+ db_path = Database.get_default_db_path()
with Database.get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
@@ -571,7 +638,12 @@ class TestDatabase(Test.TestCase):
def setUp(self) -> None:
"""Set up test database."""
- self.test_db = "test_podcast.db"
+ self.test_db = "_/var/podcastitlater/test_podcast.db"
+
+ # Ensure test directory exists
+ test_db_dir = pathlib.Path(self.test_db).parent
+ test_db_dir.mkdir(parents=True, exist_ok=True)
+
# Clean up any existing test database
test_db_path = pathlib.Path(self.test_db)
if test_db_path.exists():
@@ -637,7 +709,12 @@ class TestUserManagement(Test.TestCase):
def setUp(self) -> None:
"""Set up test database."""
- self.test_db = "test_podcast.db"
+ self.test_db = "_/var/podcastitlater/test_podcast.db"
+
+ # Ensure test directory exists
+ test_db_dir = pathlib.Path(self.test_db).parent
+ test_db_dir.mkdir(parents=True, exist_ok=True)
+
# Clean up any existing test database
test_db_path = pathlib.Path(self.test_db)
if test_db_path.exists():
@@ -741,7 +818,12 @@ class TestQueueOperations(Test.TestCase):
def setUp(self) -> None:
"""Set up test database with a user."""
- self.test_db = "test_podcast.db"
+ self.test_db = "_/var/podcastitlater/test_podcast.db"
+
+ # Ensure test directory exists
+ test_db_dir = pathlib.Path(self.test_db).parent
+ test_db_dir.mkdir(parents=True, exist_ok=True)
+
# Clean up any existing test database
test_db_path = pathlib.Path(self.test_db)
if test_db_path.exists():
@@ -977,7 +1059,12 @@ class TestEpisodeManagement(Test.TestCase):
def setUp(self) -> None:
"""Set up test database with a user."""
- self.test_db = "test_podcast.db"
+ self.test_db = "_/var/podcastitlater/test_podcast.db"
+
+ # Ensure test directory exists
+ test_db_dir = pathlib.Path(self.test_db).parent
+ test_db_dir.mkdir(parents=True, exist_ok=True)
+
# Clean up any existing test database
test_db_path = pathlib.Path(self.test_db)
if test_db_path.exists():
diff --git a/Biz/PodcastItLater/Web.py b/Biz/PodcastItLater/Web.py
index b471a29..86b2099 100644
--- a/Biz/PodcastItLater/Web.py
+++ b/Biz/PodcastItLater/Web.py
@@ -50,7 +50,14 @@ from typing import override
logger = Log.setup()
# Configuration
-DATABASE_PATH = os.getenv("DATABASE_PATH", "podcast.db")
+area = App.from_env()
+if area == App.Area.Test:
+ DATABASE_PATH = os.getenv(
+ "DATABASE_PATH",
+ "_/var/podcastitlater/podcast.db",
+ )
+else:
+ DATABASE_PATH = os.getenv("DATABASE_PATH", "/var/podcastitlater/podcast.db")
BASE_URL = os.getenv("BASE_URL", "http://localhost:8000")
PORT = int(os.getenv("PORT", "8000"))
@@ -1751,7 +1758,11 @@ class BaseWebTest(Test.TestCase):
def setUp(self) -> None:
"""Set up test database and client."""
# Create a test database context
- self.test_db_path = "test_podcast_web.db"
+ self.test_db_path = "_/var/podcastitlater/test_podcast_web.db"
+
+ # Ensure test directory exists
+ test_db_dir = pathlib.Path(self.test_db_path).parent
+ test_db_dir.mkdir(parents=True, exist_ok=True)
# Save original database path
self._original_db_path = globals()["_test_database_path"]
diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py
index 834d44b..ce3d432 100644
--- a/Biz/PodcastItLater/Worker.py
+++ b/Biz/PodcastItLater/Worker.py
@@ -41,7 +41,14 @@ S3_ENDPOINT = os.getenv("S3_ENDPOINT")
S3_BUCKET = os.getenv("S3_BUCKET")
S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY")
S3_SECRET_KEY = os.getenv("S3_SECRET_KEY")
-DATABASE_PATH = os.getenv("DATABASE_PATH", "podcast.db")
+area = App.from_env()
+if area == App.Area.Test:
+ DATABASE_PATH = os.getenv(
+ "DATABASE_PATH",
+ "_/var/podcastitlater/podcast.db",
+ )
+else:
+ DATABASE_PATH = os.getenv("DATABASE_PATH", "/var/podcastitlater/podcast.db")
# Worker configuration
MAX_CONTENT_LENGTH = 5000 # characters for TTS
@@ -943,7 +950,12 @@ class TestJobProcessing(Test.TestCase):
def setUp(self) -> None:
"""Set up test environment."""
- self.test_db = "test_podcast_worker.db"
+ self.test_db = "_/var/podcastitlater/test_podcast_worker.db"
+
+ # Ensure test directory exists
+ test_db_dir = pathlib.Path(self.test_db).parent
+ test_db_dir.mkdir(parents=True, exist_ok=True)
+
# Clean up any existing test database
test_db_path = pathlib.Path(self.test_db)
if test_db_path.exists():