summaryrefslogtreecommitdiff
path: root/Biz/PodcastItLater/Core.py
diff options
context:
space:
mode:
authorBen Sima <ben@bsima.me>2025-09-04 16:23:17 -0400
committerBen Sima <ben@bsima.me>2025-09-04 16:23:17 -0400
commit91750395b5047dfb99f5d9b7b49d336b2bfb38e8 (patch)
treee3915b25abd67c22f037bc9b29bfbd7cbd352438 /Biz/PodcastItLater/Core.py
parent2a2ff0749f18670ab82c304c8c3468aeea47846f (diff)
Refactor Admin and Database path stuff
Moved the Admin related stuff to a separate file. Removed the repetitive `db_path` arg everywhere and replaced it with correct assumptions, similar to whats in other apps.
Diffstat (limited to 'Biz/PodcastItLater/Core.py')
-rw-r--r--Biz/PodcastItLater/Core.py410
1 files changed, 139 insertions, 271 deletions
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py
index 278b97e..0ca945c 100644
--- a/Biz/PodcastItLater/Core.py
+++ b/Biz/PodcastItLater/Core.py
@@ -13,12 +13,14 @@ Includes:
import Omni.App as App
import Omni.Log as Log
import Omni.Test as Test
+import os
import pathlib
import pytest
import secrets
import sqlite3
import sys
import time
+import typing
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
@@ -26,33 +28,48 @@ from typing import Any
logger = Log.setup()
+CODEROOT = pathlib.Path(os.getenv("CODEROOT", "."))
+DATA_DIR = pathlib.Path(
+ os.environ.get("DATA_DIR", CODEROOT / "_/var/podcastitlater/"),
+)
+
+# Constants for UI display
+URL_TRUNCATE_LENGTH = 80
+TITLE_TRUNCATE_LENGTH = 50
+ERROR_TRUNCATE_LENGTH = 50
+
+# Admin whitelist
+ADMIN_EMAILS = ["ben@bensima.com"]
+
+
+def is_admin(user: dict[str, typing.Any] | None) -> bool:
+ """Check if user is an admin based on email whitelist."""
+ if not user:
+ return False
+ return user.get("email", "").lower() in [
+ email.lower() for email in ADMIN_EMAILS
+ ]
+
+
class Database: # noqa: PLR0904
"""Data access layer for PodcastItLater database operations."""
@staticmethod
- def get_default_db_path() -> str:
- """Get the default database path based on environment."""
- area = App.from_env()
- if area == App.Area.Test:
- return "_/var/podcastitlater/podcast.db"
- return "/var/podcastitlater/podcast.db"
+ def teardown() -> None:
+ """Delete the existing database, for cleanup after tests."""
+ db_path = DATA_DIR / "podcast.db"
+ if db_path.exists():
+ db_path.unlink()
@staticmethod
@contextmanager
- def get_connection(
- db_path: str | None = None,
- ) -> Iterator[sqlite3.Connection]:
+ def get_connection() -> Iterator[sqlite3.Connection]:
"""Context manager for database connections.
- Args:
- db_path: Database file path. If None, uses environment-appropriate
- default.
-
Yields:
sqlite3.Connection: Database connection with row factory set.
"""
- if db_path is None:
- db_path = Database.get_default_db_path()
+ db_path = DATA_DIR / "podcast.db"
conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row
try:
@@ -61,16 +78,9 @@ class Database: # noqa: PLR0904
conn.close()
@staticmethod
- def init_db(db_path: str | None = None) -> None:
+ def init_db() -> None:
"""Initialize database with required tables."""
- if db_path is None:
- db_path = Database.get_default_db_path()
-
- # Ensure directory exists
- db_dir = pathlib.Path(db_path).parent
- db_dir.mkdir(parents=True, exist_ok=True)
-
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
# Queue table for job processing
@@ -117,26 +127,25 @@ class Database: # noqa: PLR0904
logger.info("Database initialized successfully")
# Run migration to add user support
- Database.migrate_to_multi_user(db_path)
+ Database.migrate_to_multi_user()
# Run migration to add metadata fields
- Database.migrate_add_metadata_fields(db_path)
+ Database.migrate_add_metadata_fields()
# Run migration to add episode metadata fields
- Database.migrate_add_episode_metadata(db_path)
+ Database.migrate_add_episode_metadata()
# Run migration to add user status field
- Database.migrate_add_user_status(db_path)
+ Database.migrate_add_user_status()
# Run migration to add default titles
- Database.migrate_add_default_titles(db_path)
+ Database.migrate_add_default_titles()
@staticmethod
- def add_to_queue( # noqa: PLR0913, PLR0917
+ def add_to_queue(
url: str,
email: str,
user_id: int,
- db_path: str | None = None,
title: str | None = None,
author: str | None = None,
) -> int:
@@ -145,9 +154,7 @@ class Database: # noqa: PLR0904
Raises:
ValueError: If job ID cannot be retrieved after insert.
"""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"INSERT INTO queue (url, email, user_id, title, author) "
@@ -165,12 +172,9 @@ class Database: # noqa: PLR0904
@staticmethod
def get_pending_jobs(
limit: int = 10,
- db_path: str | None = None,
) -> list[dict[str, Any]]:
"""Fetch jobs with status='pending' ordered by creation time."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM queue WHERE status = 'pending' "
@@ -185,12 +189,9 @@ class Database: # noqa: PLR0904
job_id: int,
status: str,
error: str | None = None,
- db_path: str | None = None,
) -> None:
"""Update job status and error message."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
if status == "error":
cursor.execute(
@@ -209,12 +210,9 @@ class Database: # noqa: PLR0904
@staticmethod
def get_job_by_id(
job_id: int,
- db_path: str | None = None,
) -> dict[str, Any] | None:
"""Fetch single job by ID."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM queue WHERE id = ?", (job_id,))
row = cursor.fetchone()
@@ -229,16 +227,13 @@ class Database: # noqa: PLR0904
user_id: int | None = None,
author: str | None = None,
original_url: str | None = None,
- db_path: str | None = None,
) -> int:
"""Insert episode record, return episode ID.
Raises:
ValueError: If episode ID cannot be retrieved after insert.
"""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"INSERT INTO episodes "
@@ -265,12 +260,9 @@ class Database: # noqa: PLR0904
@staticmethod
def get_recent_episodes(
limit: int = 20,
- db_path: str | None = None,
) -> list[dict[str, Any]]:
"""Get recent episodes for RSS feed generation."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM episodes ORDER BY created_at DESC LIMIT ?",
@@ -280,11 +272,9 @@ class Database: # noqa: PLR0904
return [dict(row) for row in rows]
@staticmethod
- def get_queue_status_summary(db_path: str | None = None) -> dict[str, Any]:
+ def get_queue_status_summary() -> dict[str, Any]:
"""Get queue status summary for web interface."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
# Count jobs by status
@@ -304,11 +294,9 @@ class Database: # noqa: PLR0904
return {"status_counts": status_counts, "recent_jobs": recent_jobs}
@staticmethod
- def get_queue_status(db_path: str | None = None) -> list[dict[str, Any]]:
+ def get_queue_status() -> list[dict[str, Any]]:
"""Return pending/processing/error items for web interface."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT id, url, email, status, created_at, error_message,
@@ -323,13 +311,10 @@ class Database: # noqa: PLR0904
@staticmethod
def get_all_episodes(
- db_path: str | None = None,
user_id: int | None = None,
) -> list[dict[str, Any]]:
"""Return all episodes for RSS feed."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
if user_id:
cursor.execute(
@@ -355,12 +340,9 @@ class Database: # noqa: PLR0904
@staticmethod
def get_retryable_jobs(
max_retries: int = 3,
- db_path: str | None = None,
) -> list[dict[str, Any]]:
"""Get failed jobs that can be retried."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM queue WHERE status = 'error' "
@@ -371,11 +353,9 @@ class Database: # noqa: PLR0904
return [dict(row) for row in rows]
@staticmethod
- def retry_job(job_id: int, db_path: str | None = None) -> None:
+ def retry_job(job_id: int) -> None:
"""Reset a job to pending status for retry."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"UPDATE queue SET status = 'pending', "
@@ -386,11 +366,9 @@ class Database: # noqa: PLR0904
logger.info("Reset job %s to pending for retry", job_id)
@staticmethod
- def delete_job(job_id: int, db_path: str | None = None) -> None:
+ def delete_job(job_id: int) -> None:
"""Delete a job from the queue."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM queue WHERE id = ?", (job_id,))
conn.commit()
@@ -398,13 +376,10 @@ class Database: # noqa: PLR0904
@staticmethod
def get_all_queue_items(
- db_path: str | None = None,
user_id: int | None = None,
) -> list[dict[str, Any]]:
"""Return all queue items for admin view."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
if user_id:
cursor.execute(
@@ -428,11 +403,9 @@ class Database: # noqa: PLR0904
return [dict(row) for row in rows]
@staticmethod
- def get_status_counts(db_path: str | None = None) -> dict[str, int]:
+ def get_status_counts() -> dict[str, int]:
"""Get count of queue items by status."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT status, COUNT(*) as count
@@ -445,12 +418,9 @@ class Database: # noqa: PLR0904
@staticmethod
def get_user_status_counts(
user_id: int,
- db_path: str | None = None,
) -> dict[str, int]:
"""Get count of queue items by status for a specific user."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""
@@ -465,11 +435,9 @@ class Database: # noqa: PLR0904
return {row["status"]: row["count"] for row in rows}
@staticmethod
- def migrate_to_multi_user(db_path: str | None = None) -> None:
+ def migrate_to_multi_user() -> None:
"""Migrate database to support multiple users."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
# Create users table
@@ -518,11 +486,9 @@ class Database: # noqa: PLR0904
logger.info("Database migrated to support multiple users")
@staticmethod
- def migrate_add_metadata_fields(db_path: str | None = None) -> None:
+ def migrate_add_metadata_fields() -> None:
"""Add title and author fields to queue table."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
# Check if columns already exist
@@ -540,11 +506,9 @@ class Database: # noqa: PLR0904
logger.info("Database migrated to support metadata fields")
@staticmethod
- def migrate_add_episode_metadata(db_path: str | None = None) -> None:
+ def migrate_add_episode_metadata() -> None:
"""Add author and original_url fields to episodes table."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
# Check if columns already exist
@@ -564,11 +528,9 @@ class Database: # noqa: PLR0904
logger.info("Database migrated to support episode metadata fields")
@staticmethod
- def migrate_add_user_status(db_path: str | None = None) -> None:
+ def migrate_add_user_status() -> None:
"""Add status field to users table."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
# Check if column already exists
@@ -592,11 +554,9 @@ class Database: # noqa: PLR0904
logger.info("Database migrated to support user status")
@staticmethod
- def migrate_add_default_titles(db_path: str | None = None) -> None:
+ def migrate_add_default_titles() -> None:
"""Add default titles to queue items that have None titles."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
# Update queue items with NULL titles to have a default
@@ -616,19 +576,16 @@ class Database: # noqa: PLR0904
)
@staticmethod
- def create_user(email: str, db_path: str | None = None) -> tuple[int, str]:
+ def create_user(email: str) -> tuple[int, str]:
"""Create a new user and return (user_id, token).
Raises:
ValueError: If user ID cannot be retrieved after insert or if user
not found.
"""
- if db_path is None:
- db_path = Database.get_default_db_path()
# Generate a secure token for RSS feed access
token = secrets.token_urlsafe(32)
-
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
try:
cursor.execute(
@@ -658,12 +615,9 @@ class Database: # noqa: PLR0904
@staticmethod
def get_user_by_email(
email: str,
- db_path: str | None = None,
) -> dict[str, Any] | None:
"""Get user by email address."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE email = ?", (email,))
row = cursor.fetchone()
@@ -672,12 +626,9 @@ class Database: # noqa: PLR0904
@staticmethod
def get_user_by_token(
token: str,
- db_path: str | None = None,
) -> dict[str, Any] | None:
"""Get user by RSS token."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE token = ?", (token,))
row = cursor.fetchone()
@@ -686,12 +637,9 @@ class Database: # noqa: PLR0904
@staticmethod
def get_user_by_id(
user_id: int,
- db_path: str | None = None,
) -> dict[str, Any] | None:
"""Get user by ID."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
row = cursor.fetchone()
@@ -700,12 +648,9 @@ class Database: # noqa: PLR0904
@staticmethod
def get_user_queue_status(
user_id: int,
- db_path: str | None = None,
) -> list[dict[str, Any]]:
"""Return pending/processing/error items for a specific user."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""
@@ -726,12 +671,9 @@ class Database: # noqa: PLR0904
def get_user_recent_episodes(
user_id: int,
limit: int = 20,
- db_path: str | None = None,
) -> list[dict[str, Any]]:
"""Get recent episodes for a specific user."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM episodes WHERE user_id = ? "
@@ -744,12 +686,9 @@ class Database: # noqa: PLR0904
@staticmethod
def get_user_all_episodes(
user_id: int,
- db_path: str | None = None,
) -> list[dict[str, Any]]:
"""Get all episodes for a specific user for RSS feed."""
- if db_path is None:
- db_path = Database.get_default_db_path()
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM episodes WHERE user_id = ? "
@@ -763,16 +702,13 @@ class Database: # noqa: PLR0904
def update_user_status(
user_id: int,
status: str,
- db_path: str | None = None,
) -> None:
"""Update user account status."""
- if db_path is None:
- db_path = Database.get_default_db_path()
if status not in {"pending", "active", "disabled"}:
msg = f"Invalid status: {status}"
raise ValueError(msg)
- with Database.get_connection(db_path) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"UPDATE users SET status = ? WHERE id = ?",
@@ -785,29 +721,20 @@ class Database: # noqa: PLR0904
class TestDatabase(Test.TestCase):
"""Test the Database class."""
- def setUp(self) -> None:
+ @staticmethod
+ def setUp() -> None:
"""Set up test database."""
- self.test_db = "_/var/podcastitlater/test_podcast.db"
-
- # Ensure test directory exists
- test_db_dir = pathlib.Path(self.test_db).parent
- test_db_dir.mkdir(parents=True, exist_ok=True)
-
- # Clean up any existing test database
- test_db_path = pathlib.Path(self.test_db)
- if test_db_path.exists():
- test_db_path.unlink()
- Database.init_db(self.test_db)
+ Database.init_db()
def tearDown(self) -> None:
"""Clean up test database."""
- test_db_path = pathlib.Path(self.test_db)
- if test_db_path.exists():
- test_db_path.unlink()
+ Database.teardown()
+ # Clear user ID
+ self.user_id = None
def test_init_db(self) -> None:
"""Verify all tables and indexes are created correctly."""
- with Database.get_connection(self.test_db) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
# Check tables exist
@@ -829,7 +756,7 @@ class TestDatabase(Test.TestCase):
def test_connection_context_manager(self) -> None:
"""Ensure connections are properly closed."""
# Get a connection and verify it works
- with Database.get_connection(self.test_db) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT 1")
result = cursor.fetchone()
@@ -842,43 +769,35 @@ class TestDatabase(Test.TestCase):
def test_migration_idempotency(self) -> None:
"""Verify migrations can run multiple times safely."""
# Run migration multiple times
- Database.migrate_to_multi_user(self.test_db)
- Database.migrate_to_multi_user(self.test_db)
- Database.migrate_to_multi_user(self.test_db)
+ Database.migrate_to_multi_user()
+ Database.migrate_to_multi_user()
+ Database.migrate_to_multi_user()
# Should still work fine
- with Database.get_connection(self.test_db) as conn:
+ with Database.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users")
# Should not raise an error
+ # Test completed successfully - migration worked
+ self.assertIsNotNone(conn)
class TestUserManagement(Test.TestCase):
"""Test user management functionality."""
- def setUp(self) -> None:
+ @staticmethod
+ def setUp() -> None:
"""Set up test database."""
- self.test_db = "_/var/podcastitlater/test_podcast.db"
-
- # Ensure test directory exists
- test_db_dir = pathlib.Path(self.test_db).parent
- test_db_dir.mkdir(parents=True, exist_ok=True)
-
- # Clean up any existing test database
- test_db_path = pathlib.Path(self.test_db)
- if test_db_path.exists():
- test_db_path.unlink()
- Database.init_db(self.test_db)
+ Database.init_db()
- def tearDown(self) -> None:
+ @staticmethod
+ def tearDown() -> None:
"""Clean up test database."""
- test_db_path = pathlib.Path(self.test_db)
- if test_db_path.exists():
- test_db_path.unlink()
+ Database.teardown()
def test_create_user(self) -> None:
"""Create user with unique email and token."""
- user_id, token = Database.create_user("test@example.com", self.test_db)
+ user_id, token = Database.create_user("test@example.com")
self.assertIsInstance(user_id, int)
self.assertIsInstance(token, str)
@@ -889,13 +808,11 @@ class TestUserManagement(Test.TestCase):
# Create first user
user_id1, token1 = Database.create_user(
"test@example.com",
- self.test_db,
)
# Try to create duplicate
user_id2, token2 = Database.create_user(
"test@example.com",
- self.test_db,
)
# Should return same user
@@ -906,9 +823,9 @@ class TestUserManagement(Test.TestCase):
def test_get_user_by_email(self) -> None:
"""Retrieve user by email."""
- user_id, token = Database.create_user("test@example.com", self.test_db)
+ user_id, token = Database.create_user("test@example.com")
- user = Database.get_user_by_email("test@example.com", self.test_db)
+ user = Database.get_user_by_email("test@example.com")
self.assertIsNotNone(user)
if user is None:
self.fail("User should not be None")
@@ -918,9 +835,9 @@ class TestUserManagement(Test.TestCase):
def test_get_user_by_token(self) -> None:
"""Retrieve user by RSS token."""
- user_id, token = Database.create_user("test@example.com", self.test_db)
+ user_id, token = Database.create_user("test@example.com")
- user = Database.get_user_by_token(token, self.test_db)
+ user = Database.get_user_by_token(token)
self.assertIsNotNone(user)
if user is None:
self.fail("User should not be None")
@@ -929,9 +846,9 @@ class TestUserManagement(Test.TestCase):
def test_get_user_by_id(self) -> None:
"""Retrieve user by ID."""
- user_id, token = Database.create_user("test@example.com", self.test_db)
+ user_id, token = Database.create_user("test@example.com")
- user = Database.get_user_by_id(user_id, self.test_db)
+ user = Database.get_user_by_id(user_id)
self.assertIsNotNone(user)
if user is None:
self.fail("User should not be None")
@@ -941,12 +858,12 @@ class TestUserManagement(Test.TestCase):
def test_invalid_user_lookups(self) -> None:
"""Verify None returned for non-existent users."""
self.assertIsNone(
- Database.get_user_by_email("nobody@example.com", self.test_db),
+ Database.get_user_by_email("nobody@example.com"),
)
self.assertIsNone(
- Database.get_user_by_token("invalid-token", self.test_db),
+ Database.get_user_by_token("invalid-token"),
)
- self.assertIsNone(Database.get_user_by_id(9999, self.test_db))
+ self.assertIsNone(Database.get_user_by_id(9999))
def test_token_uniqueness(self) -> None:
"""Ensure tokens are cryptographically unique."""
@@ -954,7 +871,6 @@ class TestUserManagement(Test.TestCase):
for i in range(10):
_, token = Database.create_user(
f"user{i}@example.com",
- self.test_db,
)
tokens.add(token)
@@ -967,24 +883,13 @@ class TestQueueOperations(Test.TestCase):
def setUp(self) -> None:
"""Set up test database with a user."""
- self.test_db = "_/var/podcastitlater/test_podcast.db"
-
- # Ensure test directory exists
- test_db_dir = pathlib.Path(self.test_db).parent
- test_db_dir.mkdir(parents=True, exist_ok=True)
+ Database.init_db()
+ self.user_id, _ = Database.create_user("test@example.com")
- # Clean up any existing test database
- test_db_path = pathlib.Path(self.test_db)
- if test_db_path.exists():
- test_db_path.unlink()
- Database.init_db(self.test_db)
- self.user_id, _ = Database.create_user("test@example.com", self.test_db)
-
- def tearDown(self) -> None:
+ @staticmethod
+ def tearDown() -> None:
"""Clean up test database."""
- test_db_path = pathlib.Path(self.test_db)
- if test_db_path.exists():
- test_db_path.unlink()
+ Database.teardown()
def test_add_to_queue(self) -> None:
"""Add job with user association."""
@@ -992,7 +897,6 @@ class TestQueueOperations(Test.TestCase):
"https://example.com/article",
"test@example.com",
self.user_id,
- self.test_db,
)
self.assertIsInstance(job_id, int)
@@ -1005,25 +909,22 @@ class TestQueueOperations(Test.TestCase):
"https://example.com/1",
"test@example.com",
self.user_id,
- self.test_db,
)
time.sleep(0.01) # Ensure different timestamps
job2 = Database.add_to_queue(
"https://example.com/2",
"test@example.com",
self.user_id,
- self.test_db,
)
time.sleep(0.01)
job3 = Database.add_to_queue(
"https://example.com/3",
"test@example.com",
self.user_id,
- self.test_db,
)
# Get pending jobs
- jobs = Database.get_pending_jobs(limit=10, db_path=self.test_db)
+ jobs = Database.get_pending_jobs(limit=10)
self.assertEqual(len(jobs), 3)
# Should be in order of creation (oldest first)
@@ -1037,12 +938,11 @@ class TestQueueOperations(Test.TestCase):
"https://example.com",
"test@example.com",
self.user_id,
- self.test_db,
)
# Update to processing
- Database.update_job_status(job_id, "processing", db_path=self.test_db)
- job = Database.get_job_by_id(job_id, self.test_db)
+ Database.update_job_status(job_id, "processing")
+ job = Database.get_job_by_id(job_id)
self.assertIsNotNone(job)
if job is None:
self.fail("Job should not be None")
@@ -1053,9 +953,8 @@ class TestQueueOperations(Test.TestCase):
job_id,
"error",
"Network timeout",
- self.test_db,
)
- job = Database.get_job_by_id(job_id, self.test_db)
+ job = Database.get_job_by_id(job_id)
self.assertIsNotNone(job)
if job is None:
self.fail("Job should not be None")
@@ -1069,15 +968,14 @@ class TestQueueOperations(Test.TestCase):
"https://example.com",
"test@example.com",
self.user_id,
- self.test_db,
)
# Set to error
- Database.update_job_status(job_id, "error", "Failed", self.test_db)
+ Database.update_job_status(job_id, "error", "Failed")
# Retry
- Database.retry_job(job_id, self.test_db)
- job = Database.get_job_by_id(job_id, self.test_db)
+ Database.retry_job(job_id)
+ job = Database.get_job_by_id(job_id)
self.assertIsNotNone(job)
if job is None:
@@ -1091,14 +989,13 @@ class TestQueueOperations(Test.TestCase):
"https://example.com",
"test@example.com",
self.user_id,
- self.test_db,
)
# Delete job
- Database.delete_job(job_id, self.test_db)
+ Database.delete_job(job_id)
# Should not exist
- job = Database.get_job_by_id(job_id, self.test_db)
+ job = Database.get_job_by_id(job_id)
self.assertIsNone(job)
def test_get_retryable_jobs(self) -> None:
@@ -1108,14 +1005,12 @@ class TestQueueOperations(Test.TestCase):
"https://example.com",
"test@example.com",
self.user_id,
- self.test_db,
)
- Database.update_job_status(job_id, "error", "Failed", self.test_db)
+ Database.update_job_status(job_id, "error", "Failed")
# Should be retryable
retryable = Database.get_retryable_jobs(
max_retries=3,
- db_path=self.test_db,
)
self.assertEqual(len(retryable), 1)
self.assertEqual(retryable[0]["id"], job_id)
@@ -1125,44 +1020,39 @@ class TestQueueOperations(Test.TestCase):
job_id,
"error",
"Failed again",
- self.test_db,
)
Database.update_job_status(
job_id,
"error",
"Failed yet again",
- self.test_db,
)
# Should not be retryable anymore
retryable = Database.get_retryable_jobs(
max_retries=3,
- db_path=self.test_db,
)
self.assertEqual(len(retryable), 0)
def test_user_queue_isolation(self) -> None:
"""Ensure users only see their own jobs."""
# Create second user
- user2_id, _ = Database.create_user("user2@example.com", self.test_db)
+ user2_id, _ = Database.create_user("user2@example.com")
# Add jobs for both users
job1 = Database.add_to_queue(
"https://example.com/1",
"test@example.com",
self.user_id,
- self.test_db,
)
job2 = Database.add_to_queue(
"https://example.com/2",
"user2@example.com",
user2_id,
- self.test_db,
)
# Get user-specific queue status
- user1_jobs = Database.get_user_queue_status(self.user_id, self.test_db)
- user2_jobs = Database.get_user_queue_status(user2_id, self.test_db)
+ user1_jobs = Database.get_user_queue_status(self.user_id)
+ user2_jobs = Database.get_user_queue_status(user2_id)
self.assertEqual(len(user1_jobs), 1)
self.assertEqual(user1_jobs[0]["id"], job1)
@@ -1177,26 +1067,23 @@ class TestQueueOperations(Test.TestCase):
"https://example.com/1",
"test@example.com",
self.user_id,
- self.test_db,
)
job2 = Database.add_to_queue(
"https://example.com/2",
"test@example.com",
self.user_id,
- self.test_db,
)
job3 = Database.add_to_queue(
"https://example.com/3",
"test@example.com",
self.user_id,
- self.test_db,
)
- Database.update_job_status(job2, "processing", db_path=self.test_db)
- Database.update_job_status(job3, "error", "Failed", self.test_db)
+ Database.update_job_status(job2, "processing")
+ Database.update_job_status(job3, "error", "Failed")
# Get status counts
- counts = Database.get_user_status_counts(self.user_id, self.test_db)
+ counts = Database.get_user_status_counts(self.user_id)
self.assertEqual(counts.get("pending", 0), 1)
self.assertEqual(counts.get("processing", 0), 1)
@@ -1208,24 +1095,13 @@ class TestEpisodeManagement(Test.TestCase):
def setUp(self) -> None:
"""Set up test database with a user."""
- self.test_db = "_/var/podcastitlater/test_podcast.db"
-
- # Ensure test directory exists
- test_db_dir = pathlib.Path(self.test_db).parent
- test_db_dir.mkdir(parents=True, exist_ok=True)
+ Database.init_db()
+ self.user_id, _ = Database.create_user("test@example.com")
- # Clean up any existing test database
- test_db_path = pathlib.Path(self.test_db)
- if test_db_path.exists():
- test_db_path.unlink()
- Database.init_db(self.test_db)
- self.user_id, _ = Database.create_user("test@example.com", self.test_db)
-
- def tearDown(self) -> None:
+ @staticmethod
+ def tearDown() -> None:
"""Clean up test database."""
- test_db_path = pathlib.Path(self.test_db)
- if test_db_path.exists():
- test_db_path.unlink()
+ Database.teardown()
def test_create_episode(self) -> None:
"""Create episode with user association."""
@@ -1235,7 +1111,6 @@ class TestEpisodeManagement(Test.TestCase):
duration=300,
content_length=5000,
user_id=self.user_id,
- db_path=self.test_db,
)
self.assertIsInstance(episode_id, int)
@@ -1250,7 +1125,6 @@ class TestEpisodeManagement(Test.TestCase):
100,
1000,
self.user_id,
- self.test_db,
)
time.sleep(0.01)
ep2 = Database.create_episode(
@@ -1259,7 +1133,6 @@ class TestEpisodeManagement(Test.TestCase):
200,
2000,
self.user_id,
- self.test_db,
)
time.sleep(0.01)
ep3 = Database.create_episode(
@@ -1268,11 +1141,10 @@ class TestEpisodeManagement(Test.TestCase):
300,
3000,
self.user_id,
- self.test_db,
)
# Get recent episodes
- episodes = Database.get_recent_episodes(limit=10, db_path=self.test_db)
+ episodes = Database.get_recent_episodes(limit=10)
self.assertEqual(len(episodes), 3)
# Should be in reverse chronological order
@@ -1283,7 +1155,7 @@ class TestEpisodeManagement(Test.TestCase):
def test_get_user_episodes(self) -> None:
"""Ensure user isolation for episodes."""
# Create second user
- user2_id, _ = Database.create_user("user2@example.com", self.test_db)
+ user2_id, _ = Database.create_user("user2@example.com")
# Create episodes for both users
ep1 = Database.create_episode(
@@ -1292,7 +1164,6 @@ class TestEpisodeManagement(Test.TestCase):
100,
1000,
self.user_id,
- self.test_db,
)
ep2 = Database.create_episode(
"User2 Article",
@@ -1300,15 +1171,13 @@ class TestEpisodeManagement(Test.TestCase):
200,
2000,
user2_id,
- self.test_db,
)
# Get user-specific episodes
user1_episodes = Database.get_user_all_episodes(
self.user_id,
- self.test_db,
)
- user2_episodes = Database.get_user_all_episodes(user2_id, self.test_db)
+ user2_episodes = Database.get_user_all_episodes(user2_id)
self.assertEqual(len(user1_episodes), 1)
self.assertEqual(user1_episodes[0]["id"], ep1)
@@ -1324,10 +1193,9 @@ class TestEpisodeManagement(Test.TestCase):
duration=12345,
content_length=98765,
user_id=self.user_id,
- db_path=self.test_db,
)
- episodes = Database.get_user_all_episodes(self.user_id, self.test_db)
+ episodes = Database.get_user_all_episodes(self.user_id)
episode = episodes[0]
self.assertEqual(episode["duration"], 12345)