summaryrefslogtreecommitdiff
path: root/Biz/PodcastItLater
diff options
context:
space:
mode:
Diffstat (limited to 'Biz/PodcastItLater')
-rw-r--r--Biz/PodcastItLater/Admin.py1068
-rw-r--r--Biz/PodcastItLater/Billing.py581
-rw-r--r--Biz/PodcastItLater/Core.py2174
-rw-r--r--Biz/PodcastItLater/DESIGN.md43
-rw-r--r--Biz/PodcastItLater/Episode.py390
-rw-r--r--Biz/PodcastItLater/INFRASTRUCTURE.md38
-rw-r--r--Biz/PodcastItLater/STRIPE_TESTING.md114
-rw-r--r--Biz/PodcastItLater/TESTING.md45
-rw-r--r--Biz/PodcastItLater/Test.py276
-rw-r--r--Biz/PodcastItLater/TestMetricsView.py121
-rw-r--r--Biz/PodcastItLater/UI.py755
-rw-r--r--Biz/PodcastItLater/Web.nix93
-rw-r--r--Biz/PodcastItLater/Web.py3480
-rw-r--r--Biz/PodcastItLater/Worker.nix63
-rw-r--r--Biz/PodcastItLater/Worker.py2199
15 files changed, 11440 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Admin.py b/Biz/PodcastItLater/Admin.py
new file mode 100644
index 0000000..6f60948
--- /dev/null
+++ b/Biz/PodcastItLater/Admin.py
@@ -0,0 +1,1068 @@
+"""
+PodcastItLater Admin Interface.
+
+Admin pages and functionality for managing users and queue items.
+"""
+
+# : out podcastitlater-admin
+# : dep ludic
+# : dep httpx
+# : dep starlette
+# : dep pytest
+# : dep pytest-asyncio
+# : dep pytest-mock
+import Biz.PodcastItLater.Core as Core
+import Biz.PodcastItLater.UI as UI
+import ludic.html as html
+
+# i need to import these unused because bild cannot get local transitive python
+# dependencies yet
+import Omni.App as App # noqa: F401
+import Omni.Log as Log # noqa: F401
+import Omni.Test as Test # noqa: F401
+import sys
+import typing
+from ludic.attrs import Attrs
+from ludic.components import Component
+from ludic.types import AnyChildren
+from ludic.web import Request
+from ludic.web.datastructures import FormData
+from ludic.web.responses import Response
+from typing import override
+
+
class MetricsAttrs(Attrs):
    """Attributes for Metrics component."""

    # Aggregate metrics as produced by Core.Database.get_metrics_summary().
    metrics: dict[str, typing.Any]
    # Logged-in user row, or None when rendering without a session.
    user: dict[str, typing.Any] | None


class MetricCardAttrs(Attrs):
    """Attributes for MetricCard component."""

    # Card heading text.
    title: str
    # Numeric value shown prominently on the card.
    value: int
    # Bootstrap-icons class name (e.g. "bi-people").
    icon: str
+
+
class MetricCard(Component[AnyChildren, MetricCardAttrs]):
    """Display a single metric card."""

    @override
    def render(self) -> html.div:
        """Render an icon column beside a title/value column."""
        # Icon on the left; falls back to a generic bar-chart glyph.
        icon_col = html.div(
            html.i(
                classes=[
                    "bi",
                    self.attrs.get("icon", "bi-bar-chart"),
                    "text-primary",
                    "fs-2",
                ],
            ),
            classes=["col-auto"],
        )
        # Title and value stacked on the right.
        text_col = html.div(
            html.h6(self.attrs["title"], classes=["text-muted", "mb-1"]),
            html.h3(str(self.attrs["value"]), classes=["mb-0"]),
            classes=["col"],
        )
        return html.div(
            html.div(
                icon_col,
                text_col,
                classes=["row", "align-items-center"],
            ),
            classes=["card-body"],
        )
+
+
class TopEpisodesTableAttrs(Attrs):
    """Attributes for TopEpisodesTable component."""

    # Episode rows; each dict carries at least "title" and the count_key.
    episodes: list[dict[str, typing.Any]]
    # Column header label for the metric (e.g. "Plays").
    metric_name: str
    # Dict key in each episode holding the metric count (e.g. "play_count").
    count_key: str
+
+
class TopEpisodesTable(Component[AnyChildren, TopEpisodesTableAttrs]):
    """Display a table of top episodes by a metric.

    Renders a muted "No data yet" placeholder when the list is empty.
    """

    @override
    def render(self) -> html.div:
        episodes = self.attrs["episodes"]
        metric_name = self.attrs["metric_name"]
        count_key = self.attrs["count_key"]

        if not episodes:
            # Empty state: placeholder message instead of an empty table.
            return html.div(
                html.p(
                    "No data yet",
                    classes=["text-muted", "text-center", "py-3"],
                ),
                classes=["card-body"],
            )

        return html.div(
            html.div(
                html.table(
                    html.thead(
                        html.tr(
                            html.th("#", classes=["text-muted"]),
                            html.th("Title"),
                            html.th("Author", classes=["text-muted"]),
                            html.th(
                                metric_name,
                                classes=["text-end", "text-muted"],
                            ),
                        ),
                        classes=["table-light"],
                    ),
                    html.tbody(
                        *[
                            # One row per episode, ranked 1..n.
                            html.tr(
                                html.td(
                                    str(idx + 1),
                                    classes=["text-muted"],
                                ),
                                html.td(
                                    TruncatedText(
                                        text=episode["title"],
                                        max_length=Core.TITLE_TRUNCATE_LENGTH,
                                    ),
                                ),
                                html.td(
                                    # Author is optional; show "-" when absent.
                                    episode.get("author") or "-",
                                    classes=["text-muted"],
                                ),
                                html.td(
                                    str(episode[count_key]),
                                    classes=["text-end"],
                                ),
                            )
                            for idx, episode in enumerate(episodes)
                        ],
                    ),
                    classes=["table", "table-hover", "mb-0"],
                ),
                classes=["table-responsive"],
            ),
            classes=["card-body", "p-0"],
        )
+
+
class MetricsDashboard(Component[AnyChildren, MetricsAttrs]):
    """Admin metrics dashboard showing aggregate statistics.

    The original render() repeated the card-in-column wrapper markup eight
    times and the top-episodes card markup three times; the private helpers
    below factor that out while producing identical markup.
    """

    @staticmethod
    def _card_col(title: str, value: int, icon: str) -> html.div:
        """Render one MetricCard wrapped in a shadowed card inside a column."""
        return html.div(
            html.div(
                MetricCard(title=title, value=value, icon=icon),
                classes=["card", "shadow-sm"],
            ),
            classes=["col-md-3"],
        )

    @staticmethod
    def _top_table_col(
        heading: str,
        icon: str,
        episodes: list[dict[str, typing.Any]],
        metric_name: str,
        count_key: str,
    ) -> html.div:
        """Render one "top episodes" card column: header plus table."""
        return html.div(
            html.div(
                html.div(
                    html.h5(
                        html.i(classes=["bi", icon, "me-2"]),
                        heading,
                        classes=["card-title", "mb-0"],
                    ),
                    classes=["card-header", "bg-white"],
                ),
                TopEpisodesTable(
                    episodes=episodes,
                    metric_name=metric_name,
                    count_key=count_key,
                ),
                classes=["card", "shadow-sm"],
            ),
            classes=["col-lg-4"],
        )

    @override
    def render(self) -> UI.PageLayout:
        """Build the dashboard: growth cards, episode cards, top tables."""
        metrics = self.attrs["metrics"]
        user = self.attrs.get("user")

        return UI.PageLayout(
            html.div(
                html.h2(
                    html.i(classes=["bi", "bi-people", "me-2"]),
                    "Growth & Usage",
                    classes=["mb-4"],
                ),
                # Growth & Usage cards (missing keys default to 0).
                html.div(
                    self._card_col(
                        "Total Users",
                        metrics.get("total_users", 0),
                        "bi-people",
                    ),
                    self._card_col(
                        "Active Subs",
                        metrics.get("active_subscriptions", 0),
                        "bi-credit-card",
                    ),
                    self._card_col(
                        "Submissions (24h)",
                        metrics.get("submissions_24h", 0),
                        "bi-activity",
                    ),
                    self._card_col(
                        "Submissions (7d)",
                        metrics.get("submissions_7d", 0),
                        "bi-calendar-week",
                    ),
                    classes=["row", "g-3", "mb-5"],
                ),
                html.h2(
                    html.i(classes=["bi", "bi-graph-up", "me-2"]),
                    "Episode Metrics",
                    classes=["mb-4"],
                ),
                # Episode summary cards (these keys are required in metrics).
                html.div(
                    self._card_col(
                        "Total Episodes",
                        metrics["total_episodes"],
                        "bi-collection",
                    ),
                    self._card_col(
                        "Total Plays",
                        metrics["total_plays"],
                        "bi-play-circle",
                    ),
                    self._card_col(
                        "Total Downloads",
                        metrics["total_downloads"],
                        "bi-download",
                    ),
                    self._card_col(
                        "Total Adds",
                        metrics["total_adds"],
                        "bi-plus-circle",
                    ),
                    classes=["row", "g-3", "mb-4"],
                ),
                # Top episodes tables, one column per metric.
                html.div(
                    self._top_table_col(
                        "Most Played",
                        "bi-play-circle-fill",
                        metrics["most_played"],
                        "Plays",
                        "play_count",
                    ),
                    self._top_table_col(
                        "Most Downloaded",
                        "bi-download",
                        metrics["most_downloaded"],
                        "Downloads",
                        "download_count",
                    ),
                    self._top_table_col(
                        "Most Added to Feeds",
                        "bi-plus-circle-fill",
                        metrics["most_added"],
                        "Adds",
                        "add_count",
                    ),
                    classes=["row", "g-3"],
                ),
            ),
            user=user,
            current_page="admin-metrics",
            error=None,
        )
+
+
class AdminUsersAttrs(Attrs):
    """Attributes for AdminUsers component."""

    # All user rows to list in the management table.
    users: list[dict[str, typing.Any]]
    # Logged-in (admin) user row, or None.
    user: dict[str, typing.Any] | None


class StatusBadgeAttrs(Attrs):
    """Attributes for StatusBadge component."""

    # Status keyword (e.g. "pending", "completed", "error").
    status: str
    # Optional count; when present the badge reads "STATUS: N".
    count: int | None
+
+
class StatusBadge(Component[AnyChildren, StatusBadgeAttrs]):
    """Display a status badge with optional count."""

    @override
    def render(self) -> html.span:
        """Render the badge; shows "STATUS: N" when a count is supplied."""
        status = self.attrs["status"]
        count = self.attrs.get("count", None)

        text = f"{status.upper()}: {count}" if count is not None else status

        # Build the class list conditionally instead of appending an empty
        # string class when no count is given (the previous behavior).
        classes = ["badge", self.get_status_badge_class(status)]
        if count is not None:
            classes.append("me-3")

        return html.span(text, classes=classes)

    @staticmethod
    def get_status_badge_class(status: str) -> str:
        """Get Bootstrap badge class for status.

        Unknown statuses fall back to the neutral "bg-secondary".
        """
        return {
            "pending": "bg-warning text-dark",
            "processing": "bg-primary",
            "completed": "bg-success",
            "active": "bg-success",
            "error": "bg-danger",
            "cancelled": "bg-secondary",
            "disabled": "bg-danger",
        }.get(status, "bg-secondary")
+
+
class TruncatedTextAttrs(Attrs):
    """Attributes for TruncatedText component."""

    # Full text; shown in the tooltip even when truncated.
    text: str
    # Maximum characters before "..." is appended.
    max_length: int
    # CSS max-width for the containing div (defaults to "200px").
    max_width: str
+
+
class TruncatedText(Component[AnyChildren, TruncatedTextAttrs]):
    """Display truncated text with tooltip."""

    @override
    def render(self) -> html.div:
        """Render the text, shortening with "..." past max_length.

        The untruncated text is always available via the title tooltip.
        """
        full = self.attrs["text"]
        limit = self.attrs["max_length"]

        display = full if len(full) <= limit else full[:limit] + "..."

        return html.div(
            display,
            title=full,
            classes=["text-truncate"],
            style={"max-width": self.attrs.get("max_width", "200px")},
        )
+
+
class ActionButtonsAttrs(Attrs):
    """Attributes for ActionButtons component."""

    # Queue item ID used in the HTMX retry/delete endpoint URLs.
    job_id: int
    # Current job status; completed jobs get no Retry button.
    status: str
+
+
class ActionButtons(Component[AnyChildren, ActionButtonsAttrs]):
    """Render action buttons for queue items."""

    @override
    def render(self) -> html.div:
        """Render Retry (for non-completed jobs) and Delete buttons."""
        job_id = self.attrs["job_id"]
        status = self.attrs["status"]

        buttons = []

        if status != "completed":
            # Retry only makes sense for unfinished jobs. The former
            # `disabled=status == "completed"` attribute was always False
            # inside this branch, so it has been removed as dead code.
            buttons.append(
                html.button(
                    html.i(classes=["bi", "bi-arrow-clockwise", "me-1"]),
                    "Retry",
                    hx_post=f"/queue/{job_id}/retry",
                    hx_target="body",
                    hx_swap="outerHTML",
                    classes=["btn", "btn-sm", "btn-success", "me-1"],
                ),
            )

        buttons.append(
            html.button(
                html.i(classes=["bi", "bi-trash", "me-1"]),
                "Delete",
                hx_delete=f"/queue/{job_id}",
                hx_confirm="Are you sure you want to delete this queue item?",
                hx_target="body",
                hx_swap="outerHTML",
                classes=["btn", "btn-sm", "btn-danger"],
            ),
        )

        return html.div(
            *buttons,
            classes=["btn-group"],
        )
+
+
class QueueTableRowAttrs(Attrs):
    """Attributes for QueueTableRow component."""

    # Queue row dict with id, url, title, email, status, retry_count,
    # created_at and error_message keys.
    item: dict[str, typing.Any]
+
+
class QueueTableRow(Component[AnyChildren, QueueTableRowAttrs]):
    """Render a single queue table row."""

    @override
    def render(self) -> html.tr:
        item = self.attrs["item"]

        return html.tr(
            html.td(str(item["id"])),
            html.td(
                # URLs get a wider truncation column than titles.
                TruncatedText(
                    text=item["url"],
                    max_length=Core.TITLE_TRUNCATE_LENGTH,
                    max_width="300px",
                ),
            ),
            html.td(
                TruncatedText(
                    text=item.get("title") or "-",
                    max_length=Core.TITLE_TRUNCATE_LENGTH,
                ),
            ),
            html.td(item["email"] or "-"),
            html.td(StatusBadge(status=item["status"])),
            html.td(str(item.get("retry_count", 0))),
            html.td(html.small(item["created_at"], classes=["text-muted"])),
            html.td(
                # Only truncate when there is an actual error message.
                TruncatedText(
                    text=item["error_message"] or "-",
                    max_length=Core.ERROR_TRUNCATE_LENGTH,
                )
                if item["error_message"]
                else html.span("-", classes=["text-muted"]),
            ),
            html.td(ActionButtons(job_id=item["id"], status=item["status"])),
        )
+
+
class EpisodeTableRowAttrs(Attrs):
    """Attributes for EpisodeTableRow component."""

    # Episode row dict with id, title, audio_url, duration, content_length
    # and created_at keys.
    episode: dict[str, typing.Any]
+
+
class EpisodeTableRow(Component[AnyChildren, EpisodeTableRowAttrs]):
    """Render a single episode table row."""

    @override
    def render(self) -> html.tr:
        episode = self.attrs["episode"]

        return html.tr(
            html.td(str(episode["id"])),
            html.td(
                TruncatedText(
                    text=episode["title"],
                    max_length=Core.TITLE_TRUNCATE_LENGTH,
                ),
            ),
            html.td(
                # External link to the hosted audio file.
                html.a(
                    html.i(classes=["bi", "bi-play-circle", "me-1"]),
                    "Listen",
                    href=episode["audio_url"],
                    target="_blank",
                    classes=["btn", "btn-sm", "btn-outline-primary"],
                ),
            ),
            html.td(
                # Duration is stored in seconds; "-" when unknown/falsy.
                f"{episode['duration']}s" if episode["duration"] else "-",
            ),
            html.td(
                f"{episode['content_length']:,} chars"
                if episode["content_length"]
                else "-",
            ),
            html.td(html.small(episode["created_at"], classes=["text-muted"])),
        )
+
+
class UserTableRowAttrs(Attrs):
    """Attributes for UserTableRow component."""

    # User row dict with id, email, created_at and status keys.
    user: dict[str, typing.Any]
+
+
class UserTableRow(Component[AnyChildren, UserTableRowAttrs]):
    """Render a single user table row.

    The status <select> posts to the admin endpoint on change via HTMX.
    """

    @override
    def render(self) -> html.tr:
        user = self.attrs["user"]

        return html.tr(
            html.td(user["email"]),
            html.td(html.small(user["created_at"], classes=["text-muted"])),
            html.td(StatusBadge(status=user.get("status", "pending"))),
            html.td(
                html.select(
                    html.option(
                        "Pending",
                        value="pending",
                        selected=user.get("status") == "pending",
                    ),
                    html.option(
                        "Active",
                        value="active",
                        selected=user.get("status") == "active",
                    ),
                    html.option(
                        "Disabled",
                        value="disabled",
                        selected=user.get("status") == "disabled",
                    ),
                    name="status",
                    # Submitting on change avoids a separate save button.
                    hx_post=f"/admin/users/{user['id']}/status",
                    hx_trigger="change",
                    hx_target="body",
                    hx_swap="outerHTML",
                    classes=["form-select", "form-select-sm"],
                ),
            ),
        )
+
+
def create_table_header(columns: list[str]) -> html.thead:
    """Create a table header with given column names."""
    header_cells = (html.th(col, scope="col") for col in columns)
    return html.thead(
        html.tr(*header_cells),
        classes=["table-light"],
    )
+
+
class AdminUsers(Component[AnyChildren, AdminUsersAttrs]):
    """Admin view for managing users."""

    @override
    def render(self) -> UI.PageLayout:
        users = self.attrs["users"]
        user = self.attrs.get("user")

        return UI.PageLayout(
            html.h2(
                "User Management",
                classes=["mb-4"],
            ),
            self._render_users_table(users),
            user=user,
            current_page="admin-users",
            error=None,
        )

    @staticmethod
    def _render_users_table(
        users: list[dict[str, typing.Any]],
    ) -> html.div:
        """Render users table, one UserTableRow per user."""
        return html.div(
            html.h2("All Users", classes=["mb-3"]),
            html.div(
                html.table(
                    create_table_header([
                        "Email",
                        "Created At",
                        "Status",
                        "Actions",
                    ]),
                    html.tbody(*[UserTableRow(user=user) for user in users]),
                    classes=["table", "table-hover", "table-striped"],
                ),
                classes=["table-responsive"],
            ),
        )
+
+
class AdminViewAttrs(Attrs):
    """Attributes for AdminView component."""

    # Non-completed queue rows to display.
    queue_items: list[dict[str, typing.Any]]
    # Completed episode rows to display.
    episodes: list[dict[str, typing.Any]]
    # Counts per status keyword for the summary badges.
    status_counts: dict[str, int]
    # Logged-in (admin) user row, or None.
    user: dict[str, typing.Any] | None
+
+
class AdminView(Component[AnyChildren, AdminViewAttrs]):
    """Admin view showing all queue items and episodes in tables.

    The content div re-fetches /admin every 10 seconds via HTMX and swaps
    its own innerHTML, so the page self-refreshes without a reload.
    """

    @override
    def render(self) -> UI.PageLayout:
        queue_items = self.attrs["queue_items"]
        episodes = self.attrs["episodes"]
        status_counts = self.attrs.get("status_counts", {})
        user = self.attrs.get("user")

        return UI.PageLayout(
            html.div(
                AdminView.render_content(
                    queue_items,
                    episodes,
                    status_counts,
                ),
                id="admin-content",
                hx_get="/admin",
                hx_trigger="every 10s",
                hx_swap="innerHTML",
                hx_target="#admin-content",
            ),
            user=user,
            current_page="admin",
            error=None,
        )

    @staticmethod
    def render_content(
        queue_items: list[dict[str, typing.Any]],
        episodes: list[dict[str, typing.Any]],
        status_counts: dict[str, int],
    ) -> html.div:
        """Render the main content of the admin page.

        Also used directly by the HTMX partial-update path in
        admin_queue_status.
        """
        return html.div(
            html.h2(
                "Admin Queue Status",
                classes=["mb-4"],
            ),
            AdminView.render_status_summary(status_counts),
            AdminView.render_queue_table(queue_items),
            AdminView.render_episodes_table(episodes),
        )

    @staticmethod
    def render_status_summary(status_counts: dict[str, int]) -> html.div:
        """Render status summary section (one counted badge per status)."""
        return html.div(
            html.h2("Status Summary", classes=["mb-3"]),
            html.div(
                *[
                    StatusBadge(status=status, count=count)
                    for status, count in status_counts.items()
                ],
                classes=["mb-4"],
            ),
        )

    @staticmethod
    def render_queue_table(
        queue_items: list[dict[str, typing.Any]],
    ) -> html.div:
        """Render queue items table."""
        return html.div(
            html.h2("Queue Items", classes=["mb-3"]),
            html.div(
                html.table(
                    create_table_header([
                        "ID",
                        "URL",
                        "Title",
                        "Email",
                        "Status",
                        "Retries",
                        "Created",
                        "Error",
                        "Actions",
                    ]),
                    html.tbody(*[
                        QueueTableRow(item=item) for item in queue_items
                    ]),
                    classes=["table", "table-hover", "table-sm"],
                ),
                classes=["table-responsive", "mb-5"],
            ),
        )

    @staticmethod
    def render_episodes_table(
        episodes: list[dict[str, typing.Any]],
    ) -> html.div:
        """Render episodes table."""
        return html.div(
            html.h2("Completed Episodes", classes=["mb-3"]),
            html.div(
                html.table(
                    create_table_header([
                        "ID",
                        "Title",
                        "Audio URL",
                        "Duration",
                        "Content Length",
                        "Created",
                    ]),
                    html.tbody(*[
                        EpisodeTableRow(episode=episode) for episode in episodes
                    ]),
                    classes=["table", "table-hover", "table-sm"],
                ),
                classes=["table-responsive"],
            ),
        )
+
+
def admin_queue_status(request: Request) -> AdminView | Response | html.div:
    """Return admin view showing all queue items and episodes.

    Non-admins are redirected; HTMX polling requests receive only the
    content fragment instead of the full page.
    """
    # Check if user is logged in
    user_id = request.session.get("user_id")
    if not user_id:
        # Redirect to login
        return Response(
            "",
            status_code=302,
            headers={"Location": "/"},
        )

    user = Core.Database.get_user_by_id(
        user_id,
    )
    if not user:
        # Invalid session
        return Response(
            "",
            status_code=302,
            headers={"Location": "/"},
        )

    # Check if user is admin
    if not Core.is_admin(user):
        # Forbidden - redirect to home with error
        return Response(
            "",
            status_code=302,
            headers={"Location": "/?error=forbidden"},
        )

    # Admins can see all data (excluding completed items)
    all_queue_items = [
        item
        for item in Core.Database.get_all_queue_items(None)
        if item.get("status") != "completed"
    ]
    all_episodes = Core.Database.get_all_episodes(
        None,
    )

    # Get overall status counts for all users
    status_counts: dict[str, int] = {}
    for item in all_queue_items:
        status = item.get("status", "unknown")
        status_counts[status] = status_counts.get(status, 0) + 1

    # Check if this is an HTMX request for auto-update
    if request.headers.get("HX-Request") == "true":
        # Return just the content div for HTMX updates
        content = AdminView.render_content(
            all_queue_items,
            all_episodes,
            status_counts,
        )
        return html.div(
            content,
            hx_get="/admin",
            hx_trigger="every 10s",
            hx_swap="innerHTML",
        )

    return AdminView(
        queue_items=all_queue_items,
        episodes=all_episodes,
        status_counts=status_counts,
        user=user,
    )
+
+
def retry_queue_item(request: Request, job_id: int) -> Response:
    """Retry a failed queue item.

    Allowed for the job's owner or an admin. When invoked from the admin
    page (detected via the Referer header) responds with an HX-Redirect;
    otherwise fires a "queue-updated" HTMX trigger.
    """
    try:
        # Check if user owns this job or is admin
        user_id = request.session.get("user_id")
        if not user_id:
            return Response("Unauthorized", status_code=401)

        job = Core.Database.get_job_by_id(
            job_id,
        )
        if job is None:
            return Response("Job not found", status_code=404)

        # Check ownership or admin status
        user = Core.Database.get_user_by_id(user_id)
        if job.get("user_id") != user_id and not Core.is_admin(user):
            return Response("Forbidden", status_code=403)

        Core.Database.retry_job(job_id)

        # Check if request is from admin page via referer header
        is_from_admin = "/admin" in request.headers.get("referer", "")

        # Redirect to admin if from admin page, trigger update otherwise
        if is_from_admin:
            return Response(
                "",
                status_code=200,
                headers={"HX-Redirect": "/admin"},
            )
        return Response(
            "",
            status_code=200,
            headers={"HX-Trigger": "queue-updated"},
        )
    except (ValueError, KeyError) as e:
        return Response(
            f"Error retrying job: {e!s}",
            status_code=500,
        )
+
+
def delete_queue_item(request: Request, job_id: int) -> Response:
    """Delete a queue item.

    Mirrors retry_queue_item: owner-or-admin authorization, with the
    admin-page/HTMX response split based on the Referer header.
    """
    try:
        # Check if user owns this job or is admin
        user_id = request.session.get("user_id")
        if not user_id:
            return Response("Unauthorized", status_code=401)

        job = Core.Database.get_job_by_id(
            job_id,
        )
        if job is None:
            return Response("Job not found", status_code=404)

        # Check ownership or admin status
        user = Core.Database.get_user_by_id(user_id)
        if job.get("user_id") != user_id and not Core.is_admin(user):
            return Response("Forbidden", status_code=403)

        Core.Database.delete_job(job_id)

        # Check if request is from admin page via referer header
        is_from_admin = "/admin" in request.headers.get("referer", "")

        # Redirect to admin if from admin page, trigger update otherwise
        if is_from_admin:
            return Response(
                "",
                status_code=200,
                headers={"HX-Redirect": "/admin"},
            )
        return Response(
            "",
            status_code=200,
            headers={"HX-Trigger": "queue-updated"},
        )
    except (ValueError, KeyError) as e:
        return Response(
            f"Error deleting job: {e!s}",
            status_code=500,
        )
+
+
def admin_users(request: Request) -> AdminUsers | Response:
    """Admin page for managing users."""
    # Check if user is logged in and is admin
    user_id = request.session.get("user_id")
    if not user_id:
        return Response(
            "",
            status_code=302,
            headers={"Location": "/"},
        )

    user = Core.Database.get_user_by_id(
        user_id,
    )
    if not user or not Core.is_admin(user):
        return Response(
            "",
            status_code=302,
            headers={"Location": "/?error=forbidden"},
        )

    # Get all users, newest first, via a direct query on the users table.
    with Core.Database.get_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id, email, created_at, status FROM users "
            "ORDER BY created_at DESC",
        )
        rows = cursor.fetchall()
        users = [dict(row) for row in rows]

    return AdminUsers(users=users, user=user)
+
+
def update_user_status(
    request: Request,
    user_id: int,
    data: FormData,
) -> Response:
    """Update user account status.

    Admin-only; accepts a "status" form field restricted to
    pending/active/disabled, then redirects back to the users page.
    """
    # Check if user is logged in and is admin
    session_user_id = request.session.get("user_id")
    if not session_user_id:
        return Response("Unauthorized", status_code=401)

    user = Core.Database.get_user_by_id(
        session_user_id,
    )
    if not user or not Core.is_admin(user):
        return Response("Forbidden", status_code=403)

    # Get new status from form data; non-string values fall back to pending.
    new_status_raw = data.get("status", "pending")
    new_status = (
        new_status_raw if isinstance(new_status_raw, str) else "pending"
    )
    if new_status not in {"pending", "active", "disabled"}:
        return Response("Invalid status", status_code=400)

    # Update user status
    Core.Database.update_user_status(
        user_id,
        new_status,
    )

    # Redirect back to users page
    return Response(
        "",
        status_code=200,
        headers={"HX-Redirect": "/admin/users"},
    )
+
+
def toggle_episode_public(request: Request, episode_id: int) -> Response:
    """Toggle episode public/private status."""
    # Check if user is logged in and is admin
    session_user_id = request.session.get("user_id")
    if not session_user_id:
        return Response("Unauthorized", status_code=401)

    user = Core.Database.get_user_by_id(session_user_id)
    if not user or not Core.is_admin(user):
        return Response("Forbidden", status_code=403)

    # Get current episode status
    episode = Core.Database.get_episode_by_id(episode_id)
    if not episode:
        return Response("Episode not found", status_code=404)

    # Toggle public status (is_public is stored as 0/1 in the database)
    current_public = episode.get("is_public", 0) == 1
    if current_public:
        Core.Database.unmark_episode_public(episode_id)
    else:
        Core.Database.mark_episode_public(episode_id)

    # Redirect to home page to see updated status
    return Response(
        "",
        status_code=200,
        headers={"HX-Redirect": "/"},
    )
+
+
def admin_metrics(request: Request) -> MetricsDashboard | Response:
    """Admin metrics dashboard showing episode statistics."""
    # Check if user is logged in and is admin
    user_id = request.session.get("user_id")
    if not user_id:
        return Response(
            "",
            status_code=302,
            headers={"Location": "/"},
        )

    user = Core.Database.get_user_by_id(
        user_id,
    )
    if not user or not Core.is_admin(user):
        return Response(
            "",
            status_code=302,
            headers={"Location": "/?error=forbidden"},
        )

    # Get metrics data (shape defined by Core.Database.get_metrics_summary)
    metrics = Core.Database.get_metrics_summary()

    return MetricsDashboard(metrics=metrics, user=user)
+
+
def main() -> None:
    """Admin tests are currently in Web."""
    # No tests live in this module yet; exit cleanly under the test harness.
    if "test" in sys.argv:
        sys.exit(0)
diff --git a/Biz/PodcastItLater/Billing.py b/Biz/PodcastItLater/Billing.py
new file mode 100644
index 0000000..9f3739d
--- /dev/null
+++ b/Biz/PodcastItLater/Billing.py
@@ -0,0 +1,581 @@
+"""
+PodcastItLater Billing Integration.
+
+Stripe subscription management and usage enforcement.
+"""
+
+# : out podcastitlater-billing
+# : dep stripe
+# : dep pytest
+# : dep pytest-mock
+import Biz.PodcastItLater.Core as Core
+import json
+import logging
+import Omni.App as App
+import Omni.Log as Log
+import Omni.Test as Test
+import os
+import stripe
+import sys
+import typing
+from datetime import datetime
+from datetime import timezone
+
+logger = logging.getLogger(__name__)
+Log.setup(logger)
+
# Stripe configuration — secrets come from the environment; empty strings
# when unset (webhook verification is then skipped, see handle_webhook_event).
stripe.api_key = os.getenv("STRIPE_SECRET_KEY", "")
STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET", "")

# Price IDs from Stripe dashboard
STRIPE_PRICE_ID_PAID = os.getenv("STRIPE_PRICE_ID_PAID", "")

# Map Stripe price IDs to tier names
PRICE_TO_TIER = {
    STRIPE_PRICE_ID_PAID: "paid",
}

# Tier limits (None = unlimited). "free" gets 10 articles per lifetime
# period; "paid" is unlimited. Minutes limits are defined but unused
# (both None) — see can_submit.
TIER_LIMITS: dict[str, dict[str, int | None]] = {
    "free": {
        "articles_per_period": 10,
        "minutes_per_period": None,
    },
    "paid": {
        "articles_per_period": None,
        "minutes_per_period": None,
    },
}

# Price map for checkout (tier name -> Stripe price ID)
PRICE_MAP = {
    "paid": STRIPE_PRICE_ID_PAID,
}
+
+
def get_period_boundaries(
    user: dict[str, typing.Any],
) -> tuple[datetime, datetime]:
    """Get billing period boundaries for user.

    For paid users: use Stripe subscription period.
    For free users: lifetime (from account creation to far future).

    Args:
        user: User row; must contain "created_at" (ISO format), and for
            paid users "plan_tier", "current_period_start" and
            "current_period_end".

    Returns:
        tuple: (period_start, period_end) as datetime objects
    """
    if user.get("plan_tier") != "free" and user.get("current_period_start"):
        # Paid user - use Stripe billing period
        period_start = datetime.fromisoformat(user["current_period_start"])
        period_end = datetime.fromisoformat(user["current_period_end"])
    else:
        # Free user - lifetime limit from account creation
        period_start = datetime.fromisoformat(user["created_at"])
        # Far-future end date (~100 years out). A fixed Dec 31 date avoids
        # the ValueError that `now.replace(year=now.year + 100)` raised
        # when called on Feb 29 of a leap year whose +100 year isn't a
        # leap year.
        now = datetime.now(timezone.utc)
        period_end = datetime(now.year + 100, 12, 31, tzinfo=timezone.utc)

    return period_start, period_end
+
+
def get_usage(
    user_id: int,
    period_start: datetime,
    period_end: datetime,
) -> dict[str, int]:
    """Get usage stats for user in billing period.

    Thin wrapper delegating to Core.Database.get_usage.

    Returns:
        dict with keys: articles (int), minutes (int)
    """
    return Core.Database.get_usage(user_id, period_start, period_end)
+
+
def can_submit(user_id: int) -> tuple[bool, str, dict[str, int]]:
    """Check if user can submit article based on tier limits.

    Looks up the user's tier limits, computes their billing period, and
    compares current usage against article and minute caps.

    Returns:
        tuple: (allowed: bool, message: str, usage: dict)
    """
    user = Core.Database.get_user_by_id(user_id)
    if not user:
        return False, "User not found", {}

    # Unknown tiers fall back to the free-tier limits.
    tier = user.get("plan_tier", "free")
    limits = TIER_LIMITS.get(tier, TIER_LIMITS["free"])

    # Get billing period boundaries
    period_start, period_end = get_period_boundaries(user)

    # Get current usage
    usage = get_usage(user_id, period_start, period_end)

    # Check article limit (None means unlimited)
    article_limit = limits.get("articles_per_period")
    if article_limit is not None and usage["articles"] >= article_limit:
        msg = (
            f"You've reached your limit of {article_limit} articles "
            "per period. Upgrade to continue."
        )
        return (False, msg, usage)

    # Check minutes limit (currently None for all tiers, so never triggers)
    minute_limit = limits.get("minutes_per_period")
    if minute_limit is not None and usage.get("minutes", 0) >= minute_limit:
        return (
            False,
            f"You've reached your limit of {minute_limit} minutes per period. "
            "Please upgrade to continue.",
            usage,
        )

    return True, "", usage
+
+
def create_checkout_session(user_id: int, tier: str, base_url: str) -> str:
    """Create Stripe Checkout session for subscription.

    Args:
        user_id: User ID
        tier: Subscription tier (paid)
        base_url: Base URL for success/cancel redirects

    Returns:
        Checkout session URL to redirect user to

    Raises:
        ValueError: If tier is invalid, price ID not configured, or user
            not found
    """
    if tier not in PRICE_MAP:
        msg = f"Invalid tier: {tier}"
        raise ValueError(msg)

    price_id = PRICE_MAP[tier]
    if not price_id:
        msg = f"Stripe price ID not configured for tier: {tier}"
        raise ValueError(msg)

    user = Core.Database.get_user_by_id(user_id)
    if not user:
        msg = f"User not found: {user_id}"
        raise ValueError(msg)

    # Create checkout session; user_id travels via client_reference_id and
    # metadata so the webhook can link the customer back to our user row.
    session_params = {
        "mode": "subscription",
        "line_items": [{"price": price_id, "quantity": 1}],
        "success_url": f"{base_url}/?status=success",
        "cancel_url": f"{base_url}/?status=cancel",
        "client_reference_id": str(user_id),
        "metadata": {"user_id": str(user_id), "tier": tier},
        "allow_promotion_codes": True,
    }

    # Use existing customer if available; otherwise prefill their email.
    if user.get("stripe_customer_id"):
        session_params["customer"] = user["stripe_customer_id"]
    else:
        session_params["customer_email"] = user["email"]

    session = stripe.checkout.Session.create(**session_params)  # type: ignore[arg-type]

    logger.info(
        "Created checkout session for user %s, tier %s: %s",
        user_id,
        tier,
        session.id,
    )

    return session.url  # type: ignore[return-value]
+
+
def create_portal_session(user_id: int, base_url: str) -> str:
    """Create Stripe Billing Portal session.

    Args:
        user_id: User ID
        base_url: Base URL for return redirect

    Returns:
        Portal session URL to redirect user to

    Raises:
        ValueError: If user not found or has no Stripe customer ID
    """
    user = Core.Database.get_user_by_id(user_id)
    if not user:
        msg = f"User not found: {user_id}"
        raise ValueError(msg)

    # Portal sessions require an existing Stripe customer.
    if not user.get("stripe_customer_id"):
        msg = "User has no Stripe customer ID"
        raise ValueError(msg)

    session = stripe.billing_portal.Session.create(
        customer=user["stripe_customer_id"],
        return_url=f"{base_url}/account",
    )

    logger.info(
        "Created portal session for user %s: %s",
        user_id,
        session.id,
    )

    return session.url
+
+
def handle_webhook_event(payload: bytes, sig_header: str) -> dict[str, str]:
    """Verify and process Stripe webhook event.

    Args:
        payload: Raw webhook body
        sig_header: Stripe-Signature header value

    Returns:
        dict with processing status

    Note:
        May raise stripe.error.SignatureVerificationError if invalid signature
    """
    # Verify webhook signature (skip in test mode if secret not configured)
    if not STRIPE_WEBHOOK_SECRET:
        # Test mode without signature verification
        logger.warning(
            "Webhook signature verification skipped (no STRIPE_WEBHOOK_SECRET)",
        )
        event = json.loads(payload.decode("utf-8"))
    else:
        event = stripe.Webhook.construct_event(  # type: ignore[no-untyped-call]
            payload,
            sig_header,
            STRIPE_WEBHOOK_SECRET,
        )

    event_id = event["id"]
    event_type = event["type"]

    # Check if already processed (idempotency)
    if Core.Database.has_processed_stripe_event(event_id):
        logger.info("Skipping already processed event: %s", event_id)
        return {"status": "skipped", "reason": "already_processed"}

    logger.info("Processing webhook event: %s (%s)", event_id, event_type)

    # Dispatch table replaces the if/elif chain. Unhandled event types are
    # acknowledged but deliberately NOT marked processed.
    handlers = {
        "checkout.session.completed": _handle_checkout_completed,
        "customer.subscription.created": _handle_subscription_created,
        "customer.subscription.updated": _handle_subscription_updated,
        "customer.subscription.deleted": _handle_subscription_deleted,
        "invoice.payment_failed": _handle_payment_failed,
    }
    handler = handlers.get(event_type)
    if handler is None:
        logger.info("Unhandled event type: %s", event_type)
        return {"status": "ignored", "type": event_type}

    try:
        handler(event["data"]["object"])
        # Mark event as processed only after the handler succeeds
        Core.Database.mark_stripe_event_processed(event_id, event_type, payload)
    except Exception:
        logger.exception("Error processing webhook event %s", event_id)
        raise
    else:
        return {"status": "processed", "type": event_type}
+
+
def _handle_checkout_completed(session: dict[str, typing.Any]) -> None:
    """Handle checkout.session.completed event.

    Links the Stripe customer to the local user referenced by
    client_reference_id (or metadata.user_id as a fallback).
    """
    client_ref = session.get("client_reference_id") or session.get(
        "metadata",
        {},
    ).get("user_id")
    customer_id = session.get("customer")

    # Both a user reference and a customer id are required to link accounts.
    if not client_ref or not customer_id:
        logger.warning(
            "Missing user_id or customer_id in checkout session: %s",
            session.get("id", "unknown"),
        )
        return

    try:
        local_user_id = int(client_ref)
    except (ValueError, TypeError):
        logger.warning(
            "Invalid user_id in checkout session: %s",
            client_ref,
        )
        return

    # Link Stripe customer to user
    Core.Database.set_user_stripe_customer(local_user_id, customer_id)
    logger.info(
        "Linked user %s to Stripe customer %s", local_user_id, customer_id
    )
+
+
def _handle_subscription_created(subscription: dict[str, typing.Any]) -> None:
    """Handle customer.subscription.created by syncing subscription state."""
    _update_subscription_state(subscription)
+
+
def _handle_subscription_updated(subscription: dict[str, typing.Any]) -> None:
    """Handle customer.subscription.updated by syncing subscription state."""
    _update_subscription_state(subscription)
+
+
def _handle_subscription_deleted(subscription: dict[str, typing.Any]) -> None:
    """Handle customer.subscription.deleted event.

    Looks up the local user by Stripe customer id and downgrades them to
    the free tier. Unknown customers are logged and ignored.
    """
    customer_id = subscription["customer"]

    account = Core.Database.get_user_by_stripe_customer_id(customer_id)
    if not account:
        logger.warning("User not found for customer: %s", customer_id)
        return

    Core.Database.downgrade_to_free(account["id"])
    logger.info("Downgraded user %s to free tier", account["id"])
+
+
def _handle_payment_failed(invoice: dict[str, typing.Any]) -> None:
    """Handle invoice.payment_failed event.

    Marks the owning user's subscription as past_due. Invoices without a
    subscription id (e.g. one-off invoices) are ignored.
    """
    customer_id = invoice["customer"]
    failed_subscription = invoice.get("subscription")

    account = Core.Database.get_user_by_stripe_customer_id(customer_id)
    if not account:
        logger.warning("User not found for customer: %s", customer_id)
        return

    if not failed_subscription:
        return

    Core.Database.update_subscription_status(account["id"], "past_due")
    logger.warning(
        "Payment failed for user %s, subscription %s",
        account["id"],
        failed_subscription,
    )
+
+
def _update_subscription_state(subscription: dict[str, typing.Any]) -> None:
    """Update user subscription state from Stripe subscription object.

    Reads the customer/subscription ids, status, billing period and price
    from the Stripe payload, maps the price id to a plan tier via
    PRICE_TO_TIER, and persists everything on the matching user row.
    Logs a warning and returns early whenever a required field is missing.
    """
    customer_id = subscription.get("customer")
    subscription_id = subscription.get("id")
    status = subscription.get("status")
    cancel_at_period_end = subscription.get("cancel_at_period_end", False)

    if not customer_id or not subscription_id or not status:
        logger.warning(
            "Missing required fields in subscription: %s",
            subscription_id,
        )
        return

    # Get billing period - try multiple field names for API compatibility
    period_start_ts = (
        subscription.get("current_period_start")
        or subscription.get("billing_cycle_anchor")
        or subscription.get("start_date")
    )
    period_end_ts = subscription.get("current_period_end")

    if not period_start_ts:
        logger.warning(
            "Missing period start in subscription: %s",
            subscription_id,
        )
        return

    period_start = datetime.fromtimestamp(period_start_ts, tz=timezone.utc)

    # Calculate period end if not provided (assume monthly).
    december = 12
    if not period_end_ts:
        # BUGFIX: a naive replace(month=month + 1) raises ValueError when
        # the start day does not exist in the next month (e.g. Jan 31 ->
        # "Feb 31"), so clamp the day to the target month's last day.
        import calendar

        next_year = period_start.year + (period_start.month // december)
        next_month = period_start.month % december + 1
        last_day = calendar.monthrange(next_year, next_month)[1]
        period_end = period_start.replace(
            year=next_year,
            month=next_month,
            day=min(period_start.day, last_day),
        )
    else:
        period_end = datetime.fromtimestamp(period_end_ts, tz=timezone.utc)

    # Determine tier from price ID
    items = subscription.get("items", {})
    data = items.get("data", [])
    if not data:
        logger.warning("No items in subscription: %s", subscription_id)
        return

    price_id = data[0].get("price", {}).get("id")
    if not price_id:
        logger.warning("No price ID in subscription: %s", subscription_id)
        return

    # Unknown prices fall back to the free tier.
    tier = PRICE_TO_TIER.get(price_id, "free")

    # Find user by customer ID
    user = Core.Database.get_user_by_stripe_customer_id(customer_id)
    if not user:
        logger.warning("User not found for customer: %s", customer_id)
        return

    # Update user subscription
    Core.Database.update_user_subscription(
        user["id"],
        subscription_id,
        status,
        period_start,
        period_end,
        tier,
        cancel_at_period_end,
    )

    logger.info(
        "Updated user %s subscription: tier=%s, status=%s",
        user["id"],
        tier,
        status,
    )
+
+
def get_tier_info(tier: str) -> dict[str, typing.Any]:
    """Get tier information for display.

    Returns:
        dict with keys: name, articles_limit, price, description
    """
    free_plan: dict[str, typing.Any] = {
        "name": "Free",
        "articles_limit": 10,
        "price": "$0",
        "description": "10 articles total",
    }
    paid_plan: dict[str, typing.Any] = {
        "name": "Paid",
        "articles_limit": None,
        "price": "$12/mo",
        "description": "Unlimited articles",
    }
    # Unknown tiers fall back to the free plan.
    return paid_plan if tier == "paid" else free_plan
+
+
+# Tests
+# ruff: noqa: PLR6301, PLW0603, S101
+
+
class TestWebhookHandling(Test.TestCase):
    """Test Stripe webhook handling.

    Runs against a real on-disk SQLite database created by init_db() and
    removed again in tearDown(); tests therefore must not run in parallel
    against the same DATA_DIR.
    """

    def setUp(self) -> None:
        """Set up test database."""
        Core.Database.init_db()

    def tearDown(self) -> None:
        """Clean up test database."""
        Core.Database.teardown()

    def test_full_checkout_flow(self) -> None:
        """Test complete checkout flow from session to subscription."""
        # Create test user
        user_id, _token = Core.Database.create_user("test@example.com")

        # Temporarily set price mapping for test.
        # NOTE(review): the dict is mutated here and the *name* is rebound
        # in the finally block, hence the file-level PLW0603 noqa.
        global PRICE_TO_TIER
        old_mapping = PRICE_TO_TIER.copy()
        PRICE_TO_TIER["price_test_paid"] = "paid"

        try:
            # Step 1: Handle checkout.session.completed
            checkout_session = {
                "id": "cs_test123",
                "client_reference_id": str(user_id),
                "customer": "cus_test123",
                "metadata": {"user_id": str(user_id), "tier": "paid"},
            }
            _handle_checkout_completed(checkout_session)

            # Verify customer was linked
            user = Core.Database.get_user_by_id(user_id)
            self.assertIsNotNone(user)
            assert user is not None
            self.assertEqual(user["stripe_customer_id"], "cus_test123")

            # Step 2: Handle customer.subscription.created
            # (newer API uses billing_cycle_anchor instead of current_period_*)
            subscription = {
                "id": "sub_test123",
                "customer": "cus_test123",
                "status": "active",
                "billing_cycle_anchor": 1700000000,
                "cancel_at_period_end": False,
                "items": {
                    "data": [
                        {
                            "price": {
                                "id": "price_test_paid",
                            },
                        },
                    ],
                },
            }
            _update_subscription_state(subscription)

            # Verify subscription was created and user upgraded
            user = Core.Database.get_user_by_id(user_id)
            self.assertIsNotNone(user)
            assert user is not None
            self.assertEqual(user["plan_tier"], "paid")
            self.assertEqual(user["subscription_status"], "active")
            self.assertEqual(user["stripe_subscription_id"], "sub_test123")
            self.assertEqual(user["stripe_customer_id"], "cus_test123")
        finally:
            # Restore the original mapping even if an assertion failed.
            PRICE_TO_TIER = old_mapping

    def test_webhook_missing_fields(self) -> None:
        """Test handling webhook with missing required fields."""
        # Create test user
        user_id, _token = Core.Database.create_user("test@example.com")
        Core.Database.set_user_stripe_customer(user_id, "cus_test456")

        # Mock subscription with missing current_period_start
        subscription = {
            "id": "sub_test456",
            "customer": "cus_test456",
            "status": "active",
            # Missing current_period_start and current_period_end
            "cancel_at_period_end": False,
            "items": {"data": []},
        }

        # Should not crash, just log warning and return
        _update_subscription_state(subscription)

        # User should remain on free tier
        user = Core.Database.get_user_by_id(user_id)
        self.assertIsNotNone(user)
        assert user is not None
        self.assertEqual(user["plan_tier"], "free")
+
+
def main() -> None:
    """Run tests."""
    args = sys.argv[1:]
    if not args or args[0] != "test":
        logger.error("Usage: billing.py test")
        return
    os.environ["AREA"] = "Test"
    Test.run(App.Area.Test, [TestWebhookHandling])
+
+
# Script entry point: only runs when executed directly, not on import.
if __name__ == "__main__":
    main()
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py
new file mode 100644
index 0000000..3a88f22
--- /dev/null
+++ b/Biz/PodcastItLater/Core.py
@@ -0,0 +1,2174 @@
+"""Core, shared logic for PodcastItalater.
+
+Includes:
+- Database models
+- Data access layer
+- Shared types
+"""
+
+# : out podcastitlater-core
+# : dep pytest
+# : dep pytest-asyncio
+# : dep pytest-mock
+import hashlib
+import logging
+import Omni.App as App
+import Omni.Test as Test
+import os
+import pathlib
+import pytest
+import secrets
+import sqlite3
+import sys
+import time
+import typing
+import urllib.parse
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
# Repository root; the CODEROOT env var overrides the cwd default.
CODEROOT = pathlib.Path(os.getenv("CODEROOT", "."))
# Directory holding podcast.db; DATA_DIR env var overrides the default
# location under CODEROOT.
DATA_DIR = pathlib.Path(
    os.environ.get("DATA_DIR", CODEROOT / "_/var/podcastitlater/"),
)

# Constants for UI display (max characters shown before truncation)
URL_TRUNCATE_LENGTH = 80
TITLE_TRUNCATE_LENGTH = 50
ERROR_TRUNCATE_LENGTH = 50
+
+# Admin whitelist
+ADMIN_EMAILS = ["ben@bensima.com", "admin@example.com"]
+
+
+def is_admin(user: dict[str, typing.Any] | None) -> bool:
+ """Check if user is an admin based on email whitelist."""
+ if not user:
+ return False
+ return user.get("email", "").lower() in [
+ email.lower() for email in ADMIN_EMAILS
+ ]
+
+
def normalize_url(url: str) -> str:
    """Normalize URL for comparison and hashing.

    Normalizes:
    - Protocol (http/https)
    - Domain case (lowercase)
    - www prefix (removed)
    - Trailing slash (removed)
    - Preserves query params and fragments as they may be meaningful

    Args:
        url: URL to normalize

    Returns:
        Normalized URL string
    """
    parts = urllib.parse.urlparse(url.strip())

    # Lowercase the host and drop a leading "www." if present.
    host = parts.netloc.lower().removeprefix("www.")

    # Drop a trailing slash except when the path is just the root.
    if parts.path == "/":
        path = "/"
    else:
        path = parts.path.rstrip("/")

    # Reassemble with https as the canonical scheme; params, query, and
    # fragment pass through untouched.
    return urllib.parse.urlunparse(
        ("https", host, path, parts.params, parts.query, parts.fragment),
    )
+
+
def hash_url(url: str) -> str:
    """Generate a hash of a URL for deduplication.

    Args:
        url: URL to hash

    Returns:
        SHA256 hash of the normalized URL
    """
    digest = hashlib.sha256(normalize_url(url).encode())
    return digest.hexdigest()
+
+
+class Database: # noqa: PLR0904
+ """Data access layer for PodcastItLater database operations."""
+
+ @staticmethod
+ def teardown() -> None:
+ """Delete the existing database, for cleanup after tests."""
+ db_path = DATA_DIR / "podcast.db"
+ if db_path.exists():
+ db_path.unlink()
+
+ @staticmethod
+ @contextmanager
+ def get_connection() -> Iterator[sqlite3.Connection]:
+ """Context manager for database connections.
+
+ Yields:
+ sqlite3.Connection: Database connection with row factory set.
+ """
+ db_path = DATA_DIR / "podcast.db"
+ conn = sqlite3.connect(db_path)
+ conn.row_factory = sqlite3.Row
+ try:
+ yield conn
+ finally:
+ conn.close()
+
    @staticmethod
    def init_db() -> None:
        """Initialize database with required tables.

        Creates the base queue/episodes tables and indexes, then runs
        every schema migration in order. All DDL uses IF NOT EXISTS and
        the migrations are idempotent, so this is safe on every startup.
        The migration order below matters: later migrations assume
        columns added by earlier ones.
        """
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Queue table for job processing
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS queue (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    url TEXT,
                    email TEXT,
                    status TEXT DEFAULT 'pending',
                    retry_count INTEGER DEFAULT 0,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    error_message TEXT,
                    title TEXT,
                    author TEXT
                )
            """)

            # Episodes table for completed podcasts
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS episodes (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    title TEXT NOT NULL,
                    content_length INTEGER,
                    audio_url TEXT NOT NULL,
                    duration INTEGER,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Create indexes for performance
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_queue_status ON queue(status)",
            )
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_queue_created "
                "ON queue(created_at)",
            )
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_episodes_created "
                "ON episodes(created_at)",
            )

            conn.commit()
            logger.info("Database initialized successfully")

        # Run migration to add user support
        Database.migrate_to_multi_user()

        # Run migration to add metadata fields
        Database.migrate_add_metadata_fields()

        # Run migration to add episode metadata fields
        Database.migrate_add_episode_metadata()

        # Run migration to add user status field
        Database.migrate_add_user_status()

        # Run migration to add default titles
        Database.migrate_add_default_titles()

        # Run migration to add billing fields
        Database.migrate_add_billing_fields()

        # Run migration to add stripe events table
        Database.migrate_add_stripe_events_table()

        # Run migration to add public feed features
        Database.migrate_add_public_feed()
+
+ @staticmethod
+ def add_to_queue(
+ url: str,
+ email: str,
+ user_id: int,
+ title: str | None = None,
+ author: str | None = None,
+ ) -> int:
+ """Insert new job into queue with metadata, return job ID.
+
+ Raises:
+ ValueError: If job ID cannot be retrieved after insert.
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT INTO queue (url, email, user_id, title, author) "
+ "VALUES (?, ?, ?, ?, ?)",
+ (url, email, user_id, title, author),
+ )
+ conn.commit()
+ job_id = cursor.lastrowid
+ if job_id is None:
+ msg = "Failed to get job ID after insert"
+ raise ValueError(msg)
+ logger.info("Added job %s to queue: %s", job_id, url)
+ return job_id
+
+ @staticmethod
+ def get_pending_jobs(
+ limit: int = 10,
+ ) -> list[dict[str, Any]]:
+ """Fetch jobs with status='pending' ordered by creation time."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM queue WHERE status = 'pending' "
+ "ORDER BY created_at ASC LIMIT ?",
+ (limit,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def update_job_status(
+ job_id: int,
+ status: str,
+ error: str | None = None,
+ ) -> None:
+ """Update job status and error message."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ if error is not None:
+ if status == "error":
+ cursor.execute(
+ "UPDATE queue SET status = ?, error_message = ?, "
+ "retry_count = retry_count + 1 WHERE id = ?",
+ (status, error, job_id),
+ )
+ else:
+ cursor.execute(
+ "UPDATE queue SET status = ?, "
+ "error_message = ? WHERE id = ?",
+ (status, error, job_id),
+ )
+ else:
+ cursor.execute(
+ "UPDATE queue SET status = ? WHERE id = ?",
+ (status, job_id),
+ )
+ conn.commit()
+ logger.info("Updated job %s status to %s", job_id, status)
+
+ @staticmethod
+ def get_job_by_id(
+ job_id: int,
+ ) -> dict[str, Any] | None:
+ """Fetch single job by ID."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM queue WHERE id = ?", (job_id,))
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def create_episode( # noqa: PLR0913, PLR0917
+ title: str,
+ audio_url: str,
+ duration: int,
+ content_length: int,
+ user_id: int | None = None,
+ author: str | None = None,
+ original_url: str | None = None,
+ original_url_hash: str | None = None,
+ ) -> int:
+ """Insert episode record, return episode ID.
+
+ Raises:
+ ValueError: If episode ID cannot be retrieved after insert.
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT INTO episodes "
+ "(title, audio_url, duration, content_length, user_id, "
+ "author, original_url, original_url_hash) "
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
+ (
+ title,
+ audio_url,
+ duration,
+ content_length,
+ user_id,
+ author,
+ original_url,
+ original_url_hash,
+ ),
+ )
+ conn.commit()
+ episode_id = cursor.lastrowid
+ if episode_id is None:
+ msg = "Failed to get episode ID after insert"
+ raise ValueError(msg)
+ logger.info("Created episode %s: %s", episode_id, title)
+ return episode_id
+
+ @staticmethod
+ def get_recent_episodes(
+ limit: int = 20,
+ ) -> list[dict[str, Any]]:
+ """Get recent episodes for RSS feed generation."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM episodes ORDER BY created_at DESC LIMIT ?",
+ (limit,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_queue_status_summary() -> dict[str, Any]:
+ """Get queue status summary for web interface."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Count jobs by status
+ cursor.execute(
+ "SELECT status, COUNT(*) as count FROM queue GROUP BY status",
+ )
+ rows = cursor.fetchall()
+ status_counts = {row["status"]: row["count"] for row in rows}
+
+ # Get recent jobs
+ cursor.execute(
+ "SELECT * FROM queue ORDER BY created_at DESC LIMIT 10",
+ )
+ rows = cursor.fetchall()
+ recent_jobs = [dict(row) for row in rows]
+
+ return {"status_counts": status_counts, "recent_jobs": recent_jobs}
+
+ @staticmethod
+ def get_queue_status() -> list[dict[str, Any]]:
+ """Return pending/processing/error items for web interface."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("""
+ SELECT id, url, email, status, created_at, error_message,
+ title, author
+ FROM queue
+ WHERE status IN (
+ 'pending', 'processing', 'extracting',
+ 'synthesizing', 'uploading', 'error'
+ )
+ ORDER BY created_at DESC
+ LIMIT 20
+ """)
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_episode_by_id(episode_id: int) -> dict[str, Any] | None:
+ """Fetch single episode by ID."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT id, title, audio_url, duration, created_at,
+ content_length, author, original_url, user_id, is_public
+ FROM episodes
+ WHERE id = ?
+ """,
+ (episode_id,),
+ )
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def get_all_episodes(
+ user_id: int | None = None,
+ ) -> list[dict[str, Any]]:
+ """Return all episodes for RSS feed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ if user_id:
+ cursor.execute(
+ """
+ SELECT id, title, audio_url, duration, created_at,
+ content_length, author, original_url
+ FROM episodes
+ WHERE user_id = ?
+ ORDER BY created_at DESC
+ """,
+ (user_id,),
+ )
+ else:
+ cursor.execute("""
+ SELECT id, title, audio_url, duration, created_at,
+ content_length, author, original_url
+ FROM episodes
+ ORDER BY created_at DESC
+ """)
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_retryable_jobs(
+ max_retries: int = 3,
+ ) -> list[dict[str, Any]]:
+ """Get failed jobs that can be retried."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM queue WHERE status = 'error' "
+ "AND retry_count < ? ORDER BY created_at ASC",
+ (max_retries,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def retry_job(job_id: int) -> None:
+ """Reset a job to pending status for retry."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE queue SET status = 'pending', "
+ "error_message = NULL WHERE id = ?",
+ (job_id,),
+ )
+ conn.commit()
+ logger.info("Reset job %s to pending for retry", job_id)
+
+ @staticmethod
+ def delete_job(job_id: int) -> None:
+ """Delete a job from the queue."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("DELETE FROM queue WHERE id = ?", (job_id,))
+ conn.commit()
+ logger.info("Deleted job %s from queue", job_id)
+
+ @staticmethod
+ def get_all_queue_items(
+ user_id: int | None = None,
+ ) -> list[dict[str, Any]]:
+ """Return all queue items for admin view."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ if user_id:
+ cursor.execute(
+ """
+ SELECT id, url, email, status, retry_count, created_at,
+ error_message, title, author
+ FROM queue
+ WHERE user_id = ?
+ ORDER BY created_at DESC
+ """,
+ (user_id,),
+ )
+ else:
+ cursor.execute("""
+ SELECT id, url, email, status, retry_count, created_at,
+ error_message, title, author
+ FROM queue
+ ORDER BY created_at DESC
+ """)
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_status_counts() -> dict[str, int]:
+ """Get count of queue items by status."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("""
+ SELECT status, COUNT(*) as count
+ FROM queue
+ GROUP BY status
+ """)
+ rows = cursor.fetchall()
+ return {row["status"]: row["count"] for row in rows}
+
+ @staticmethod
+ def get_user_status_counts(
+ user_id: int,
+ ) -> dict[str, int]:
+ """Get count of queue items by status for a specific user."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT status, COUNT(*) as count
+ FROM queue
+ WHERE user_id = ?
+ GROUP BY status
+ """,
+ (user_id,),
+ )
+ rows = cursor.fetchall()
+ return {row["status"]: row["count"] for row in rows}
+
    @staticmethod
    def migrate_to_multi_user() -> None:
        """Migrate database to support multiple users.

        Idempotent: the users table uses IF NOT EXISTS and the user_id
        columns are only added when missing (checked via PRAGMA
        table_info, since SQLite lacks ADD COLUMN IF NOT EXISTS).
        """
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Create users table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    email TEXT UNIQUE NOT NULL,
                    token TEXT UNIQUE NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Add user_id columns to existing tables
            # Check if columns already exist to make migration idempotent
            cursor.execute("PRAGMA table_info(queue)")
            queue_info = cursor.fetchall()
            # PRAGMA table_info rows: (cid, name, type, ...); col[1] is name.
            queue_columns = [col[1] for col in queue_info]

            if "user_id" not in queue_columns:
                cursor.execute(
                    "ALTER TABLE queue ADD COLUMN user_id INTEGER "
                    "REFERENCES users(id)",
                )

            cursor.execute("PRAGMA table_info(episodes)")
            episodes_info = cursor.fetchall()
            episodes_columns = [col[1] for col in episodes_info]

            if "user_id" not in episodes_columns:
                cursor.execute(
                    "ALTER TABLE episodes ADD COLUMN user_id INTEGER "
                    "REFERENCES users(id)",
                )

            # Create indexes
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_queue_user_id "
                "ON queue(user_id)",
            )
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_episodes_user_id "
                "ON episodes(user_id)",
            )

            conn.commit()
            logger.info("Database migrated to support multiple users")
+
    @staticmethod
    def migrate_add_metadata_fields() -> None:
        """Add title and author fields to queue table.

        Idempotent: columns are only added when missing (checked via
        PRAGMA table_info).
        """
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Check if columns already exist
            cursor.execute("PRAGMA table_info(queue)")
            queue_info = cursor.fetchall()
            queue_columns = [col[1] for col in queue_info]

            if "title" not in queue_columns:
                cursor.execute("ALTER TABLE queue ADD COLUMN title TEXT")

            if "author" not in queue_columns:
                cursor.execute("ALTER TABLE queue ADD COLUMN author TEXT")

            conn.commit()
            logger.info("Database migrated to support metadata fields")
+
    @staticmethod
    def migrate_add_episode_metadata() -> None:
        """Add author and original_url fields to episodes table.

        Idempotent: columns are only added when missing (checked via
        PRAGMA table_info).
        """
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Check if columns already exist
            cursor.execute("PRAGMA table_info(episodes)")
            episodes_info = cursor.fetchall()
            episodes_columns = [col[1] for col in episodes_info]

            if "author" not in episodes_columns:
                cursor.execute("ALTER TABLE episodes ADD COLUMN author TEXT")

            if "original_url" not in episodes_columns:
                cursor.execute(
                    "ALTER TABLE episodes ADD COLUMN original_url TEXT",
                )

            conn.commit()
            logger.info("Database migrated to support episode metadata fields")
+
    @staticmethod
    def migrate_add_user_status() -> None:
        """Add status field to users table.

        Idempotent: the column is only added when missing. Existing rows
        get the default 'active'.
        """
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Check if column already exists
            cursor.execute("PRAGMA table_info(users)")
            users_info = cursor.fetchall()
            users_columns = [col[1] for col in users_info]

            if "status" not in users_columns:
                # Add status column with default 'active'
                cursor.execute(
                    "ALTER TABLE users ADD COLUMN status TEXT DEFAULT 'active'",
                )

            conn.commit()
            logger.info("Database migrated to support user status")
+
    @staticmethod
    def migrate_add_billing_fields() -> None:
        """Add billing-related fields to users table.

        Idempotency here relies on catching sqlite3.OperationalError for
        columns that already exist, rather than PRAGMA checks.
        """
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Add columns one by one (SQLite limitation)
            # Note: SQLite ALTER TABLE doesn't support adding UNIQUE constraints
            # We add them without UNIQUE and rely on application logic
            columns_to_add = [
                ("plan_tier", "TEXT NOT NULL DEFAULT 'free'"),
                ("stripe_customer_id", "TEXT"),
                ("stripe_subscription_id", "TEXT"),
                ("subscription_status", "TEXT"),
                ("current_period_start", "TIMESTAMP"),
                ("current_period_end", "TIMESTAMP"),
                ("cancel_at_period_end", "INTEGER NOT NULL DEFAULT 0"),
            ]

            for column_name, column_def in columns_to_add:
                try:
                    # Column names come from the static list above, so the
                    # f-string here is not an injection risk.
                    query = f"ALTER TABLE users ADD COLUMN {column_name} "
                    cursor.execute(query + column_def)
                    logger.info("Added column users.%s", column_name)
                except sqlite3.OperationalError as e:  # noqa: PERF203
                    # Column already exists, skip
                    logger.debug(
                        "Column users.%s already exists: %s",
                        column_name,
                        e,
                    )

            conn.commit()
+
    @staticmethod
    def migrate_add_stripe_events_table() -> None:
        """Create stripe_events table for webhook idempotency.

        The primary key is the Stripe event id, so replayed webhooks can
        be detected by lookup. Idempotent via IF NOT EXISTS.
        """
        with Database.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS stripe_events (
                    id TEXT PRIMARY KEY,
                    type TEXT NOT NULL,
                    payload TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_stripe_events_created "
                "ON stripe_events(created_at)",
            )
            conn.commit()
            logger.info("Created stripe_events table")
+
    @staticmethod
    def migrate_add_public_feed() -> None:
        """Add is_public column and related tables for public feed feature.

        Idempotent. Note that episodes_columns is snapshotted once before
        any ALTER statements; both column checks below deliberately use
        that pre-migration snapshot.
        """
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Add is_public column to episodes
            cursor.execute("PRAGMA table_info(episodes)")
            episodes_info = cursor.fetchall()
            episodes_columns = [col[1] for col in episodes_info]

            if "is_public" not in episodes_columns:
                cursor.execute(
                    "ALTER TABLE episodes ADD COLUMN is_public INTEGER "
                    "NOT NULL DEFAULT 0",
                )
                logger.info("Added is_public column to episodes")

            # Create user_episodes junction table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS user_episodes (
                    user_id INTEGER NOT NULL,
                    episode_id INTEGER NOT NULL,
                    added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    PRIMARY KEY (user_id, episode_id),
                    FOREIGN KEY (user_id) REFERENCES users(id),
                    FOREIGN KEY (episode_id) REFERENCES episodes(id)
                )
            """)
            logger.info("Created user_episodes junction table")

            # Create index on episode_id for reverse lookups
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_user_episodes_episode "
                "ON user_episodes(episode_id)",
            )

            # Add original_url_hash column to episodes
            if "original_url_hash" not in episodes_columns:
                cursor.execute(
                    "ALTER TABLE episodes ADD COLUMN original_url_hash TEXT",
                )
                cursor.execute(
                    "CREATE INDEX IF NOT EXISTS idx_episodes_url_hash "
                    "ON episodes(original_url_hash)",
                )
                logger.info("Added original_url_hash column to episodes")

            # Create episode_metrics table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS episode_metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    episode_id INTEGER NOT NULL,
                    user_id INTEGER,
                    event_type TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (episode_id) REFERENCES episodes(id),
                    FOREIGN KEY (user_id) REFERENCES users(id)
                )
            """)
            logger.info("Created episode_metrics table")

            # Create indexes for metrics queries
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_episode_metrics_episode "
                "ON episode_metrics(episode_id)",
            )
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_episode_metrics_type "
                "ON episode_metrics(event_type)",
            )

            # Create index on is_public for efficient public feed queries
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_episodes_public "
                "ON episodes(is_public)",
            )

            conn.commit()
            logger.info("Database migrated for public feed feature")
+
    @staticmethod
    def migrate_add_default_titles() -> None:
        """Add default titles to queue items that have None titles.

        Backfills 'Untitled Article' so UI code can rely on a non-NULL
        title. Idempotent: the WHERE clause only touches NULL titles.
        """
        with Database.get_connection() as conn:
            cursor = conn.cursor()

            # Update queue items with NULL titles to have a default
            cursor.execute("""
                UPDATE queue
                SET title = 'Untitled Article'
                WHERE title IS NULL
            """)

            # Get count of updated rows
            updated_count = cursor.rowcount

            conn.commit()
            logger.info(
                "Updated %s queue items with default titles",
                updated_count,
            )
+
    @staticmethod
    def create_user(email: str, status: str = "active") -> tuple[int, str]:
        """Create a new user and return (user_id, token).

        Args:
            email: User email address
            status: Initial status (active or disabled)

        Raises:
            ValueError: If user ID cannot be retrieved after insert or if user
                not found, or if status is invalid.
        """
        # NOTE(review): "pending" is accepted here although the docstring
        # only mentions active/disabled — confirm which is intended.
        if status not in {"pending", "active", "disabled"}:
            msg = f"Invalid status: {status}"
            raise ValueError(msg)

        # Generate a secure token for RSS feed access
        token = secrets.token_urlsafe(32)
        with Database.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    "INSERT INTO users (email, token, status) VALUES (?, ?, ?)",
                    (email, token, status),
                )
                conn.commit()
                user_id = cursor.lastrowid
                if user_id is None:
                    msg = "Failed to get user ID after insert"
                    raise ValueError(msg)
                logger.info(
                    "Created user %s with email %s (status: %s)",
                    user_id,
                    email,
                    status,
                )
            except sqlite3.IntegrityError:
                # User already exists: return the existing id/token.
                # The requested status argument is ignored in this path.
                cursor.execute(
                    "SELECT id, token FROM users WHERE email = ?",
                    (email,),
                )
                row = cursor.fetchone()
                if row is None:
                    msg = f"User with email {email} not found"
                    raise ValueError(msg) from None
                return int(row["id"]), str(row["token"])
            else:
                return int(user_id), str(token)
+
+ @staticmethod
+ def get_user_by_email(
+ email: str,
+ ) -> dict[str, Any] | None:
+ """Get user by email address."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM users WHERE email = ?", (email,))
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def get_user_by_token(
+ token: str,
+ ) -> dict[str, Any] | None:
+ """Get user by RSS token."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM users WHERE token = ?", (token,))
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def get_user_by_id(
+ user_id: int,
+ ) -> dict[str, Any] | None:
+ """Get user by ID."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def get_queue_position(job_id: int) -> int | None:
+ """Get position of job in pending queue."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ # Get created_at of this job
+ cursor.execute(
+ "SELECT created_at FROM queue WHERE id = ?",
+ (job_id,),
+ )
+ row = cursor.fetchone()
+ if not row:
+ return None
+ created_at = row[0]
+
+ # Count pending items created before or at same time
+ cursor.execute(
+ """
+ SELECT COUNT(*) FROM queue
+ WHERE status = 'pending' AND created_at <= ?
+ """,
+ (created_at,),
+ )
+ return int(cursor.fetchone()[0])
+
+ @staticmethod
+ def get_user_queue_status(
+ user_id: int,
+ ) -> list[dict[str, Any]]:
+ """Return pending/processing/error items for a specific user."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT id, url, email, status, created_at, error_message,
+ title, author
+ FROM queue
+ WHERE user_id = ? AND
+ status IN (
+ 'pending', 'processing', 'extracting',
+ 'synthesizing', 'uploading', 'error'
+ )
+ ORDER BY created_at DESC
+ LIMIT 20
+ """,
+ (user_id,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_user_recent_episodes(
+ user_id: int,
+ limit: int = 20,
+ ) -> list[dict[str, Any]]:
+ """Get recent episodes for a specific user."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM episodes WHERE user_id = ? "
+ "ORDER BY created_at DESC LIMIT ?",
+ (user_id, limit),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_user_all_episodes(
+ user_id: int,
+ ) -> list[dict[str, Any]]:
+ """Get all episodes for a specific user for RSS feed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM episodes WHERE user_id = ? "
+ "ORDER BY created_at DESC",
+ (user_id,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def update_user_status(
+ user_id: int,
+ status: str,
+ ) -> None:
+ """Update user account status."""
+ if status not in {"pending", "active", "disabled"}:
+ msg = f"Invalid status: {status}"
+ raise ValueError(msg)
+
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE users SET status = ? WHERE id = ?",
+ (status, user_id),
+ )
+ conn.commit()
+ logger.info("Updated user %s status to %s", user_id, status)
+
+ @staticmethod
+ def delete_user(user_id: int) -> None:
+ """Delete user and all associated data."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # 1. Get owned episode IDs
+ cursor.execute(
+ "SELECT id FROM episodes WHERE user_id = ?",
+ (user_id,),
+ )
+ owned_episode_ids = [row[0] for row in cursor.fetchall()]
+
+ # 2. Delete references to owned episodes
+ if owned_episode_ids:
+ # Construct placeholders for IN clause
+ placeholders = ",".join("?" * len(owned_episode_ids))
+
+ # Delete from user_episodes where these episodes are referenced
+ query = f"DELETE FROM user_episodes WHERE episode_id IN ({placeholders})" # noqa: S608, E501
+ cursor.execute(query, tuple(owned_episode_ids))
+
+ # Delete metrics for these episodes
+ query = f"DELETE FROM episode_metrics WHERE episode_id IN ({placeholders})" # noqa: S608, E501
+ cursor.execute(query, tuple(owned_episode_ids))
+
+ # 3. Delete owned episodes
+ cursor.execute("DELETE FROM episodes WHERE user_id = ?", (user_id,))
+
+ # 4. Delete user's data referencing others or themselves
+ cursor.execute(
+ "DELETE FROM user_episodes WHERE user_id = ?",
+ (user_id,),
+ )
+ cursor.execute(
+ "DELETE FROM episode_metrics WHERE user_id = ?",
+ (user_id,),
+ )
+ cursor.execute("DELETE FROM queue WHERE user_id = ?", (user_id,))
+
+ # 5. Delete user
+ cursor.execute("DELETE FROM users WHERE id = ?", (user_id,))
+
+ conn.commit()
+ logger.info("Deleted user %s and all associated data", user_id)
+
+ @staticmethod
+ def update_user_email(user_id: int, new_email: str) -> None:
+ """Update user's email address.
+
+ Args:
+ user_id: ID of the user to update
+ new_email: New email address
+
+ Raises:
+ ValueError: If email is already taken by another user
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ try:
+ cursor.execute(
+ "UPDATE users SET email = ? WHERE id = ?",
+ (new_email, user_id),
+ )
+ conn.commit()
+ logger.info("Updated user %s email to %s", user_id, new_email)
+ except sqlite3.IntegrityError:
+ msg = f"Email {new_email} is already taken"
+ raise ValueError(msg) from None
+
+ @staticmethod
+ def mark_episode_public(episode_id: int) -> None:
+ """Mark an episode as public."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE episodes SET is_public = 1 WHERE id = ?",
+ (episode_id,),
+ )
+ conn.commit()
+ logger.info("Marked episode %s as public", episode_id)
+
+ @staticmethod
+ def unmark_episode_public(episode_id: int) -> None:
+ """Mark an episode as private (not public)."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE episodes SET is_public = 0 WHERE id = ?",
+ (episode_id,),
+ )
+ conn.commit()
+ logger.info("Unmarked episode %s as public", episode_id)
+
+ @staticmethod
+ def get_public_episodes(limit: int = 50) -> list[dict[str, Any]]:
+ """Get public episodes for public feed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT id, title, audio_url, duration, created_at,
+ content_length, author, original_url
+ FROM episodes
+ WHERE is_public = 1
+ ORDER BY created_at DESC
+ LIMIT ?
+ """,
+ (limit,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def add_episode_to_user(user_id: int, episode_id: int) -> None:
+ """Add an episode to a user's feed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ try:
+ cursor.execute(
+ "INSERT INTO user_episodes (user_id, episode_id) "
+ "VALUES (?, ?)",
+ (user_id, episode_id),
+ )
+ conn.commit()
+ logger.info(
+ "Added episode %s to user %s feed",
+ episode_id,
+ user_id,
+ )
+ except sqlite3.IntegrityError:
+ # Episode already in user's feed
+ logger.info(
+ "Episode %s already in user %s feed",
+ episode_id,
+ user_id,
+ )
+
+ @staticmethod
+ def user_has_episode(user_id: int, episode_id: int) -> bool:
+ """Check if a user has an episode in their feed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT 1 FROM user_episodes "
+ "WHERE user_id = ? AND episode_id = ?",
+ (user_id, episode_id),
+ )
+ return cursor.fetchone() is not None
+
+ @staticmethod
+ def track_episode_metric(
+ episode_id: int,
+ event_type: str,
+ user_id: int | None = None,
+ ) -> None:
+ """Track an episode metric event.
+
+ Args:
+ episode_id: ID of the episode
+ event_type: Type of event ('added', 'played', 'downloaded')
+ user_id: Optional user ID (None for anonymous events)
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT INTO episode_metrics (episode_id, user_id, event_type) "
+ "VALUES (?, ?, ?)",
+ (episode_id, user_id, event_type),
+ )
+ conn.commit()
+ logger.info(
+ "Tracked %s event for episode %s (user: %s)",
+ event_type,
+ episode_id,
+ user_id or "anonymous",
+ )
+
+ @staticmethod
+ def get_user_episodes(user_id: int) -> list[dict[str, Any]]:
+ """Get all episodes in a user's feed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT e.id, e.title, e.audio_url, e.duration, e.created_at,
+ e.content_length, e.author, e.original_url, e.is_public,
+ ue.added_at
+ FROM episodes e
+ JOIN user_episodes ue ON e.id = ue.episode_id
+ WHERE ue.user_id = ?
+ ORDER BY ue.added_at DESC
+ """,
+ (user_id,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def get_episode_by_url_hash(url_hash: str) -> dict[str, Any] | None:
+ """Get episode by original URL hash."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM episodes WHERE original_url_hash = ?",
+ (url_hash,),
+ )
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def get_metrics_summary() -> dict[str, Any]:
+ """Get aggregate metrics summary for admin dashboard.
+
+ Returns:
+ dict with keys:
+ - total_episodes: Total number of episodes
+ - total_plays: Total play events
+ - total_downloads: Total download events
+ - total_adds: Total add events
+ - most_played: List of top 10 most played episodes
+ - most_downloaded: List of top 10 most downloaded episodes
+ - most_added: List of top 10 most added episodes
+ - total_users: Total number of users
+ - active_subscriptions: Number of active subscriptions
+ - submissions_24h: Submissions in last 24 hours
+ - submissions_7d: Submissions in last 7 days
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Get total episodes
+ cursor.execute("SELECT COUNT(*) as count FROM episodes")
+ total_episodes = cursor.fetchone()["count"]
+
+ # Get event counts
+ cursor.execute(
+ "SELECT COUNT(*) as count FROM episode_metrics "
+ "WHERE event_type = 'played'",
+ )
+ total_plays = cursor.fetchone()["count"]
+
+ cursor.execute(
+ "SELECT COUNT(*) as count FROM episode_metrics "
+ "WHERE event_type = 'downloaded'",
+ )
+ total_downloads = cursor.fetchone()["count"]
+
+ cursor.execute(
+ "SELECT COUNT(*) as count FROM episode_metrics "
+ "WHERE event_type = 'added'",
+ )
+ total_adds = cursor.fetchone()["count"]
+
+ # Get most played episodes
+ cursor.execute(
+ """
+ SELECT e.id, e.title, e.author, COUNT(*) as play_count
+ FROM episode_metrics em
+ JOIN episodes e ON em.episode_id = e.id
+ WHERE em.event_type = 'played'
+ GROUP BY em.episode_id
+ ORDER BY play_count DESC
+ LIMIT 10
+ """,
+ )
+ most_played = [dict(row) for row in cursor.fetchall()]
+
+ # Get most downloaded episodes
+ cursor.execute(
+ """
+ SELECT e.id, e.title, e.author, COUNT(*) as download_count
+ FROM episode_metrics em
+ JOIN episodes e ON em.episode_id = e.id
+ WHERE em.event_type = 'downloaded'
+ GROUP BY em.episode_id
+ ORDER BY download_count DESC
+ LIMIT 10
+ """,
+ )
+ most_downloaded = [dict(row) for row in cursor.fetchall()]
+
+ # Get most added episodes
+ cursor.execute(
+ """
+ SELECT e.id, e.title, e.author, COUNT(*) as add_count
+ FROM episode_metrics em
+ JOIN episodes e ON em.episode_id = e.id
+ WHERE em.event_type = 'added'
+ GROUP BY em.episode_id
+ ORDER BY add_count DESC
+ LIMIT 10
+ """,
+ )
+ most_added = [dict(row) for row in cursor.fetchall()]
+
+ # Get user metrics
+ cursor.execute("SELECT COUNT(*) as count FROM users")
+ total_users = cursor.fetchone()["count"]
+
+ cursor.execute(
+ "SELECT COUNT(*) as count FROM users "
+ "WHERE subscription_status = 'active'",
+ )
+ active_subscriptions = cursor.fetchone()["count"]
+
+ # Get recent submission metrics
+ cursor.execute(
+ "SELECT COUNT(*) as count FROM queue "
+ "WHERE created_at >= datetime('now', '-1 day')",
+ )
+ submissions_24h = cursor.fetchone()["count"]
+
+ cursor.execute(
+ "SELECT COUNT(*) as count FROM queue "
+ "WHERE created_at >= datetime('now', '-7 days')",
+ )
+ submissions_7d = cursor.fetchone()["count"]
+
+ return {
+ "total_episodes": total_episodes,
+ "total_plays": total_plays,
+ "total_downloads": total_downloads,
+ "total_adds": total_adds,
+ "most_played": most_played,
+ "most_downloaded": most_downloaded,
+ "most_added": most_added,
+ "total_users": total_users,
+ "active_subscriptions": active_subscriptions,
+ "submissions_24h": submissions_24h,
+ "submissions_7d": submissions_7d,
+ }
+
+ @staticmethod
+ def track_episode_event(
+ episode_id: int,
+ event_type: str,
+ user_id: int | None = None,
+ ) -> None:
+ """Track an episode event (added, played, downloaded)."""
+ if event_type not in {"added", "played", "downloaded"}:
+ msg = f"Invalid event type: {event_type}"
+ raise ValueError(msg)
+
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT INTO episode_metrics "
+ "(episode_id, user_id, event_type) VALUES (?, ?, ?)",
+ (episode_id, user_id, event_type),
+ )
+ conn.commit()
+ logger.info(
+ "Tracked %s event for episode %s",
+ event_type,
+ episode_id,
+ )
+
+ @staticmethod
+ def get_episode_metrics(episode_id: int) -> dict[str, int]:
+ """Get aggregated metrics for an episode."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT event_type, COUNT(*) as count
+ FROM episode_metrics
+ WHERE episode_id = ?
+ GROUP BY event_type
+ """,
+ (episode_id,),
+ )
+ rows = cursor.fetchall()
+ return {row["event_type"]: row["count"] for row in rows}
+
+ @staticmethod
+ def get_episode_metric_events(episode_id: int) -> list[dict[str, Any]]:
+ """Get raw metric events for an episode (for testing)."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT id, episode_id, user_id, event_type, created_at
+ FROM episode_metrics
+ WHERE episode_id = ?
+ ORDER BY created_at DESC
+ """,
+ (episode_id,),
+ )
+ rows = cursor.fetchall()
+ return [dict(row) for row in rows]
+
+ @staticmethod
+ def set_user_stripe_customer(user_id: int, customer_id: str) -> None:
+ """Link Stripe customer ID to user."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE users SET stripe_customer_id = ? WHERE id = ?",
+ (customer_id, user_id),
+ )
+ conn.commit()
+ logger.info(
+ "Linked user %s to Stripe customer %s",
+ user_id,
+ customer_id,
+ )
+
+ @staticmethod
+ def update_user_subscription( # noqa: PLR0913, PLR0917
+ user_id: int,
+ subscription_id: str,
+ status: str,
+ period_start: Any,
+ period_end: Any,
+ tier: str,
+ cancel_at_period_end: bool, # noqa: FBT001
+ ) -> None:
+ """Update user subscription details."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ UPDATE users SET
+ stripe_subscription_id = ?,
+ subscription_status = ?,
+ current_period_start = ?,
+ current_period_end = ?,
+ plan_tier = ?,
+ cancel_at_period_end = ?
+ WHERE id = ?
+ """,
+ (
+ subscription_id,
+ status,
+ period_start.isoformat(),
+ period_end.isoformat(),
+ tier,
+ 1 if cancel_at_period_end else 0,
+ user_id,
+ ),
+ )
+ conn.commit()
+ logger.info(
+ "Updated user %s subscription: tier=%s, status=%s",
+ user_id,
+ tier,
+ status,
+ )
+
+ @staticmethod
+ def update_subscription_status(user_id: int, status: str) -> None:
+ """Update only the subscription status (e.g., past_due)."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE users SET subscription_status = ? WHERE id = ?",
+ (status, user_id),
+ )
+ conn.commit()
+ logger.info(
+ "Updated user %s subscription status to %s",
+ user_id,
+ status,
+ )
+
+ @staticmethod
+ def downgrade_to_free(user_id: int) -> None:
+ """Downgrade user to free tier and clear subscription data."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ UPDATE users SET
+ plan_tier = 'free',
+ subscription_status = 'canceled',
+ stripe_subscription_id = NULL,
+ current_period_start = NULL,
+ current_period_end = NULL,
+ cancel_at_period_end = 0
+ WHERE id = ?
+ """,
+ (user_id,),
+ )
+ conn.commit()
+ logger.info("Downgraded user %s to free tier", user_id)
+
+ @staticmethod
+ def get_user_by_stripe_customer_id(
+ customer_id: str,
+ ) -> dict[str, Any] | None:
+ """Get user by Stripe customer ID."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM users WHERE stripe_customer_id = ?",
+ (customer_id,),
+ )
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def has_processed_stripe_event(event_id: str) -> bool:
+ """Check if Stripe event has already been processed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT id FROM stripe_events WHERE id = ?",
+ (event_id,),
+ )
+ return cursor.fetchone() is not None
+
+ @staticmethod
+ def mark_stripe_event_processed(
+ event_id: str,
+ event_type: str,
+ payload: bytes,
+ ) -> None:
+ """Mark Stripe event as processed for idempotency."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT OR IGNORE INTO stripe_events (id, type, payload) "
+ "VALUES (?, ?, ?)",
+ (event_id, event_type, payload.decode("utf-8")),
+ )
+ conn.commit()
+
+ @staticmethod
+ def get_usage(
+ user_id: int,
+ period_start: Any,
+ period_end: Any,
+ ) -> dict[str, int]:
+ """Get usage stats for user in period.
+
+ Counts episodes added to user's feed (via user_episodes table)
+ during the billing period, regardless of who created them.
+
+ Returns:
+ dict with keys: articles (int), minutes (int)
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Count articles added to user's feed in period
+ # Uses user_episodes junction table to track when episodes
+ # were added, which correctly handles shared/existing episodes
+ cursor.execute(
+ """
+ SELECT COUNT(*) as count, SUM(e.duration) as total_seconds
+ FROM user_episodes ue
+ JOIN episodes e ON e.id = ue.episode_id
+ WHERE ue.user_id = ? AND ue.added_at >= ? AND ue.added_at < ?
+ """,
+ (user_id, period_start.isoformat(), period_end.isoformat()),
+ )
+ row = cursor.fetchone()
+
+ articles = row["count"] if row else 0
+ total_seconds = (
+ row["total_seconds"] if row and row["total_seconds"] else 0
+ )
+ minutes = total_seconds // 60
+
+ return {"articles": articles, "minutes": minutes}
+
+
class TestDatabase(Test.TestCase):
    """Exercise schema creation, connections, migrations, and metrics."""

    @staticmethod
    def setUp() -> None:
        """Create a fresh database before each test."""
        Database.init_db()

    def tearDown(self) -> None:
        """Drop test database state."""
        Database.teardown()
        # Clear user ID
        self.user_id = None

    def test_init_db(self) -> None:
        """Verify all tables and indexes are created correctly."""
        with Database.get_connection() as conn:
            cur = conn.cursor()

            cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
            names = {r[0] for r in cur.fetchall()}
            for table in ("queue", "episodes", "users"):
                self.assertIn(table, names)

            cur.execute("SELECT name FROM sqlite_master WHERE type='index'")
            names = {r[0] for r in cur.fetchall()}
            for index in (
                "idx_queue_status",
                "idx_queue_created",
                "idx_episodes_created",
                "idx_queue_user_id",
                "idx_episodes_user_id",
            ):
                self.assertIn(index, names)

    def test_connection_context_manager(self) -> None:
        """Ensure connections are properly closed."""
        with Database.get_connection() as conn:
            cur = conn.cursor()
            cur.execute("SELECT 1")
            self.assertEqual(cur.fetchone()[0], 1)

        # Using the cursor after the context exits must fail.
        with pytest.raises(sqlite3.ProgrammingError):
            cur.execute("SELECT 1")

    def test_migration_idempotency(self) -> None:
        """Verify migrations can run multiple times safely."""
        for _ in range(3):
            Database.migrate_to_multi_user()

        # Schema must still be queryable after repeated migrations.
        with Database.get_connection() as conn:
            conn.cursor().execute("SELECT * FROM users")
            self.assertIsNotNone(conn)

    def test_get_metrics_summary_extended(self) -> None:
        """Verify extended metrics summary."""
        user_id, _ = Database.create_user("test@example.com")
        Database.create_episode(
            "Test Article",
            "url",
            100,
            1000,
            user_id,
        )
        Database.add_to_queue(
            "https://example.com",
            "test@example.com",
            user_id,
        )

        metrics = Database.get_metrics_summary()

        for key in (
            "total_users",
            "active_subscriptions",
            "submissions_24h",
            "submissions_7d",
        ):
            self.assertIn(key, metrics)

        self.assertEqual(metrics["total_users"], 1)
        self.assertEqual(metrics["submissions_24h"], 1)
        self.assertEqual(metrics["submissions_7d"], 1)
+
+
class TestUserManagement(Test.TestCase):
    """Test user management functionality."""

    @staticmethod
    def setUp() -> None:
        """Create a fresh database before each test."""
        Database.init_db()

    @staticmethod
    def tearDown() -> None:
        """Drop test database state."""
        Database.teardown()

    def test_create_user(self) -> None:
        """Create user with unique email and token."""
        user_id, token = Database.create_user("test@example.com")

        self.assertIsInstance(user_id, int)
        self.assertIsInstance(token, str)
        # Token must be long enough to be unguessable.
        self.assertGreater(len(token), 20)

    def test_create_duplicate_user(self) -> None:
        """Verify duplicate emails return existing user."""
        first = Database.create_user("test@example.com")
        second = Database.create_user("test@example.com")

        self.assertIsNotNone(first[0])
        self.assertIsNotNone(second[0])
        # The same (id, token) pair comes back both times.
        self.assertEqual(first[0], second[0])
        self.assertEqual(first[1], second[1])

    def test_get_user_by_email(self) -> None:
        """Retrieve user by email."""
        user_id, token = Database.create_user("test@example.com")

        user = Database.get_user_by_email("test@example.com")
        self.assertIsNotNone(user)
        if user is None:
            self.fail("User should not be None")
        self.assertEqual(user["id"], user_id)
        self.assertEqual(user["email"], "test@example.com")
        self.assertEqual(user["token"], token)

    def test_get_user_by_token(self) -> None:
        """Retrieve user by RSS token."""
        user_id, token = Database.create_user("test@example.com")

        user = Database.get_user_by_token(token)
        self.assertIsNotNone(user)
        if user is None:
            self.fail("User should not be None")
        self.assertEqual(user["id"], user_id)
        self.assertEqual(user["email"], "test@example.com")

    def test_get_user_by_id(self) -> None:
        """Retrieve user by ID."""
        user_id, token = Database.create_user("test@example.com")

        user = Database.get_user_by_id(user_id)
        self.assertIsNotNone(user)
        if user is None:
            self.fail("User should not be None")
        self.assertEqual(user["email"], "test@example.com")
        self.assertEqual(user["token"], token)

    def test_invalid_user_lookups(self) -> None:
        """Verify None returned for non-existent users."""
        self.assertIsNone(Database.get_user_by_email("nobody@example.com"))
        self.assertIsNone(Database.get_user_by_token("invalid-token"))
        self.assertIsNone(Database.get_user_by_id(9999))

    def test_token_uniqueness(self) -> None:
        """Ensure tokens are cryptographically unique."""
        tokens = {
            Database.create_user(f"user{i}@example.com")[1]
            for i in range(10)
        }
        # All ten tokens must be distinct.
        self.assertEqual(len(tokens), 10)

    def test_delete_user(self) -> None:
        """Test user deletion and cleanup."""
        user_id, _ = Database.create_user("delete_me@example.com")

        # Seed queue, episode, feed, and metric rows for the user.
        Database.add_to_queue(
            "https://example.com/article",
            "delete_me@example.com",
            user_id,
        )
        ep_id = Database.create_episode(
            title="Test Episode",
            audio_url="url",
            duration=100,
            content_length=1000,
            user_id=user_id,
        )
        Database.add_episode_to_user(user_id, ep_id)
        Database.track_episode_metric(ep_id, "played", user_id)

        Database.delete_user(user_id)

        # User, queue items, and episodes must all be gone.
        self.assertIsNone(Database.get_user_by_id(user_id))
        self.assertEqual(len(Database.get_user_queue_status(user_id)), 0)
        self.assertIsNone(Database.get_episode_by_id(ep_id))

    def test_update_user_email(self) -> None:
        """Update user email address."""
        user_id, _ = Database.create_user("old@example.com")

        Database.update_user_email(user_id, "new@example.com")

        user = Database.get_user_by_id(user_id)
        self.assertIsNotNone(user)
        if user:
            self.assertEqual(user["email"], "new@example.com")

        # Old email should not exist
        self.assertIsNone(Database.get_user_by_email("old@example.com"))

    @staticmethod
    def test_update_user_email_duplicate() -> None:
        """Cannot update to an existing email."""
        user_id1, _ = Database.create_user("user1@example.com")
        Database.create_user("user2@example.com")

        # Updating user1 to user2's email must be rejected.
        with pytest.raises(ValueError, match="already taken"):
            Database.update_user_email(user_id1, "user2@example.com")
+
+
class TestQueueOperations(Test.TestCase):
    """Test queue operations."""

    def setUp(self) -> None:
        """Set up test database with a user."""
        Database.init_db()
        self.user_id, _ = Database.create_user("test@example.com")

    @staticmethod
    def tearDown() -> None:
        """Clean up test database."""
        Database.teardown()

    def _enqueue(self, url: str) -> int:
        """Add a queue item for the fixture user and return its id."""
        return Database.add_to_queue(url, "test@example.com", self.user_id)

    def test_add_to_queue(self) -> None:
        """Add job with user association."""
        job_id = self._enqueue("https://example.com/article")

        self.assertIsInstance(job_id, int)
        self.assertGreater(job_id, 0)

    def test_get_pending_jobs(self) -> None:
        """Retrieve jobs in correct order."""
        ids = []
        for n in (1, 2, 3):
            ids.append(self._enqueue(f"https://example.com/{n}"))
            time.sleep(0.01)  # ensure distinct created_at timestamps

        jobs = Database.get_pending_jobs(limit=10)

        # Oldest submissions must come first.
        self.assertEqual(len(jobs), 3)
        self.assertEqual([j["id"] for j in jobs], ids)

    def test_update_job_status(self) -> None:
        """Update status and error messages."""
        job_id = self._enqueue("https://example.com")

        Database.update_job_status(job_id, "processing")
        job = Database.get_job_by_id(job_id)
        self.assertIsNotNone(job)
        if job is None:
            self.fail("Job should not be None")
        self.assertEqual(job["status"], "processing")

        # An error transition records the message and bumps retry_count.
        Database.update_job_status(job_id, "error", "Network timeout")
        job = Database.get_job_by_id(job_id)
        self.assertIsNotNone(job)
        if job is None:
            self.fail("Job should not be None")
        self.assertEqual(job["status"], "error")
        self.assertEqual(job["error_message"], "Network timeout")
        self.assertEqual(job["retry_count"], 1)

    def test_retry_job(self) -> None:
        """Reset failed jobs for retry."""
        job_id = self._enqueue("https://example.com")
        Database.update_job_status(job_id, "error", "Failed")

        Database.retry_job(job_id)
        job = Database.get_job_by_id(job_id)

        self.assertIsNotNone(job)
        if job is None:
            self.fail("Job should not be None")
        self.assertEqual(job["status"], "pending")
        self.assertIsNone(job["error_message"])

    def test_delete_job(self) -> None:
        """Remove jobs from queue."""
        job_id = self._enqueue("https://example.com")

        Database.delete_job(job_id)

        self.assertIsNone(Database.get_job_by_id(job_id))

    def test_get_retryable_jobs(self) -> None:
        """Find jobs eligible for retry."""
        job_id = self._enqueue("https://example.com")
        Database.update_job_status(job_id, "error", "Failed")

        retryable = Database.get_retryable_jobs(max_retries=3)
        self.assertEqual(len(retryable), 1)
        self.assertEqual(retryable[0]["id"], job_id)

        # Two more failures push the job past the retry limit.
        Database.update_job_status(job_id, "error", "Failed again")
        Database.update_job_status(job_id, "error", "Failed yet again")

        retryable = Database.get_retryable_jobs(max_retries=3)
        self.assertEqual(len(retryable), 0)

    def test_user_queue_isolation(self) -> None:
        """Ensure users only see their own jobs."""
        user2_id, _ = Database.create_user("user2@example.com")

        job1 = self._enqueue("https://example.com/1")
        job2 = Database.add_to_queue(
            "https://example.com/2",
            "user2@example.com",
            user2_id,
        )

        user1_jobs = Database.get_user_queue_status(self.user_id)
        user2_jobs = Database.get_user_queue_status(user2_id)

        self.assertEqual([j["id"] for j in user1_jobs], [job1])
        self.assertEqual([j["id"] for j in user2_jobs], [job2])

    def test_status_counts(self) -> None:
        """Verify status aggregation queries."""
        self._enqueue("https://example.com/1")
        job2 = self._enqueue("https://example.com/2")
        job3 = self._enqueue("https://example.com/3")

        Database.update_job_status(job2, "processing")
        Database.update_job_status(job3, "error", "Failed")

        counts = Database.get_user_status_counts(self.user_id)

        self.assertEqual(counts.get("pending", 0), 1)
        self.assertEqual(counts.get("processing", 0), 1)
        self.assertEqual(counts.get("error", 0), 1)

    def test_queue_position(self) -> None:
        """Verify queue position calculation."""
        ids = []
        for n in (1, 2, 3):
            ids.append(self._enqueue(f"https://example.com/{n}"))
            time.sleep(0.01)  # ensure distinct created_at timestamps
        job1, job2, job3 = ids

        self.assertEqual(Database.get_queue_position(job1), 1)
        self.assertEqual(Database.get_queue_position(job2), 2)
        self.assertEqual(Database.get_queue_position(job3), 3)

        # A job that leaves 'pending' vacates its slot in the queue.
        Database.update_job_status(job2, "processing")

        self.assertEqual(Database.get_queue_position(job1), 1)
        self.assertIsNone(Database.get_queue_position(job2))
        self.assertEqual(Database.get_queue_position(job3), 2)
+
+
class TestEpisodeManagement(Test.TestCase):
    """Test episode management functionality."""

    def setUp(self) -> None:
        """Set up test database with a user."""
        Database.init_db()
        self.user_id, _ = Database.create_user("test@example.com")

    @staticmethod
    def tearDown() -> None:
        """Clean up test database."""
        Database.teardown()

    def test_create_episode(self) -> None:
        """Create episode with user association."""
        episode_id = Database.create_episode(
            title="Test Article",
            audio_url="https://example.com/audio.mp3",
            duration=300,
            content_length=5000,
            user_id=self.user_id,
        )

        self.assertIsInstance(episode_id, int)
        self.assertGreater(episode_id, 0)

    def test_get_recent_episodes(self) -> None:
        """Retrieve episodes in reverse chronological order."""
        created = []
        for n in (1, 2, 3):
            created.append(
                Database.create_episode(
                    f"Article {n}",
                    f"url{n}",
                    n * 100,
                    n * 1000,
                    self.user_id,
                ),
            )
            time.sleep(0.01)  # ensure distinct created_at timestamps

        episodes = Database.get_recent_episodes(limit=10)

        # Newest first.
        self.assertEqual(len(episodes), 3)
        self.assertEqual(
            [e["id"] for e in episodes],
            list(reversed(created)),
        )

    def test_get_user_episodes(self) -> None:
        """Ensure user isolation for episodes."""
        user2_id, _ = Database.create_user("user2@example.com")

        ep1 = Database.create_episode(
            "User1 Article",
            "url1",
            100,
            1000,
            self.user_id,
        )
        ep2 = Database.create_episode(
            "User2 Article",
            "url2",
            200,
            2000,
            user2_id,
        )

        user1_episodes = Database.get_user_all_episodes(self.user_id)
        user2_episodes = Database.get_user_all_episodes(user2_id)

        # Each user sees exactly their own episode.
        self.assertEqual([e["id"] for e in user1_episodes], [ep1])
        self.assertEqual([e["id"] for e in user2_episodes], [ep2])

    def test_episode_metadata(self) -> None:
        """Verify duration and content_length storage."""
        Database.create_episode(
            title="Test Article",
            audio_url="https://example.com/audio.mp3",
            duration=12345,
            content_length=98765,
            user_id=self.user_id,
        )

        episode = Database.get_user_all_episodes(self.user_id)[0]

        self.assertEqual(episode["duration"], 12345)
        self.assertEqual(episode["content_length"], 98765)
+
+
def test() -> None:
    """Run the tests."""
    suites = [
        TestDatabase,
        TestUserManagement,
        TestQueueOperations,
        TestEpisodeManagement,
    ]
    Test.run(App.Area.Test, suites)
+
+
def main() -> None:
    """Run all PodcastItLater.Core tests."""
    # Only run the suite when invoked with a "test" argument.
    wants_tests = "test" in sys.argv
    if wants_tests:
        test()
diff --git a/Biz/PodcastItLater/DESIGN.md b/Biz/PodcastItLater/DESIGN.md
new file mode 100644
index 0000000..29c4837
--- /dev/null
+++ b/Biz/PodcastItLater/DESIGN.md
@@ -0,0 +1,43 @@
+# PodcastItLater Design & Architecture
+
+## Overview
+Service converting web articles to podcast episodes via email/web submission.
+
+## Architecture
+- **Web**: `Biz/PodcastItLater/Web.py` (Ludic + HTMX + Starlette)
+- **Worker**: `Biz/PodcastItLater/Worker.py` (Background processing)
+- **Core**: `Biz/PodcastItLater/Core.py` (DB & Shared Logic)
+- **Billing**: `Biz/PodcastItLater/Billing.py` (Stripe Integration)
+
+## Key Features
+1. **User Management**: Email-based magic links, RSS tokens.
+2. **Article Processing**: Trafilatura extraction -> LLM cleanup -> TTS.
+3. **Billing (In Progress)**: Stripe Checkout/Portal, Freemium model.
+
+## Path to Paid Product (Epic: t-143KQl2)
+
+### 1. Billing Infrastructure
+- **Stripe**: Use Stripe Checkout for subs, Portal for management.
+- **Webhooks**: Handle `checkout.session.completed`, `customer.subscription.*`.
+- **Tiers**:
+ - `Free`: 10 articles/month.
+ - `Paid`: Unlimited (initially).
+
+### 2. Usage Tracking
+- **Table**: `users` table needs `plan_tier`, `subscription_status`, `stripe_customer_id`.
+- **Logic**: Check usage count vs tier limit before allowing submission.
+- **Reset**: Usage counters reset at billing period boundary.
+
+### 3. Admin Dashboard
+- View all users and their status.
+- Manually retry/delete jobs.
+- View metrics (signups, conversions).
+
+## UX Polish (Epic: t-1vIPJYG)
+- **Mobile**: Ensure all pages work on mobile.
+- **Feedback**: Better error messages for failed URLs.
+- **Navigation**: Clean up navbar, account management access.
+
+## Audio Improvements
+- **Intro/Outro**: Add metadata-rich intro ("Title by Author").
+- **Sound Design**: Crossfade intro music.
diff --git a/Biz/PodcastItLater/Episode.py b/Biz/PodcastItLater/Episode.py
new file mode 100644
index 0000000..7090c70
--- /dev/null
+++ b/Biz/PodcastItLater/Episode.py
@@ -0,0 +1,390 @@
+"""
+PodcastItLater Episode Detail Components.
+
+Components for displaying individual episode pages with media player,
+share functionality, and signup prompts for non-authenticated users.
+"""
+
+# : out podcastitlater-episode
+# : dep ludic
+import Biz.PodcastItLater.UI as UI
+import ludic.html as html
+import sys
+import typing
+from ludic.attrs import Attrs
+from ludic.components import Component
+from ludic.types import AnyChildren
+from typing import override
+
+
class EpisodePlayerAttrs(Attrs):
    """Attributes for EpisodePlayer component."""

    # Public URL of the episode audio file fed to the <audio> element.
    audio_url: str
    # Episode title; accepted here but not rendered by EpisodePlayer itself.
    title: str
    # Database id, used for the player element id and tracking endpoint.
    episode_id: int
+
+
class EpisodePlayer(Component[AnyChildren, EpisodePlayerAttrs]):
    """HTML5 audio player for episode playback.

    Renders a Bootstrap card with an <audio> element and an inline script
    that POSTs a one-time "played" event to the server on first play.
    """

    @override
    def render(self) -> html.div:
        # The `title` attr is accepted but not rendered here; the page
        # heading is emitted by the enclosing page component.
        audio_url = self.attrs["audio_url"]
        episode_id = self.attrs["episode_id"]

        return html.div(
            html.div(
                html.div(
                    html.h5(
                        html.i(classes=["bi", "bi-play-circle", "me-2"]),
                        "Listen",
                        classes=["card-title", "mb-3"],
                    ),
                    html.audio(
                        html.source(src=audio_url, type="audio/mpeg"),
                        "Your browser does not support the audio element.",
                        controls=True,
                        # Only fetch duration/metadata up front, not audio.
                        preload="metadata",
                        id=f"audio-player-{episode_id}",
                        classes=["w-100"],
                        style={"max-width": "100%"},
                    ),
                    # JavaScript to track play events
                    html.script(
                        # f-string: doubled braces {{ }} are literal JS
                        # braces; only {episode_id} is interpolated.
                        f"""
                        (function() {{
                            var player = document.getElementById(
                                'audio-player-{episode_id}'
                            );
                            var hasTrackedPlay = false;

                            player.addEventListener('play', function() {{
                                // Track first play only
                                if (!hasTrackedPlay) {{
                                    hasTrackedPlay = true;

                                    // Send play event to server
                                    fetch('/episode/{episode_id}/track', {{
                                        method: 'POST',
                                        headers: {{
                                            'Content-Type':
                                                'application/x-www-form-urlencoded'
                                        }},
                                        body: 'event_type=played'
                                    }});
                                }}
                            }});
                        }})();
                        """,
                    ),
                    classes=["card-body"],
                ),
                classes=["card", "mb-4"],
            ),
        )
+
+
class ShareButtonAttrs(Attrs):
    """Attributes for ShareButton component."""

    # Absolute URL of the episode page copied to the clipboard.
    share_url: str
+
+
class ShareButton(Component[AnyChildren, ShareButtonAttrs]):
    """Button to copy episode URL to clipboard.

    Renders a card with a read-only input holding the share URL and a
    button whose inline JS copies the URL and flashes "Copied!" feedback
    for two seconds.
    """

    @override
    def render(self) -> html.div:
        share_url = self.attrs["share_url"]

        return html.div(
            html.div(
                html.div(
                    html.h5(
                        html.i(classes=["bi", "bi-share", "me-2"]),
                        "Share Episode",
                        classes=["card-title", "mb-3"],
                    ),
                    html.div(
                        html.div(
                            html.button(
                                html.i(classes=["bi", "bi-copy", "me-1"]),
                                "Copy",
                                type="button",
                                id="share-button",
                                # NOTE(review): only the first fragment is an
                                # f-string; the later plain literals emit
                                # "{{" / "}}" verbatim, so the setTimeout
                                # callback becomes a nested JS block
                                # statement. That still executes, but the
                                # doubled braces look unintended — confirm.
                                on_click=f"navigator.clipboard.writeText('{share_url}'); "  # noqa: E501
                                "const btn = document.getElementById('share-button'); "  # noqa: E501
                                "const originalHTML = btn.innerHTML; "
                                "btn.innerHTML = '<i class=\"bi bi-check me-1\"></i>Copied!'; "  # noqa: E501
                                "btn.classList.remove('btn-outline-secondary'); "  # noqa: E501
                                "btn.classList.add('btn-success'); "
                                "setTimeout(() => {{ "
                                "btn.innerHTML = originalHTML; "
                                "btn.classList.remove('btn-success'); "
                                "btn.classList.add('btn-outline-secondary'); "
                                "}}, 2000);",
                                classes=["btn", "btn-outline-secondary"],
                            ),
                            html.input(
                                type="text",
                                value=share_url,
                                readonly=True,
                                # Select the whole URL on focus for easy
                                # manual copying.
                                on_focus="this.select()",
                                classes=["form-control"],
                            ),
                            classes=["input-group"],
                        ),
                    ),
                    classes=["card-body"],
                ),
                classes=["card"],
            ),
            classes=["mb-4"],
        )
+
+
class SignupBannerAttrs(Attrs):
    """Attributes for SignupBanner component."""

    # NOTE(review): neither attribute is referenced in
    # SignupBanner.render() below — confirm whether they are still needed.
    creator_email: str
    base_url: str
+
+
class SignupBanner(Component[AnyChildren, SignupBannerAttrs]):
    """Banner prompting non-authenticated users to sign up.

    Static call-to-action card with an email form that posts to /login
    via HTMX and swaps the response into #signup-result.
    """

    @override
    def render(self) -> html.div:
        # NOTE(review): self.attrs (creator_email, base_url) is not read
        # here; the banner content is entirely static.
        return html.div(
            html.div(
                html.div(
                    html.div(
                        html.i(
                            classes=[
                                "bi",
                                "bi-info-circle-fill",
                                "me-2",
                            ],
                        ),
                        html.strong(
                            "This episode was created using PodcastItLater.",
                        ),
                        classes=["mb-3"],
                    ),
                    html.div(
                        html.p(
                            "Want to convert your own articles "
                            "to podcast episodes?",
                            classes=["mb-2"],
                        ),
                        html.form(
                            html.div(
                                html.input(
                                    type="email",
                                    name="email",
                                    placeholder="Enter your email to start",
                                    required=True,
                                    classes=["form-control"],
                                ),
                                html.button(
                                    html.i(
                                        classes=[
                                            "bi",
                                            "bi-arrow-right-circle",
                                            "me-2",
                                        ],
                                    ),
                                    "Sign Up",
                                    type="submit",
                                    classes=["btn", "btn-primary"],
                                ),
                                classes=["input-group"],
                            ),
                            # HTMX: submit to the login endpoint and render
                            # its response inline below the form.
                            hx_post="/login",
                            hx_target="#signup-result",
                            hx_swap="innerHTML",
                        ),
                        html.div(id="signup-result", classes=["mt-2"]),
                    ),
                    classes=["card-body"],
                ),
                classes=["card", "border-primary"],
            ),
            classes=["mb-4"],
        )
+
+
class EpisodeDetailPageAttrs(Attrs):
    """Attributes for EpisodeDetailPage component."""

    # Episode row as a dict (title, audio_url, duration, created_at, ...).
    episode: dict[str, typing.Any]
    # Public short id used in the shareable /episode/<sqid> URL.
    episode_sqid: str
    # Email of the submitting user, when known; drives the signup banner.
    creator_email: str | None
    # Currently logged-in user, or None for anonymous visitors.
    user: dict[str, typing.Any] | None
    # Site origin used to build absolute share/OG URLs.
    base_url: str
    # True when the viewing user already has this episode in their feed.
    user_has_episode: bool
+
+
class EpisodeDetailPage(Component[AnyChildren, EpisodeDetailPageAttrs]):
    """Full page view for a single episode.

    Assembles Open Graph / Twitter meta tags, the audio player, share
    controls, and conditional sections (signup banner for anonymous
    visitors, "add to feed" card for logged-in users who do not yet
    have the episode).
    """

    @override
    def render(self) -> UI.PageLayout:
        episode = self.attrs["episode"]
        episode_sqid = self.attrs["episode_sqid"]
        # Optional attrs use .get() so callers may omit them entirely.
        creator_email = self.attrs.get("creator_email")
        user = self.attrs.get("user")
        base_url = self.attrs["base_url"]
        user_has_episode = self.attrs.get("user_has_episode", False)

        # Canonical public URL for this episode (sqid, not raw id).
        share_url = f"{base_url}/episode/{episode_sqid}"
        duration_str = UI.format_duration(episode.get("duration"))

        # Build page title
        page_title = f"{episode['title']} - PodcastItLater"

        # Build meta tags for Open Graph
        meta_tags = [
            html.meta(property="og:title", content=episode["title"]),
            html.meta(property="og:type", content="website"),
            html.meta(property="og:url", content=share_url),
            html.meta(
                property="og:description",
                # Author is appended only when present on the episode.
                content=f"Listen to this article read aloud. "
                f"Duration: {duration_str}"
                + (f" by {episode['author']}" if episode.get("author") else ""),
            ),
            html.meta(
                property="og:site_name",
                content="PodcastItLater",
            ),
            html.meta(property="og:audio", content=episode["audio_url"]),
            html.meta(property="og:audio:type", content="audio/mpeg"),
        ]

        # Add Twitter Card tags
        meta_tags.extend([
            html.meta(name="twitter:card", content="summary"),
            html.meta(name="twitter:title", content=episode["title"]),
            html.meta(
                name="twitter:description",
                content=f"Listen to this article. Duration: {duration_str}",
            ),
        ])

        # Add author if available
        if episode.get("author"):
            meta_tags.append(
                html.meta(property="article:author", content=episode["author"]),
            )

        return UI.PageLayout(
            # Show signup banner if user is not logged in
            SignupBanner(
                creator_email=creator_email or "a user",
                base_url=base_url,
            )
            if not user and creator_email
            else html.div(),
            # Episode title and metadata
            html.div(
                html.h2(
                    episode["title"],
                    classes=["display-6", "mb-3"],
                ),
                html.div(
                    # Author badge is optional; an empty span keeps the
                    # argument positional slot filled either way.
                    html.span(
                        html.i(classes=["bi", "bi-person", "me-1"]),
                        f"by {episode['author']}",
                        classes=["text-muted", "me-3"],
                    )
                    if episode.get("author")
                    else html.span(),
                    html.span(
                        html.i(classes=["bi", "bi-clock", "me-1"]),
                        f"Duration: {duration_str}",
                        classes=["text-muted", "me-3"],
                    ),
                    html.span(
                        html.i(classes=["bi", "bi-calendar", "me-1"]),
                        f"Created: {episode['created_at']}",
                        classes=["text-muted"],
                    ),
                    classes=["mb-3"],
                ),
                html.div(
                    html.a(
                        html.i(classes=["bi", "bi-link-45deg", "me-1"]),
                        "View original article",
                        href=episode["original_url"],
                        target="_blank",
                        rel="noopener",
                        classes=["btn", "btn-sm", "btn-outline-secondary"],
                    ),
                )
                if episode.get("original_url")
                else html.div(),
                classes=["mb-4"],
            ),
            # Audio player
            EpisodePlayer(
                audio_url=episode["audio_url"],
                title=episode["title"],
                episode_id=episode["id"],
            ),
            # Share button
            ShareButton(share_url=share_url),
            # Add to feed button (logged-in users without episode)
            html.div(
                html.div(
                    html.div(
                        html.h5(
                            html.i(classes=["bi", "bi-plus-circle", "me-2"]),
                            "Add to Your Feed",
                            classes=["card-title", "mb-3"],
                        ),
                        html.p(
                            "Save this episode to your personal feed "
                            "to listen later.",
                            classes=["text-muted", "mb-3"],
                        ),
                        html.button(
                            html.i(classes=["bi", "bi-plus-lg", "me-1"]),
                            "Add to My Feed",
                            # Uses the numeric id (not the sqid) — the
                            # endpoint below expects the database id.
                            hx_post=f"/episode/{episode['id']}/add-to-feed",
                            hx_target="#add-to-feed-result",
                            hx_swap="innerHTML",
                            classes=["btn", "btn-primary"],
                        ),
                        html.div(id="add-to-feed-result", classes=["mt-2"]),
                        classes=["card-body"],
                    ),
                    classes=["card"],
                ),
                classes=["mb-4"],
            )
            if user and not user_has_episode
            else html.div(),
            # Back to home link
            html.div(
                html.a(
                    html.i(classes=["bi", "bi-arrow-left", "me-1"]),
                    "Back to Home",
                    href="/",
                    classes=["btn", "btn-link"],
                ),
                classes=["mt-4"],
            ),
            user=user,
            current_page="",
            error=None,
            page_title=page_title,
            meta_tags=meta_tags,
        )
+
+
def main() -> None:
    """Episode module has no tests currently.

    Exits with status 0 when invoked with the "test" argument so the
    build treats this module as passing despite having no suite yet.
    """
    if "test" in sys.argv:
        sys.exit(0)
diff --git a/Biz/PodcastItLater/INFRASTRUCTURE.md b/Biz/PodcastItLater/INFRASTRUCTURE.md
new file mode 100644
index 0000000..1c61618
--- /dev/null
+++ b/Biz/PodcastItLater/INFRASTRUCTURE.md
@@ -0,0 +1,38 @@
+# Infrastructure Setup for PodcastItLater
+
+## Mailgun Setup
+
+Since PodcastItLater requires sending transactional emails (magic links), we use Mailgun.
+
+### 1. Sign up for Mailgun
+Sign up at [mailgun.com](https://www.mailgun.com/).
+
+### 2. Add Domain
+Add `podcastitlater.com` (or `mg.podcastitlater.com`) to Mailgun.
+We recommend using the root domain `podcastitlater.com` if you want emails to come from `@podcastitlater.com`.
+
+### 3. Configure DNS
+Mailgun will provide DNS records to verify the domain and authorize email sending. You must add these to your DNS provider (e.g., Cloudflare, Namecheap).
+
+Required records usually include:
+- **TXT** (SPF): `v=spf1 include:mailgun.org ~all`
+- **TXT** (DKIM): `k=rsa; p=...` (Provided by Mailgun)
+- **MX** (if receiving email, optional for just sending): `10 mxa.mailgun.org`, `10 mxb.mailgun.org`
+- **CNAME** (for tracking, optional): `email.podcastitlater.com` -> `mailgun.org`
+
+### 4. Verify Domain
+Click "Verify DNS Settings" in Mailgun dashboard. This may take up to 24 hours but is usually instant.
+
+### 5. Generate API Key / SMTP Credentials
+Go to "Sending" -> "Domain Settings" -> "SMTP Credentials".
+Create a new SMTP user (e.g., `postmaster@podcastitlater.com`).
+**Save the password immediately.**
+
+### 6. Update Secrets
+Update the production secrets file on the server (`/run/podcastitlater/env`):
+
+```bash
+SMTP_SERVER=smtp.mailgun.org
+SMTP_PASSWORD=your-new-smtp-password
+EMAIL_FROM=noreply@podcastitlater.com
+```
diff --git a/Biz/PodcastItLater/STRIPE_TESTING.md b/Biz/PodcastItLater/STRIPE_TESTING.md
new file mode 100644
index 0000000..1461c06
--- /dev/null
+++ b/Biz/PodcastItLater/STRIPE_TESTING.md
@@ -0,0 +1,114 @@
+# Stripe Testing Guide
+
+## Testing Stripe Integration Without Real Transactions
+
+### 1. Use Stripe Test Mode
+
+Stripe provides test API keys that allow you to simulate payments without real money:
+
+1. Get your test keys from https://dashboard.stripe.com/test/apikeys
+2. Set environment variables with test keys:
+ ```bash
+ export STRIPE_SECRET_KEY="sk_test_..."
+   export STRIPE_WEBHOOK_SECRET="whsec_..."
+   export STRIPE_PRICE_ID_PRO="price_..."
+ ```
+
+### 2. Use Stripe Test Cards
+
+In test mode, use these test card numbers:
+- **Success**: `4242 4242 4242 4242`
+- **Decline**: `4000 0000 0000 0002`
+- **3D Secure**: `4000 0025 0000 3155`
+
+Any future expiry date and any 3-digit CVC will work.
+
+### 3. Trigger Test Webhooks
+
+Use Stripe CLI to trigger webhook events locally:
+
+```bash
+# Install Stripe CLI
+# https://stripe.com/docs/stripe-cli
+
+# Login
+stripe login
+
+# Forward webhooks to local server
+stripe listen --forward-to localhost:8000/stripe/webhook
+
+# Trigger specific events
+stripe trigger checkout.session.completed
+stripe trigger customer.subscription.created
+stripe trigger customer.subscription.updated
+stripe trigger invoice.payment_failed
+```
+
+### 4. Run Unit Tests
+
+The billing module includes unit tests that mock Stripe webhooks:
+
+```bash
+# Run billing tests
+AREA=Test python3 Biz/PodcastItLater/Billing.py test
+
+# Or use bild
+bild --test Biz/PodcastItLater/Billing.py
+```
+
+### 5. Test Migration on Production
+
+To fix the production database missing columns issue, you need to trigger the migration.
+
+The migration runs automatically when `Database.init_db()` is called, but production may have an old database.
+
+**Option A: Restart the web service**
+The init_db() runs on startup, so restarting should apply migrations.
+
+**Option B: Run migration manually**
+```bash
+# SSH to production
+# Run Python REPL with proper environment
+python3
+>>> import os
+>>> os.environ["AREA"] = "Live"
+>>> os.environ["DATA_DIR"] = "/var/podcastitlater"
+>>> import Biz.PodcastItLater.Core as Core
+>>> Core.Database.init_db()
+```
+
+### 6. Verify Database Schema
+
+Check that billing columns exist:
+
+```bash
+sqlite3 /var/podcastitlater/podcast.db
+.schema users
+```
+
+Should show:
+- `stripe_customer_id TEXT`
+- `stripe_subscription_id TEXT`
+- `subscription_status TEXT`
+- `current_period_start TEXT`
+- `current_period_end TEXT`
+- `plan_tier TEXT NOT NULL DEFAULT 'free'`
+- `cancel_at_period_end INTEGER DEFAULT 0`
+
+### 7. End-to-End Test Flow
+
+1. Start in test mode: `AREA=Test PORT=8000 python3 Biz/PodcastItLater/Web.py`
+2. Login with test account
+3. Go to /billing
+4. Click "Upgrade Now"
+5. Use test card: 4242 4242 4242 4242
+6. Stripe CLI will forward webhook to your local server
+7. Verify subscription updated in database
+
+### 8. Common Issues
+
+**KeyError in webhook**: Make sure you're using safe `.get()` access for all Stripe object fields, as the structure can vary.
+
+**Database column missing**: Run migrations by restarting the service or calling `Database.init_db()`.
+
+**Webhook signature verification fails**: Make sure `STRIPE_WEBHOOK_SECRET` matches your endpoint secret from Stripe dashboard.
diff --git a/Biz/PodcastItLater/TESTING.md b/Biz/PodcastItLater/TESTING.md
new file mode 100644
index 0000000..2911610
--- /dev/null
+++ b/Biz/PodcastItLater/TESTING.md
@@ -0,0 +1,45 @@
+# PodcastItLater Testing Strategy
+
+## Overview
+We use `pytest` with `Omni.Test` integration. Tests are co-located with code or in `Biz/PodcastItLater/Test.py` for E2E.
+
+## Test Categories
+
+### 1. Core (Database/Logic)
+- **Location**: `Biz/PodcastItLater/Core.py`
+- **Scope**: User creation, Job queue ops, Episode management.
+- **Key Tests**:
+ - `test_create_user`: Unique tokens.
+ - `test_queue_isolation`: Users see only their jobs.
+
+### 2. Web (HTTP/UI)
+- **Location**: `Biz/PodcastItLater/Web.py`
+- **Scope**: Routes, Auth, HTMX responses.
+- **Key Tests**:
+ - `test_submit_requires_auth`.
+ - `test_rss_feed_xml`.
+ - `test_admin_access_control`.
+
+### 3. Worker (Processing)
+- **Location**: `Biz/PodcastItLater/Worker.py`
+- **Scope**: Extraction, TTS, S3 upload.
+- **Key Tests**:
+ - `test_extract_content`: Mocked network calls.
+ - `test_tts_chunking`: Handle long text.
+ - **Error Handling**: Ensure retries work and errors are logged.
+
+### 4. Billing (Stripe)
+- **Location**: `Biz/PodcastItLater/Billing.py`
+- **Scope**: Webhook processing, Entitlement checks.
+- **Key Tests**:
+ - `test_webhook_subscription_update`: Update local DB.
+ - `test_enforce_limits`: Block submission if over limit.
+
+## Running Tests
+```bash
+# Run all
+bild --test Biz/PodcastItLater.hs
+
+# Run specific file
+./Biz/PodcastItLater/Web.py test
+```
diff --git a/Biz/PodcastItLater/Test.py b/Biz/PodcastItLater/Test.py
new file mode 100644
index 0000000..86b04f4
--- /dev/null
+++ b/Biz/PodcastItLater/Test.py
@@ -0,0 +1,276 @@
+"""End-to-end tests for PodcastItLater."""
+
+# : dep boto3
+# : dep botocore
+# : dep feedgen
+# : dep httpx
+# : dep itsdangerous
+# : dep ludic
+# : dep openai
+# : dep psutil
+# : dep pydub
+# : dep pytest
+# : dep pytest-asyncio
+# : dep pytest-mock
+# : dep sqids
+# : dep starlette
+# : dep stripe
+# : dep trafilatura
+# : dep uvicorn
+# : out podcastitlater-e2e-test
+# : run ffmpeg
+import Biz.PodcastItLater.Core as Core
+import Biz.PodcastItLater.UI as UI
+import Biz.PodcastItLater.Web as Web
+import Biz.PodcastItLater.Worker as Worker
+import Omni.App as App
+import Omni.Test as Test
+import pathlib
+import re
+import sys
+import unittest.mock
+from starlette.testclient import TestClient
+
+
class BaseWebTest(Test.TestCase):
    """Base test class with common setup.

    Gives each test a Starlette TestClient bound to the real Web app and
    a freshly initialized database.
    """

    def setUp(self) -> None:
        """Set up test environment."""
        self.app = Web.app
        self.client = TestClient(self.app)

        # Initialize database for each test
        Core.Database.init_db()

    @staticmethod
    def tearDown() -> None:
        """Clean up after each test."""
        # NOTE(review): tearDown as @staticmethod is unusual; unittest
        # still invokes it via the instance so it works, but it cannot
        # touch self.
        Core.Database.teardown()
+
+
class TestEndToEnd(BaseWebTest):
    """Test complete end-to-end flows.

    Drives the real web app and worker code paths; only external
    services (article extraction, TTS, S3, audio decoding) are mocked.
    """

    def setUp(self) -> None:
        """Set up test client with logged-in user."""
        super().setUp()

        # Create and login user
        self.user_id, self.token = Core.Database.create_user(
            "test@example.com",
        )
        Core.Database.update_user_status(
            self.user_id,
            "active",
        )
        self.client.post("/login", data={"email": "test@example.com"})

    def test_full_article_to_rss_flow(self) -> None:  # noqa: PLR0915
        """Test complete flow: submit URL → process → appears in RSS feed."""
        # Step 1: Submit article URL
        response = self.client.post(
            "/submit",
            data={"url": "https://example.com/great-article"},
        )

        self.assertEqual(response.status_code, 200)
        self.assertIn("Article submitted successfully", response.text)

        # Extract job ID from response
        match = re.search(r"Job ID: (\d+)", response.text)
        self.assertIsNotNone(match)
        # The explicit None check (after assertIsNotNone) narrows the type
        # for the checker; self.fail is unreachable in practice.
        if match is None:
            self.fail("Job ID not found in response")
        job_id = int(match.group(1))

        # Verify job was created
        job = Core.Database.get_job_by_id(job_id)
        self.assertIsNotNone(job)
        if job is None:
            self.fail("Job should not be None")
        self.assertEqual(job["status"], "pending")
        self.assertEqual(job["user_id"], self.user_id)

        # Step 2: Process the job with mocked external services
        shutdown_handler = Worker.ShutdownHandler()
        processor = Worker.ArticleProcessor(shutdown_handler)

        # Mock external dependencies
        mock_audio_data = b"fake-mp3-audio-content-12345"

        with (
            # Article fetch/extraction: return a fixed (title, body).
            unittest.mock.patch.object(
                Worker.ArticleProcessor,
                "extract_article_content",
                return_value=(
                    "Great Article Title",
                    "This is the article content.",
                ),
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.prepare_text_for_tts",
                return_value=["This is the article content."],
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.check_memory_usage",
                return_value=50.0,
            ),
            # OpenAI TTS call: stubbed below with fake MP3 bytes.
            unittest.mock.patch.object(
                processor.openai_client.audio.speech,
                "create",
            ) as mock_tts,
            unittest.mock.patch.object(
                processor,
                "upload_to_s3",
                return_value="https://cdn.example.com/episode_123_Great_Article.mp3",
            ),
            unittest.mock.patch(
                "pydub.AudioSegment.from_mp3",
            ) as mock_audio_segment,
            unittest.mock.patch(
                "pathlib.Path.read_bytes",
                return_value=mock_audio_data,
            ),
        ):
            # Mock TTS response
            mock_tts_response = unittest.mock.MagicMock()
            mock_tts_response.content = mock_audio_data
            mock_tts.return_value = mock_tts_response

            # Mock audio segment
            mock_segment = unittest.mock.MagicMock()
            # export() must actually write a file because downstream code
            # reads the exported path back from disk.
            mock_segment.export = lambda path, **_kwargs: pathlib.Path(
                path,
            ).write_bytes(
                mock_audio_data,
            )
            mock_audio_segment.return_value = mock_segment

            # Process the pending job
            Worker.process_pending_jobs(processor)

        # Step 3: Verify job was marked completed
        job = Core.Database.get_job_by_id(job_id)
        self.assertIsNotNone(job)
        if job is None:
            self.fail("Job should not be None")
        self.assertEqual(job["status"], "completed")

        # Step 4: Verify episode was created
        episodes = Core.Database.get_user_all_episodes(self.user_id)
        self.assertEqual(len(episodes), 1)

        episode = episodes[0]
        self.assertEqual(episode["title"], "Great Article Title")
        self.assertEqual(
            episode["audio_url"],
            "https://cdn.example.com/episode_123_Great_Article.mp3",
        )
        self.assertGreater(episode["duration"], 0)
        self.assertEqual(episode["user_id"], self.user_id)

        # Step 5: Verify episode appears in RSS feed
        response = self.client.get(f"/feed/{self.token}.rss")

        self.assertEqual(response.status_code, 200)
        self.assertEqual(
            response.headers["content-type"],
            "application/rss+xml; charset=utf-8",
        )

        # Check RSS contains the episode
        self.assertIn("Great Article Title", response.text)
        self.assertIn(
            "https://cdn.example.com/episode_123_Great_Article.mp3",
            response.text,
        )
        self.assertIn("<enclosure", response.text)
        self.assertIn('type="audio/mpeg"', response.text)

        # Step 6: Verify only this user's episode is in their feed
        # Create another user with their own episode
        user2_id, token2 = Core.Database.create_user("other@example.com")
        Core.Database.create_episode(
            "Other User's Article",
            "https://cdn.example.com/other.mp3",
            200,
            3000,
            user2_id,
        )

        # Original user's feed should not contain other user's episode
        response = self.client.get(f"/feed/{self.token}.rss")
        self.assertIn("Great Article Title", response.text)
        self.assertNotIn("Other User's Article", response.text)

        # Other user's feed should only contain their episode
        response = self.client.get(f"/feed/{token2}.rss")
        self.assertNotIn("Great Article Title", response.text)
        self.assertIn("Other User's Article", response.text)
+
+
class TestUI(Test.TestCase):
    """Test UI components.

    Exercises PageLayout's private navbar renderer directly and checks
    the emitted HTML for expected menu entries.
    """

    def test_render_navbar(self) -> None:
        """Test navbar rendering."""
        user = {"email": "test@example.com", "id": 1}
        layout = UI.PageLayout(
            user=user,
            current_page="home",
            error=None,
            page_title="Test",
            meta_tags=[],
        )
        navbar = layout._render_navbar(user, "home")  # noqa: SLF001
        html_output = navbar.to_html()

        # Check basic structure
        self.assertIn("navbar", html_output)
        self.assertIn("Home", html_output)
        self.assertIn("Public Feed", html_output)
        self.assertIn("Pricing", html_output)
        self.assertIn("Manage Account", html_output)

        # Check active state
        self.assertIn("active", html_output)

        # Check non-admin user doesn't see admin menu
        self.assertNotIn("Admin", html_output)

    def test_render_navbar_admin(self) -> None:
        """Test navbar rendering for admin."""
        # Admin access appears to be keyed off this specific email
        # (see also TestMetricsView) — presumably hard-coded in the app.
        user = {"email": "ben@bensima.com", "id": 1}  # Admin email
        layout = UI.PageLayout(
            user=user,
            current_page="admin",
            error=None,
            page_title="Test",
            meta_tags=[],
        )
        navbar = layout._render_navbar(user, "admin")  # noqa: SLF001
        html_output = navbar.to_html()

        # Check admin menu present
        self.assertIn("Admin", html_output)
        self.assertIn("Queue Status", html_output)
+
+
def test() -> None:
    """Run all end-to-end tests.

    Registers both suites with the shared Omni.Test runner under the
    Test area.
    """
    Test.run(
        App.Area.Test,
        [
            TestEndToEnd,
            TestUI,
        ],
    )
+
+
def main() -> None:
    """Run the tests.

    The previous ``if "test" in sys.argv`` check was dead code: both the
    branch and its else called test(), so the suite always ran. Call it
    unconditionally to make that intent explicit.
    """
    test()
diff --git a/Biz/PodcastItLater/TestMetricsView.py b/Biz/PodcastItLater/TestMetricsView.py
new file mode 100644
index 0000000..c6fdd46
--- /dev/null
+++ b/Biz/PodcastItLater/TestMetricsView.py
@@ -0,0 +1,121 @@
+"""Tests for Admin metrics view."""
+
+# : out podcastitlater-test-metrics
+# : dep pytest
+# : dep starlette
+# : dep httpx
+# : dep ludic
+# : dep feedgen
+# : dep itsdangerous
+# : dep uvicorn
+# : dep stripe
+# : dep sqids
+
+import Biz.PodcastItLater.Core as Core
+import Biz.PodcastItLater.Web as Web
+import Omni.Test as Test
+from starlette.testclient import TestClient
+
+
class BaseWebTest(Test.TestCase):
    """Base class for web tests.

    Provides a fresh database and a TestClient for the real Web app
    around every test.
    """

    def setUp(self) -> None:
        """Set up test database and client."""
        Core.Database.init_db()
        self.client = TestClient(Web.app)

    @staticmethod
    def tearDown() -> None:
        """Clean up test database."""
        # NOTE(review): static tearDown mirrors Test.py's BaseWebTest;
        # unittest calls it via the instance, so it still runs.
        Core.Database.teardown()
+
+
class TestMetricsView(BaseWebTest):
    """Test Admin Metrics View.

    Covers access control (admin / regular / anonymous) and the rendered
    metric values on /admin/metrics.
    """

    def test_admin_metrics_view_access(self) -> None:
        """Admin user should be able to access metrics view."""
        # Create admin user (admin status is tied to this email).
        _admin_id, _ = Core.Database.create_user("ben@bensima.com")
        self.client.post("/login", data={"email": "ben@bensima.com"})

        response = self.client.get("/admin/metrics")
        self.assertEqual(response.status_code, 200)
        self.assertIn("Growth & Usage", response.text)
        self.assertIn("Total Users", response.text)

    def test_admin_metrics_data(self) -> None:
        """Metrics view should show correct data."""
        # Create admin user
        admin_id, _ = Core.Database.create_user("ben@bensima.com")
        self.client.post("/login", data={"email": "ben@bensima.com"})

        # Create some data
        # 1. Users
        Core.Database.create_user("user1@example.com")
        user2_id, _ = Core.Database.create_user("user2@example.com")

        # 2. Subscriptions (simulate by setting subscription_status)
        # No public setter exists for this, so write the column directly.
        with Core.Database.get_connection() as conn:
            conn.execute(
                "UPDATE users SET subscription_status = 'active' WHERE id = ?",
                (user2_id,),
            )
            conn.commit()

        # 3. Submissions
        Core.Database.add_to_queue(
            "http://example.com/1",
            "user1@example.com",
            admin_id,
        )

        # Get metrics page
        response = self.client.get("/admin/metrics")
        self.assertEqual(response.status_code, 200)

        # Check labels
        self.assertIn("Total Users", response.text)
        self.assertIn("Active Subs", response.text)
        self.assertIn("Submissions (24h)", response.text)

        # Check values (metrics dict is passed to template,
        # we check rendered HTML)
        # Total users: 3 (admin + user1 + user2)
        # Active subs: 1 (user2)
        # Submissions 24h: 1

        # Check for values in HTML
        # Note: This is a bit brittle, but effective for quick verification
        self.assertIn('<h3 class="mb-0">3</h3>', response.text)
        self.assertIn('<h3 class="mb-0">1</h3>', response.text)

    def test_non_admin_access_denied(self) -> None:
        """Non-admin users should be denied access."""
        # Create regular user
        Core.Database.create_user("regular@example.com")
        self.client.post("/login", data={"email": "regular@example.com"})

        response = self.client.get("/admin/metrics")
        # Should redirect to /?error=forbidden
        self.assertEqual(response.status_code, 302)
        self.assertIn("error=forbidden", response.headers["Location"])

    def test_anonymous_access_redirect(self) -> None:
        """Anonymous users should be redirected to login."""
        response = self.client.get("/admin/metrics")
        self.assertEqual(response.status_code, 302)
        self.assertEqual(response.headers["Location"], "/")
+
+
def main() -> None:
    """Run the tests.

    Unlike the other test modules, this one passes Web.area (the web
    app's configured area) to the runner rather than App.Area.Test.
    """
    Test.run(
        Web.area,
        [TestMetricsView],
    )
+
+
# Script entry point: run the metrics-view suite when executed directly.
if __name__ == "__main__":
    main()
diff --git a/Biz/PodcastItLater/UI.py b/Biz/PodcastItLater/UI.py
new file mode 100644
index 0000000..e9ef27d
--- /dev/null
+++ b/Biz/PodcastItLater/UI.py
@@ -0,0 +1,755 @@
+"""
+PodcastItLater Shared UI Components.
+
+Common UI components and utilities shared across web pages.
+"""
+
+# : lib
+# : dep ludic
+import Biz.PodcastItLater.Core as Core
+import ludic.html as html
+import typing
+from ludic.attrs import Attrs
+from ludic.components import Component
+from ludic.types import AnyChildren
+from typing import override
+
+
+def format_duration(seconds: int | None) -> str:
+ """Format duration from seconds to human-readable format.
+
+ Examples:
+ 300 -> "5m"
+ 3840 -> "1h 4m"
+ 11520 -> "3h 12m"
+ """
+ if seconds is None or seconds <= 0:
+ return "Unknown"
+
+ # Constants for time conversion
+ seconds_per_minute = 60
+ minutes_per_hour = 60
+ seconds_per_hour = 3600
+
+ # Round up to nearest minute
+ minutes = (seconds + seconds_per_minute - 1) // seconds_per_minute
+
+ # Show as minutes only if under 60 minutes (exclusive)
+ # 3599 seconds rounds up to 60 minutes, which we keep as "60m"
+ if minutes <= minutes_per_hour:
+ # If exactly 3600 seconds (already 60 full minutes without rounding)
+ if seconds >= seconds_per_hour:
+ return "1h"
+ return f"{minutes}m"
+
+ hours = minutes // minutes_per_hour
+ remaining_minutes = minutes % minutes_per_hour
+
+ if remaining_minutes == 0:
+ return f"{hours}h"
+
+ return f"{hours}h {remaining_minutes}m"
+
+
def create_bootstrap_styles() -> html.style:
    """Return an inline <style> that imports Bootstrap CSS and icons.

    Both stylesheets are pulled from the jsDelivr CDN via CSS @import.
    """
    css_imports = (
        "@import url('https://cdn.jsdelivr.net/npm/bootstrap@5.3.2"
        "/dist/css/bootstrap.min.css');"
        "@import url('https://cdn.jsdelivr.net/npm/bootstrap-icons"
        "@1.11.3/font/bootstrap-icons.min.css');"
    )
    return html.style(css_imports)
+
+
def create_auto_dark_mode_style() -> html.style:
    """Return CSS that switches to Bootstrap's dark palette automatically.

    Uses a prefers-color-scheme media query so no JavaScript toggle is
    needed; overrides the Bootstrap CSS custom properties plus a few
    component-specific rules (navbar, table header).
    """
    dark_mode_css = """
        /* Auto dark mode - applies Bootstrap dark theme via media query */
        @media (prefers-color-scheme: dark) {
            :root {
                color-scheme: dark;
                --bs-body-color: #dee2e6;
                --bs-body-color-rgb: 222, 226, 230;
                --bs-body-bg: #212529;
                --bs-body-bg-rgb: 33, 37, 41;
                --bs-emphasis-color: #fff;
                --bs-emphasis-color-rgb: 255, 255, 255;
                --bs-secondary-color: rgba(222, 226, 230, 0.75);
                --bs-secondary-bg: #343a40;
                --bs-tertiary-color: rgba(222, 226, 230, 0.5);
                --bs-tertiary-bg: #2b3035;
                --bs-heading-color: inherit;
                --bs-link-color: #6ea8fe;
                --bs-link-hover-color: #8bb9fe;
                --bs-link-color-rgb: 110, 168, 254;
                --bs-link-hover-color-rgb: 139, 185, 254;
                --bs-code-color: #e685b5;
                --bs-border-color: #495057;
                --bs-border-color-translucent: rgba(255, 255, 255, 0.15);
            }

            /* Navbar dark mode */
            .navbar.bg-body-tertiary {
                background-color: #2b3035 !important;
            }

            .navbar .navbar-text {
                color: #dee2e6 !important;
            }

            /* Table header dark mode */
            .table-light {
                --bs-table-bg: #343a40;
                --bs-table-color: #dee2e6;
                background-color: #343a40 !important;
                color: #dee2e6 !important;
            }
        }
        """
    return html.style(dark_mode_css)
+
+
def create_htmx_script() -> html.script:
    """Return the <script> tag that loads HTMX from the unpkg CDN.

    The integrity hash pins the exact published artifact (SRI).
    """
    cdn_src = "https://unpkg.com/htmx.org@1.9.10"
    sri_hash = (
        "sha384-D1Kt99CQMDuVetoL1lrYwg5t+9QdHe7NLX/SoJYkXDFfX37iInKRy5xLSi8nO7UC"
    )
    return html.script(
        src=cdn_src,
        integrity=sri_hash,
        crossorigin="anonymous",
    )
+
+
def create_bootstrap_js() -> html.script:
    """Return the <script> tag for the Bootstrap JS bundle (jsDelivr CDN).

    The bundle includes Popper, needed by the dropdown/collapse widgets;
    the integrity hash pins the exact artifact (SRI).
    """
    cdn_src = (
        "https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/js/bootstrap.bundle.min.js"
    )
    sri_hash = (
        "sha384-C6RzsynM9kWDrMNeT87bh95OGNyZPhcTNXj1NW7RuBCsyN/o0jlpcV8Qyq46cDfL"
    )
    return html.script(
        src=cdn_src,
        integrity=sri_hash,
        crossorigin="anonymous",
    )
+
+
class PageLayoutAttrs(Attrs):
    """Attributes for PageLayout component."""

    # Logged-in user row as a dict, or None for anonymous visitors;
    # controls whether the navbar (and admin menu) render.
    user: dict[str, typing.Any] | None
    # Key of the page being rendered; drives the navbar "active" state.
    current_page: str
    # Optional error text shown in a danger alert above the content.
    error: str | None
    # <title> text; PageLayout falls back to "PodcastItLater" when None.
    page_title: str | None
    # Extra <meta> elements appended into <head> (e.g. Open Graph tags).
    meta_tags: list[html.meta] | None
+
+
class PageLayout(Component[AnyChildren, PageLayoutAttrs]):
    """Reusable page layout with header and navbar.

    Wraps page content (``self.children``) in a full HTML document:
    head (meta, title, HTMX), masthead, optional error alert, optional
    navbar, then the Bootstrap JS bundle at the end of <body>.
    """

    @staticmethod
    def _render_nav_item(
        label: str,
        href: str,
        icon: str,
        *,
        is_active: bool,
    ) -> html.li:
        """Render one navbar link with a Bootstrap icon.

        ``is_active`` adds Bootstrap's "active" class to highlight the
        current page.
        """
        return html.li(
            html.a(
                html.i(classes=["bi", f"bi-{icon}", "me-1"]),
                label,
                href=href,
                classes=[
                    "nav-link",
                    "active" if is_active else "",
                ],
            ),
            classes=["nav-item"],
        )

    @staticmethod
    def _render_admin_dropdown(
        is_active_func: typing.Callable[[str], bool],
    ) -> html.li:
        """Render the admin-only dropdown (queue, users, metrics links)."""
        # Highlight the dropdown toggle when any admin page is current.
        is_active = is_active_func("admin") or is_active_func("admin-users")
        return html.li(
            html.a(  # type: ignore[call-arg]
                html.i(classes=["bi", "bi-gear-fill", "me-1"]),
                "Admin",
                href="#",
                id="adminDropdown",
                role="button",
                data_bs_toggle="dropdown",
                aria_expanded="false",
                classes=[
                    "nav-link",
                    "dropdown-toggle",
                    "active" if is_active else "",
                ],
            ),
            html.ul(  # type: ignore[call-arg]
                html.li(
                    html.a(
                        html.i(classes=["bi", "bi-list-task", "me-2"]),
                        "Queue Status",
                        href="/admin",
                        classes=["dropdown-item"],
                    ),
                ),
                html.li(
                    html.a(
                        html.i(classes=["bi", "bi-people-fill", "me-2"]),
                        "Manage Users",
                        href="/admin/users",
                        classes=["dropdown-item"],
                    ),
                ),
                html.li(
                    html.a(
                        html.i(classes=["bi", "bi-graph-up", "me-2"]),
                        "Metrics",
                        href="/admin/metrics",
                        classes=["dropdown-item"],
                    ),
                ),
                classes=["dropdown-menu"],
                aria_labelledby="adminDropdown",
            ),
            classes=["nav-item", "dropdown"],
        )

    @staticmethod
    def _render_navbar(
        user: dict[str, typing.Any] | None,
        current_page: str,
    ) -> html.nav:
        """Render navigation bar."""

        def is_active(page: str) -> bool:
            # True when `page` is the one currently being rendered.
            return current_page == page

        return html.nav(
            html.div(
                # Hamburger toggle shown on small screens (Bootstrap
                # collapse targets #navbarNav below).
                html.button(  # type: ignore[call-arg]
                    html.span(classes=["navbar-toggler-icon"]),
                    classes=["navbar-toggler", "ms-auto"],
                    type="button",
                    data_bs_toggle="collapse",
                    data_bs_target="#navbarNav",
                    aria_controls="navbarNav",
                    aria_expanded="false",
                    aria_label="Toggle navigation",
                ),
                html.div(
                    html.ul(
                        PageLayout._render_nav_item(
                            "Home",
                            "/",
                            "house-fill",
                            is_active=is_active("home"),
                        ),
                        PageLayout._render_nav_item(
                            "Public Feed",
                            "/public",
                            "globe",
                            is_active=is_active("public"),
                        ),
                        PageLayout._render_nav_item(
                            "Pricing",
                            "/pricing",
                            "stars",
                            is_active=is_active("pricing"),
                        ),
                        PageLayout._render_nav_item(
                            "Manage Account",
                            "/account",
                            "person-circle",
                            is_active=is_active("account"),
                        ),
                        # Admin dropdown only for admins; an empty span
                        # keeps the child list well-formed otherwise.
                        PageLayout._render_admin_dropdown(is_active)
                        if user and Core.is_admin(user)
                        else html.span(),
                        classes=["navbar-nav"],
                    ),
                    id="navbarNav",
                    classes=["collapse", "navbar-collapse"],
                ),
                classes=["container-fluid"],
            ),
            classes=[
                "navbar",
                "navbar-expand-lg",
                "bg-body-tertiary",
                "rounded",
                "mb-4",
            ],
        )

    @override
    def render(self) -> html.html:
        """Assemble the full HTML document around ``self.children``."""
        user = self.attrs.get("user")
        current_page = self.attrs.get("current_page", "")
        error = self.attrs.get("error")
        page_title = self.attrs.get("page_title") or "PodcastItLater"
        meta_tags = self.attrs.get("meta_tags") or []

        return html.html(
            html.head(
                html.meta(charset="utf-8"),
                html.meta(
                    name="viewport",
                    content="width=device-width, initial-scale=1",
                ),
                # Advertise both schemes so the dark-mode CSS applies.
                html.meta(
                    name="color-scheme",
                    content="light dark",
                ),
                html.title(page_title),
                *meta_tags,
                create_htmx_script(),
            ),
            html.body(
                create_bootstrap_styles(),
                create_auto_dark_mode_style(),
                html.div(
                    # Masthead.
                    html.div(
                        html.h1(
                            "PodcastItLater",
                            classes=["display-4", "mb-2"],
                        ),
                        html.p(
                            "Convert web articles to podcast episodes",
                            classes=["lead", "text-muted"],
                        ),
                        classes=["text-center", "mb-4", "pt-4"],
                    ),
                    # Optional error banner above the page content.
                    html.div(
                        html.div(
                            html.i(
                                classes=[
                                    "bi",
                                    "bi-exclamation-triangle-fill",
                                    "me-2",
                                ],
                            ),
                            error or "",
                            classes=[
                                "alert",
                                "alert-danger",
                                "d-flex",
                                "align-items-center",
                            ],
                            role="alert",  # type: ignore[call-arg]
                        ),
                    )
                    if error
                    else html.div(),
                    # Navbar only renders for logged-in users.
                    self._render_navbar(user, current_page)
                    if user
                    else html.div(),
                    *self.children,
                    classes=["container", "px-3", "px-md-4"],
                    style={"max-width": "900px"},
                ),
                # Bootstrap JS goes last so the DOM exists before init.
                create_bootstrap_js(),
            ),
        )
+
+
class AccountPageAttrs(Attrs):
    """Attributes for AccountPage component."""

    # Logged-in user row; render() reads email, created_at, plan_tier.
    user: dict[str, typing.Any]
    # Usage counters for the current billing period ("articles" key).
    usage: dict[str, int]
    # Plan limits; "articles_per_period" of None means unlimited.
    limits: dict[str, int | None]
    # Stripe billing-portal URL for paid users (None when unavailable).
    portal_url: str | None
+
+
class AccountPage(Component[AnyChildren, AccountPageAttrs]):
    """Account management page component.

    Sections: profile info, subscription/usage with progress bar,
    logout, and a "danger zone" with account deletion.
    """

    @override
    def render(self) -> PageLayout:
        """Render the account page inside the standard PageLayout."""
        user = self.attrs["user"]
        usage = self.attrs["usage"]
        limits = self.attrs["limits"]
        portal_url = self.attrs["portal_url"]

        plan_tier = user.get("plan_tier", "free")
        is_paid = plan_tier == "paid"

        # A None limit means the plan has no article cap.
        article_limit = limits.get("articles_per_period")
        article_usage = usage.get("articles", 0)

        limit_text = (
            "Unlimited" if article_limit is None else str(article_limit)
        )

        # Percentage for the progress bar, clamped to 100.
        usage_percent = 0
        if article_limit:
            usage_percent = min(100, int((article_usage / article_limit) * 100))

        progress_style = (
            {"width": f"{usage_percent}%"} if article_limit else {"width": "0%"}
        )

        return PageLayout(
            html.div(
                html.div(
                    html.div(
                        html.div(
                            html.div(
                                html.h2(
                                    html.i(
                                        classes=[
                                            "bi",
                                            "bi-person-circle",
                                            "me-2",
                                        ],
                                    ),
                                    "My Account",
                                    classes=["card-title", "mb-4"],
                                ),
                                # User Info Section
                                html.div(
                                    html.h5("Profile", classes=["mb-3"]),
                                    html.div(
                                        html.strong("Email: "),
                                        html.span(user.get("email", "")),
                                        # HTMX swaps this row for an
                                        # inline email-edit form.
                                        html.button(
                                            "Change",
                                            classes=[
                                                "btn",
                                                "btn-sm",
                                                "btn-outline-secondary",
                                                "ms-2",
                                                "py-0",
                                            ],
                                            hx_get="/settings/email/edit",
                                            hx_target="closest div",
                                            hx_swap="outerHTML",
                                        ),
                                        classes=[
                                            "mb-2",
                                            "d-flex",
                                            "align-items-center",
                                        ],
                                    ),
                                    html.p(
                                        html.strong("Member since: "),
                                        # Date part of the ISO timestamp.
                                        user.get("created_at", "").split("T")[
                                            0
                                        ],
                                        classes=["mb-4"],
                                    ),
                                    classes=["mb-5"],
                                ),
                                # Subscription Section
                                html.div(
                                    html.h5("Subscription", classes=["mb-3"]),
                                    html.div(
                                        html.div(
                                            html.strong("Current Plan"),
                                            html.span(
                                                plan_tier.title(),
                                                classes=[
                                                    "badge",
                                                    "bg-success"
                                                    if is_paid
                                                    else "bg-secondary",
                                                    "ms-2",
                                                ],
                                            ),
                                            classes=[
                                                "d-flex",
                                                "align-items-center",
                                                "mb-3",
                                            ],
                                        ),
                                        # Usage Stats
                                        html.div(
                                            html.p(
                                                "Usage this period:",
                                                classes=["mb-2", "text-muted"],
                                            ),
                                            html.div(
                                                html.div(
                                                    f"{article_usage} / "
                                                    f"{limit_text}",
                                                    classes=["mb-1"],
                                                ),
                                                # Progress bar only for
                                                # capped (free) plans.
                                                html.div(
                                                    html.div(
                                                        classes=[
                                                            "progress-bar",
                                                        ],
                                                        role="progressbar",  # type: ignore[call-arg]
                                                        style=progress_style,  # type: ignore[arg-type]
                                                    ),
                                                    classes=[
                                                        "progress",
                                                        "mb-3",
                                                    ],
                                                    style={"height": "10px"},
                                                )
                                                if article_limit
                                                else html.div(),
                                                classes=["mb-3"],
                                            ),
                                        ),
                                        # Actions
                                        html.div(
                                            # Paid users go to the Stripe
                                            # portal; free users see an
                                            # upgrade link.
                                            html.form(
                                                html.button(
                                                    html.i(
                                                        classes=[
                                                            "bi",
                                                            "bi-credit-card",
                                                            "me-2",
                                                        ],
                                                    ),
                                                    "Manage Subscription",
                                                    type="submit",
                                                    classes=[
                                                        "btn",
                                                        "btn-outline-primary",
                                                    ],
                                                ),
                                                method="post",
                                                action=portal_url,
                                            )
                                            if is_paid and portal_url
                                            else html.a(
                                                html.i(
                                                    classes=[
                                                        "bi",
                                                        "bi-star-fill",
                                                        "me-2",
                                                    ],
                                                ),
                                                "Upgrade to Pro",
                                                href="/pricing",
                                                classes=["btn", "btn-primary"],
                                            ),
                                            classes=["d-flex", "gap-2"],
                                        ),
                                        classes=[
                                            "card",
                                            "card-body",
                                            "bg-light",
                                        ],
                                    ),
                                    classes=["mb-5"],
                                ),
                                # Logout Section
                                html.div(
                                    html.form(
                                        html.button(
                                            html.i(
                                                classes=[
                                                    "bi",
                                                    "bi-box-arrow-right",
                                                    "me-2",
                                                ],
                                            ),
                                            "Log Out",
                                            type="submit",
                                            classes=[
                                                "btn",
                                                "btn-outline-danger",
                                            ],
                                        ),
                                        action="/logout",
                                        method="post",
                                    ),
                                    classes=["border-top", "pt-4"],
                                ),
                                # Delete Account Section
                                html.div(
                                    html.h5(
                                        "Danger Zone",
                                        classes=["text-danger", "mb-3"],
                                    ),
                                    html.div(
                                        html.h6("Delete Account"),
                                        html.p(
                                            "Once you delete your account, "
                                            "there is no going back. "
                                            "Please be certain.",
                                            classes=["card-text"],
                                        ),
                                        # hx_confirm shows a browser
                                        # confirm() before the DELETE.
                                        html.button(
                                            html.i(
                                                classes=[
                                                    "bi",
                                                    "bi-trash",
                                                    "me-2",
                                                ],
                                            ),
                                            "Delete Account",
                                            hx_delete="/account",
                                            hx_confirm=(
                                                "Are you absolutely sure you "
                                                "want to delete your account? "
                                                "This action cannot be undone."
                                            ),
                                            classes=["btn", "btn-danger"],
                                        ),
                                        classes=[
                                            "card",
                                            "card-body",
                                            "border-danger",
                                        ],
                                    ),
                                    classes=["mt-5", "pt-4", "border-top"],
                                ),
                                classes=["card-body", "p-4"],
                            ),
                            classes=["card", "shadow-sm"],
                        ),
                        classes=["col-lg-8", "mx-auto"],
                    ),
                    classes=["row"],
                ),
            ),
            user=user,
            current_page="account",
            page_title="Account - PodcastItLater",
            error=None,
            meta_tags=[],
        )
+
+
class PricingPageAttrs(Attrs):
    """Attributes for PricingPage component."""

    # Logged-in user row (None for anonymous); selects which CTA button
    # each plan card shows.
    user: dict[str, typing.Any] | None
+
+
class PricingPage(Component[AnyChildren, PricingPageAttrs]):
    """Pricing page component.

    Two side-by-side plan cards (Free and Unlimited); the call-to-action
    on each card depends on login state and the user's current tier.
    """

    @override
    def render(self) -> PageLayout:
        """Render the plan comparison inside the standard PageLayout."""
        user = self.attrs.get("user")
        # Anonymous visitors are treated as free tier.
        current_tier = user.get("plan_tier", "free") if user else "free"

        return PageLayout(
            html.div(
                html.div(
                    # Free Tier
                    html.div(
                        html.div(
                            html.div(
                                html.h3("Free", classes=["card-title"]),
                                html.h4(
                                    "$0",
                                    classes=[
                                        "card-subtitle",
                                        "mb-3",
                                        "text-muted",
                                    ],
                                ),
                                html.p(
                                    "10 articles total",
                                    classes=["card-text"],
                                ),
                                html.ul(
                                    html.li("Convert 10 articles"),
                                    html.li("Basic features"),
                                    classes=["list-unstyled", "mb-4"],
                                ),
                                # Disabled marker button when this is
                                # already the user's plan.
                                html.button(
                                    "Current Plan",
                                    classes=[
                                        "btn",
                                        "btn-outline-primary",
                                        "w-100",
                                    ],
                                    disabled=True,
                                )
                                if current_tier == "free"
                                else html.div(),
                                classes=["card-body"],
                            ),
                            classes=["card", "mb-4", "shadow-sm", "h-100"],
                        ),
                        classes=["col-md-6"],
                    ),
                    # Paid Tier
                    html.div(
                        html.div(
                            html.div(
                                html.h3(
                                    "Unlimited",
                                    classes=["card-title"],
                                ),
                                html.h4(
                                    "$12/mo",
                                    classes=[
                                        "card-subtitle",
                                        "mb-3",
                                        "text-muted",
                                    ],
                                ),
                                html.p(
                                    "Unlimited articles",
                                    classes=["card-text"],
                                ),
                                html.ul(
                                    html.li("Unlimited conversions"),
                                    html.li("Priority processing"),
                                    html.li("Support independent software"),
                                    classes=["list-unstyled", "mb-4"],
                                ),
                                # CTA: upgrade form for logged-in free
                                # users, disabled marker for paid users,
                                # login link for anonymous visitors.
                                html.form(
                                    html.button(
                                        "Upgrade Now",
                                        type="submit",
                                        classes=[
                                            "btn",
                                            "btn-primary",
                                            "w-100",
                                        ],
                                    ),
                                    action="/upgrade",
                                    method="post",
                                )
                                if user and current_tier == "free"
                                else (
                                    html.button(
                                        "Current Plan",
                                        classes=[
                                            "btn",
                                            "btn-success",
                                            "w-100",
                                        ],
                                        disabled=True,
                                    )
                                    if user and current_tier == "paid"
                                    else html.a(
                                        "Login to Upgrade",
                                        href="/",
                                        classes=[
                                            "btn",
                                            "btn-primary",
                                            "w-100",
                                        ],
                                    )
                                ),
                                classes=["card-body"],
                            ),
                            classes=[
                                "card",
                                "mb-4",
                                "shadow-sm",
                                "border-primary",
                                "h-100",
                            ],
                        ),
                        classes=["col-md-6"],
                    ),
                    classes=["row"],
                ),
            ),
            user=user,
            current_page="pricing",
            page_title="Pricing - PodcastItLater",
            error=None,
            meta_tags=[],
        )
diff --git a/Biz/PodcastItLater/Web.nix b/Biz/PodcastItLater/Web.nix
new file mode 100644
index 0000000..7533ca4
--- /dev/null
+++ b/Biz/PodcastItLater/Web.nix
@@ -0,0 +1,93 @@
{
  options,
  lib,
  config,
  ...
}: let
  cfg = config.services.podcastitlater-web;
  rootDomain = "podcastitlater.com";
  # Fleet-wide port registry (ssh/http/https used for the firewall below).
  ports = import ../../Omni/Cloud/Ports.nix;
in {
  options.services.podcastitlater-web = {
    enable = lib.mkEnableOption "Enable the PodcastItLater web service";
    port = lib.mkOption {
      type = lib.types.int;
      default = 8000;
      description = ''
        The port on which PodcastItLater web will listen for
        incoming HTTP traffic.
      '';
    };
    dataDir = lib.mkOption {
      type = lib.types.path;
      default = "/var/podcastitlater";
      description = "Data directory for PodcastItLater (shared with worker)";
    };
    package = lib.mkOption {
      type = lib.types.package;
      description = "PodcastItLater web package to use";
    };
  };
  config = lib.mkIf cfg.enable {
    systemd.services.podcastitlater-web = {
      path = [cfg.package];
      wantedBy = ["multi-user.target"];
      preStart = ''
        # Create data directory if it doesn't exist
        mkdir -p ${cfg.dataDir}

        # Manual step: create this file with secrets
        # SECRET_KEY=your-secret-key-for-sessions
        # SESSION_SECRET=your-session-secret
        # EMAIL_FROM=noreply@podcastitlater.com
        # SMTP_SERVER=smtp.mailgun.org
        # SMTP_PASSWORD=your-smtp-password
        # STRIPE_SECRET_KEY=sk_live_your_stripe_secret_key
        # STRIPE_WEBHOOK_SECRET=whsec_your_webhook_secret
        # STRIPE_PRICE_ID_PRO=price_your_pro_price_id
        test -f /run/podcastitlater/env
      '';
      script = ''
        ${cfg.package}/bin/podcastitlater-web
      '';
      description = ''
        PodcastItLater Web Service
      '';
      serviceConfig = {
        # Non-secret configuration is passed inline; secrets come from
        # the manually-provisioned EnvironmentFile checked in preStart.
        Environment = [
          "PORT=${toString cfg.port}"
          "AREA=Live"
          "DATA_DIR=${cfg.dataDir}"
          "BASE_URL=https://${rootDomain}"
        ];
        EnvironmentFile = "/run/podcastitlater/env";
        KillSignal = "INT";
        Type = "simple";
        Restart = "on-abort";
        RestartSec = "1";
      };
    };

    # Nginx configuration
    # TLS-terminating reverse proxy in front of the uvicorn app.
    services.nginx = {
      enable = true;
      recommendedGzipSettings = true;
      recommendedOptimisation = true;
      recommendedProxySettings = true;
      recommendedTlsSettings = true;
      statusPage = true;

      virtualHosts."${rootDomain}" = {
        forceSSL = true;
        enableACME = true;
        locations."/" = {
          proxyPass = "http://127.0.0.1:${toString cfg.port}";
          proxyWebsockets = true;
        };
      };
    };

    # Ensure firewall allows web traffic
    networking.firewall.allowedTCPPorts = [ports.ssh ports.http ports.https];
  };
}
diff --git a/Biz/PodcastItLater/Web.py b/Biz/PodcastItLater/Web.py
new file mode 100644
index 0000000..30b5236
--- /dev/null
+++ b/Biz/PodcastItLater/Web.py
@@ -0,0 +1,3480 @@
+"""
+PodcastItLater Web Service.
+
+Web frontend for converting articles to podcast episodes.
+Provides ludic + htmx interface and RSS feed generation.
+"""
+
+# : out podcastitlater-web
+# : dep ludic
+# : dep feedgen
+# : dep httpx
+# : dep itsdangerous
+# : dep uvicorn
+# : dep pytest
+# : dep pytest-asyncio
+# : dep pytest-mock
+# : dep starlette
+# : dep stripe
+# : dep sqids
+import Biz.EmailAgent
+import Biz.PodcastItLater.Admin as Admin
+import Biz.PodcastItLater.Billing as Billing
+import Biz.PodcastItLater.Core as Core
+import Biz.PodcastItLater.Episode as Episode
+import Biz.PodcastItLater.UI as UI
+import html as html_module
+import httpx
+import logging
+import ludic.html as html
+import Omni.App as App
+import Omni.Log as Log
+import Omni.Test as Test
+import os
+import pathlib
+import re
+import sys
+import tempfile
+import typing
+import urllib.parse
+import uvicorn
+from datetime import datetime
+from datetime import timezone
+from feedgen.feed import FeedGenerator # type: ignore[import-untyped]
+from itsdangerous import URLSafeTimedSerializer
+from ludic.attrs import Attrs
+from ludic.components import Component
+from ludic.types import AnyChildren
+from ludic.web import LudicApp
+from ludic.web import Request
+from ludic.web.datastructures import FormData
+from ludic.web.responses import Response
+from sqids import Sqids
+from starlette.middleware.sessions import SessionMiddleware
+from starlette.responses import RedirectResponse
+from starlette.testclient import TestClient
+from typing import override
+from unittest.mock import patch
+
+logger = logging.getLogger(__name__)
+Log.setup(logger)
+
+
# Configuration
# Deployment area (Test/Live) read from the environment by Omni.App;
# used elsewhere to enable dev-mode shortcuts.
area = App.from_env()
# Public base URL for absolute links (magic links, RSS enclosures).
BASE_URL = os.getenv("BASE_URL", "http://localhost:8000")
# HTTP listen port (set via PORT in the systemd unit).
PORT = int(os.getenv("PORT", "8000"))

# Initialize sqids for episode URL encoding
# min_length pads short ids so episode slugs are at least 8 chars.
sqids = Sqids(min_length=8)
+
+
def encode_episode_id(episode_id: int) -> str:
    """Encode episode ID to sqid for URLs."""
    # Sqids encodes a list of ints; we only ever pack a single id.
    encoded = sqids.encode([episode_id])
    return str(encoded)
+
+
def decode_episode_id(sqid: str) -> int | None:
    """Decode sqid to episode ID. Returns None if invalid."""
    try:
        numbers = sqids.decode(sqid)
    except (ValueError, IndexError):
        # Malformed sqid strings are treated as "no episode".
        return None
    return numbers[0] if numbers else None
+
+
# Authentication configuration
MAGIC_LINK_MAX_AGE = 3600  # 1 hour
SESSION_MAX_AGE = 30 * 24 * 3600  # 30 days
EMAIL_FROM = os.getenv("EMAIL_FROM", "noreply@podcastitlater.com")
SMTP_SERVER = os.getenv("SMTP_SERVER", "smtp.mailgun.org")
SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "")

# Initialize serializer for magic links
# NOTE(review): the "dev-secret-key" fallback is for local development
# only -- SECRET_KEY must be set in production (see Web.nix env file).
magic_link_serializer = URLSafeTimedSerializer(
    os.getenv("SECRET_KEY", "dev-secret-key"),
)


# Channel-level defaults used when generating RSS feeds.
RSS_CONFIG = {
    "author": "PodcastItLater",
    "language": "en-US",
    "base_url": BASE_URL,
}
+
+
def _find_meta_content(html_content: str, prop: str) -> str | None:
    """Return the ``content`` value of a <meta> tag for ``prop``.

    Matches tags whose ``property`` or ``name`` attribute equals
    ``prop``, in EITHER attribute order (``property`` before or after
    ``content``) and with other attributes in between -- real pages emit
    both forms, e.g. ``<meta content="..." property="og:title">``.
    """
    escaped = re.escape(prop)
    patterns = (
        # property/name first, content later.
        rf'<meta\s[^>]*?(?:property|name)=["\']{escaped}["\']'
        rf'[^>]*?content=["\'](.*?)["\']',
        # content first, property/name later.
        rf'<meta\s[^>]*?content=["\'](.*?)["\']'
        rf'[^>]*?(?:property|name)=["\']{escaped}["\']',
    )
    for pattern in patterns:
        match = re.search(pattern, html_content, re.IGNORECASE)
        if match:
            return match.group(1)
    return None


def extract_og_metadata(url: str) -> tuple[str | None, str | None]:
    """Extract Open Graph title and author from URL.

    Returns:
        tuple: (title, author) - both may be None if extraction fails
    """
    try:
        # Fetch the page; follow redirects so shortened/canonical
        # URLs still resolve to the article page.
        response = httpx.get(url, timeout=10.0, follow_redirects=True)
        response.raise_for_status()

        # Simple regex-based extraction to avoid heavy dependencies.
        html_content = response.text

        # og:title is the article headline.
        title = _find_meta_content(html_content, "og:title")

        # Prefer an explicit article:author; fall back to the site name.
        author = _find_meta_content(html_content, "article:author")
        if author is None:
            author = _find_meta_content(html_content, "og:site_name")

        # Meta content is HTML-escaped; decode entities for display.
        if title:
            title = html_module.unescape(title)
        if author:
            author = html_module.unescape(author)

    except (httpx.HTTPError, httpx.TimeoutException, re.error) as e:
        # Best-effort: metadata failures are logged, never fatal.
        logger.warning("Failed to extract metadata from %s: %s", url, e)
        return None, None
    else:
        return title, author
+
+
def send_magic_link(email: str, token: str) -> None:
    """Email a one-time login link containing ``token`` to ``email``.

    The body is written to a temporary file because the email agent
    takes a file path rather than a string; the file is always cleaned
    up afterwards, even when sending fails.
    """
    login_url = f"{BASE_URL}/auth/verify?token={token}"

    # Allocate a closed temp file we can hand to the email agent by path.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".txt",
        delete=False,
        encoding="utf-8",
    ) as handle:
        body_path = pathlib.Path(handle.name)

    body_path.write_text(f"""
Hello,

Click this link to login to PodcastItLater:
{login_url}

This link will expire in 1 hour.

If you didn't request this, please ignore this email.

Best,
PodcastItLater
""")

    try:
        Biz.EmailAgent.send_email(
            to_addrs=[email],
            from_addr=EMAIL_FROM,
            smtp_server=SMTP_SERVER,
            password=SMTP_PASSWORD,
            subject="Login to PodcastItLater",
            body_text=body_path,
        )
    finally:
        # Remove the temp file no matter what happened above.
        body_path.unlink(missing_ok=True)
+
+
class LoginFormAttrs(Attrs):
    """Attributes for LoginForm component."""

    # Error message rendered in the #login-result area (None = no error).
    error: str | None
+
+
class LoginForm(Component[AnyChildren, LoginFormAttrs]):
    """Simple email-based login/registration form.

    Posts the email to /login via HTMX; in Test deployments a banner
    advertises the pre-filled demo account for instant login.
    """

    @override
    def render(self) -> html.div:
        """Render the login card (plus dev banner in Test area)."""
        error = self.attrs.get("error")
        is_dev_mode = App.from_env() == App.Area.Test

        return html.div(
            # Dev mode banner
            html.div(
                html.div(
                    html.i(classes=["bi", "bi-info-circle", "me-2"]),
                    html.strong("Dev/Test Mode: "),
                    "Use ",
                    html.code(
                        "demo@example.com",
                        classes=["text-dark", "mx-1"],
                    ),
                    " for instant login",
                    classes=[
                        "alert",
                        "alert-info",
                        "d-flex",
                        "align-items-center",
                        "mb-3",
                    ],
                ),
            )
            if is_dev_mode
            else html.div(),
            html.div(
                html.div(
                    html.h4(
                        html.i(classes=["bi", "bi-envelope-fill", "me-2"]),
                        "Login / Register",
                        classes=["card-title", "mb-3"],
                    ),
                    html.form(
                        html.div(
                            html.label(
                                "Email address",
                                for_="email",
                                classes=["form-label"],
                            ),
                            html.input(
                                type="email",
                                id="email",
                                name="email",
                                placeholder="your@email.com",
                                # Pre-fill the demo account in dev mode.
                                value="demo@example.com" if is_dev_mode else "",
                                required=True,
                                classes=["form-control", "mb-3"],
                            ),
                        ),
                        html.button(
                            html.i(
                                classes=["bi", "bi-arrow-right-circle", "me-2"],
                            ),
                            "Continue",
                            type="submit",
                            classes=["btn", "btn-primary", "w-100"],
                        ),
                        # HTMX: post and swap the response into the
                        # result div below the form.
                        hx_post="/login",
                        hx_target="#login-result",
                        hx_swap="innerHTML",
                    ),
                    html.div(
                        error or "",
                        id="login-result",
                        classes=["mt-3"],
                    ),
                    classes=["card-body"],
                ),
                classes=["card"],
            ),
            classes=["mb-4"],
        )
+
+
class SubmitForm(Component[AnyChildren, Attrs]):
    """Article submission form with HTMX.

    Posts the URL to /submit; on success a small inline script clears
    the input so another article can be pasted immediately.
    """

    @override
    def render(self) -> html.div:
        """Render the submission card."""
        return html.div(
            html.div(
                html.div(
                    html.h4(
                        html.i(classes=["bi", "bi-file-earmark-plus", "me-2"]),
                        "Submit Article",
                        classes=["card-title", "mb-3"],
                    ),
                    html.form(
                        html.div(
                            html.label(
                                "Article URL",
                                for_="url",
                                classes=["form-label"],
                            ),
                            html.div(
                                html.input(
                                    type="url",
                                    id="url",
                                    name="url",
                                    placeholder="https://example.com/article",
                                    required=True,
                                    classes=["form-control"],
                                    # Select-all on focus for easy paste.
                                    on_focus="this.select()",
                                ),
                                html.button(
                                    html.i(classes=["bi", "bi-send-fill"]),
                                    type="submit",
                                    classes=["btn", "btn-primary"],
                                ),
                                classes=["input-group", "mb-3"],
                            ),
                        ),
                        hx_post="/submit",
                        hx_target="#submit-result",
                        hx_swap="innerHTML",
                        # Clear the input after a successful submission.
                        hx_on=(
                            "htmx:afterRequest: "
                            "if(event.detail.successful) "
                            "document.getElementById('url').value = ''"
                        ),
                    ),
                    html.div(id="submit-result", classes=["mt-2"]),
                    classes=["card-body"],
                ),
                classes=["card"],
            ),
            classes=["mb-4"],
        )
+
+
class QueueStatusAttrs(Attrs):
    """Attributes for QueueStatus component."""

    # Queue rows as dicts; render() reads id, status, url, created_at,
    # title, author and error_message keys.
    items: list[dict[str, typing.Any]]
+
+
class QueueStatus(Component[AnyChildren, QueueStatusAttrs]):
    """Display queue items with auto-refresh.

    Each item becomes a card showing its status badge, metadata, queue
    position (pending only), error details, and per-status action
    buttons (retry, cancel, remove) wired up via HTMX.
    """

    @override
    def render(self) -> html.div:
        """Render the queue list, or an empty-state message."""
        items = self.attrs["items"]
        if not items:
            return html.div(
                html.h4(
                    html.i(classes=["bi", "bi-list-check", "me-2"]),
                    "Queue Status",
                    classes=["mb-3"],
                ),
                html.p("No items in queue", classes=["text-muted"]),
            )

        # Map status to Bootstrap badge classes
        status_classes = {
            "pending": "bg-warning text-dark",
            "processing": "bg-primary",
            "extracting": "bg-info text-dark",
            "synthesizing": "bg-primary",
            "uploading": "bg-success",
            "error": "bg-danger",
            "cancelled": "bg-secondary",
        }

        # Matching Bootstrap icon per status.
        status_icons = {
            "pending": "bi-clock",
            "processing": "bi-arrow-repeat",
            "extracting": "bi-file-text",
            "synthesizing": "bi-mic",
            "uploading": "bi-cloud-arrow-up",
            "error": "bi-exclamation-triangle",
            "cancelled": "bi-x-circle",
        }

        queue_items = []
        for item in items:
            # Unknown statuses fall back to a neutral badge/icon.
            badge_class = status_classes.get(item["status"], "bg-secondary")
            icon_class = status_icons.get(item["status"], "bi-question-circle")

            # Get queue position for pending items
            queue_pos = None
            if item["status"] == "pending":
                queue_pos = Core.Database.get_queue_position(item["id"])

            queue_items.append(
                html.div(
                    html.div(
                        # Header row: item id + status badge.
                        html.div(
                            html.strong(f"#{item['id']}", classes=["me-2"]),
                            html.span(
                                html.i(classes=["bi", icon_class, "me-1"]),
                                item["status"].upper(),
                                classes=["badge", badge_class],
                            ),
                            classes=[
                                "d-flex",
                                "align-items-center",
                                "justify-content-between",
                            ],
                        ),
                        # Add title and author if available
                        *(
                            [
                                html.div(
                                    html.strong(
                                        item["title"],
                                        classes=["d-block"],
                                    ),
                                    html.small(
                                        f"by {item['author']}",
                                        classes=["text-muted"],
                                    )
                                    if item.get("author")
                                    else html.span(),
                                    classes=["mt-2"],
                                ),
                            ]
                            if item.get("title")
                            else []
                        ),
                        # Truncated source URL.
                        html.small(
                            html.i(classes=["bi", "bi-link-45deg", "me-1"]),
                            item["url"][: Core.URL_TRUNCATE_LENGTH]
                            + (
                                "..."
                                if len(item["url"]) > Core.URL_TRUNCATE_LENGTH
                                else ""
                            ),
                            classes=["text-muted", "d-block", "mt-2"],
                        ),
                        html.small(
                            html.i(classes=["bi", "bi-calendar", "me-1"]),
                            f"Created: {item['created_at']}",
                            classes=["text-muted", "d-block", "mt-1"],
                        ),
                        # Display queue position if available
                        html.small(
                            html.i(
                                classes=["bi", "bi-hourglass-split", "me-1"],
                            ),
                            f"Position in queue: #{queue_pos}",
                            classes=["text-info", "d-block", "mt-1"],
                        )
                        if queue_pos
                        else html.span(),
                        # Inline error alert for failed items.
                        *(
                            [
                                html.div(
                                    html.i(
                                        classes=[
                                            "bi",
                                            "bi-exclamation-circle",
                                            "me-1",
                                        ],
                                    ),
                                    f"Error: {item['error_message']}",
                                    classes=[
                                        "alert",
                                        "alert-danger",
                                        "mt-2",
                                        "mb-0",
                                        "py-1",
                                        "px-2",
                                        "small",
                                    ],
                                ),
                            ]
                            if item["error_message"]
                            else []
                        ),
                        # Add cancel button for pending jobs, remove for others
                        html.div(
                            # Retry button for error items
                            html.button(
                                html.i(
                                    classes=[
                                        "bi",
                                        "bi-arrow-clockwise",
                                        "me-1",
                                    ],
                                ),
                                "Retry",
                                hx_post=f"/queue/{item['id']}/retry",
                                hx_trigger="click",
                                # Fire queue-updated so the list refreshes.
                                hx_on=(
                                    "htmx:afterRequest: "
                                    "if(event.detail.successful) "
                                    "htmx.trigger('body', 'queue-updated')"
                                ),
                                classes=[
                                    "btn",
                                    "btn-sm",
                                    "btn-outline-primary",
                                    "mt-2",
                                    "me-2",
                                ],
                            )
                            if item["status"] == "error"
                            else html.span(),
                            html.button(
                                html.i(classes=["bi", "bi-x-lg", "me-1"]),
                                "Cancel",
                                hx_post=f"/queue/{item['id']}/cancel",
                                hx_trigger="click",
                                hx_on=(
                                    "htmx:afterRequest: "
                                    "if(event.detail.successful) "
                                    "htmx.trigger('body', 'queue-updated')"
                                ),
                                classes=[
                                    "btn",
                                    "btn-sm",
                                    "btn-outline-danger",
                                    "mt-2",
                                ],
                            )
                            if item["status"] == "pending"
                            else html.button(
                                html.i(classes=["bi", "bi-trash", "me-1"]),
                                "Remove",
                                hx_delete=f"/queue/{item['id']}",
                                hx_trigger="click",
                                hx_confirm="Remove this item from the queue?",
                                hx_on=(
                                    "htmx:afterRequest: "
                                    "if(event.detail.successful) "
                                    "htmx.trigger('body', 'queue-updated')"
                                ),
                                classes=[
                                    "btn",
                                    "btn-sm",
                                    "btn-outline-secondary",
                                    "mt-2",
                                ],
                            ),
                            classes=["mt-2"],
                        ),
                        classes=["card-body"],
                    ),
                    classes=["card", "mb-2"],
                ),
            )

        return html.div(
            html.h4(
                html.i(classes=["bi", "bi-list-check", "me-2"]),
                "Queue Status",
                classes=["mb-3"],
            ),
            *queue_items,
        )
+
+
class EpisodeListAttrs(Attrs):
    """Attributes for EpisodeList component."""

    # Episode rows; render() reads id, title, author, duration,
    # created_at, original_url and is_public keys.
    episodes: list[dict[str, typing.Any]]
    # Feed URL offered for copy-to-clipboard subscription (may be None).
    rss_url: str | None
    # Viewing user (None when anonymous); enables admin/user buttons.
    user: dict[str, typing.Any] | None
    # True when the user is looking at their own feed (hides the
    # "Add to my feed" button).
    viewing_own_feed: bool
+
+
+class EpisodeList(Component[AnyChildren, EpisodeListAttrs]):
+ """List recent episodes (no audio player - use podcast app)."""
+
+ @override
+ def render(self) -> html.div:
+ episodes = self.attrs["episodes"]
+ rss_url = self.attrs.get("rss_url")
+ user = self.attrs.get("user")
+ viewing_own_feed = self.attrs.get("viewing_own_feed", False)
+
+ if not episodes:
+ return html.div(
+ html.h4(
+ html.i(classes=["bi", "bi-broadcast", "me-2"]),
+ "Recent Episodes",
+ classes=["mb-3"],
+ ),
+ html.p("No episodes yet", classes=["text-muted"]),
+ )
+
+ episode_items = []
+ for episode in episodes:
+ duration_str = UI.format_duration(episode.get("duration"))
+ episode_sqid = encode_episode_id(episode["id"])
+ is_public = episode.get("is_public", 0) == 1
+
+ # Admin "Add to public feed" button at bottom of card
+ admin_button: html.div | html.button = html.div()
+ if user and Core.is_admin(user):
+ if is_public:
+ admin_button = html.button(
+ html.i(classes=["bi", "bi-check-circle-fill", "me-1"]),
+ "Added to public feed",
+ hx_post=f"/admin/episode/{episode['id']}/toggle-public",
+ hx_target="body",
+ hx_swap="outerHTML",
+ classes=["btn", "btn-sm", "btn-success", "mt-2"],
+ )
+ else:
+ admin_button = html.button(
+ html.i(classes=["bi", "bi-plus-circle", "me-1"]),
+ "Add to public feed",
+ hx_post=f"/admin/episode/{episode['id']}/toggle-public",
+ hx_target="body",
+ hx_swap="outerHTML",
+ classes=[
+ "btn",
+ "btn-sm",
+ "btn-outline-success",
+ "mt-2",
+ ],
+ )
+
+ # "Add to my feed" button for logged-in users
+ # (only when NOT viewing own feed)
+ user_button: html.div | html.button = html.div()
+ if user and not viewing_own_feed:
+ # Check if user already has this episode
+ user_has_episode = Core.Database.user_has_episode(
+ user["id"],
+ episode["id"],
+ )
+ if user_has_episode:
+ user_button = html.button(
+ html.i(classes=["bi", "bi-check-circle-fill", "me-1"]),
+ "In your feed",
+ disabled=True,
+ classes=[
+ "btn",
+ "btn-sm",
+ "btn-secondary",
+ "mt-2",
+ "ms-2",
+ ],
+ )
+ else:
+ user_button = html.button(
+ html.i(classes=["bi", "bi-plus-circle", "me-1"]),
+ "Add to my feed",
+ hx_post=f"/episode/{episode['id']}/add-to-feed",
+ hx_swap="none",
+ classes=[
+ "btn",
+ "btn-sm",
+ "btn-outline-primary",
+ "mt-2",
+ "ms-2",
+ ],
+ )
+
+ episode_items.append(
+ html.div(
+ html.div(
+ html.h5(
+ html.a(
+ episode["title"],
+ href=f"/episode/{episode_sqid}",
+ classes=["text-decoration-none"],
+ ),
+ classes=["card-title", "mb-2"],
+ ),
+ # Show author if available
+ html.p(
+ html.i(classes=["bi", "bi-person", "me-1"]),
+ f"by {episode['author']}",
+ classes=["text-muted", "small", "mb-3"],
+ )
+ if episode.get("author")
+ else html.div(),
+ html.div(
+ html.small(
+ html.i(classes=["bi", "bi-clock", "me-1"]),
+ f"Duration: {duration_str}",
+ classes=["text-muted", "me-3"],
+ ),
+ html.small(
+ html.i(classes=["bi", "bi-calendar", "me-1"]),
+ f"Created: {episode['created_at']}",
+ classes=["text-muted"],
+ ),
+ classes=["mb-2"],
+ ),
+ # Show link to original article if available
+ html.div(
+ html.a(
+ html.i(classes=["bi", "bi-link-45deg", "me-1"]),
+ "View original article",
+ href=episode["original_url"],
+ target="_blank",
+ classes=[
+ "btn",
+ "btn-sm",
+ "btn-outline-primary",
+ ],
+ ),
+ )
+ if episode.get("original_url")
+ else html.div(),
+ # Buttons row (admin and user buttons)
+ html.div(
+ admin_button,
+ user_button,
+ classes=["d-flex", "flex-wrap"],
+ ),
+ classes=["card-body"],
+ ),
+ classes=["card", "mb-3"],
+ ),
+ )
+
+ return html.div(
+ html.h4(
+ html.i(classes=["bi", "bi-broadcast", "me-2"]),
+ "Recent Episodes",
+ classes=["mb-3"],
+ ),
+ # RSS feed link with copy-to-clipboard
+ html.div(
+ html.div(
+ html.label(
+ html.i(classes=["bi", "bi-rss-fill", "me-2"]),
+ "Subscribe in your podcast app:",
+ classes=["form-label", "fw-bold"],
+ ),
+ html.div(
+ html.button(
+ html.i(classes=["bi", "bi-copy", "me-1"]),
+ "Copy",
+ type="button",
+ id="rss-copy-button",
+ on_click=f"navigator.clipboard.writeText('{rss_url}'); " # noqa: E501
+ "const btn = document.getElementById('rss-copy-button'); " # noqa: E501
+ "const originalHTML = btn.innerHTML; "
+ "btn.innerHTML = '<i class=\"bi bi-check me-1\"></i>Copied!'; " # noqa: E501
+ "btn.classList.remove('btn-outline-secondary'); "
+ "btn.classList.add('btn-success'); "
+ "setTimeout(() => {{ "
+ "btn.innerHTML = originalHTML; "
+ "btn.classList.remove('btn-success'); "
+ "btn.classList.add('btn-outline-secondary'); "
+ "}}, 2000);",
+ classes=["btn", "btn-outline-secondary"],
+ ),
+ html.input(
+ type="text",
+ value=rss_url or "",
+ readonly=True,
+ on_focus="this.select()",
+ classes=["form-control"],
+ ),
+ classes=["input-group", "mb-3"],
+ ),
+ ),
+ )
+ if rss_url
+ else html.div(),
+ *episode_items,
+ )
+
+
class HomePageAttrs(Attrs):
    """Attributes for HomePage component."""

    # Pending/processing queue rows for the logged-in user (empty when
    # logged out).
    queue_items: list[dict[str, typing.Any]]
    # Episode rows to render: the user's own feed, or the public feed
    # for anonymous visitors.
    episodes: list[dict[str, typing.Any]]
    # Session user row, or None when not logged in.
    user: dict[str, typing.Any] | None
    # Optional user-facing error message shown by the page layout.
    error: str | None
+
+
class PublicFeedPageAttrs(Attrs):
    """Attributes for PublicFeedPage component."""

    # Public episode rows to display.
    episodes: list[dict[str, typing.Any]]
    # Logged-in user row, or None for anonymous visitors; only affects
    # which action buttons the episode cards render.
    user: dict[str, typing.Any] | None
+
+
class PublicFeedPage(Component[AnyChildren, PublicFeedPageAttrs]):
    """Public feed page without auto-refresh."""

    @override
    def render(self) -> UI.PageLayout:
        """Render the heading, intro blurb, and public episode list."""
        current_user = self.attrs.get("user")

        heading = html.h2(
            html.i(classes=["bi", "bi-globe", "me-2"]),
            "Public Feed",
            classes=["mb-3"],
        )
        blurb = html.p(
            "Featured articles converted to audio by our community. "
            "Subscribe to get new episodes in your podcast app!",
            classes=["lead", "text-muted", "mb-4"],
        )
        listing = EpisodeList(
            episodes=self.attrs["episodes"],
            rss_url=f"{BASE_URL}/public.rss",
            user=current_user,
            viewing_own_feed=False,
        )

        return UI.PageLayout(
            html.div(heading, blurb, listing),
            user=current_user,
            current_page="public",
            error=None,
        )
+
+
class HomePage(Component[AnyChildren, HomePageAttrs]):
    """Main page combining all components.

    Logged-out visitors see a login form plus the public feed; logged-in
    users see the plan callout, submit form, and a self-refreshing
    dashboard (queue status + their episodes).
    """

    @staticmethod
    def _render_plan_callout(
        user: dict[str, typing.Any],
    ) -> html.div:
        """Render plan info callout box below navbar.

        Free-tier users get a remaining-quota summary with an upgrade
        button; paid users get an empty div (no callout).
        """
        tier = user.get("plan_tier", "free")

        if tier == "free":
            # Get usage and show quota
            period_start, period_end = Billing.get_period_boundaries(user)
            usage = Billing.get_usage(user["id"], period_start, period_end)
            articles_used = usage["articles"]
            # NOTE(review): free-plan limit is hard-coded here — confirm it
            # stays in sync with Billing's tier limits.
            articles_limit = 10
            articles_left = max(0, articles_limit - articles_used)

            return html.div(
                html.div(
                    html.div(
                        html.i(
                            classes=[
                                "bi",
                                "bi-info-circle-fill",
                                "me-2",
                            ],
                        ),
                        html.strong(f"{articles_left} articles remaining"),
                        " of your free plan limit. ",
                        html.br(),
                        "Upgrade to ",
                        html.strong("Paid Plan"),
                        " for unlimited articles at $12/month.",
                    ),
                    # Posts to the checkout route with the paid tier
                    # preselected via a hidden input.
                    html.form(
                        html.input(
                            type="hidden",
                            name="tier",
                            value="paid",
                        ),
                        html.button(
                            html.i(
                                classes=[
                                    "bi",
                                    "bi-arrow-up-circle",
                                    "me-1",
                                ],
                            ),
                            "Upgrade Now",
                            type="submit",
                            classes=[
                                "btn",
                                "btn-success",
                                "btn-sm",
                                "mt-2",
                            ],
                        ),
                        method="post",
                        action="/billing/checkout",
                    ),
                    classes=[
                        "alert",
                        "alert-info",
                        "d-flex",
                        "justify-content-between",
                        "align-items-center",
                        "mb-4",
                    ],
                ),
                classes=["mb-4"],
            )
        # Paid user - no callout needed
        return html.div()

    @override
    def render(self) -> UI.PageLayout | html.html:
        """Render either the logged-out or the logged-in home page."""
        queue_items = self.attrs["queue_items"]
        episodes = self.attrs["episodes"]
        user = self.attrs.get("user")
        error = self.attrs.get("error")

        if not user:
            # Show public feed with login form for logged-out users
            return UI.PageLayout(
                LoginForm(error=error),
                html.div(
                    html.h4(
                        html.i(classes=["bi", "bi-broadcast", "me-2"]),
                        "Public Feed",
                        classes=["mb-3", "mt-4"],
                    ),
                    html.p(
                        "Featured articles converted to audio. "
                        "Sign up to create your own personal feed!",
                        classes=["text-muted", "mb-3"],
                    ),
                    EpisodeList(
                        episodes=episodes,
                        rss_url=None,
                        user=None,
                        viewing_own_feed=False,
                    ),
                ),
                user=None,
                current_page="home",
                error=error,
            )

        return UI.PageLayout(
            self._render_plan_callout(user),
            SubmitForm(),
            # This wrapper polls /dashboard-updates, which returns only the
            # inner fragments (QueueStatus + EpisodeList) to swap in.
            html.div(
                QueueStatus(items=queue_items),
                EpisodeList(
                    episodes=episodes,
                    rss_url=f"{BASE_URL}/feed/{user['token']}.rss",
                    user=user,
                    viewing_own_feed=True,
                ),
                id="dashboard-content",
                hx_get="/dashboard-updates",
                hx_trigger="every 3s, queue-updated from:body",
                hx_swap="innerHTML",
            ),
            user=user,
            current_page="home",
            error=error,
        )
+
+
# Create ludic app with session support
app = LudicApp()
app.add_middleware(
    SessionMiddleware,
    # NOTE(review): falls back to a hard-coded dev secret when SESSION_SECRET
    # is unset — ensure the env var is always set in production.
    secret_key=os.getenv("SESSION_SECRET", "dev-secret-key"),
    max_age=SESSION_MAX_AGE,  # 30 days
    same_site="lax",
    https_only=App.from_env() == App.Area.Live,  # HTTPS only in production
)
+
+
@app.get("/")
def index(request: Request) -> HomePage:
    """Display main page with form and status.

    Shows the logged-in dashboard (queue + personal feed) when a session
    exists, otherwise the public feed with a login form.  The ``error``
    and ``status`` query params carry user-facing messages set by
    redirects from other routes.
    """
    user_id = request.session.get("user_id")
    user = None
    queue_items = []
    episodes = []
    error = request.query_params.get("error")
    status = request.query_params.get("status")

    # Map error codes to user-friendly messages
    error_messages = {
        "invalid_link": "Invalid login link",
        "expired_link": "Login link has expired. Please request a new one.",
        "user_not_found": "User not found. Please try logging in again.",
        "forbidden": "Access denied. Admin privileges required.",
        "cancel": "Checkout cancelled.",
        # Fix: several routes redirect here with ?error=login_required;
        # previously the code was missing from this map, so the user saw
        # no feedback at all.
        "login_required": "Please log in first.",
    }

    # Handle billing status messages
    if status == "success":
        error_message = None
    elif status == "cancel":
        error_message = error_messages["cancel"]
    else:
        # Unknown codes fall through to None (silently ignored).
        error_message = error_messages.get(error) if error else None

    if user_id:
        user = Core.Database.get_user_by_id(user_id)
        if user:
            # Get user-specific queue items and episodes
            queue_items = Core.Database.get_user_queue_status(
                user_id,
            )
            episodes = Core.Database.get_user_episodes(
                user_id,
            )
    else:
        # Show public feed when not logged in
        episodes = Core.Database.get_public_episodes(10)

    return HomePage(
        queue_items=queue_items,
        episodes=episodes,
        user=user,
        error=error_message,
    )
+
+
@app.get("/public")
def public_feed(request: Request) -> PublicFeedPage:
    """Display public feed page."""
    # Public episodes are shown to everyone; the user (if logged in)
    # only affects which action buttons the episode cards render.
    feed_episodes = Core.Database.get_public_episodes(50)

    session_user_id = request.session.get("user_id")
    current_user = (
        Core.Database.get_user_by_id(session_user_id)
        if session_user_id
        else None
    )

    return PublicFeedPage(episodes=feed_episodes, user=current_user)
+
+
@app.get("/pricing")
def pricing(request: Request) -> UI.PricingPage:
    """Display pricing page."""
    # The page is public; the user is only used to tailor the layout.
    session_user_id = request.session.get("user_id")
    current_user = (
        Core.Database.get_user_by_id(session_user_id)
        if session_user_id
        else None
    )
    return UI.PricingPage(user=current_user)
+
+
@app.post("/upgrade")
def upgrade(request: Request) -> RedirectResponse:
    """Start upgrade checkout flow."""
    session_user_id = request.session.get("user_id")
    if not session_user_id:
        return RedirectResponse(url="/?error=login_required")

    try:
        destination = Billing.create_checkout_session(
            session_user_id,
            "paid",
            BASE_URL,
        )
    except ValueError:
        # Billing raised; log with traceback and bounce back to pricing.
        logger.exception("Failed to create checkout session")
        return RedirectResponse(url="/pricing?error=checkout_failed")
    # 303 forces a GET on the Stripe-hosted checkout URL.
    return RedirectResponse(url=destination, status_code=303)
+
+
@app.post("/logout")
def logout(request: Request) -> RedirectResponse:
    """Log out the current user by clearing the session."""
    request.session.clear()
    # 303 turns the follow-up request into a GET of the home page.
    return RedirectResponse(url="/", status_code=303)
+
+
@app.post("/billing/portal")
def billing_portal(request: Request) -> RedirectResponse:
    """Redirect to Stripe billing portal."""
    session_user_id = request.session.get("user_id")
    if not session_user_id:
        return RedirectResponse(url="/?error=login_required")

    try:
        destination = Billing.create_portal_session(session_user_id, BASE_URL)
    except ValueError as e:
        # If user has no customer ID (e.g. free tier), redirect to pricing
        logger.warning("Failed to create portal session: %s", e)
        return RedirectResponse(url="/pricing")
    return RedirectResponse(url=destination, status_code=303)
+
+
def _handle_test_login(email: str, request: Request) -> Response:
    """Handle login in test mode.

    Test mode skips magic links entirely: unknown emails get an account
    created on the spot, and active users are logged in immediately.
    """
    # The demo account is always allowed in, even if currently inactive.
    is_demo_account = email == "demo@example.com"

    user = Core.Database.get_user_by_email(email)
    if user is None:
        # First login in test mode creates the account as active.
        new_id, new_token = Core.Database.create_user(email, status="active")
        user = {
            "id": new_id,
            "email": email,
            "token": new_token,
            "status": "active",
        }
    elif is_demo_account and user.get("status") != "active":
        # Auto-activate the demo account if it exists but isn't active.
        Core.Database.update_user_status(user["id"], "active")
        user["status"] = "active"

    if user.get("status") == "active":
        # Establish the session with extended lifetime and redirect home.
        request.session["user_id"] = user["id"]
        request.session["permanent"] = True
        return Response(
            '<div class="alert alert-success">✓ Logged in (dev mode)</div>',
            status_code=200,
            headers={"HX-Redirect": "/"},
        )

    # Inactive account: show the pending-activation notice instead.
    pending_message = (
        '<div class="alert alert-warning">'
        "Account created, currently pending. "
        'Email <a href="mailto:ben@bensima.com" '
        'class="alert-link">ben@bensima.com</a> '
        'or message <a href="https://x.com/bensima" '
        'target="_blank" class="alert-link">@bensima</a> '
        "to get your account activated.</div>"
    )
    return Response(pending_message, status_code=200)
+
+
def _handle_production_login(email: str) -> Response:
    """Handle login in production mode.

    Active users receive a magic-link email; inactive users get a
    pending-activation notice.
    """
    pending_message = (
        '<div class="alert alert-warning">'
        "Account created, currently pending. "
        'Email <a href="mailto:ben@bensima.com" '
        'class="alert-link">ben@bensima.com</a> '
        'or message <a href="https://x.com/bensima" '
        'target="_blank" class="alert-link">@bensima</a> '
        "to get your account activated.</div>"
    )

    # Get or create user
    user = Core.Database.get_user_by_email(email)
    if not user:
        user_id, token = Core.Database.create_user(email)
        # NOTE(review): this dict hard-codes status "active" for brand-new
        # users, so they always pass the check below and get a magic link.
        # Confirm this matches create_user's actual default status.
        user = {
            "id": user_id,
            "email": email,
            "token": token,
            "status": "active",
        }

    # Check if user is active
    if user.get("status") != "active":
        return Response(pending_message, status_code=200)

    # Generate magic link token (signed payload; verified in /auth/verify)
    magic_token = magic_link_serializer.dumps({
        "user_id": user["id"],
        "email": email,
    })

    # Send email
    send_magic_link(email, magic_token)

    return Response(
        f'<div class="alert alert-success">✓ Magic link sent to {email}. '
        f"Check your email!</div>",
        status_code=200,
    )
+
+
@app.post("/login")
def login(request: Request, data: FormData) -> Response:
    """Handle login/registration."""
    try:
        raw_email = data.get("email", "")
        email = raw_email.strip().lower() if isinstance(raw_email, str) else ""

        if not email:
            return Response(
                '<div class="alert alert-danger">Email is required</div>',
                status_code=400,
            )

        # Test mode logs the user in immediately; production sends a
        # magic-link email instead.
        if App.from_env() == App.Area.Test:
            return _handle_test_login(email, request)
        return _handle_production_login(email)

    except Exception as e:
        logger.exception("Login error")
        return Response(
            f'<div class="alert alert-danger">Error: {e!s}</div>',
            status_code=500,
        )
+
+
@app.get("/auth/verify")
def verify_magic_link(request: Request) -> Response:
    """Verify magic link and log user in.

    Reads the signed ``token`` query param, checks its age, and starts a
    session for the embedded user id.  Failures redirect home with an
    error code that ``index`` translates into a message.
    """
    token = request.query_params.get("token")

    if not token:
        return RedirectResponse("/?error=invalid_link")

    try:
        # Verify token
        data = magic_link_serializer.loads(token, max_age=MAGIC_LINK_MAX_AGE)
        user_id = data["user_id"]

        # Verify user still exists
        user = Core.Database.get_user_by_id(user_id)
        if not user:
            return RedirectResponse("/?error=user_not_found")

        # Set session with extended lifetime
        request.session["user_id"] = user_id
        request.session["permanent"] = True

        return RedirectResponse("/")

    except (ValueError, KeyError):
        # Token is invalid or expired
        # NOTE(review): if magic_link_serializer is an itsdangerous
        # serializer, bad/expired tokens raise BadSignature, which is NOT a
        # ValueError subclass — verify these except clauses actually catch
        # the failures they are meant to handle.
        return RedirectResponse("/?error=expired_link")
+
+
@app.get("/settings/email/edit")
def edit_email_form(request: Request) -> typing.Any:
    """Return form to edit email.

    HTMX fragment: swapped in place of the read-only email row
    (hx_target="closest div", hx_swap="outerHTML").
    """
    user_id = request.session.get("user_id")
    if not user_id:
        return Response("Unauthorized", status_code=401)

    user = Core.Database.get_user_by_id(user_id)
    if not user:
        return Response("User not found", status_code=404)

    return html.div(
        html.form(
            html.strong("Email: ", classes=["me-2"]),
            # Pre-filled with the current address.
            html.input(
                type="email",
                name="email",
                value=user["email"],
                required=True,
                classes=[
                    "form-control",
                    "form-control-sm",
                    "d-inline-block",
                    "w-auto",
                    "me-2",
                ],
            ),
            html.button(
                "Save",
                type="submit",
                classes=["btn", "btn-sm", "btn-primary", "me-1"],
            ),
            # Cancel swaps back to the read-only view without saving.
            html.button(
                "Cancel",
                hx_get="/settings/email/cancel",
                hx_target="closest div",
                hx_swap="outerHTML",
                classes=["btn", "btn-sm", "btn-secondary"],
            ),
            hx_post="/settings/email",
            hx_target="closest div",
            hx_swap="outerHTML",
            classes=["d-flex", "align-items-center"],
        ),
        classes=["mb-2"],
    )
+
+
@app.get("/settings/email/cancel")
def cancel_edit_email(request: Request) -> typing.Any:
    """Cancel email editing and show original view.

    Returns the read-only email row fragment; also reused by
    ``update_email`` after a successful save.
    """
    user_id = request.session.get("user_id")
    if not user_id:
        return Response("Unauthorized", status_code=401)

    user = Core.Database.get_user_by_id(user_id)
    if not user:
        return Response("User not found", status_code=404)

    return html.div(
        html.strong("Email: "),
        html.span(user["email"]),
        # "Change" swaps this row for the editable form.
        html.button(
            "Change",
            classes=[
                "btn",
                "btn-sm",
                "btn-outline-secondary",
                "ms-2",
                "py-0",
            ],
            hx_get="/settings/email/edit",
            hx_target="closest div",
            hx_swap="outerHTML",
        ),
        classes=["mb-2", "d-flex", "align-items-center"],
    )
+
+
@app.post("/settings/email")
def update_email(request: Request, data: FormData) -> typing.Any:
    """Update user email.

    On success, returns the read-only row fragment; on a validation
    failure from the database layer, re-renders the edit form with the
    rejected value and an inline error.
    """
    user_id = request.session.get("user_id")
    if not user_id:
        return Response("Unauthorized", status_code=401)

    new_email_raw = data.get("email", "")
    # Normalize: trim whitespace and lowercase.
    new_email = (
        new_email_raw.strip().lower() if isinstance(new_email_raw, str) else ""
    )

    if not new_email:
        return Response("Email required", status_code=400)

    try:
        Core.Database.update_user_email(user_id, new_email)
        return cancel_edit_email(request)
    except ValueError as e:
        # Return form with error
        return html.div(
            html.form(
                html.strong("Email: ", classes=["me-2"]),
                # Keep the rejected value so the user can correct it.
                html.input(
                    type="email",
                    name="email",
                    value=new_email,
                    required=True,
                    classes=[
                        "form-control",
                        "form-control-sm",
                        "d-inline-block",
                        "w-auto",
                        "me-2",
                        "is-invalid",
                    ],
                ),
                html.button(
                    "Save",
                    type="submit",
                    classes=["btn", "btn-sm", "btn-primary", "me-1"],
                ),
                html.button(
                    "Cancel",
                    hx_get="/settings/email/cancel",
                    hx_target="closest div",
                    hx_swap="outerHTML",
                    classes=["btn", "btn-sm", "btn-secondary"],
                ),
                # Inline validation message from the database layer.
                html.div(
                    str(e),
                    classes=["invalid-feedback", "d-block", "ms-2"],
                ),
                hx_post="/settings/email",
                hx_target="closest div",
                hx_swap="outerHTML",
                classes=["d-flex", "align-items-center", "flex-wrap"],
            ),
            classes=["mb-2"],
        )
+
+
@app.get("/account")
def account_page(request: Request) -> typing.Any:
    """Account management page."""
    session_user_id = request.session.get("user_id")
    if not session_user_id:
        return RedirectResponse(url="/?error=login_required")

    user = Core.Database.get_user_by_id(session_user_id)
    if not user:
        return RedirectResponse(url="/?error=user_not_found")

    # Usage for the current billing period plus this tier's limits.
    period_start, period_end = Billing.get_period_boundaries(user)
    tier = user.get("plan_tier", "free")

    return UI.AccountPage(
        user=user,
        usage=Billing.get_usage(user["id"], period_start, period_end),
        limits=Billing.TIER_LIMITS.get(tier, Billing.TIER_LIMITS["free"]),
        # Only paid users have a Stripe portal to manage.
        portal_url="/billing/portal" if tier == "paid" else None,
    )
+
+
@app.delete("/account")
def delete_account(request: Request) -> Response:
    """Delete user account."""
    session_user_id = request.session.get("user_id")
    if not session_user_id:
        return RedirectResponse(url="/?error=login_required")

    # Remove the account row first, then drop the session.
    Core.Database.delete_user(session_user_id)
    request.session.clear()

    # HX-Redirect makes HTMX navigate the whole page after the delete.
    return Response(
        "Account deleted",
        headers={"HX-Redirect": "/?message=account_deleted"},
    )
+
+
@app.post("/submit")
def submit_article(  # noqa: PLR0911, PLR0914
    request: Request,
    data: FormData,
) -> typing.Any:
    """Handle manual form submission.

    Validates the session and URL, enforces plan usage limits, dedupes
    by URL hash (re-using an already-processed episode when possible),
    and otherwise queues the article for processing.  Always returns an
    HTML alert fragment for HTMX to swap in.
    """
    try:
        # Check if user is logged in
        user_id = request.session.get("user_id")
        if not user_id:
            return html.div(
                html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]),
                "Error: Please login first",
                classes=["alert", "alert-danger"],
            )

        user = Core.Database.get_user_by_id(user_id)
        if not user:
            return html.div(
                html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]),
                "Error: Invalid session",
                classes=["alert", "alert-danger"],
            )

        url_raw = data.get("url", "")
        url = url_raw.strip() if isinstance(url_raw, str) else ""

        if not url:
            return html.div(
                html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]),
                "Error: URL is required",
                classes=["alert", "alert-danger"],
            )

        # Basic URL validation: must have both a scheme and a host.
        parsed = urllib.parse.urlparse(url)
        if not parsed.scheme or not parsed.netloc:
            return html.div(
                html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]),
                "Error: Invalid URL format",
                classes=["alert", "alert-danger"],
            )

        # Check usage limits
        allowed, _msg, usage = Billing.can_submit(user_id)
        if not allowed:
            tier = user.get("plan_tier", "free")
            tier_info = Billing.get_tier_info(tier)
            limit = tier_info.get("articles_limit", 0)
            return html.div(
                html.i(classes=["bi", "bi-exclamation-circle", "me-2"]),
                html.strong("Limit reached: "),
                f"You've used {usage['articles']}/{limit} articles "
                "this period. ",
                html.a(
                    "Upgrade your plan",
                    href="/billing",
                    classes=["alert-link"],
                ),
                " to continue.",
                classes=["alert", "alert-warning"],
            )

        # Check if episode already exists for this URL
        url_hash = Core.hash_url(url)
        existing_episode = Core.Database.get_episode_by_url_hash(url_hash)

        if existing_episode:
            # Episode already processed - check if user has it
            episode_id = existing_episode["id"]
            if Core.Database.user_has_episode(user_id, episode_id):
                return html.div(
                    html.i(classes=["bi", "bi-info-circle", "me-2"]),
                    "This episode is already in your feed.",
                    classes=["alert", "alert-info"],
                )
            # Add existing episode to user's feed
            Core.Database.add_episode_to_user(user_id, episode_id)
            Core.Database.track_episode_event(
                episode_id,
                "added",
                user_id,
            )
            return html.div(
                html.i(classes=["bi", "bi-check-circle", "me-2"]),
                "✓ Episode added to your feed! ",
                html.a(
                    "View episode",
                    href=f"/episode/{encode_episode_id(episode_id)}",
                    classes=["alert-link"],
                ),
                classes=["alert", "alert-success"],
            )

        # Episode doesn't exist yet - extract metadata and queue for processing
        title, author = extract_og_metadata(url)

        job_id = Core.Database.add_to_queue(
            url,
            user["email"],
            user_id,
            title=title,
            author=author,
        )
        return html.div(
            html.i(classes=["bi", "bi-check-circle", "me-2"]),
            f"✓ Article submitted successfully! Job ID: {job_id}",
            classes=["alert", "alert-success"],
        )

    except (httpx.HTTPError, httpx.TimeoutException, ValueError) as e:
        # Covers metadata-fetch failures and bad input values.
        return html.div(
            html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]),
            f"Error: {e!s}",
            classes=["alert", "alert-danger"],
        )
+
+
@app.get("/feed/{token}.rss")
def rss_feed(request: Request, token: str) -> Response:  # noqa: ARG001
    """Generate user-specific RSS podcast feed.

    The token doubles as authentication — it is the only credential a
    podcast client presents — so an unknown token yields a 404.
    """
    try:
        # Validate token and get user
        user = Core.Database.get_user_by_token(token)
        if not user:
            return Response("Invalid feed token", status_code=404)

        # Get episodes for this user only
        episodes = Core.Database.get_user_all_episodes(
            user["id"],
        )

        # Extract first name from email for personalization
        email_name = user["email"].split("@")[0].split(".")[0].title()

        fg = FeedGenerator()
        fg.title(f"{email_name}'s Article Podcast")
        fg.description(f"Web articles converted to audio for {user['email']}")
        fg.author(name=RSS_CONFIG["author"])
        fg.language(RSS_CONFIG["language"])
        fg.link(href=f"{RSS_CONFIG['base_url']}/feed/{token}.rss")
        fg.id(f"{RSS_CONFIG['base_url']}/feed/{token}.rss")

        for episode in episodes:
            fe = fg.add_entry()
            episode_sqid = encode_episode_id(episode["id"])
            fe.id(f"{RSS_CONFIG['base_url']}/episode/{episode_sqid}")
            fe.title(episode["title"])
            # No separate summary stored; reuse the title as description.
            fe.description(episode["title"])
            fe.enclosure(
                episode["audio_url"],
                str(episode.get("content_length", 0)),
                "audio/mpeg",
            )
            # SQLite timestamps don't have timezone info, so add UTC
            created_at = datetime.fromisoformat(episode["created_at"])
            if created_at.tzinfo is None:
                created_at = created_at.replace(tzinfo=timezone.utc)
            fe.pubDate(created_at)

        rss_str = fg.rss_str(pretty=True)
        return Response(
            rss_str,
            media_type="application/rss+xml; charset=utf-8",
        )

    except (ValueError, KeyError, AttributeError) as e:
        return Response(f"Error generating feed: {e}", status_code=500)
+
+
+# Backwards compatibility: .xml extension
+@app.get("/feed/{token}.xml")
+def rss_feed_xml_alias(request: Request, token: str) -> Response:
+ """Alias for .rss feed (backwards compatibility)."""
+ return rss_feed(request, token)
+
+
@app.get("/public.rss")
def public_rss_feed(request: Request) -> Response:  # noqa: ARG001
    """Generate public RSS podcast feed.

    No authentication: serves the 50 most recent publicly flagged
    episodes as a podcast feed.
    """
    try:
        # Get public episodes
        episodes = Core.Database.get_public_episodes(50)

        fg = FeedGenerator()
        fg.title("PodcastItLater Public Feed")
        fg.description("Curated articles converted to audio")
        fg.author(name=RSS_CONFIG["author"])
        fg.language(RSS_CONFIG["language"])
        fg.link(href=f"{RSS_CONFIG['base_url']}/public.rss")
        fg.id(f"{RSS_CONFIG['base_url']}/public.rss")

        for episode in episodes:
            fe = fg.add_entry()
            episode_sqid = encode_episode_id(episode["id"])
            fe.id(f"{RSS_CONFIG['base_url']}/episode/{episode_sqid}")
            fe.title(episode["title"])
            # No separate summary stored; reuse the title as description.
            fe.description(episode["title"])
            fe.enclosure(
                episode["audio_url"],
                str(episode.get("content_length", 0)),
                "audio/mpeg",
            )
            # SQLite timestamps don't have timezone info, so add UTC
            created_at = datetime.fromisoformat(episode["created_at"])
            if created_at.tzinfo is None:
                created_at = created_at.replace(tzinfo=timezone.utc)
            fe.pubDate(created_at)

        rss_str = fg.rss_str(pretty=True)
        return Response(
            rss_str,
            media_type="application/rss+xml; charset=utf-8",
        )

    except (ValueError, KeyError, AttributeError) as e:
        return Response(f"Error generating feed: {e}", status_code=500)
+
+
+# Backwards compatibility: .xml extension
+@app.get("/public.xml")
+def public_rss_feed_xml_alias(request: Request) -> Response:
+ """Alias for .rss feed (backwards compatibility)."""
+ return public_rss_feed(request)
+
+
@app.get("/episode/{episode_id:int}")
def episode_detail_legacy(
    request: Request,  # noqa: ARG001
    episode_id: int,
) -> RedirectResponse:
    """Redirect legacy integer episode IDs to sqid URLs.

    Deprecated: This route exists for backward compatibility.
    Will be removed in a future version.
    """
    # 301 so clients and crawlers update their stored URL.
    return RedirectResponse(
        url=f"/episode/{encode_episode_id(episode_id)}",
        status_code=301,
    )
+
+
@app.get("/episode/{episode_sqid}")
def episode_detail(
    request: Request,
    episode_sqid: str,
) -> Episode.EpisodeDetailPage | Response:
    """Display individual episode page (public, no auth required)."""
    try:
        # Resolve the sqid to a numeric episode ID.
        episode_id = decode_episode_id(episode_sqid)
        if episode_id is None:
            return Response("Invalid episode ID", status_code=404)

        episode = Core.Database.get_episode_by_id(episode_id)
        if not episode:
            return Response("Episode not found", status_code=404)

        # Creator attribution, when the episode records a submitting user.
        creator_email = None
        if episode.get("user_id"):
            creator = Core.Database.get_user_by_id(episode["user_id"])
            if creator:
                creator_email = creator["email"]

        # Viewer context: who is looking, and do they already have this
        # episode in their own feed?
        viewer = None
        already_in_feed = False
        viewer_id = request.session.get("user_id")
        if viewer_id:
            viewer = Core.Database.get_user_by_id(viewer_id)
            already_in_feed = Core.Database.user_has_episode(
                viewer_id,
                episode_id,
            )

        return Episode.EpisodeDetailPage(
            episode=episode,
            episode_sqid=episode_sqid,
            creator_email=creator_email,
            user=viewer,
            base_url=BASE_URL,
            user_has_episode=already_in_feed,
        )

    except (ValueError, KeyError) as e:
        logger.exception("Error loading episode")
        return Response(f"Error loading episode: {e}", status_code=500)
+
+
@app.get("/status")
def queue_status(request: Request) -> QueueStatus:
    """Return HTMX endpoint for live queue updates."""
    session_user_id = request.session.get("user_id")
    # Anonymous visitors get an empty queue rather than an error.
    items = (
        Core.Database.get_user_queue_status(session_user_id)
        if session_user_id
        else []
    )
    return QueueStatus(items=items)
+
+
@app.get("/dashboard-updates")
def dashboard_updates(request: Request) -> Response:
    """Return both queue status and recent episodes for dashboard updates."""
    session_user_id = request.session.get("user_id")

    if not session_user_id:
        # Anonymous poll: render empty fragments.
        empty_queue = QueueStatus(items=[])
        empty_list = EpisodeList(
            episodes=[],
            rss_url=None,
            user=None,
            viewing_own_feed=False,
        )
        return Response(
            str(empty_queue) + str(empty_list),
            media_type="text/html",
        )

    # User info is needed for the personal RSS URL.
    user = Core.Database.get_user_by_id(session_user_id)
    rss_url = f"{BASE_URL}/feed/{user['token']}.rss" if user else None

    # Only the inner fragments are returned; the polling wrapper div with
    # the HTMX attributes lives in HomePage.
    queue_fragment = QueueStatus(
        items=Core.Database.get_user_queue_status(session_user_id),
    )
    episodes_fragment = EpisodeList(
        episodes=Core.Database.get_user_recent_episodes(session_user_id, 10),
        rss_url=rss_url,
        user=user,
        viewing_own_feed=True,
    )
    return Response(
        str(queue_fragment) + str(episodes_fragment),
        media_type="text/html",
    )
+
+
# Register admin routes
# (handler implementations live in the Admin module; registering them on
# this app gives them the same session middleware as the other routes)
app.get("/admin")(Admin.admin_queue_status)
app.post("/queue/{job_id}/retry")(Admin.retry_queue_item)
+
+
@app.post("/billing/checkout")
def billing_checkout(request: Request, data: FormData) -> Response:
    """Create Stripe Checkout session."""
    session_user_id = request.session.get("user_id")
    if not session_user_id:
        return Response("Unauthorized", status_code=401)

    # Only the "paid" tier is purchasable.
    tier_value = data.get("tier", "paid")
    if not isinstance(tier_value, str):
        tier_value = "paid"
    if tier_value != "paid":
        return Response("Invalid tier", status_code=400)

    try:
        destination = Billing.create_checkout_session(
            session_user_id,
            tier_value,
            BASE_URL,
        )
    except ValueError as e:
        logger.exception("Checkout error")
        return Response(f"Error: {e!s}", status_code=400)
    # 303 sends the browser to the Stripe-hosted checkout page.
    return RedirectResponse(url=destination, status_code=303)
+
+
@app.post("/stripe/webhook")
async def stripe_webhook(request: Request) -> Response:
    """Handle Stripe webhook events."""
    body = await request.body()
    signature = request.headers.get("stripe-signature", "")

    try:
        outcome = Billing.handle_webhook_event(body, signature)
    except Exception as e:
        # Boundary handler: log with traceback and report a 400 to Stripe.
        logger.exception("Webhook error")
        return Response(f"Error: {e!s}", status_code=400)
    return Response(f"OK: {outcome['status']}", status_code=200)
+
+
@app.post("/queue/{job_id}/cancel")
def cancel_queue_item(request: Request, job_id: int) -> Response:
    """Cancel a pending queue item."""
    try:
        session_user_id = request.session.get("user_id")
        if not session_user_id:
            return Response("Unauthorized", status_code=401)

        # Ownership check: users may only cancel their own jobs.
        job = Core.Database.get_job_by_id(job_id)
        if job is None or job.get("user_id") != session_user_id:
            return Response("Forbidden", status_code=403)

        # Jobs that already started (or finished) cannot be cancelled.
        if job.get("status") != "pending":
            return Response("Can only cancel pending jobs", status_code=400)

        Core.Database.update_job_status(
            job_id,
            "cancelled",
            error="Cancelled by user",
        )

        # Empty body; the HX-Trigger header tells the page to refresh.
        return Response(
            "",
            status_code=200,
            headers={"HX-Trigger": "queue-updated"},
        )
    except (ValueError, KeyError) as e:
        return Response(
            f"Error cancelling job: {e!s}",
            status_code=500,
        )
+
+
# Register remaining admin routes (handlers defined in the Admin module).
app.delete("/queue/{job_id}")(Admin.delete_queue_item)
app.get("/admin/users")(Admin.admin_users)
app.get("/admin/metrics")(Admin.admin_metrics)
app.post("/admin/users/{user_id}/status")(Admin.update_user_status)
app.post("/admin/episode/{episode_id}/toggle-public")(
    Admin.toggle_episode_public,
)
+
+
@app.post("/episode/{episode_id}/add-to-feed")
def add_episode_to_feed(request: Request, episode_id: int) -> Response:
    """Add an episode to the user's feed."""
    session_user_id = request.session.get("user_id")
    if not session_user_id:
        # 200 with an alert fragment so HTMX can swap it in place.
        return Response(
            '<div class="alert alert-warning">Please login first</div>',
            status_code=200,
        )

    if not Core.Database.get_episode_by_id(episode_id):
        return Response(
            '<div class="alert alert-danger">Episode not found</div>',
            status_code=404,
        )

    # No-op if the episode is already in the user's feed.
    if Core.Database.user_has_episode(session_user_id, episode_id):
        return Response(
            '<div class="alert alert-info">Already in your feed</div>',
            status_code=200,
        )

    Core.Database.add_episode_to_user(session_user_id, episode_id)
    Core.Database.track_episode_event(episode_id, "added", session_user_id)

    # Reload whatever page the user came from so button states update.
    return Response(
        "",
        status_code=200,
        headers={"HX-Redirect": request.headers.get("referer", "/")},
    )
+
+
@app.post("/episode/{episode_id}/track")
def track_episode(
    request: Request,
    episode_id: int,
    data: FormData,
) -> Response:
    """Track an episode metric event (play, download)."""
    raw_event = data.get("event_type", "")
    event_type = raw_event if isinstance(raw_event, str) else ""

    # Only these two event types are accepted from the client.
    if event_type not in {"played", "downloaded"}:
        return Response("Invalid event type", status_code=400)

    # Anonymous listeners are recorded with user_id=None.
    Core.Database.track_episode_event(
        episode_id,
        event_type,
        request.session.get("user_id"),
    )
    return Response("", status_code=200)
+
+
class BaseWebTest(Test.TestCase):
    """Base class for web tests with database setup.

    Each test gets a freshly initialized database and an HTTP test client
    bound to the module-level ``app``; the database is torn down after
    every test.
    """

    def setUp(self) -> None:
        """Set up test database and client."""
        Core.Database.init_db()
        # Create test client
        self.client = TestClient(app)

    @staticmethod
    def tearDown() -> None:
        """Clean up test database."""
        # NOTE(review): unittest calls self.tearDown(); a staticmethod
        # works here but an instance method would be more conventional.
        Core.Database.teardown()
+
+
class TestDurationFormatting(Test.TestCase):
    """Exercise UI.format_duration across typical and edge inputs."""

    def test_format_duration_minutes_only(self) -> None:
        """Test formatting durations less than an hour."""
        cases = ((60, "1m"), (240, "4m"), (300, "5m"), (3599, "60m"))
        for seconds, expected in cases:
            self.assertEqual(UI.format_duration(seconds), expected)

    def test_format_duration_hours_and_minutes(self) -> None:
        """Test formatting durations with hours and minutes."""
        cases = (
            (3600, "1h"),
            (3840, "1h 4m"),
            (11520, "3h 12m"),
            (7320, "2h 2m"),
        )
        for seconds, expected in cases:
            self.assertEqual(UI.format_duration(seconds), expected)

    def test_format_duration_round_up(self) -> None:
        """Partial minutes round up to the next whole minute."""
        cases = ((61, "2m"), (119, "2m"), (121, "3m"), (3601, "1h 1m"))
        for seconds, expected in cases:
            self.assertEqual(UI.format_duration(seconds), expected)

    def test_format_duration_edge_cases(self) -> None:
        """None, zero, and negative durations all render as Unknown."""
        for seconds in (None, 0, -100):
            self.assertEqual(UI.format_duration(seconds), "Unknown")
+
+
class TestAuthentication(BaseWebTest):
    """Test authentication functionality.

    These tests drive the email-based login flow through the HTTP client
    and depend on session state carried across sequential requests.
    """

    def test_login_new_user_active(self) -> None:
        """New users should be created with active status."""
        response = self.client.post("/login", data={"email": "new@example.com"})
        self.assertEqual(response.status_code, 200)

        # Verify user was created with active status
        user = Core.Database.get_user_by_email(
            "new@example.com",
        )
        self.assertIsNotNone(user)
        if user is None:
            msg = "no user found"
            raise Test.TestError(msg)
        self.assertEqual(user.get("status"), "active")

    def test_login_active_user(self) -> None:
        """Active users should be able to login."""
        # Create user and set to active
        user_id, _ = Core.Database.create_user(
            "active@example.com",
        )
        Core.Database.update_user_status(user_id, "active")

        response = self.client.post(
            "/login",
            data={"email": "active@example.com"},
        )

        # A successful login responds with an HTMX redirect header.
        self.assertEqual(response.status_code, 200)
        self.assertIn("HX-Redirect", response.headers)

    def test_login_existing_pending_user(self) -> None:
        """Existing pending users should see the pending message."""
        # Create a pending user
        _user_id, _ = Core.Database.create_user(
            "pending@example.com",
            status="pending",
        )

        response = self.client.post(
            "/login",
            data={"email": "pending@example.com"},
        )

        self.assertEqual(response.status_code, 200)
        self.assertIn("Account created, currently pending", response.text)
        self.assertIn("ben@bensima.com", response.text)
        self.assertIn("@bensima", response.text)

    def test_login_disabled_user(self) -> None:
        """Disabled users should not be able to login."""
        # Create user and set to disabled
        user_id, _ = Core.Database.create_user(
            "disabled@example.com",
        )
        Core.Database.update_user_status(
            user_id,
            "disabled",
        )

        response = self.client.post(
            "/login",
            data={"email": "disabled@example.com"},
        )

        # NOTE(review): disabled users currently see the same "pending"
        # message as pending users — confirm this is intentional.
        self.assertEqual(response.status_code, 200)
        self.assertIn("Account created, currently pending", response.text)

    def test_login_invalid_email(self) -> None:
        """Reject malformed emails."""
        response = self.client.post("/login", data={"email": ""})

        self.assertEqual(response.status_code, 400)
        self.assertIn("Email is required", response.text)

    def test_session_persistence(self) -> None:
        """Verify session across requests."""
        # Create active user
        _user_id, _ = Core.Database.create_user(
            "test@example.com",
            status="active",
        )
        # Login
        self.client.post("/login", data={"email": "test@example.com"})

        # Access protected page
        response = self.client.get("/")

        # Should see logged-in content (navbar with Manage Account link)
        self.assertIn("Manage Account", response.text)
        self.assertIn("Home", response.text)

    def test_protected_routes_pending_user(self) -> None:
        """Pending users should not access protected routes."""
        # Create pending user
        Core.Database.create_user("pending@example.com", status="pending")

        # Try to login
        response = self.client.post(
            "/login",
            data={"email": "pending@example.com"},
        )
        self.assertEqual(response.status_code, 200)

        # Should not have session
        response = self.client.get("/")
        self.assertNotIn("Logged in as:", response.text)

    def test_protected_routes(self) -> None:
        """Ensure auth required for user actions."""
        # Try to submit without login
        response = self.client.post(
            "/submit",
            data={"url": "https://example.com"},
        )

        self.assertIn("Please login first", response.text)
+
+
class TestArticleSubmission(BaseWebTest):
    """Test article submission functionality.

    All tests run with a logged-in active user established in setUp;
    submission responses are HTMX HTML fragments.
    """

    def setUp(self) -> None:
        """Set up test client with logged-in user."""
        super().setUp()
        # Create active user and login
        user_id, _ = Core.Database.create_user(
            "test@example.com",
        )
        Core.Database.update_user_status(user_id, "active")
        self.client.post("/login", data={"email": "test@example.com"})

    def test_submit_valid_url(self) -> None:
        """Accept well-formed URLs."""
        response = self.client.post(
            "/submit",
            data={"url": "https://example.com/article"},
        )

        self.assertEqual(response.status_code, 200)
        self.assertIn("Article submitted successfully", response.text)
        self.assertIn("Job ID:", response.text)

    def test_submit_invalid_url(self) -> None:
        """Reject malformed URLs."""
        response = self.client.post("/submit", data={"url": "not-a-url"})

        self.assertIn("Invalid URL format", response.text)

    def test_submit_without_auth(self) -> None:
        """Reject unauthenticated submissions."""
        # Clear session
        self.client.get("/logout")

        response = self.client.post(
            "/submit",
            data={"url": "https://example.com"},
        )

        self.assertIn("Please login first", response.text)

    def test_submit_creates_job(self) -> None:
        """Verify job creation in database."""
        response = self.client.post(
            "/submit",
            data={"url": "https://example.com/test"},
        )

        # Extract job ID from response
        match = re.search(r"Job ID: (\d+)", response.text)
        self.assertIsNotNone(match)
        if match is None:
            self.fail("Job ID not found in response")
        job_id = int(match.group(1))

        # Verify job in database
        job = Core.Database.get_job_by_id(job_id)
        self.assertIsNotNone(job)
        if job is None:  # Type guard for mypy
            self.fail("Job should not be None")
        self.assertEqual(job["url"], "https://example.com/test")
        self.assertEqual(job["status"], "pending")

    def test_htmx_response(self) -> None:
        """Ensure proper HTMX response format."""
        response = self.client.post(
            "/submit",
            data={"url": "https://example.com"},
        )

        # Should return HTML fragment, not full page
        self.assertNotIn("<!DOCTYPE", response.text)
        self.assertIn("<div", response.text)
+
+
class TestRSSFeed(BaseWebTest):
    """Test RSS feed generation.

    setUp creates one active user with a feed token and two episodes in
    that user's feed; the feed is fetched via /feed/{token}.rss.
    """

    def setUp(self) -> None:
        """Set up test client and create test data."""
        super().setUp()

        # Create user and episodes
        self.user_id, self.token = Core.Database.create_user(
            "test@example.com",
        )
        Core.Database.update_user_status(
            self.user_id,
            "active",
        )

        # Create test episodes
        ep1_id = Core.Database.create_episode(
            "Episode 1",
            "https://example.com/ep1.mp3",
            300,
            5000,
            self.user_id,
        )
        ep2_id = Core.Database.create_episode(
            "Episode 2",
            "https://example.com/ep2.mp3",
            600,
            10000,
            self.user_id,
        )
        Core.Database.add_episode_to_user(self.user_id, ep1_id)
        Core.Database.add_episode_to_user(self.user_id, ep2_id)

    def test_feed_generation(self) -> None:
        """Generate valid RSS XML."""
        response = self.client.get(f"/feed/{self.token}.rss")

        self.assertEqual(response.status_code, 200)
        self.assertEqual(
            response.headers["content-type"],
            "application/rss+xml; charset=utf-8",
        )

        # Verify RSS structure
        self.assertIn("<?xml", response.text)
        self.assertIn("<rss", response.text)
        self.assertIn("<channel>", response.text)
        self.assertIn("<item>", response.text)

    def test_feed_user_isolation(self) -> None:
        """Only show user's episodes."""
        # Create another user with episodes
        user2_id, _ = Core.Database.create_user(
            "other@example.com",
        )
        other_ep_id = Core.Database.create_episode(
            "Other Episode",
            "https://example.com/other.mp3",
            400,
            6000,
            user2_id,
        )
        Core.Database.add_episode_to_user(user2_id, other_ep_id)

        # Get first user's feed
        response = self.client.get(f"/feed/{self.token}.rss")

        # Should only have user's episodes
        self.assertIn("Episode 1", response.text)
        self.assertIn("Episode 2", response.text)
        self.assertNotIn("Other Episode", response.text)

    def test_feed_invalid_token(self) -> None:
        """Return 404 for bad tokens."""
        response = self.client.get("/feed/invalid-token.rss")

        self.assertEqual(response.status_code, 404)

    def test_feed_metadata(self) -> None:
        """Verify personalized feed titles."""
        response = self.client.get(f"/feed/{self.token}.rss")

        # Should personalize based on email
        self.assertIn("Test's Article Podcast", response.text)
        self.assertIn("test@example.com", response.text)

    def test_feed_episode_order(self) -> None:
        """Ensure reverse chronological order."""
        response = self.client.get(f"/feed/{self.token}.rss")

        # Episode 2 should appear before Episode 1
        ep2_pos = response.text.find("Episode 2")
        ep1_pos = response.text.find("Episode 1")
        self.assertLess(ep2_pos, ep1_pos)

    def test_feed_enclosures(self) -> None:
        """Verify audio URLs and metadata."""
        response = self.client.get(f"/feed/{self.token}.rss")

        # Check enclosure tags
        self.assertIn("<enclosure", response.text)
        self.assertIn('type="audio/mpeg"', response.text)

    def test_feed_xml_alias_works(self) -> None:
        """Test .xml extension works for backwards compatibility."""
        # Get feed with .xml extension
        response_xml = self.client.get(f"/feed/{self.token}.xml")
        # Get feed with .rss extension
        response_rss = self.client.get(f"/feed/{self.token}.rss")

        # Both should work and return same content
        self.assertEqual(response_xml.status_code, 200)
        self.assertEqual(response_rss.status_code, 200)
        self.assertEqual(response_xml.text, response_rss.text)

    def test_public_feed_xml_alias_works(self) -> None:
        """Test .xml extension works for public feed."""
        # Get feed with .xml extension
        response_xml = self.client.get("/public.xml")
        # Get feed with .rss extension
        response_rss = self.client.get("/public.rss")

        # Both should work and return same content
        self.assertEqual(response_xml.status_code, 200)
        self.assertEqual(response_rss.status_code, 200)
        self.assertEqual(response_xml.text, response_rss.text)
+
+
class TestAdminInterface(BaseWebTest):
    """Test admin interface functionality.

    Uses ben@bensima.com as the admin account — presumably the admin
    check is hard-coded to that email; confirm against the auth code.
    """

    def setUp(self) -> None:
        """Set up test client with logged-in user."""
        super().setUp()

        # Create and login admin user
        self.user_id, _ = Core.Database.create_user(
            "ben@bensima.com",
        )
        Core.Database.update_user_status(
            self.user_id,
            "active",
        )
        self.client.post("/login", data={"email": "ben@bensima.com"})

        # Create test data
        self.job_id = Core.Database.add_to_queue(
            "https://example.com/test",
            "ben@bensima.com",
            self.user_id,
        )

    def test_queue_status_view(self) -> None:
        """Verify queue display."""
        response = self.client.get("/admin")

        self.assertEqual(response.status_code, 200)
        self.assertIn("Queue Status", response.text)
        self.assertIn("https://example.com/test", response.text)

    def test_retry_action(self) -> None:
        """Test retry button functionality."""
        # Set job to error state
        Core.Database.update_job_status(
            self.job_id,
            "error",
            "Failed",
        )

        # Retry
        response = self.client.post(f"/queue/{self.job_id}/retry")

        self.assertEqual(response.status_code, 200)
        self.assertIn("HX-Redirect", response.headers)

        # Job should be pending again
        job = Core.Database.get_job_by_id(self.job_id)
        self.assertIsNotNone(job)
        if job is not None:
            self.assertEqual(job["status"], "pending")

    def test_delete_action(self) -> None:
        """Test delete button functionality."""
        response = self.client.delete(
            f"/queue/{self.job_id}",
            headers={"referer": "/admin"},
        )

        self.assertEqual(response.status_code, 200)
        self.assertIn("HX-Redirect", response.headers)

        # Job should be gone
        job = Core.Database.get_job_by_id(self.job_id)
        self.assertIsNone(job)

    def test_user_data_isolation(self) -> None:
        """Ensure admin sees all data."""
        # Create another user's job
        user2_id, _ = Core.Database.create_user(
            "other@example.com",
        )
        Core.Database.add_to_queue(
            "https://example.com/other",
            "other@example.com",
            user2_id,
        )

        # View queue status as admin
        response = self.client.get("/admin")

        # Admin should see all jobs
        self.assertIn("https://example.com/test", response.text)
        self.assertIn("https://example.com/other", response.text)

    def test_status_summary(self) -> None:
        """Verify status counts display."""
        # Create jobs with different statuses
        Core.Database.update_job_status(
            self.job_id,
            "error",
            "Failed",
        )
        job2 = Core.Database.add_to_queue(
            "https://example.com/2",
            "test@example.com",
            self.user_id,
        )
        Core.Database.update_job_status(
            job2,
            "processing",
        )

        response = self.client.get("/admin")

        # Should show status counts
        self.assertIn("ERROR: 1", response.text)
        self.assertIn("PROCESSING: 1", response.text)
+
+
class TestMetricsDashboard(BaseWebTest):
    """Test metrics dashboard functionality.

    The dashboard lives at /admin/metrics and is restricted to the admin
    account; non-admins are redirected with ?error=forbidden.
    """

    def setUp(self) -> None:
        """Set up test client with logged-in admin user."""
        super().setUp()

        # Create and login admin user
        self.user_id, _ = Core.Database.create_user(
            "ben@bensima.com",
        )
        Core.Database.update_user_status(
            self.user_id,
            "active",
        )
        self.client.post("/login", data={"email": "ben@bensima.com"})

    def test_metrics_page_requires_admin(self) -> None:
        """Verify non-admin users cannot access metrics."""
        # Create non-admin user
        user_id, _ = Core.Database.create_user("user@example.com")
        Core.Database.update_user_status(user_id, "active")

        # Login as non-admin
        self.client.get("/logout")
        self.client.post("/login", data={"email": "user@example.com"})

        # Try to access metrics
        response = self.client.get("/admin/metrics", follow_redirects=False)

        # Should redirect
        self.assertEqual(response.status_code, 302)
        self.assertEqual(response.headers["Location"], "/?error=forbidden")

    def test_metrics_page_requires_login(self) -> None:
        """Verify unauthenticated users are redirected."""
        self.client.get("/logout")

        response = self.client.get("/admin/metrics", follow_redirects=False)

        self.assertEqual(response.status_code, 302)
        self.assertEqual(response.headers["Location"], "/")

    def test_metrics_displays_summary(self) -> None:
        """Verify metrics summary is displayed."""
        # Create test episode
        episode_id = Core.Database.create_episode(
            title="Test Episode",
            audio_url="http://example.com/audio.mp3",
            content_length=1000,
            duration=300,
        )
        Core.Database.add_episode_to_user(self.user_id, episode_id)

        # Track some events
        Core.Database.track_episode_event(episode_id, "played")
        Core.Database.track_episode_event(episode_id, "played")
        Core.Database.track_episode_event(episode_id, "downloaded")
        Core.Database.track_episode_event(episode_id, "added", self.user_id)

        # Get metrics page
        response = self.client.get("/admin/metrics")

        self.assertEqual(response.status_code, 200)
        self.assertIn("Episode Metrics", response.text)
        self.assertIn("Total Episodes", response.text)
        self.assertIn("Total Plays", response.text)

    def test_growth_metrics_display(self) -> None:
        """Verify growth and usage metrics are displayed."""
        # Create an active subscriber
        user2_id, _ = Core.Database.create_user("active@example.com")
        Core.Database.update_user_subscription(
            user2_id,
            subscription_id="sub_test",
            status="active",
            period_start=datetime.now(timezone.utc),
            period_end=datetime.now(timezone.utc),
            tier="paid",
            cancel_at_period_end=False,
        )

        # Create a queue item
        Core.Database.add_to_queue(
            "https://example.com/new",
            "active@example.com",
            user2_id,
        )

        # Get metrics page
        response = self.client.get("/admin/metrics")

        self.assertEqual(response.status_code, 200)
        # "&" is HTML-escaped in the rendered heading.
        self.assertIn("Growth &amp; Usage", response.text)
        self.assertIn("Total Users", response.text)
        self.assertIn("Active Subs", response.text)
        self.assertIn("Submissions (24h)", response.text)

        self.assertIn("Total Downloads", response.text)
        self.assertIn("Total Adds", response.text)

    def test_metrics_shows_top_episodes(self) -> None:
        """Verify top episodes tables are displayed."""
        # Create test episodes
        episode1 = Core.Database.create_episode(
            title="Popular Episode",
            audio_url="http://example.com/popular.mp3",
            content_length=1000,
            duration=300,
            author="Test Author",
        )
        Core.Database.add_episode_to_user(self.user_id, episode1)

        episode2 = Core.Database.create_episode(
            title="Less Popular Episode",
            audio_url="http://example.com/less.mp3",
            content_length=1000,
            duration=300,
        )
        Core.Database.add_episode_to_user(self.user_id, episode2)

        # Track events - more for episode1
        for _ in range(5):
            Core.Database.track_episode_event(episode1, "played")
        for _ in range(2):
            Core.Database.track_episode_event(episode2, "played")

        for _ in range(3):
            Core.Database.track_episode_event(episode1, "downloaded")
        Core.Database.track_episode_event(episode2, "downloaded")

        # Get metrics page
        response = self.client.get("/admin/metrics")

        self.assertEqual(response.status_code, 200)
        self.assertIn("Most Played", response.text)
        self.assertIn("Most Downloaded", response.text)
        self.assertIn("Popular Episode", response.text)

    def test_metrics_empty_state(self) -> None:
        """Verify metrics page works with no data."""
        response = self.client.get("/admin/metrics")

        self.assertEqual(response.status_code, 200)
        self.assertIn("Episode Metrics", response.text)
        # Should show 0 for counts
        self.assertIn("Total Episodes", response.text)
+
+
class TestJobCancellation(BaseWebTest):
    """Test job cancellation functionality.

    Only pending jobs owned by the logged-in user may be cancelled; a
    successful cancel fires the "queue-updated" HX-Trigger so the client
    refreshes the queue view.
    """

    def setUp(self) -> None:
        """Set up test client with logged-in user and pending job."""
        super().setUp()

        # Create and login user
        self.user_id, _ = Core.Database.create_user(
            "test@example.com",
        )
        Core.Database.update_user_status(
            self.user_id,
            "active",
        )
        self.client.post("/login", data={"email": "test@example.com"})

        # Create pending job
        self.job_id = Core.Database.add_to_queue(
            "https://example.com/test",
            "test@example.com",
            self.user_id,
        )

    def test_cancel_pending_job(self) -> None:
        """Successfully cancel a pending job."""
        response = self.client.post(f"/queue/{self.job_id}/cancel")

        self.assertEqual(response.status_code, 200)
        self.assertIn("HX-Trigger", response.headers)
        self.assertEqual(response.headers["HX-Trigger"], "queue-updated")

        # Verify job status is cancelled
        job = Core.Database.get_job_by_id(self.job_id)
        self.assertIsNotNone(job)
        if job is not None:
            self.assertEqual(job["status"], "cancelled")
            self.assertEqual(job.get("error_message", ""), "Cancelled by user")

    def test_cannot_cancel_processing_job(self) -> None:
        """Prevent cancelling jobs that are already processing."""
        # Set job to processing
        Core.Database.update_job_status(
            self.job_id,
            "processing",
        )

        response = self.client.post(f"/queue/{self.job_id}/cancel")

        self.assertEqual(response.status_code, 400)
        self.assertIn("Can only cancel pending jobs", response.text)

    def test_cannot_cancel_completed_job(self) -> None:
        """Prevent cancelling completed jobs."""
        # Set job to completed
        Core.Database.update_job_status(
            self.job_id,
            "completed",
        )

        response = self.client.post(f"/queue/{self.job_id}/cancel")

        self.assertEqual(response.status_code, 400)

    def test_cannot_cancel_other_users_job(self) -> None:
        """Prevent users from cancelling other users' jobs."""
        # Create another user's job
        user2_id, _ = Core.Database.create_user(
            "other@example.com",
        )
        other_job_id = Core.Database.add_to_queue(
            "https://example.com/other",
            "other@example.com",
            user2_id,
        )

        # Try to cancel it
        response = self.client.post(f"/queue/{other_job_id}/cancel")

        self.assertEqual(response.status_code, 403)

    def test_cancel_without_auth(self) -> None:
        """Require authentication to cancel jobs."""
        # Logout
        self.client.get("/logout")

        response = self.client.post(f"/queue/{self.job_id}/cancel")

        self.assertEqual(response.status_code, 401)

    def test_cancel_button_visibility(self) -> None:
        """Cancel button only shows for pending jobs."""
        # Create jobs with different statuses
        processing_job = Core.Database.add_to_queue(
            "https://example.com/processing",
            "test@example.com",
            self.user_id,
        )
        Core.Database.update_job_status(
            processing_job,
            "processing",
        )

        # Get status view
        response = self.client.get("/status")

        # Should have cancel button for pending job
        self.assertIn(f'hx-post="/queue/{self.job_id}/cancel"', response.text)
        self.assertIn("Cancel", response.text)

        # Should NOT have cancel button for processing job
        self.assertNotIn(
            f'hx-post="/queue/{processing_job}/cancel"',
            response.text,
        )
+
+
class TestEpisodeDetailPage(BaseWebTest):
    """Test episode detail page functionality.

    Episode pages are addressed by a sqid-encoded id; legacy integer ids
    301-redirect to the sqid URL.
    """

    def setUp(self) -> None:
        """Set up test client with user and episode."""
        super().setUp()

        # Create user and episode
        self.user_id, self.token = Core.Database.create_user(
            "creator@example.com",
            status="active",
        )
        self.episode_id = Core.Database.create_episode(
            title="Test Episode",
            audio_url="https://example.com/audio.mp3",
            duration=300,
            content_length=5000,
            user_id=self.user_id,
            author="Test Author",
            original_url="https://example.com/article",
        )
        Core.Database.add_episode_to_user(self.user_id, self.episode_id)
        self.episode_sqid = encode_episode_id(self.episode_id)

    def test_episode_page_loads(self) -> None:
        """Episode page should load successfully."""
        response = self.client.get(f"/episode/{self.episode_sqid}")

        self.assertEqual(response.status_code, 200)
        self.assertIn("Test Episode", response.text)
        self.assertIn("Test Author", response.text)

    def test_episode_not_found(self) -> None:
        """Non-existent episode should return 404."""
        response = self.client.get("/episode/invalidcode")

        self.assertEqual(response.status_code, 404)

    def test_audio_player_present(self) -> None:
        """Audio player should be present on episode page."""
        response = self.client.get(f"/episode/{self.episode_sqid}")

        self.assertIn("<audio", response.text)
        self.assertIn("controls", response.text)
        self.assertIn("https://example.com/audio.mp3", response.text)

    def test_share_button_present(self) -> None:
        """Share button should be present."""
        response = self.client.get(f"/episode/{self.episode_sqid}")

        self.assertIn("Share Episode", response.text)
        self.assertIn("navigator.clipboard.writeText", response.text)

    def test_original_article_link(self) -> None:
        """Original article link should be present."""
        response = self.client.get(f"/episode/{self.episode_sqid}")

        self.assertIn("View original article", response.text)
        self.assertIn("https://example.com/article", response.text)

    def test_signup_banner_for_non_authenticated(self) -> None:
        """Non-authenticated users should see signup banner."""
        response = self.client.get(f"/episode/{self.episode_sqid}")

        self.assertIn("This episode was created by", response.text)
        self.assertIn("creator@example.com", response.text)
        self.assertIn("Sign Up", response.text)

    def test_no_signup_banner_for_authenticated(self) -> None:
        """Authenticated users should not see signup banner."""
        # Login
        self.client.post("/login", data={"email": "creator@example.com"})

        response = self.client.get(f"/episode/{self.episode_sqid}")

        self.assertNotIn("This episode was created by", response.text)

    def test_episode_links_from_home_page(self) -> None:
        """Episode titles on home page should link to detail page."""
        # Login to see episodes
        self.client.post("/login", data={"email": "creator@example.com"})

        response = self.client.get("/")

        self.assertIn(f'href="/episode/{self.episode_sqid}"', response.text)
        self.assertIn("Test Episode", response.text)

    def test_legacy_integer_id_redirects(self) -> None:
        """Legacy integer episode IDs should redirect to sqid URLs."""
        response = self.client.get(
            f"/episode/{self.episode_id}",
            follow_redirects=False,
        )

        # Permanent redirect preserves old bookmarks/links.
        self.assertEqual(response.status_code, 301)
        self.assertEqual(
            response.headers["location"],
            f"/episode/{self.episode_sqid}",
        )
+
+
class TestPublicFeed(BaseWebTest):
    """Test public feed functionality.

    Covers visibility of public vs. private episodes, admin-only toggling
    of the public flag, and admins adding other users' episodes to their
    own or the public feed.
    """

    def setUp(self) -> None:
        """Set up test database, client, and create sample episodes."""
        super().setUp()

        # Create admin user
        self.admin_id, _ = Core.Database.create_user(
            "ben@bensima.com",
            status="active",
        )

        # Create some episodes, some public, some private
        self.public_episode_id = Core.Database.create_episode(
            title="Public Episode",
            audio_url="https://example.com/public.mp3",
            duration=300,
            content_length=1000,
            user_id=self.admin_id,
            author="Test Author",
            original_url="https://example.com/public",
            original_url_hash=Core.hash_url("https://example.com/public"),
        )
        Core.Database.mark_episode_public(self.public_episode_id)

        # Episodes default to private; this one stays private.
        self.private_episode_id = Core.Database.create_episode(
            title="Private Episode",
            audio_url="https://example.com/private.mp3",
            duration=200,
            content_length=800,
            user_id=self.admin_id,
            author="Test Author",
            original_url="https://example.com/private",
            original_url_hash=Core.hash_url("https://example.com/private"),
        )

    def test_public_feed_page(self) -> None:
        """Public feed page should show only public episodes."""
        response = self.client.get("/public")

        self.assertEqual(response.status_code, 200)
        self.assertIn("Public Episode", response.text)
        self.assertNotIn("Private Episode", response.text)

    def test_home_page_shows_public_feed_when_logged_out(self) -> None:
        """Home page should show public episodes when user is not logged in."""
        response = self.client.get("/")

        self.assertEqual(response.status_code, 200)
        self.assertIn("Public Episode", response.text)
        self.assertNotIn("Private Episode", response.text)

    def test_admin_can_toggle_episode_public(self) -> None:
        """Admin should be able to toggle episode public/private status."""
        # Login as admin
        self.client.post("/login", data={"email": "ben@bensima.com"})

        # Toggle private episode to public
        response = self.client.post(
            f"/admin/episode/{self.private_episode_id}/toggle-public",
        )

        self.assertEqual(response.status_code, 200)

        # Verify it's now public (SQLite stores the flag as 0/1)
        episode = Core.Database.get_episode_by_id(self.private_episode_id)
        self.assertEqual(episode["is_public"], 1)  # type: ignore[index]

    def test_non_admin_cannot_toggle_public(self) -> None:
        """Non-admin users should not be able to toggle public status."""
        # Create and login as regular user
        _user_id, _ = Core.Database.create_user("user@example.com")
        self.client.post("/login", data={"email": "user@example.com"})

        # Try to toggle
        response = self.client.post(
            f"/admin/episode/{self.private_episode_id}/toggle-public",
        )

        self.assertEqual(response.status_code, 403)

    def test_admin_can_add_user_episode_to_own_feed(self) -> None:
        """Admin can add another user's episode to their own feed."""
        # Create regular user and their episode
        user_id, _ = Core.Database.create_user(
            "user@example.com",
            status="active",
        )
        user_episode_id = Core.Database.create_episode(
            title="User Episode",
            audio_url="https://example.com/user.mp3",
            duration=400,
            content_length=1200,
            user_id=user_id,
            author="User Author",
            original_url="https://example.com/user-article",
            original_url_hash=Core.hash_url("https://example.com/user-article"),
        )
        Core.Database.add_episode_to_user(user_id, user_episode_id)

        # Login as admin
        self.client.post("/login", data={"email": "ben@bensima.com"})

        # Admin adds user's episode to their own feed
        response = self.client.post(f"/episode/{user_episode_id}/add-to-feed")

        self.assertEqual(response.status_code, 200)

        # Verify episode is now in admin's feed
        admin_episodes = Core.Database.get_user_episodes(self.admin_id)
        episode_ids = [e["id"] for e in admin_episodes]
        self.assertIn(user_episode_id, episode_ids)

        # Verify "added" event was tracked
        metrics = Core.Database.get_episode_metric_events(user_episode_id)
        added_events = [m for m in metrics if m["event_type"] == "added"]
        self.assertEqual(len(added_events), 1)
        self.assertEqual(added_events[0]["user_id"], self.admin_id)

    def test_admin_can_add_user_episode_to_public_feed(self) -> None:
        """Admin should be able to add another user's episode to public feed."""
        # Create regular user and their episode
        user_id, _ = Core.Database.create_user(
            "user@example.com",
            status="active",
        )
        user_episode_id = Core.Database.create_episode(
            title="User Episode for Public",
            audio_url="https://example.com/user-public.mp3",
            duration=500,
            content_length=1500,
            user_id=user_id,
            author="User Author",
            original_url="https://example.com/user-public-article",
            original_url_hash=Core.hash_url(
                "https://example.com/user-public-article",
            ),
        )
        Core.Database.add_episode_to_user(user_id, user_episode_id)

        # Verify episode is private initially
        episode = Core.Database.get_episode_by_id(user_episode_id)
        self.assertEqual(episode["is_public"], 0)  # type: ignore[index]

        # Login as admin
        self.client.post("/login", data={"email": "ben@bensima.com"})

        # Admin toggles episode to public
        response = self.client.post(
            f"/admin/episode/{user_episode_id}/toggle-public",
        )

        self.assertEqual(response.status_code, 200)

        # Verify episode is now public
        episode = Core.Database.get_episode_by_id(user_episode_id)
        self.assertEqual(episode["is_public"], 1)  # type: ignore[index]

        # Verify episode appears in public feed
        public_episodes = Core.Database.get_public_episodes()
        episode_ids = [e["id"] for e in public_episodes]
        self.assertIn(user_episode_id, episode_ids)
+
+
class TestEpisodeDeduplication(BaseWebTest):
    """Exercise URL-hash-based episode deduplication."""

    def setUp(self) -> None:
        """Create an active user and one episode keyed by its URL hash."""
        super().setUp()

        self.user_id, self.token = Core.Database.create_user(
            "test@example.com",
            status="active",
        )

        # The canonical article and its dedup key.
        self.existing_url = "https://example.com/article"
        self.url_hash = Core.hash_url(self.existing_url)

        self.episode_id = Core.Database.create_episode(
            title="Existing Article",
            audio_url="https://example.com/audio.mp3",
            duration=300,
            content_length=1000,
            user_id=self.user_id,
            author="Test Author",
            original_url=self.existing_url,
            original_url_hash=self.url_hash,
        )

    def test_url_normalization(self) -> None:
        """URLs should be normalized for deduplication."""
        # Scheme, www prefix, host case, and trailing slash should all
        # collapse to one hash.
        variants = (
            "http://example.com/article",
            "https://example.com/article",
            "https://www.example.com/article",
            "https://EXAMPLE.COM/article",
            "https://example.com/article/",
        )
        distinct = {Core.hash_url(variant) for variant in variants}
        self.assertEqual(len(distinct), 1)

    def test_find_existing_episode_by_hash(self) -> None:
        """Should find existing episode by normalized URL hash."""
        for variant in (
            "http://example.com/article",
            "https://www.example.com/article",
        ):
            found = Core.Database.get_episode_by_url_hash(
                Core.hash_url(variant),
            )
            self.assertIsNotNone(found)
            if found is not None:
                self.assertEqual(found["id"], self.episode_id)

    def test_add_existing_episode_to_user_feed(self) -> None:
        """Should add existing episode to new user's feed."""
        second_user_id, _ = Core.Database.create_user("user2@example.com")

        Core.Database.add_episode_to_user(second_user_id, self.episode_id)

        feed_ids = [
            episode["id"]
            for episode in Core.Database.get_user_episodes(second_user_id)
        ]
        self.assertIn(self.episode_id, feed_ids)
+
+
class TestMetricsTracking(BaseWebTest):
    """Test episode metrics tracking.

    Covers direct Database-level event recording ("added"/"played",
    with or without a user) and the HTTP tracking endpoint.
    """

    def setUp(self) -> None:
        """Set up test database, client, and create test episode."""
        super().setUp()

        self.user_id, _ = Core.Database.create_user(
            "test@example.com",
            status="active",
        )

        self.episode_id = Core.Database.create_episode(
            title="Test Episode",
            audio_url="https://example.com/audio.mp3",
            duration=300,
            content_length=1000,
            user_id=self.user_id,
            author="Test Author",
            original_url="https://example.com/article",
            original_url_hash=Core.hash_url("https://example.com/article"),
        )

    def test_track_episode_added(self) -> None:
        """Should track when episode is added to feed."""
        Core.Database.track_episode_event(
            self.episode_id,
            "added",
            self.user_id,
        )

        # Verify metric was recorded
        metrics = Core.Database.get_episode_metric_events(self.episode_id)
        self.assertEqual(len(metrics), 1)
        self.assertEqual(metrics[0]["event_type"], "added")
        self.assertEqual(metrics[0]["user_id"], self.user_id)

    def test_track_episode_played(self) -> None:
        """Should track when episode is played."""
        Core.Database.track_episode_event(
            self.episode_id,
            "played",
            self.user_id,
        )

        metrics = Core.Database.get_episode_metric_events(self.episode_id)
        self.assertEqual(len(metrics), 1)
        self.assertEqual(metrics[0]["event_type"], "played")

    def test_track_anonymous_play(self) -> None:
        """Should track plays from anonymous users."""
        # Anonymous plays are recorded with a NULL user_id
        Core.Database.track_episode_event(
            self.episode_id,
            "played",
            user_id=None,
        )

        metrics = Core.Database.get_episode_metric_events(self.episode_id)
        self.assertEqual(len(metrics), 1)
        self.assertEqual(metrics[0]["event_type"], "played")
        self.assertIsNone(metrics[0]["user_id"])

    def test_track_endpoint(self) -> None:
        """POST /episode/{id}/track should record metrics."""
        # Login as user
        self.client.post("/login", data={"email": "test@example.com"})

        response = self.client.post(
            f"/episode/{self.episode_id}/track",
            data={"event_type": "played"},
        )

        self.assertEqual(response.status_code, 200)

        # Verify metric was recorded (assertGreater rather than exact
        # count: the endpoint may record more than one related event)
        metrics = Core.Database.get_episode_metric_events(self.episode_id)
        played_metrics = [m for m in metrics if m["event_type"] == "played"]
        self.assertGreater(len(played_metrics), 0)
+
+
class TestUsageLimits(BaseWebTest):
    """Test usage tracking and limit enforcement.

    Free tier is capped at 10 articles per billing period; paid tier is
    unlimited. Usage is counted from the user_episodes join table, so
    adding an *existing* episode to a feed still consumes quota.
    """

    def setUp(self) -> None:
        """Set up test with free tier user."""
        super().setUp()

        # Create free tier user (no subscription => defaults to free)
        self.user_id, self.token = Core.Database.create_user(
            "free@example.com",
            status="active",
        )
        # Login
        self.client.post("/login", data={"email": "free@example.com"})

    def test_usage_counts_episodes_added_to_feed(self) -> None:
        """Usage should count episodes added via user_episodes table."""
        user = Core.Database.get_user_by_id(self.user_id)
        self.assertIsNotNone(user)
        assert user is not None  # type narrowing # noqa: S101
        period_start, period_end = Billing.get_period_boundaries(user)

        # Initially no usage
        usage = Billing.get_usage(self.user_id, period_start, period_end)
        self.assertEqual(usage["articles"], 0)

        # Add an episode to user's feed
        ep_id = Core.Database.create_episode(
            title="Test Episode",
            audio_url="https://example.com/test.mp3",
            duration=300,
            content_length=1000,
            user_id=self.user_id,
            author="Test",
            original_url="https://example.com/article",
            original_url_hash=Core.hash_url("https://example.com/article"),
        )
        Core.Database.add_episode_to_user(self.user_id, ep_id)

        # Usage should now be 1
        usage = Billing.get_usage(self.user_id, period_start, period_end)
        self.assertEqual(usage["articles"], 1)

    def test_usage_counts_existing_episodes_correctly(self) -> None:
        """Adding existing episodes should count toward usage."""
        # Create another user who creates an episode
        other_user_id, _ = Core.Database.create_user("other@example.com")
        ep_id = Core.Database.create_episode(
            title="Other User Episode",
            audio_url="https://example.com/other.mp3",
            duration=400,
            content_length=1200,
            user_id=other_user_id,
            author="Other",
            original_url="https://example.com/other-article",
            original_url_hash=Core.hash_url(
                "https://example.com/other-article",
            ),
        )
        Core.Database.add_episode_to_user(other_user_id, ep_id)

        # Free user adds it to their feed (deduplicated episode,
        # but it must still count against *their* quota)
        Core.Database.add_episode_to_user(self.user_id, ep_id)

        # Check usage for free user
        user = Core.Database.get_user_by_id(self.user_id)
        self.assertIsNotNone(user)
        assert user is not None  # type narrowing # noqa: S101
        period_start, period_end = Billing.get_period_boundaries(user)
        usage = Billing.get_usage(self.user_id, period_start, period_end)

        # Should count as 1 article for free user
        self.assertEqual(usage["articles"], 1)

    def test_free_tier_limit_enforcement(self) -> None:
        """Free tier users should be blocked at 10 articles."""
        # Add 10 episodes (the free tier limit)
        for i in range(10):
            ep_id = Core.Database.create_episode(
                title=f"Episode {i}",
                audio_url=f"https://example.com/ep{i}.mp3",
                duration=300,
                content_length=1000,
                user_id=self.user_id,
                author="Test",
                original_url=f"https://example.com/article{i}",
                original_url_hash=Core.hash_url(
                    f"https://example.com/article{i}",
                ),
            )
            Core.Database.add_episode_to_user(self.user_id, ep_id)

        # Try to submit 11th article
        response = self.client.post(
            "/submit",
            data={"url": "https://example.com/article11"},
        )

        # Should be blocked — the limit error is rendered inline in the
        # form (HTTP 200 with an error message), not an HTTP error status
        self.assertEqual(response.status_code, 200)
        self.assertIn("Limit reached", response.text)
        self.assertIn("10", response.text)
        self.assertIn("Upgrade", response.text)

    def test_can_submit_blocks_at_limit(self) -> None:
        """can_submit should return False at limit."""
        # Add 10 episodes
        for i in range(10):
            ep_id = Core.Database.create_episode(
                title=f"Episode {i}",
                audio_url=f"https://example.com/ep{i}.mp3",
                duration=300,
                content_length=1000,
                user_id=self.user_id,
                author="Test",
                original_url=f"https://example.com/article{i}",
                original_url_hash=Core.hash_url(
                    f"https://example.com/article{i}",
                ),
            )
            Core.Database.add_episode_to_user(self.user_id, ep_id)

        # Check can_submit: (allowed, human message, usage dict)
        allowed, msg, usage = Billing.can_submit(self.user_id)

        self.assertFalse(allowed)
        self.assertIn("10", msg)
        self.assertIn("limit", msg.lower())
        self.assertEqual(usage["articles"], 10)

    def test_paid_tier_unlimited(self) -> None:
        """Paid tier should have no article limits."""
        # Create a paid tier user directly
        paid_user_id, _ = Core.Database.create_user("paid@example.com")

        # Simulate paid subscription via update_user_subscription
        now = datetime.now(timezone.utc)
        period_start = now
        december = 12
        january = 1
        # NOTE(review): naive month increment — now.replace(month=...) can
        # raise ValueError on month-end days (e.g. Jan 31 -> "Feb 31"), and
        # the Dec -> Jan wrap keeps the same year. Acceptable for a test
        # fixture, but this will flake when run on the 29th-31st.
        period_end = now.replace(
            month=now.month + 1 if now.month < december else january,
        )

        Core.Database.update_user_subscription(
            paid_user_id,
            subscription_id="sub_test123",
            status="active",
            period_start=period_start,
            period_end=period_end,
            tier="paid",
            cancel_at_period_end=False,
        )

        # Add 20 episodes (more than free limit)
        for i in range(20):
            ep_id = Core.Database.create_episode(
                title=f"Episode {i}",
                audio_url=f"https://example.com/ep{i}.mp3",
                duration=300,
                content_length=1000,
                user_id=paid_user_id,
                author="Test",
                original_url=f"https://example.com/article{i}",
                original_url_hash=Core.hash_url(
                    f"https://example.com/article{i}",
                ),
            )
            Core.Database.add_episode_to_user(paid_user_id, ep_id)

        # Should still be allowed to submit
        allowed, msg, usage = Billing.can_submit(paid_user_id)

        self.assertTrue(allowed)
        self.assertEqual(msg, "")
        self.assertEqual(usage["articles"], 20)
+
+
class TestAccountPage(BaseWebTest):
    """Test account page functionality.

    Covers rendering, auth gating, logout, Stripe billing-portal
    redirect, email updates, and account deletion.
    """

    def setUp(self) -> None:
        """Set up test with user."""
        super().setUp()
        self.user_id, _ = Core.Database.create_user(
            "test@example.com",
            status="active",
        )
        self.client.post("/login", data={"email": "test@example.com"})

    def test_account_page_logged_in(self) -> None:
        """Account page should render for logged-in users."""
        # Create some usage to verify stats are shown
        ep_id = Core.Database.create_episode(
            title="Test Episode",
            audio_url="https://example.com/audio.mp3",
            duration=300,
            content_length=1000,
            user_id=self.user_id,
            author="Test Author",
            original_url="https://example.com/article",
            original_url_hash=Core.hash_url("https://example.com/article"),
        )
        Core.Database.add_episode_to_user(self.user_id, ep_id)

        response = self.client.get("/account")

        self.assertEqual(response.status_code, 200)
        self.assertIn("My Account", response.text)
        self.assertIn("test@example.com", response.text)
        self.assertIn("1 / 10", response.text)  # Usage / Limit for free tier

    def test_account_page_login_required(self) -> None:
        """Should redirect to login if not logged in."""
        self.client.post("/logout")
        response = self.client.get("/account", follow_redirects=False)
        # 307 preserves the request method on redirect
        self.assertEqual(response.status_code, 307)
        self.assertEqual(response.headers["location"], "/?error=login_required")

    def test_logout(self) -> None:
        """Logout should clear session."""
        # 303 See Other: POST redirects to a GET of "/"
        response = self.client.post("/logout", follow_redirects=False)
        self.assertEqual(response.status_code, 303)
        self.assertEqual(response.headers["location"], "/")

        # Verify session cleared
        response = self.client.get("/account", follow_redirects=False)
        self.assertEqual(response.status_code, 307)

    def test_billing_portal_redirect(self) -> None:
        """Billing portal should redirect to Stripe."""
        # First set a customer ID (portal requires an existing customer)
        Core.Database.set_user_stripe_customer(self.user_id, "cus_test")

        # Mock the create_portal_session method so no Stripe call is made
        with patch(
            "Biz.PodcastItLater.Billing.create_portal_session",
        ) as mock_portal:
            mock_portal.return_value = "https://billing.stripe.com/test"

            response = self.client.post(
                "/billing/portal",
                follow_redirects=False,
            )

            self.assertEqual(response.status_code, 303)
            self.assertEqual(
                response.headers["location"],
                "https://billing.stripe.com/test",
            )

    def test_update_email_success(self) -> None:
        """Should allow updating email."""
        # POST new email
        response = self.client.post(
            "/settings/email",
            data={"email": "new@example.com"},
        )
        self.assertEqual(response.status_code, 200)

        # Verify update in DB
        user = Core.Database.get_user_by_id(self.user_id)
        self.assertEqual(user["email"], "new@example.com")  # type: ignore[index]

    def test_update_email_duplicate(self) -> None:
        """Should prevent updating to existing email."""
        # Create another user
        Core.Database.create_user("other@example.com")

        # Try to update to their email
        response = self.client.post(
            "/settings/email",
            data={"email": "other@example.com"},
        )

        # Should show error (return 200 with error message in form)
        self.assertEqual(response.status_code, 200)
        self.assertIn("already taken", response.text.lower())

    def test_delete_account(self) -> None:
        """Should allow user to delete their account."""
        # Delete account (htmx flow: server signals client-side redirect
        # via the HX-Redirect response header)
        response = self.client.delete("/account")
        self.assertEqual(response.status_code, 200)
        self.assertIn("HX-Redirect", response.headers)

        # Verify user gone
        user = Core.Database.get_user_by_id(self.user_id)
        self.assertIsNone(user)

        # Verify session cleared
        response = self.client.get("/account", follow_redirects=False)
        self.assertEqual(response.status_code, 307)
+
+
class TestAdminUsers(BaseWebTest):
    """Test admin user management functionality.

    Admin identity appears to be keyed on the ben@bensima.com email —
    TODO confirm against the admin check in Web.py.
    """

    def setUp(self) -> None:
        """Set up test client with logged-in admin user."""
        super().setUp()

        # Create and login admin user
        self.user_id, _ = Core.Database.create_user(
            "ben@bensima.com",
        )
        Core.Database.update_user_status(
            self.user_id,
            "active",
        )
        self.client.post("/login", data={"email": "ben@bensima.com"})

        # Create another regular user to operate on
        self.other_user_id, _ = Core.Database.create_user("user@example.com")
        Core.Database.update_user_status(self.other_user_id, "active")

    def test_admin_users_page_access(self) -> None:
        """Admin can access users page."""
        response = self.client.get("/admin/users")
        self.assertEqual(response.status_code, 200)
        self.assertIn("User Management", response.text)
        self.assertIn("user@example.com", response.text)

    def test_non_admin_users_page_access(self) -> None:
        """Non-admin cannot access users page."""
        # Login as regular user
        self.client.get("/logout")
        self.client.post("/login", data={"email": "user@example.com"})

        # Page requests get redirected with an error flag rather than 403
        response = self.client.get("/admin/users")
        self.assertEqual(response.status_code, 302)
        self.assertIn("error=forbidden", response.headers["Location"])

    def test_admin_can_update_user_status(self) -> None:
        """Admin can update user status."""
        response = self.client.post(
            f"/admin/users/{self.other_user_id}/status",
            data={"status": "disabled"},
        )
        self.assertEqual(response.status_code, 200)

        user = Core.Database.get_user_by_id(self.other_user_id)
        assert user is not None  # noqa: S101
        self.assertEqual(user["status"], "disabled")

    def test_non_admin_cannot_update_user_status(self) -> None:
        """Non-admin cannot update user status."""
        # Login as regular user
        self.client.get("/logout")
        self.client.post("/login", data={"email": "user@example.com"})

        # Mutation endpoints return a hard 403 (unlike the page redirect)
        response = self.client.post(
            f"/admin/users/{self.other_user_id}/status",
            data={"status": "disabled"},
        )
        self.assertEqual(response.status_code, 403)

        # Status must be unchanged
        user = Core.Database.get_user_by_id(self.other_user_id)
        assert user is not None  # noqa: S101
        self.assertEqual(user["status"], "active")

    def test_update_user_status_invalid_status(self) -> None:
        """Invalid status validation."""
        response = self.client.post(
            f"/admin/users/{self.other_user_id}/status",
            data={"status": "invalid_status"},
        )
        self.assertEqual(response.status_code, 400)

        # Status must be unchanged
        user = Core.Database.get_user_by_id(self.other_user_id)
        assert user is not None  # noqa: S101
        self.assertEqual(user["status"], "active")
+
+
def test() -> None:
    """Run all tests for the web module.

    Hands the suite list to the Omni test harness; suites are
    independent, the list order is just execution order.
    """
    Test.run(
        App.Area.Test,
        [
            TestDurationFormatting,
            TestAuthentication,
            TestArticleSubmission,
            TestRSSFeed,
            TestAdminInterface,
            TestJobCancellation,
            TestEpisodeDetailPage,
            TestPublicFeed,
            TestEpisodeDeduplication,
            TestMetricsTracking,
            TestUsageLimits,
            TestAccountPage,
            TestAdminUsers,
        ],
    )
+
+
def main() -> None:
    """Run the web server.

    If "test" appears anywhere in argv, run the test suite instead of
    serving.
    """
    if "test" in sys.argv:
        test()
    else:
        # Initialize database on startup
        Core.Database.init_db()
        # Binds all interfaces; deployment fronts this with a reverse proxy
        uvicorn.run(app, host="0.0.0.0", port=PORT)  # noqa: S104
diff --git a/Biz/PodcastItLater/Worker.nix b/Biz/PodcastItLater/Worker.nix
new file mode 100644
index 0000000..974a3ba
--- /dev/null
+++ b/Biz/PodcastItLater/Worker.nix
@@ -0,0 +1,63 @@
# NixOS module for the PodcastItLater background worker.
# Defines a systemd service that runs the article-to-podcast pipeline,
# sharing a data directory with the web service.
{
  options,
  lib,
  config,
  pkgs,
  ...
}: let
  cfg = config.services.podcastitlater-worker;
in {
  options.services.podcastitlater-worker = {
    enable = lib.mkEnableOption "Enable the PodcastItLater worker service";
    dataDir = lib.mkOption {
      type = lib.types.path;
      default = "/var/podcastitlater";
      description = "Data directory for PodcastItLater (shared with web)";
    };
    package = lib.mkOption {
      type = lib.types.package;
      description = "PodcastItLater worker package to use";
    };
  };
  config = lib.mkIf cfg.enable {
    systemd.services.podcastitlater-worker = {
      path = [cfg.package pkgs.ffmpeg]; # ffmpeg needed for pydub
      wantedBy = ["multi-user.target"];
      after = ["network.target"];
      preStart = ''
        # Create data directory if it doesn't exist
        mkdir -p ${cfg.dataDir}

        # Manual step: create this file with secrets
        # OPENAI_API_KEY=your-openai-api-key
        # S3_ENDPOINT=https://your-s3-endpoint.digitaloceanspaces.com
        # S3_BUCKET=your-bucket-name
        # S3_ACCESS_KEY=your-s3-access-key
        # S3_SECRET_KEY=your-s3-secret-key
        test -f /run/podcastitlater/worker-env
      '';
      script = ''
        ${cfg.package}/bin/podcastitlater-worker
      '';
      description = ''
        PodcastItLater Worker Service - processes articles to podcasts
      '';
      serviceConfig = {
        Environment = [
          "AREA=Live"
          "DATA_DIR=${cfg.dataDir}"
        ];
        # Secrets are injected from the manually-provisioned env file
        # asserted in preStart above.
        EnvironmentFile = "/run/podcastitlater/worker-env";
        KillSignal = "TERM";
        # "mixed": SIGTERM to the main process only, SIGKILL to the
        # whole group on timeout.
        KillMode = "mixed";
        Type = "simple";
        Restart = "always";
        RestartSec = "10";
        # Give the worker time to finish current job
        TimeoutStopSec = "300"; # 5 minutes
        # Send SIGTERM first, then SIGKILL after timeout
        SendSIGKILL = "yes";
      };
    };
  };
}
diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py
new file mode 100644
index 0000000..bf6ef9e
--- /dev/null
+++ b/Biz/PodcastItLater/Worker.py
@@ -0,0 +1,2199 @@
+"""Background worker for processing article-to-podcast conversions."""
+
+# : dep boto3
+# : dep botocore
+# : dep openai
+# : dep psutil
+# : dep pydub
+# : dep pytest
+# : dep pytest-asyncio
+# : dep pytest-mock
+# : dep trafilatura
+# : out podcastitlater-worker
+# : run ffmpeg
+import Biz.PodcastItLater.Core as Core
+import boto3 # type: ignore[import-untyped]
+import concurrent.futures
+import io
+import json
+import logging
+import Omni.App as App
+import Omni.Log as Log
+import Omni.Test as Test
+import openai
+import operator
+import os
+import psutil # type: ignore[import-untyped]
+import pytest
+import signal
+import sys
+import tempfile
+import threading
+import time
+import trafilatura
+import typing
+import unittest.mock
+from botocore.exceptions import ClientError # type: ignore[import-untyped]
+from datetime import datetime
+from datetime import timedelta
+from datetime import timezone
+from pathlib import Path
+from pydub import AudioSegment # type: ignore[import-untyped]
+from typing import Any
+
logger = logging.getLogger(__name__)
Log.setup(logger)

# Configuration from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # required; checked in ArticleProcessor.__init__
S3_ENDPOINT = os.getenv("S3_ENDPOINT")  # S3-compatible endpoint (DO Spaces)
S3_BUCKET = os.getenv("S3_BUCKET")
S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY")
S3_SECRET_KEY = os.getenv("S3_SECRET_KEY")
area = App.from_env()

# Worker configuration
MAX_CONTENT_LENGTH = 5000  # characters for TTS
MAX_ARTICLE_SIZE = 500_000  # 500KB character limit for articles
POLL_INTERVAL = 30  # seconds
MAX_RETRIES = 3
TTS_MODEL = "tts-1"  # OpenAI TTS model name
TTS_VOICE = "alloy"  # OpenAI TTS voice preset
MEMORY_THRESHOLD = 80  # Percentage threshold for memory usage
CROSSFADE_DURATION = 500  # ms for crossfading segments
PAUSE_DURATION = 1000  # ms for silence between segments
+
+
+class ShutdownHandler:
+ """Handles graceful shutdown of the worker."""
+
+ def __init__(self) -> None:
+ """Initialize shutdown handler."""
+ self.shutdown_requested = threading.Event()
+ self.current_job_id: int | None = None
+ self.lock = threading.Lock()
+
+ # Register signal handlers
+ signal.signal(signal.SIGTERM, self._handle_signal)
+ signal.signal(signal.SIGINT, self._handle_signal)
+
+ def _handle_signal(self, signum: int, _frame: Any) -> None:
+ """Handle shutdown signals."""
+ logger.info(
+ "Received signal %d, initiating graceful shutdown...",
+ signum,
+ )
+ self.shutdown_requested.set()
+
+ def is_shutdown_requested(self) -> bool:
+ """Check if shutdown has been requested."""
+ return self.shutdown_requested.is_set()
+
+ def set_current_job(self, job_id: int | None) -> None:
+ """Set the currently processing job."""
+ with self.lock:
+ self.current_job_id = job_id
+
+ def get_current_job(self) -> int | None:
+ """Get the currently processing job."""
+ with self.lock:
+ return self.current_job_id
+
+
+class ArticleProcessor:
+ """Handles the complete article-to-podcast conversion pipeline."""
+
    def __init__(self, shutdown_handler: ShutdownHandler) -> None:
        """Initialize the processor with required services.

        Args:
            shutdown_handler: Shared handler the worker checks so a
                SIGTERM can stop work between pipeline steps.

        Raises:
            ValueError: If OPENAI_API_KEY environment variable is not set.
        """
        if not OPENAI_API_KEY:
            msg = "OPENAI_API_KEY environment variable is required"
            raise ValueError(msg)

        self.openai_client: openai.OpenAI = openai.OpenAI(
            api_key=OPENAI_API_KEY,
        )
        self.shutdown_handler = shutdown_handler

        # Initialize S3 client for Digital Ocean Spaces
        if all([S3_ENDPOINT, S3_BUCKET, S3_ACCESS_KEY, S3_SECRET_KEY]):
            self.s3_client: Any = boto3.client(
                "s3",
                endpoint_url=S3_ENDPOINT,
                aws_access_key_id=S3_ACCESS_KEY,
                aws_secret_access_key=S3_SECRET_KEY,
            )
        else:
            # Deliberately non-fatal: extraction and TTS still work
            # locally; only upload_to_s3 will raise.
            logger.warning("S3 configuration incomplete, uploads will fail")
            self.s3_client = None
+
    @staticmethod
    def extract_article_content(
        url: str,
    ) -> tuple[str, str, str | None, str | None]:
        """Extract title, content, author, and date from article URL.

        Args:
            url: Publicly fetchable article URL.

        Returns:
            tuple: (title, content, author, publication_date); author and
            publication_date are None when trafilatura finds no metadata.

        Raises:
            ValueError: If content cannot be downloaded, extracted, or large.
        """
        try:
            downloaded = trafilatura.fetch_url(url)
            if not downloaded:
                msg = f"Failed to download content from {url}"
                raise ValueError(msg)  # noqa: TRY301

            # Check size before processing: reject oversized raw HTML
            # early instead of paying for extraction.
            if (
                len(downloaded) > MAX_ARTICLE_SIZE * 4
            ):  # Rough HTML to text ratio
                msg = f"Article too large: {len(downloaded)} bytes"
                raise ValueError(msg)  # noqa: TRY301

            # Extract with metadata (JSON output carries title/author/date
            # alongside the body text)
            result = trafilatura.extract(
                downloaded,
                include_comments=False,
                include_tables=False,
                with_metadata=True,
                output_format="json",
            )

            if not result:
                msg = f"Failed to extract content from {url}"
                raise ValueError(msg)  # noqa: TRY301

            data = json.loads(result)

            title = data.get("title", "Untitled Article")
            content = data.get("text", "")
            author = data.get("author")
            pub_date = data.get("date")

            if not content:
                msg = f"No content extracted from {url}"
                raise ValueError(msg)  # noqa: TRY301

            # Enforce content size limit: truncate rather than fail once
            # we have usable text.
            if len(content) > MAX_ARTICLE_SIZE:
                logger.warning(
                    "Article content truncated from %d to %d characters",
                    len(content),
                    MAX_ARTICLE_SIZE,
                )
                content = content[:MAX_ARTICLE_SIZE]

            logger.info(
                "Extracted article: %s (%d chars, author: %s, date: %s)",
                title,
                len(content),
                author or "unknown",
                pub_date or "unknown",
            )
        except Exception:
            logger.exception("Failed to extract content from %s", url)
            raise
        else:
            return title, content, author, pub_date
+
    def text_to_speech(
        self,
        text: str,
        title: str,
        author: str | None = None,
        pub_date: str | None = None,
    ) -> bytes:
        """Convert text to speech with intro/outro using OpenAI TTS API.

        Uses parallel processing for chunks while maintaining order.
        Adds intro with metadata and outro with attribution.

        Args:
            text: Article content to convert
            title: Article title
            author: Article author (optional)
            pub_date: Publication date (optional)

        Returns:
            Combined MP3 bytes: intro + article body + outro.

        Raises:
            ValueError: If no chunks are generated from text.
        """
        try:
            # Generate intro audio
            intro_text = self._create_intro_text(title, author, pub_date)
            intro_audio = self._generate_tts_segment(intro_text)

            # Generate outro audio
            outro_text = self._create_outro_text(title, author)
            outro_audio = self._generate_tts_segment(outro_text)

            # Use LLM to prepare and chunk the main content
            chunks = prepare_text_for_tts(text, title)

            if not chunks:
                msg = "No chunks generated from text"
                raise ValueError(msg)  # noqa: TRY301

            logger.info("Processing %d chunks for TTS", len(chunks))

            # Check memory before parallel processing; parallel TTS holds
            # all chunk results in memory at once.
            mem_usage = check_memory_usage()
            if mem_usage > MEMORY_THRESHOLD - 20:  # Leave 20% buffer
                logger.warning(
                    "High memory usage (%.1f%%), falling back to serial "
                    "processing",
                    mem_usage,
                )
                content_audio_bytes = self._text_to_speech_serial(chunks)
            else:
                # Determine max workers
                max_workers = min(
                    4,  # Reasonable limit to avoid rate limiting
                    len(chunks),  # No more workers than chunks
                    max(1, psutil.cpu_count() // 2),  # Use half of CPU cores
                )

                logger.info(
                    "Using %d workers for parallel TTS processing",
                    max_workers,
                )

                # Process chunks in parallel; each result is tagged with
                # its index because completion order is arbitrary.
                chunk_results: list[tuple[int, bytes]] = []

                with concurrent.futures.ThreadPoolExecutor(
                    max_workers=max_workers,
                ) as executor:
                    # Submit all chunks for processing
                    future_to_index = {
                        executor.submit(self._process_tts_chunk, chunk, i): i
                        for i, chunk in enumerate(chunks)
                    }

                    # Collect results as they complete
                    for future in concurrent.futures.as_completed(
                        future_to_index,
                    ):
                        index = future_to_index[future]
                        try:
                            audio_data = future.result()
                            chunk_results.append((index, audio_data))
                        except Exception:
                            # Any failed chunk aborts the whole episode
                            logger.exception(
                                "Failed to process chunk %d",
                                index,
                            )
                            raise

                # Sort results by index to maintain order
                chunk_results.sort(key=operator.itemgetter(0))

                # Combine audio chunks
                content_audio_bytes = self._combine_audio_chunks([
                    data for _, data in chunk_results
                ])

            # Combine intro, content, and outro with pauses
            return ArticleProcessor._combine_intro_content_outro(
                intro_audio,
                content_audio_bytes,
                outro_audio,
            )

        except Exception:
            logger.exception("TTS generation failed")
            raise
+
+ @staticmethod
+ def _create_intro_text(
+ title: str,
+ author: str | None,
+ pub_date: str | None,
+ ) -> str:
+ """Create intro text with available metadata."""
+ parts = [f"Title: {title}"]
+
+ if author:
+ parts.append(f"Author: {author}")
+
+ if pub_date:
+ parts.append(f"Published: {pub_date}")
+
+ return ". ".join(parts) + "."
+
+ @staticmethod
+ def _create_outro_text(title: str, author: str | None) -> str:
+ """Create outro text with attribution."""
+ if author:
+ return (
+ f"This has been an audio version of {title} "
+ f"by {author}, created using Podcast It Later."
+ )
+ return (
+ f"This has been an audio version of {title}, "
+ "created using Podcast It Later."
+ )
+
    def _generate_tts_segment(self, text: str) -> bytes:
        """Generate TTS audio for a single segment (intro/outro).

        Unlike the chunked article path, this is one synchronous API
        call with no retry or size handling — intros/outros are short.

        Args:
            text: Text to convert to speech

        Returns:
            MP3 audio bytes
        """
        response = self.openai_client.audio.speech.create(
            model=TTS_MODEL,
            voice=TTS_VOICE,
            input=text,
        )
        return response.content
+
    @staticmethod
    def _combine_intro_content_outro(
        intro_audio: bytes,
        content_audio: bytes,
        outro_audio: bytes,
    ) -> bytes:
        """Combine intro, content, and outro with crossfades.

        Args:
            intro_audio: MP3 bytes for intro
            content_audio: MP3 bytes for main content
            outro_audio: MP3 bytes for outro

        Returns:
            Combined MP3 audio bytes
        """
        # Load audio segments
        intro = AudioSegment.from_mp3(io.BytesIO(intro_audio))
        content = AudioSegment.from_mp3(io.BytesIO(content_audio))
        outro = AudioSegment.from_mp3(io.BytesIO(outro_audio))

        # Create bridge silence (pause + 2 * crossfade to account for
        # overlap: each append below consumes CROSSFADE_DURATION ms from
        # both sides of the bridge, leaving PAUSE_DURATION of net silence)
        bridge = AudioSegment.silent(
            duration=PAUSE_DURATION + 2 * CROSSFADE_DURATION
        )

        def safe_append(
            seg1: AudioSegment, seg2: AudioSegment, crossfade: int
        ) -> AudioSegment:
            """Append with crossfade, degrading to plain concatenation
            when either segment is shorter than the crossfade window
            (pydub raises otherwise)."""
            if len(seg1) < crossfade or len(seg2) < crossfade:
                logger.warning(
                    "Segment too short for crossfade (%dms vs %dms/%dms), using concatenation",
                    crossfade,
                    len(seg1),
                    len(seg2),
                )
                return seg1 + seg2
            return seg1.append(seg2, crossfade=crossfade)

        # Combine segments with crossfades
        # Intro -> Bridge -> Content -> Bridge -> Outro
        # This effectively fades out the previous segment and fades in the next one
        combined = safe_append(intro, bridge, CROSSFADE_DURATION)
        combined = safe_append(combined, content, CROSSFADE_DURATION)
        combined = safe_append(combined, bridge, CROSSFADE_DURATION)
        combined = safe_append(combined, outro, CROSSFADE_DURATION)

        # Export to bytes
        output = io.BytesIO()
        combined.export(output, format="mp3")
        return output.getvalue()
+
    def _process_tts_chunk(self, chunk: str, index: int) -> bytes:
        """Process a single TTS chunk.

        Runs on a ThreadPoolExecutor worker thread in text_to_speech;
        the OpenAI client is shared across threads.

        Args:
            chunk: Text to convert to speech
            index: Chunk index for logging

        Returns:
            Audio data as bytes (MP3)
        """
        logger.info(
            "Generating TTS for chunk %d (%d chars)",
            index + 1,
            len(chunk),
        )

        response = self.openai_client.audio.speech.create(
            model=TTS_MODEL,
            voice=TTS_VOICE,
            input=chunk,
            response_format="mp3",
        )

        return response.content
+
    @staticmethod
    def _combine_audio_chunks(audio_chunks: list[bytes]) -> bytes:
        """Combine multiple audio chunks with silence gaps.

        Args:
            audio_chunks: List of audio data (MP3) in order

        Returns:
            Combined audio data (MP3, 128k bitrate)

        Raises:
            ValueError: If audio_chunks is empty.
        """
        if not audio_chunks:
            msg = "No audio chunks to combine"
            raise ValueError(msg)

        # Create a temporary file for the combined audio; delete=False so
        # the path survives closing the handle, cleaned up in finally.
        with tempfile.NamedTemporaryFile(
            suffix=".mp3",
            delete=False,
        ) as temp_file:
            temp_path = temp_file.name

        try:
            # Start with the first chunk
            combined_audio = AudioSegment.from_mp3(io.BytesIO(audio_chunks[0]))

            # Add remaining chunks with silence gaps
            for chunk_data in audio_chunks[1:]:
                chunk_audio = AudioSegment.from_mp3(io.BytesIO(chunk_data))
                silence = AudioSegment.silent(duration=300)  # 300ms gap
                combined_audio = combined_audio + silence + chunk_audio

            # Export to file
            combined_audio.export(temp_path, format="mp3", bitrate="128k")

            # Read back the combined audio
            return Path(temp_path).read_bytes()

        finally:
            # Clean up temp file
            if Path(temp_path).exists():
                Path(temp_path).unlink()
+
    def _text_to_speech_serial(self, chunks: list[str]) -> bytes:
        """Fallback serial processing for high memory situations.

        This is the original serial implementation: one chunk at a time,
        appending to an on-disk MP3 so at most two segments are in memory
        at once.

        Args:
            chunks: Non-empty, ordered list of text chunks
                (callers guard against empty input — chunks[0] below
                would raise IndexError otherwise).

        Returns:
            Combined MP3 bytes of all chunks with 300ms gaps.
        """
        # Create a temporary file for streaming audio concatenation
        with tempfile.NamedTemporaryFile(
            suffix=".mp3",
            delete=False,
        ) as temp_file:
            temp_path = temp_file.name

        try:
            # Process first chunk
            logger.info("Generating TTS for chunk 1/%d", len(chunks))
            response = self.openai_client.audio.speech.create(
                model=TTS_MODEL,
                voice=TTS_VOICE,
                input=chunks[0],
                response_format="mp3",
            )

            # Write first chunk directly to file
            Path(temp_path).write_bytes(response.content)

            # Process remaining chunks
            for i, chunk in enumerate(chunks[1:], 1):
                logger.info(
                    "Generating TTS for chunk %d/%d (%d chars)",
                    i + 1,
                    len(chunks),
                    len(chunk),
                )

                response = self.openai_client.audio.speech.create(
                    model=TTS_MODEL,
                    voice=TTS_VOICE,
                    input=chunk,
                    response_format="mp3",
                )

                # Append to existing file with silence gap
                # Load only the current segment
                current_segment = AudioSegment.from_mp3(
                    io.BytesIO(response.content),
                )

                # Load existing audio, append, and save back
                # (re-encodes the whole file each iteration — slow but
                # keeps peak memory bounded, which is the point here)
                existing_audio = AudioSegment.from_mp3(temp_path)
                silence = AudioSegment.silent(duration=300)
                combined = existing_audio + silence + current_segment

                # Export back to the same file
                combined.export(temp_path, format="mp3", bitrate="128k")

                # Force garbage collection to free memory
                del existing_audio, current_segment, combined

                # Small delay between API calls
                if i < len(chunks) - 1:
                    time.sleep(0.5)

            # Read final result
            audio_data = Path(temp_path).read_bytes()

            logger.info(
                "Generated combined TTS audio: %d bytes",
                len(audio_data),
            )
            return audio_data

        finally:
            # Clean up temp file
            temp_file_path = Path(temp_path)
            if temp_file_path.exists():
                temp_file_path.unlink()
+
def upload_to_s3(self, audio_data: bytes, filename: str) -> str:
    """Upload audio file to S3-compatible storage and return public URL.

    Args:
        audio_data: Encoded MP3 bytes to upload.
        filename: Object key to store the audio under.

    Returns:
        Public URL of the uploaded object.

    Raises:
        ValueError: If S3 client is not configured.
        ClientError: If S3 upload fails.
    """
    if not self.s3_client:
        msg = "S3 client not configured"
        raise ValueError(msg)

    try:
        # Upload file using streaming to minimize memory usage
        audio_stream = io.BytesIO(audio_data)
        self.s3_client.upload_fileobj(
            audio_stream,
            S3_BUCKET,
            filename,
            ExtraArgs={
                "ContentType": "audio/mpeg",
                "ACL": "public-read",
            },
        )

        # Construct public URL. BUG FIX: the URL must end with the
        # object key; previously the key was omitted, so the returned
        # episode URL pointed at a nonexistent path.
        audio_url = f"{S3_ENDPOINT}/{S3_BUCKET}/{filename}"
        logger.info(
            "Uploaded audio to: %s (%d bytes)",
            audio_url,
            len(audio_data),
        )
    except ClientError:
        logger.exception("S3 upload failed")
        raise
    else:
        return audio_url
+
+ @staticmethod
+ def estimate_duration(audio_data: bytes) -> int:
+ """Estimate audio duration in seconds based on file size and bitrate."""
+ # Rough estimation: MP3 at 128kbps = ~16KB per second
+ estimated_seconds = len(audio_data) // 16000
+ return max(1, estimated_seconds) # Minimum 1 second
+
+ @staticmethod
+ def generate_filename(job_id: int, title: str) -> str:
+ """Generate unique filename for audio file."""
+ timestamp = int(datetime.now(tz=timezone.utc).timestamp())
+ # Create safe filename from title
+ safe_title = "".join(
+ c for c in title if c.isalnum() or c in {" ", "-", "_"}
+ ).rstrip()
+ safe_title = safe_title.replace(" ", "_")[:50] # Limit length
+ return f"episode_{timestamp}_{job_id}_{safe_title}.mp3"
+
def process_job(
    self,
    job: dict[str, Any],
) -> None:
    """Process a single job through the complete pipeline.

    Pipeline: extract article -> synthesize audio -> upload to S3 ->
    create episode record -> mark job completed. The job's status row is
    updated at each step, and the job is re-queued (reset to "pending")
    if a shutdown is requested between steps.

    Args:
        job: Queue row; must contain "id" and "url", may contain
            "user_id" and "author".

    Raises:
        Exception: Re-raises any processing failure after recording the
            job as "error".
    """
    job_id = job["id"]
    url = job["url"]

    # Check memory before starting; defer (leave pending) rather than
    # risk OOM mid-pipeline.
    mem_usage = check_memory_usage()
    if mem_usage > MEMORY_THRESHOLD:
        logger.warning(
            "High memory usage (%.1f%%), deferring job %d",
            mem_usage,
            job_id,
        )
        return

    # Track current job for graceful shutdown
    self.shutdown_handler.set_current_job(job_id)

    try:
        logger.info("Processing job %d: %s", job_id, url)

        # Update status to processing
        Core.Database.update_job_status(
            job_id,
            "processing",
        )

        # Check for shutdown before each major step
        if self.shutdown_handler.is_shutdown_requested():
            logger.info("Shutdown requested, aborting job %d", job_id)
            Core.Database.update_job_status(job_id, "pending")
            return

        # Step 1: Extract article content
        Core.Database.update_job_status(job_id, "extracting")
        title, content, author, pub_date = (
            ArticleProcessor.extract_article_content(url)
        )

        if self.shutdown_handler.is_shutdown_requested():
            logger.info("Shutdown requested, aborting job %d", job_id)
            Core.Database.update_job_status(job_id, "pending")
            return

        # Step 2: Generate audio with metadata
        Core.Database.update_job_status(job_id, "synthesizing")
        audio_data = self.text_to_speech(content, title, author, pub_date)

        if self.shutdown_handler.is_shutdown_requested():
            logger.info("Shutdown requested, aborting job %d", job_id)
            Core.Database.update_job_status(job_id, "pending")
            return

        # Step 3: Upload to S3
        Core.Database.update_job_status(job_id, "uploading")
        filename = ArticleProcessor.generate_filename(job_id, title)
        audio_url = self.upload_to_s3(audio_data, filename)

        # Step 4: Calculate duration (size-based estimate)
        duration = ArticleProcessor.estimate_duration(audio_data)

        # Step 5: Create episode record
        url_hash = Core.hash_url(url)
        episode_id = Core.Database.create_episode(
            title=title,
            audio_url=audio_url,
            duration=duration,
            content_length=len(content),
            user_id=job.get("user_id"),
            author=job.get("author"),  # Pass author from job
            original_url=url,  # Pass the original article URL
            original_url_hash=url_hash,
        )

        # Add episode to user's feed (anonymous jobs have no user_id)
        user_id = job.get("user_id")
        if user_id:
            Core.Database.add_episode_to_user(user_id, episode_id)
            Core.Database.track_episode_event(
                episode_id,
                "added",
                user_id,
            )

        # Step 6: Mark job as complete
        Core.Database.update_job_status(
            job_id,
            "completed",
        )

        logger.info(
            "Successfully processed job %d -> episode %d",
            job_id,
            episode_id,
        )

    except Exception as e:
        # Record the failure on the job row, then propagate so the
        # caller can decide whether to continue with other jobs.
        error_msg = str(e)
        logger.exception("Job %d failed: %s", job_id, error_msg)
        Core.Database.update_job_status(
            job_id,
            "error",
            error_msg,
        )
        raise
    finally:
        # Clear current job
        self.shutdown_handler.set_current_job(None)
+
+
def prepare_text_for_tts(text: str, title: str) -> list[str]:
    """Use LLM to prepare text for TTS, returning chunks ready for speech.

    The article is first split mechanically, then each chunk is lightly
    edited by the LLM; if editing fails, the raw chunk (with a plain
    intro/outro where appropriate) is used instead.
    """
    raw_chunks = split_text_into_chunks(text, max_chars=3000)
    logger.info("Split article into %d raw chunks", len(raw_chunks))

    last_index = len(raw_chunks) - 1
    edited_chunks: list[str] = []

    for index, raw in enumerate(raw_chunks):
        first = index == 0
        last = index == last_index

        try:
            edited_chunks.append(
                edit_chunk_for_speech(
                    raw,
                    title=title if first else None,
                    is_first=first,
                    is_last=last,
                ),
            )
        except Exception:
            logger.exception("Failed to edit chunk %d", index + 1)
            # Fall back to the raw chunk when the LLM is unavailable.
            if first:
                fallback = f"This is an audio version of {title}. {raw}"
            elif last:
                fallback = f"{raw} This concludes the article."
            else:
                fallback = raw
            edited_chunks.append(fallback)

    return edited_chunks
+
+
def split_text_into_chunks(text: str, max_chars: int = 3000) -> list[str]:
    """Split text into chunks at paragraph and sentence boundaries.

    Paragraphs (separated by blank lines) are packed into chunks of at
    most roughly ``max_chars``; a paragraph that is itself too long is
    broken at sentence boundaries instead.
    """
    chunks: list[str] = []
    buffer = ""

    for raw_para in text.split("\n\n"):
        para = raw_para.strip()
        if not para:
            continue

        if len(para) > max_chars:
            # Oversized paragraph: fall back to sentence-level packing.
            for sentence in para.split(". "):
                if len(buffer) + len(sentence) + 2 < max_chars:
                    buffer += sentence + ". "
                else:
                    if buffer:
                        chunks.append(buffer.strip())
                    buffer = sentence + ". "
        elif len(buffer) + len(para) + 2 > max_chars:
            # Paragraph would overflow the current chunk: flush first.
            if buffer:
                chunks.append(buffer.strip())
            buffer = para + " "
        else:
            buffer += para + " "

    # Flush whatever remains in the buffer.
    if buffer:
        chunks.append(buffer.strip())

    return chunks
+
+
def edit_chunk_for_speech(
    chunk: str,
    title: str | None = None,
    *,
    is_first: bool = False,
    is_last: bool = False,
) -> str:
    """Use LLM to lightly edit a single chunk for speech.

    Args:
        chunk: Raw text chunk to edit.
        title: Article title; only used when ``is_first`` is set, to ask
            the model for a spoken intro.
        is_first: Whether this is the first chunk (gets an intro).
        is_last: Whether this is the last chunk (gets a closing).

    Returns:
        The edited chunk, truncated at a sentence boundary to at most
        4000 characters.

    Raises:
        ValueError: If no content is returned from LLM.
    """
    system_prompt = (
        "You are a podcast script editor. Your job is to lightly edit text "
        "to make it sound natural when spoken aloud.\n\n"
        "Guidelines:\n"
    )
    system_prompt += """
- Remove URLs and email addresses, replacing with descriptive phrases
- Convert bullet points and lists into flowing sentences
- Fix any awkward phrasing for speech
- Remove references like "click here" or "see below"
- Keep edits minimal - preserve the original content and style
- Do NOT add commentary or explanations
- Return ONLY the edited text, no JSON or formatting
"""

    user_prompt = chunk

    # Add intro/outro if needed
    if is_first and title:
        user_prompt = (
            f"Add a brief intro mentioning this is an audio version of "
            f"'{title}', then edit this text:\n\n{chunk}"
        )
    elif is_last:
        user_prompt = f"Edit this text and add a brief closing:\n\n{chunk}"

    try:
        client: openai.OpenAI = openai.OpenAI(api_key=OPENAI_API_KEY)
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.3,  # Lower temperature for more consistent edits
            max_tokens=4000,
        )

        content = response.choices[0].message.content
        if not content:
            msg = "No content returned from LLM"
            raise ValueError(msg)  # noqa: TRY301

        # Ensure the chunk isn't too long for downstream TTS.
        max_chunk_length = 4000
        if len(content) > max_chunk_length:
            # Truncate at sentence boundary.
            # NOTE(review): if the very first sentence alone exceeds the
            # limit, this yields an empty string — confirm intended.
            sentences = content.split(". ")
            truncated = ""
            for sentence in sentences:
                if len(truncated) + len(sentence) + 2 < max_chunk_length:
                    truncated += sentence + ". "
                else:
                    break
            content = truncated.strip()

    except Exception:
        logger.exception("LLM chunk editing failed")
        raise
    else:
        return content
+
+
+def parse_datetime_with_timezone(created_at: str | datetime) -> datetime:
+ """Parse datetime string and ensure it has timezone info."""
+ if isinstance(created_at, str):
+ # Handle timezone-aware datetime strings
+ if created_at.endswith("Z"):
+ created_at = created_at[:-1] + "+00:00"
+ last_attempt = datetime.fromisoformat(created_at)
+ if last_attempt.tzinfo is None:
+ last_attempt = last_attempt.replace(tzinfo=timezone.utc)
+ else:
+ last_attempt = created_at
+ if last_attempt.tzinfo is None:
+ last_attempt = last_attempt.replace(tzinfo=timezone.utc)
+ return last_attempt
+
+
def should_retry_job(job: dict[str, Any], max_retries: int) -> bool:
    """Decide whether a failed job is due for another attempt.

    Retries stop once ``max_retries`` is reached, and attempts are
    spaced with exponential backoff (30s, 60s, 120s, ...).
    NOTE(review): backoff is measured from the job's created_at, not
    its last attempt time — confirm that matches the intended cadence.
    """
    attempts = job["retry_count"]
    if attempts >= max_retries:
        return False

    # Exponential backoff: 30s, 60s, 120s
    delay = timedelta(seconds=30 * (2**attempts))
    started = parse_datetime_with_timezone(job["created_at"])
    return datetime.now(tz=timezone.utc) - started > delay
+
+
def process_pending_jobs(
    processor: ArticleProcessor,
) -> None:
    """Process up to five pending jobs, recording any failures.

    Stops early if a shutdown has been requested. A job whose failure
    was not recorded by ``process_job`` itself is marked as "error".
    """
    for job in Core.Database.get_pending_jobs(limit=5):
        if processor.shutdown_handler.is_shutdown_requested():
            logger.info("Shutdown requested, stopping job processing")
            break

        job_id = job["id"]
        try:
            processor.process_job(job)
        except Exception as exc:
            # Ensure the job is marked as error even if process_job
            # did not handle it.
            logger.exception("Failed to process job: %d", job_id)
            # Only overwrite the status if it is still "processing".
            refreshed = Core.Database.get_job_by_id(job_id)
            if refreshed and refreshed.get("status") == "processing":
                Core.Database.update_job_status(
                    job_id,
                    "error",
                    str(exc),
                )
+
+
def process_retryable_jobs() -> None:
    """Re-queue failed jobs whose exponential-backoff window has elapsed."""
    for job in Core.Database.get_retryable_jobs(MAX_RETRIES):
        if not should_retry_job(job, MAX_RETRIES):
            continue
        logger.info(
            "Retrying job %d (attempt %d)",
            job["id"],
            job["retry_count"] + 1,
        )
        # Resetting to "pending" lets the normal pipeline pick it up.
        Core.Database.update_job_status(
            job["id"],
            "pending",
        )
+
+
def check_memory_usage() -> float:
    """Return this process's memory usage as a percentage (0.0-100.0).

    Returns 0.0 when the reading fails, so callers treat an unreadable
    value as "memory is fine" instead of deferring work forever.
    """
    try:
        # psutil.Process.memory_percent() returns a float percentage;
        # the previous "int | Any" annotation and comment were wrong.
        return psutil.Process().memory_percent()
    except (psutil.Error, OSError):
        logger.warning("Failed to check memory usage")
        return 0.0
+
+
def cleanup_stale_jobs() -> None:
    """Reset jobs stuck in processing state on startup.

    A job left in "processing" by a crashed or killed worker would
    otherwise never run again; resetting it to "pending" re-queues it.
    """
    with Core.Database.get_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            """
            UPDATE queue
            SET status = 'pending'
            WHERE status = 'processing'
            """,
        )
        reset_count = cursor.rowcount
        conn.commit()

    if reset_count > 0:
        logger.info(
            "Reset %d stale jobs from processing to pending",
            reset_count,
        )
+
+
def main_loop() -> None:
    """Poll for jobs and process them in a continuous loop.

    Runs until the shutdown handler reports a requested shutdown. Each
    cycle processes pending and retryable jobs, then sleeps for up to
    POLL_INTERVAL seconds on an interruptible event wait.
    """
    shutdown_handler = ShutdownHandler()
    processor = ArticleProcessor(shutdown_handler)

    # Clean up any stale jobs from previous runs
    cleanup_stale_jobs()

    logger.info("Worker started, polling for jobs...")

    while not shutdown_handler.is_shutdown_requested():
        try:
            # Process pending jobs
            process_pending_jobs(processor)
            process_retryable_jobs()

            # Check if there's any work (log-only; no behavior change)
            pending_jobs = Core.Database.get_pending_jobs(
                limit=1,
            )
            retryable_jobs = Core.Database.get_retryable_jobs(
                MAX_RETRIES,
            )

            if not pending_jobs and not retryable_jobs:
                logger.debug("No jobs to process, sleeping...")

        except Exception:
            # Keep the worker alive on unexpected errors; the next poll
            # cycle will try again.
            logger.exception("Error in main loop")

        # Use interruptible sleep so shutdown isn't delayed by a full
        # poll interval.
        if not shutdown_handler.is_shutdown_requested():
            shutdown_handler.shutdown_requested.wait(timeout=POLL_INTERVAL)

    # Graceful shutdown
    current_job = shutdown_handler.get_current_job()
    if current_job:
        logger.info(
            "Waiting for job %d to complete before shutdown...",
            current_job,
        )
        # The job will complete or be reset to pending

    logger.info("Worker shutdown complete")
+
+
def move() -> None:
    """Make the worker move.

    Initializes the database schema, then runs the polling loop until
    shutdown. A KeyboardInterrupt is treated as a clean stop; any other
    crash is logged and re-raised.
    """
    try:
        Core.Database.init_db()  # ensure schema exists before polling
        main_loop()
    except KeyboardInterrupt:
        logger.info("Worker stopped by user")
    except Exception:
        logger.exception("Worker crashed")
        raise
+
+
class TestArticleExtraction(Test.TestCase):
    """Test article extraction functionality.

    All network access is mocked at the trafilatura boundary
    (``fetch_url`` returns canned HTML, ``extract`` returns canned JSON),
    so these tests exercise only our parsing/validation logic.
    """

    def test_extract_valid_article(self) -> None:
        """Extract from well-formed HTML."""
        # Mock trafilatura.fetch_url and extract
        mock_html = (
            "<html><body><h1>Test Article</h1><p>Content here</p></body></html>"
        )
        mock_result = json.dumps({
            "title": "Test Article",
            "text": "Content here",
        })

        with (
            unittest.mock.patch(
                "trafilatura.fetch_url",
                return_value=mock_html,
            ),
            unittest.mock.patch(
                "trafilatura.extract",
                return_value=mock_result,
            ),
        ):
            title, content, author, pub_date = (
                ArticleProcessor.extract_article_content(
                    "https://example.com",
                )
            )

        self.assertEqual(title, "Test Article")
        self.assertEqual(content, "Content here")
        # No author/date in the mocked extraction result.
        self.assertIsNone(author)
        self.assertIsNone(pub_date)

    def test_extract_missing_title(self) -> None:
        """Handle articles without titles."""
        mock_html = "<html><body><p>Content without title</p></body></html>"
        mock_result = json.dumps({"text": "Content without title"})

        with (
            unittest.mock.patch(
                "trafilatura.fetch_url",
                return_value=mock_html,
            ),
            unittest.mock.patch(
                "trafilatura.extract",
                return_value=mock_result,
            ),
        ):
            title, content, author, pub_date = (
                ArticleProcessor.extract_article_content(
                    "https://example.com",
                )
            )

        # Missing title falls back to the placeholder.
        self.assertEqual(title, "Untitled Article")
        self.assertEqual(content, "Content without title")
        self.assertIsNone(author)
        self.assertIsNone(pub_date)

    def test_extract_empty_content(self) -> None:
        """Handle empty articles."""
        mock_html = "<html><body></body></html>"
        mock_result = json.dumps({"title": "Empty Article", "text": ""})

        with (
            unittest.mock.patch(
                "trafilatura.fetch_url",
                return_value=mock_html,
            ),
            unittest.mock.patch(
                "trafilatura.extract",
                return_value=mock_result,
            ),
            pytest.raises(ValueError, match="No content extracted") as cm,
        ):
            ArticleProcessor.extract_article_content(
                "https://example.com",
            )

        # Redundant with the match= above, but kept as a belt-and-braces
        # check on the exception message.
        self.assertIn("No content extracted", str(cm.value))

    def test_extract_network_error(self) -> None:
        """Handle connection failures."""
        # fetch_url returning None signals a failed download.
        with (
            unittest.mock.patch("trafilatura.fetch_url", return_value=None),
            pytest.raises(ValueError, match="Failed to download") as cm,
        ):
            ArticleProcessor.extract_article_content("https://example.com")

        self.assertIn("Failed to download", str(cm.value))

    @staticmethod
    def test_extract_timeout() -> None:
        """Handle slow responses."""
        # Timeouts are expected to propagate to the caller unchanged.
        with (
            unittest.mock.patch(
                "trafilatura.fetch_url",
                side_effect=TimeoutError("Timeout"),
            ),
            pytest.raises(TimeoutError),
        ):
            ArticleProcessor.extract_article_content("https://example.com")

    def test_content_sanitization(self) -> None:
        """Remove unwanted elements."""
        mock_html = """
        <html><body>
        <h1>Article</h1>
        <p>Good content</p>
        <script>alert('bad')</script>
        <table><tr><td>data</td></tr></table>
        </body></html>
        """
        mock_result = json.dumps({
            "title": "Article",
            "text": "Good content",  # Tables and scripts removed
        })

        with (
            unittest.mock.patch(
                "trafilatura.fetch_url",
                return_value=mock_html,
            ),
            unittest.mock.patch(
                "trafilatura.extract",
                return_value=mock_result,
            ),
        ):
            _title, content, _author, _pub_date = (
                ArticleProcessor.extract_article_content(
                    "https://example.com",
                )
            )

        self.assertEqual(content, "Good content")
        self.assertNotIn("script", content)
        self.assertNotIn("table", content)
+
+
class TestTextToSpeech(Test.TestCase):
    """Test text-to-speech functionality.

    OpenAI and pydub are mocked throughout: ``setUp`` patches the API
    key, a canned TTS response, and ``AudioSegment.from_mp3`` so no
    network calls or ffmpeg invocations happen in tests.
    """

    def setUp(self) -> None:
        """Set up mocks."""
        # Mock OpenAI API key
        self.env_patcher = unittest.mock.patch.dict(
            os.environ,
            {"OPENAI_API_KEY": "test-key"},
        )
        self.env_patcher.start()

        # Mock OpenAI response
        self.mock_audio_response: unittest.mock.MagicMock = (
            unittest.mock.MagicMock()
        )
        self.mock_audio_response.content = b"fake-audio-data"

        # Mock AudioSegment to avoid ffmpeg issues in tests
        self.mock_audio_segment: unittest.mock.MagicMock = (
            unittest.mock.MagicMock()
        )
        self.mock_audio_segment.export.return_value = None
        self.audio_segment_patcher = unittest.mock.patch(
            "pydub.AudioSegment.from_mp3",
            return_value=self.mock_audio_segment,
        )
        self.audio_segment_patcher.start()

        # Mock the concatenation operations (segment + silence + segment)
        self.mock_audio_segment.__add__.return_value = self.mock_audio_segment

    def tearDown(self) -> None:
        """Clean up mocks."""
        self.env_patcher.stop()
        self.audio_segment_patcher.stop()

    def test_tts_generation(self) -> None:
        """Generate audio from text."""

        # Mock the export to write test audio data
        def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
            buffer.write(b"test-audio-output")
            buffer.seek(0)

        self.mock_audio_segment.export.side_effect = mock_export

        # Mock OpenAI client
        mock_client = unittest.mock.MagicMock()
        mock_audio = unittest.mock.MagicMock()
        mock_speech = unittest.mock.MagicMock()
        mock_speech.create.return_value = self.mock_audio_response
        mock_audio.speech = mock_speech
        mock_client.audio = mock_audio

        with (
            unittest.mock.patch("openai.OpenAI", return_value=mock_client),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.prepare_text_for_tts",
                return_value=["Test content"],
            ),
        ):
            shutdown_handler = ShutdownHandler()
            processor = ArticleProcessor(shutdown_handler)
            audio_data = processor.text_to_speech(
                "Test content",
                "Test Title",
            )

        self.assertIsInstance(audio_data, bytes)
        self.assertEqual(audio_data, b"test-audio-output")

    def test_tts_chunking(self) -> None:
        """Handle long articles with chunking."""
        long_text = "Long content " * 1000
        chunks = ["Chunk 1", "Chunk 2", "Chunk 3"]

        def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
            buffer.write(b"test-audio-output")
            buffer.seek(0)

        self.mock_audio_segment.export.side_effect = mock_export

        # Mock AudioSegment.silent
        # Mock OpenAI client
        mock_client = unittest.mock.MagicMock()
        mock_audio = unittest.mock.MagicMock()
        mock_speech = unittest.mock.MagicMock()
        mock_speech.create.return_value = self.mock_audio_response
        mock_audio.speech = mock_speech
        mock_client.audio = mock_audio

        with (
            unittest.mock.patch("openai.OpenAI", return_value=mock_client),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.prepare_text_for_tts",
                return_value=chunks,
            ),
            unittest.mock.patch(
                "pydub.AudioSegment.silent",
                return_value=self.mock_audio_segment,
            ),
        ):
            shutdown_handler = ShutdownHandler()
            processor = ArticleProcessor(shutdown_handler)
            audio_data = processor.text_to_speech(
                long_text,
                "Long Article",
            )

        # Should have called TTS for each chunk
        self.assertIsInstance(audio_data, bytes)
        self.assertEqual(audio_data, b"test-audio-output")

    def test_tts_empty_text(self) -> None:
        """Handle empty input."""
        with unittest.mock.patch(
            "Biz.PodcastItLater.Worker.prepare_text_for_tts",
            return_value=[],
        ):
            shutdown_handler = ShutdownHandler()
            processor = ArticleProcessor(shutdown_handler)
            # Zero chunks means there is nothing to synthesize.
            with pytest.raises(ValueError, match="No chunks generated") as cm:
                processor.text_to_speech("", "Empty")

        self.assertIn("No chunks generated", str(cm.value))

    def test_tts_special_characters(self) -> None:
        """Handle unicode and special chars."""
        special_text = 'Unicode: 你好世界 Émojis: 🎙️📰 Special: <>&"'

        def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
            buffer.write(b"test-audio-output")
            buffer.seek(0)

        self.mock_audio_segment.export.side_effect = mock_export

        # Mock OpenAI client
        mock_client = unittest.mock.MagicMock()
        mock_audio = unittest.mock.MagicMock()
        mock_speech = unittest.mock.MagicMock()
        mock_speech.create.return_value = self.mock_audio_response
        mock_audio.speech = mock_speech
        mock_client.audio = mock_audio

        with (
            unittest.mock.patch("openai.OpenAI", return_value=mock_client),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.prepare_text_for_tts",
                return_value=[special_text],
            ),
        ):
            shutdown_handler = ShutdownHandler()
            processor = ArticleProcessor(shutdown_handler)
            audio_data = processor.text_to_speech(
                special_text,
                "Special",
            )

        self.assertIsInstance(audio_data, bytes)
        self.assertEqual(audio_data, b"test-audio-output")

    def test_llm_text_preparation(self) -> None:
        """Verify LLM editing."""
        # Test the actual text preparation functions (no mocks needed;
        # split_text_into_chunks is pure).
        chunks = split_text_into_chunks("Short text", max_chars=100)
        self.assertEqual(len(chunks), 1)
        self.assertEqual(chunks[0], "Short text")

        # Test long text splitting
        long_text = "Sentence one. " * 100
        chunks = split_text_into_chunks(long_text, max_chars=100)
        self.assertGreater(len(chunks), 1)
        for chunk in chunks:
            self.assertLessEqual(len(chunk), 100)

    @staticmethod
    def test_llm_failure_fallback() -> None:
        """Handle LLM API failures."""
        # Mock LLM failure
        with unittest.mock.patch("openai.OpenAI") as mock_openai:
            mock_client = mock_openai.return_value
            mock_client.chat.completions.create.side_effect = Exception(
                "API Error",
            )

            # edit_chunk_for_speech propagates the failure; the raw-text
            # fallback lives in prepare_text_for_tts.
            with pytest.raises(Exception, match="API Error"):
                edit_chunk_for_speech("Test chunk", "Title", is_first=True)

    def test_chunk_concatenation(self) -> None:
        """Verify audio joining."""
        # Mock multiple audio segments
        chunks = ["Chunk 1", "Chunk 2"]

        def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
            buffer.write(b"test-audio-output")
            buffer.seek(0)

        self.mock_audio_segment.export.side_effect = mock_export

        # Mock OpenAI client
        mock_client = unittest.mock.MagicMock()
        mock_audio = unittest.mock.MagicMock()
        mock_speech = unittest.mock.MagicMock()
        mock_speech.create.return_value = self.mock_audio_response
        mock_audio.speech = mock_speech
        mock_client.audio = mock_audio

        with (
            unittest.mock.patch("openai.OpenAI", return_value=mock_client),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.prepare_text_for_tts",
                return_value=chunks,
            ),
            unittest.mock.patch(
                "pydub.AudioSegment.silent",
                return_value=self.mock_audio_segment,
            ),
        ):
            shutdown_handler = ShutdownHandler()
            processor = ArticleProcessor(shutdown_handler)
            audio_data = processor.text_to_speech("Test", "Title")

        # Should produce combined audio
        self.assertIsInstance(audio_data, bytes)
        self.assertEqual(audio_data, b"test-audio-output")

    def test_parallel_tts_generation(self) -> None:
        """Test parallel TTS processing."""
        chunks = ["Chunk 1", "Chunk 2", "Chunk 3", "Chunk 4"]

        # Mock responses for each chunk
        mock_responses = []
        for i in range(len(chunks)):
            mock_resp = unittest.mock.MagicMock()
            mock_resp.content = f"audio-{i}".encode()
            mock_responses.append(mock_resp)

        # Mock OpenAI client
        mock_client = unittest.mock.MagicMock()
        mock_audio = unittest.mock.MagicMock()
        mock_speech = unittest.mock.MagicMock()

        # Make create return different responses for each call
        mock_speech.create.side_effect = mock_responses
        mock_audio.speech = mock_speech
        mock_client.audio = mock_audio

        # Mock AudioSegment operations
        mock_segment = unittest.mock.MagicMock()
        mock_segment.__add__.return_value = mock_segment

        # Parallel path exports to a file path, not a BytesIO buffer.
        def mock_export(path: str, **_kwargs: typing.Any) -> None:
            Path(path).write_bytes(b"combined-audio")

        mock_segment.export = mock_export

        with (
            unittest.mock.patch("openai.OpenAI", return_value=mock_client),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.prepare_text_for_tts",
                return_value=chunks,
            ),
            unittest.mock.patch(
                "pydub.AudioSegment.from_mp3",
                return_value=mock_segment,
            ),
            unittest.mock.patch(
                "pydub.AudioSegment.silent",
                return_value=mock_segment,
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.check_memory_usage",
                return_value=50.0,  # Normal memory usage
            ),
        ):
            shutdown_handler = ShutdownHandler()
            processor = ArticleProcessor(shutdown_handler)
            audio_data = processor.text_to_speech("Test content", "Test Title")

        # Verify all chunks were processed
        self.assertEqual(mock_speech.create.call_count, len(chunks))
        self.assertEqual(audio_data, b"combined-audio")

    def test_parallel_tts_high_memory_fallback(self) -> None:
        """Test fallback to serial processing when memory is high."""
        chunks = ["Chunk 1", "Chunk 2"]

        def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
            buffer.write(b"serial-audio")
            buffer.seek(0)

        self.mock_audio_segment.export.side_effect = mock_export

        # Mock OpenAI client
        mock_client = unittest.mock.MagicMock()
        mock_audio = unittest.mock.MagicMock()
        mock_speech = unittest.mock.MagicMock()
        mock_speech.create.return_value = self.mock_audio_response
        mock_audio.speech = mock_speech
        mock_client.audio = mock_audio

        with (
            unittest.mock.patch("openai.OpenAI", return_value=mock_client),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.prepare_text_for_tts",
                return_value=chunks,
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.check_memory_usage",
                return_value=65.0,  # High memory usage
            ),
            unittest.mock.patch(
                "pydub.AudioSegment.silent",
                return_value=self.mock_audio_segment,
            ),
        ):
            shutdown_handler = ShutdownHandler()
            processor = ArticleProcessor(shutdown_handler)
            audio_data = processor.text_to_speech("Test content", "Test Title")

        # Should use serial processing
        self.assertEqual(audio_data, b"serial-audio")

    @staticmethod
    def test_parallel_tts_error_handling() -> None:
        """Test error handling in parallel TTS processing."""
        chunks = ["Chunk 1", "Chunk 2"]

        # Mock OpenAI client with one failure
        mock_client = unittest.mock.MagicMock()
        mock_audio = unittest.mock.MagicMock()
        mock_speech = unittest.mock.MagicMock()

        # First call succeeds, second fails
        mock_resp1 = unittest.mock.MagicMock()
        mock_resp1.content = b"audio-1"
        mock_speech.create.side_effect = [mock_resp1, Exception("API Error")]

        mock_audio.speech = mock_speech
        mock_client.audio = mock_audio

        # Set up the test context
        shutdown_handler = ShutdownHandler()
        processor = ArticleProcessor(shutdown_handler)

        # A failing chunk should abort the whole synthesis.
        with (
            unittest.mock.patch("openai.OpenAI", return_value=mock_client),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.prepare_text_for_tts",
                return_value=chunks,
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.check_memory_usage",
                return_value=50.0,
            ),
            pytest.raises(Exception, match="API Error"),
        ):
            processor.text_to_speech("Test content", "Test Title")

    def test_parallel_tts_order_preservation(self) -> None:
        """Test that chunks are combined in the correct order."""
        chunks = ["First", "Second", "Third", "Fourth", "Fifth"]

        # Create mock responses with identifiable content
        mock_responses = []
        for chunk in chunks:
            mock_resp = unittest.mock.MagicMock()
            mock_resp.content = f"audio-{chunk}".encode()
            mock_responses.append(mock_resp)

        # Mock OpenAI client
        mock_client = unittest.mock.MagicMock()
        mock_audio = unittest.mock.MagicMock()
        mock_speech = unittest.mock.MagicMock()
        mock_speech.create.side_effect = mock_responses
        mock_audio.speech = mock_speech
        mock_client.audio = mock_audio

        # Track the order of segments being combined
        combined_order = []

        def mock_from_mp3(data: io.BytesIO) -> unittest.mock.MagicMock:
            content = data.read()
            combined_order.append(content.decode())
            segment = unittest.mock.MagicMock()
            segment.__add__.return_value = segment
            return segment

        mock_segment = unittest.mock.MagicMock()
        mock_segment.__add__.return_value = mock_segment

        def mock_export(path: str, **_kwargs: typing.Any) -> None:
            # Verify order is preserved before "writing" the result.
            expected_order = [f"audio-{chunk}" for chunk in chunks]
            if combined_order != expected_order:
                msg = f"Order mismatch: {combined_order} != {expected_order}"
                raise AssertionError(msg)
            Path(path).write_bytes(b"ordered-audio")

        mock_segment.export = mock_export

        with (
            unittest.mock.patch("openai.OpenAI", return_value=mock_client),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.prepare_text_for_tts",
                return_value=chunks,
            ),
            unittest.mock.patch(
                "pydub.AudioSegment.from_mp3",
                side_effect=mock_from_mp3,
            ),
            unittest.mock.patch(
                "pydub.AudioSegment.silent",
                return_value=mock_segment,
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.check_memory_usage",
                return_value=50.0,
            ),
        ):
            shutdown_handler = ShutdownHandler()
            processor = ArticleProcessor(shutdown_handler)
            audio_data = processor.text_to_speech("Test content", "Test Title")

        self.assertEqual(audio_data, b"ordered-audio")
+
+
class TestIntroOutro(Test.TestCase):
    """Test intro and outro generation with metadata.

    Exercises the private ``_create_intro_text`` / ``_create_outro_text``
    helpers directly (hence the SLF001 suppressions) and verifies that
    extraction passes author/date metadata through.
    """

    def test_create_intro_text_full_metadata(self) -> None:
        """Test intro text creation with all metadata."""
        intro = ArticleProcessor._create_intro_text(  # noqa: SLF001
            title="Test Article",
            author="John Doe",
            pub_date="2024-01-15",
        )
        self.assertIn("Title: Test Article", intro)
        self.assertIn("Author: John Doe", intro)
        self.assertIn("Published: 2024-01-15", intro)

    def test_create_intro_text_no_author(self) -> None:
        """Test intro text without author."""
        intro = ArticleProcessor._create_intro_text(  # noqa: SLF001
            title="Test Article",
            author=None,
            pub_date="2024-01-15",
        )
        self.assertIn("Title: Test Article", intro)
        # Missing author must be omitted entirely, not rendered empty.
        self.assertNotIn("Author:", intro)
        self.assertIn("Published: 2024-01-15", intro)

    def test_create_intro_text_minimal(self) -> None:
        """Test intro text with only title."""
        intro = ArticleProcessor._create_intro_text(  # noqa: SLF001
            title="Test Article",
            author=None,
            pub_date=None,
        )
        self.assertEqual(intro, "Title: Test Article.")

    def test_create_outro_text_with_author(self) -> None:
        """Test outro text with author."""
        outro = ArticleProcessor._create_outro_text(  # noqa: SLF001
            title="Test Article",
            author="Jane Smith",
        )
        self.assertIn("Test Article", outro)
        self.assertIn("Jane Smith", outro)
        self.assertIn("Podcast It Later", outro)

    def test_create_outro_text_no_author(self) -> None:
        """Test outro text without author."""
        outro = ArticleProcessor._create_outro_text(  # noqa: SLF001
            title="Test Article",
            author=None,
        )
        self.assertIn("Test Article", outro)
        self.assertNotIn("by", outro)
        self.assertIn("Podcast It Later", outro)

    def test_extract_with_metadata(self) -> None:
        """Test that extraction returns metadata."""
        mock_html = "<html><body><p>Content</p></body></html>"
        mock_result = json.dumps({
            "title": "Test Article",
            "text": "Article content",
            "author": "Test Author",
            "date": "2024-01-15",
        })

        with (
            unittest.mock.patch(
                "trafilatura.fetch_url",
                return_value=mock_html,
            ),
            unittest.mock.patch(
                "trafilatura.extract",
                return_value=mock_result,
            ),
        ):
            title, content, author, pub_date = (
                ArticleProcessor.extract_article_content(
                    "https://example.com",
                )
            )

        self.assertEqual(title, "Test Article")
        self.assertEqual(content, "Article content")
        self.assertEqual(author, "Test Author")
        self.assertEqual(pub_date, "2024-01-15")
+
+
class TestMemoryEfficiency(Test.TestCase):
    """Test memory-efficient processing."""

    def test_large_article_size_limit(self) -> None:
        """Articles whose HTML exceeds the size limit are rejected."""
        oversized = "x" * (MAX_ARTICLE_SIZE + 1000)  # just past the limit

        with (
            unittest.mock.patch(
                "trafilatura.fetch_url",
                return_value=oversized * 4,  # simulate very large HTML
            ),
            pytest.raises(ValueError, match="Article too large") as excinfo,
        ):
            ArticleProcessor.extract_article_content("https://example.com")

        self.assertIn("Article too large", str(excinfo.value))

    def test_content_truncation(self) -> None:
        """Oversized extracted text is clipped to MAX_ARTICLE_SIZE."""
        payload = json.dumps({
            "title": "Large Article",
            "text": "Content " * 100_000,  # far beyond the size cap
        })

        with (
            unittest.mock.patch(
                "trafilatura.fetch_url",
                return_value="<html><body>content</body></html>",
            ),
            unittest.mock.patch(
                "trafilatura.extract",
                return_value=payload,
            ),
        ):
            title, content, _author, _pub_date = (
                ArticleProcessor.extract_article_content(
                    "https://example.com",
                )
            )

        self.assertEqual(title, "Large Article")
        self.assertLessEqual(len(content), MAX_ARTICLE_SIZE)

    def test_memory_usage_check(self) -> None:
        """check_memory_usage reports a percentage in [0, 100]."""
        percent = check_memory_usage()
        self.assertIsInstance(percent, float)
        self.assertGreaterEqual(percent, 0.0)
        self.assertLessEqual(percent, 100.0)
+
+
class TestJobProcessing(Test.TestCase):
    """Test job processing functionality.

    Covers the happy path of ``process_job``, each external failure mode
    (extraction, TTS, S3 upload), retry bookkeeping, graceful shutdown,
    stale-job cleanup, and memory-pressure deferral. All network and
    service calls are mocked.
    """

    def setUp(self) -> None:
        """Create a fresh DB, a test user/job, and fake service env vars."""
        Core.Database.init_db()

        # Create test user and job
        self.user_id, _ = Core.Database.create_user(
            "test@example.com",
        )
        self.job_id = Core.Database.add_to_queue(
            "https://example.com/article",
            "test@example.com",
            self.user_id,
        )

        # Mock environment so the processor's OpenAI/S3 config resolves.
        self.env_patcher = unittest.mock.patch.dict(
            os.environ,
            {
                "OPENAI_API_KEY": "test-key",
                "S3_ENDPOINT": "https://s3.example.com",
                "S3_BUCKET": "test-bucket",
                "S3_ACCESS_KEY": "test-access",
                "S3_SECRET_KEY": "test-secret",
            },
        )
        self.env_patcher.start()

    def tearDown(self) -> None:
        """Stop env patching and drop the test database."""
        self.env_patcher.stop()
        Core.Database.teardown()

    def test_process_job_success(self) -> None:
        """Complete pipeline execution."""
        shutdown_handler = ShutdownHandler()
        processor = ArticleProcessor(shutdown_handler)
        job = Core.Database.get_job_by_id(self.job_id)
        if job is None:
            msg = "no job found for %s"
            raise Test.TestError(msg, self.job_id)

        # Mock all external calls
        with (
            unittest.mock.patch.object(
                ArticleProcessor,
                "extract_article_content",
                return_value=(
                    "Test Title",
                    "Test content",
                    "Test Author",
                    "2024-01-15",
                ),
            ),
            unittest.mock.patch.object(
                ArticleProcessor,
                "text_to_speech",
                return_value=b"audio-data",
            ),
            unittest.mock.patch.object(
                processor,
                "upload_to_s3",
                return_value="https://s3.example.com/audio.mp3",
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Core.Database.update_job_status",
            ) as mock_update,
            unittest.mock.patch(
                "Biz.PodcastItLater.Core.Database.create_episode",
            ) as mock_create,
        ):
            mock_create.return_value = 1
            processor.process_job(job)

        # Verify job was marked complete
        mock_update.assert_called_with(self.job_id, "completed")
        mock_create.assert_called_once()

    def test_process_job_extraction_failure(self) -> None:
        """Handle bad URLs."""
        shutdown_handler = ShutdownHandler()
        processor = ArticleProcessor(shutdown_handler)
        job = Core.Database.get_job_by_id(self.job_id)
        if job is None:
            msg = "no job found for %s"
            raise Test.TestError(msg, self.job_id)

        with (
            unittest.mock.patch.object(
                ArticleProcessor,
                "extract_article_content",
                side_effect=ValueError("Bad URL"),
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Core.Database.update_job_status",
            ) as mock_update,
            pytest.raises(ValueError, match="Bad URL"),
        ):
            processor.process_job(job)

        # Job should be marked as error
        mock_update.assert_called_with(self.job_id, "error", "Bad URL")

    def test_process_job_tts_failure(self) -> None:
        """Handle TTS errors."""
        shutdown_handler = ShutdownHandler()
        processor = ArticleProcessor(shutdown_handler)
        job = Core.Database.get_job_by_id(self.job_id)
        if job is None:
            msg = "no job found for %s"
            raise Test.TestError(msg, self.job_id)

        with (
            unittest.mock.patch.object(
                ArticleProcessor,
                "extract_article_content",
                # extract_article_content returns a 4-tuple
                # (title, content, author, pub_date) -- see
                # test_process_job_success; a 2-tuple would fail unpacking
                # before text_to_speech is ever reached.
                return_value=("Title", "Content", None, None),
            ),
            unittest.mock.patch.object(
                ArticleProcessor,
                "text_to_speech",
                side_effect=Exception("TTS Error"),
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Core.Database.update_job_status",
            ) as mock_update,
            pytest.raises(Exception, match="TTS Error"),
        ):
            processor.process_job(job)

        mock_update.assert_called_with(self.job_id, "error", "TTS Error")

    def test_process_job_s3_failure(self) -> None:
        """Handle upload errors."""
        shutdown_handler = ShutdownHandler()
        processor = ArticleProcessor(shutdown_handler)
        job = Core.Database.get_job_by_id(self.job_id)
        if job is None:
            msg = "no job found for %s"
            raise Test.TestError(msg, self.job_id)

        with (
            unittest.mock.patch.object(
                ArticleProcessor,
                "extract_article_content",
                # 4-tuple to match extract_article_content's real shape.
                return_value=("Title", "Content", None, None),
            ),
            unittest.mock.patch.object(
                ArticleProcessor,
                "text_to_speech",
                return_value=b"audio",
            ),
            unittest.mock.patch.object(
                processor,
                "upload_to_s3",
                side_effect=ClientError({}, "PutObject"),
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Core.Database.update_job_status",
            ),
            pytest.raises(ClientError),
        ):
            processor.process_job(job)

    def test_job_retry_logic(self) -> None:
        """Verify exponential backoff."""
        # Set job to error with retry count
        Core.Database.update_job_status(
            self.job_id,
            "error",
            "First failure",
        )
        Core.Database.update_job_status(
            self.job_id,
            "error",
            "Second failure",
        )

        job = Core.Database.get_job_by_id(self.job_id)
        if job is None:
            msg = "no job found for %s"
            raise Test.TestError(msg, self.job_id)

        self.assertEqual(job["retry_count"], 2)

        # Should be retryable
        retryable = Core.Database.get_retryable_jobs(
            max_retries=3,
        )
        self.assertEqual(len(retryable), 1)

    def test_max_retries(self) -> None:
        """Stop after max attempts."""
        # Exceed retry limit
        for i in range(4):
            Core.Database.update_job_status(
                self.job_id,
                "error",
                f"Failure {i}",
            )

        # Should not be retryable
        retryable = Core.Database.get_retryable_jobs(
            max_retries=3,
        )
        self.assertEqual(len(retryable), 0)

    def test_graceful_shutdown(self) -> None:
        """Test graceful shutdown during job processing."""
        shutdown_handler = ShutdownHandler()
        processor = ArticleProcessor(shutdown_handler)
        job = Core.Database.get_job_by_id(self.job_id)
        if job is None:
            msg = "no job found for %s"
            raise Test.TestError(msg, self.job_id)

        def mock_tts(*_args: Any) -> bytes:
            # Request shutdown mid-pipeline so process_job must bail out
            # after TTS instead of completing.
            shutdown_handler.shutdown_requested.set()
            return b"audio-data"

        with (
            unittest.mock.patch.object(
                ArticleProcessor,
                "extract_article_content",
                return_value=(
                    "Test Title",
                    "Test content",
                    "Test Author",
                    "2024-01-15",
                ),
            ),
            unittest.mock.patch.object(
                ArticleProcessor,
                "text_to_speech",
                side_effect=mock_tts,
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Core.Database.update_job_status",
            ) as mock_update,
        ):
            processor.process_job(job)

        # Job should be reset to pending due to shutdown
        mock_update.assert_any_call(self.job_id, "pending")

    def test_cleanup_stale_jobs(self) -> None:
        """Test cleanup of stale processing jobs."""
        # Manually set job to processing
        Core.Database.update_job_status(self.job_id, "processing")

        # Run cleanup
        cleanup_stale_jobs()

        # Job should be back to pending
        job = Core.Database.get_job_by_id(self.job_id)
        if job is None:
            msg = "no job found for %s"
            raise Test.TestError(msg, self.job_id)
        self.assertEqual(job["status"], "pending")

    def test_concurrent_processing(self) -> None:
        """Handle multiple jobs."""
        # Create multiple jobs
        job2 = Core.Database.add_to_queue(
            "https://example.com/2",
            "test@example.com",
            self.user_id,
        )
        job3 = Core.Database.add_to_queue(
            "https://example.com/3",
            "test@example.com",
            self.user_id,
        )

        # Get pending jobs
        jobs = Core.Database.get_pending_jobs(limit=5)

        self.assertEqual(len(jobs), 3)
        self.assertEqual({j["id"] for j in jobs}, {self.job_id, job2, job3})

    def test_memory_threshold_deferral(self) -> None:
        """Test that jobs are deferred when memory usage is high."""
        shutdown_handler = ShutdownHandler()
        processor = ArticleProcessor(shutdown_handler)
        job = Core.Database.get_job_by_id(self.job_id)
        if job is None:
            msg = "no job found for %s"
            raise Test.TestError(msg, self.job_id)

        # Mock high memory usage
        with (
            unittest.mock.patch(
                "Biz.PodcastItLater.Worker.check_memory_usage",
                return_value=90.0,  # High memory usage
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Core.Database.update_job_status",
            ) as mock_update,
        ):
            processor.process_job(job)

        # Job should not be processed (no status updates)
        mock_update.assert_not_called()
+
+
class TestWorkerErrorHandling(Test.TestCase):
    """Test worker error handling and recovery."""

    def setUp(self) -> None:
        """Set up a database with one queued job and a live processor."""
        Core.Database.init_db()
        self.user_id, _ = Core.Database.create_user("test@example.com")
        self.job_id = Core.Database.add_to_queue(
            "https://example.com",
            "test@example.com",
            self.user_id,
        )
        self.shutdown_handler = ShutdownHandler()
        self.processor = ArticleProcessor(self.shutdown_handler)

    def tearDown(self) -> None:
        """Clean up.

        Regular instance method (was ``@staticmethod``) for consistency
        with the other test classes in this module; unittest invokes it
        on the instance either way.
        """
        Core.Database.teardown()

    def test_process_pending_jobs_exception_handling(self) -> None:
        """Test that process_pending_jobs handles exceptions."""

        def side_effect(job: dict[str, Any]) -> None:
            # Simulate process_job starting and setting status to processing
            Core.Database.update_job_status(job["id"], "processing")
            msg = "Unexpected Error"
            raise ValueError(msg)

        with (
            unittest.mock.patch.object(
                self.processor,
                "process_job",
                side_effect=side_effect,
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Core.Database.update_job_status",
                # Delegate to the real implementation so status changes
                # actually land in the database while still being spied on.
                side_effect=Core.Database.update_job_status,
            ) as _mock_update,
        ):
            process_pending_jobs(self.processor)

        # Job should be marked as error
        job = Core.Database.get_job_by_id(self.job_id)
        self.assertIsNotNone(job)
        if job:
            self.assertEqual(job["status"], "error")
            self.assertIn("Unexpected Error", job["error_message"])

    def test_process_retryable_jobs_success(self) -> None:
        """Test processing of retryable jobs."""
        # Set up a retryable job
        Core.Database.update_job_status(self.job_id, "error", "Fail 1")

        # Modify created_at to be in the past to satisfy backoff
        with Core.Database.get_connection() as conn:
            conn.execute(
                "UPDATE queue SET created_at = ? WHERE id = ?",
                (
                    (
                        datetime.now(tz=timezone.utc) - timedelta(minutes=5)
                    ).isoformat(),
                    self.job_id,
                ),
            )
            conn.commit()

        process_retryable_jobs()

        job = Core.Database.get_job_by_id(self.job_id)
        self.assertIsNotNone(job)
        if job:
            self.assertEqual(job["status"], "pending")

    def test_process_retryable_jobs_not_ready(self) -> None:
        """Test that jobs are not retried before backoff period."""
        # Set up a retryable job that just failed
        Core.Database.update_job_status(self.job_id, "error", "Fail 1")

        # created_at is now, so backoff should prevent retry
        process_retryable_jobs()

        job = Core.Database.get_job_by_id(self.job_id)
        self.assertIsNotNone(job)
        if job:
            self.assertEqual(job["status"], "error")
+
+
class TestTextChunking(Test.TestCase):
    """Test text chunking edge cases."""

    def test_split_text_single_long_word(self) -> None:
        """A single word longer than the limit stays in one chunk."""
        word = "a" * 4000
        pieces = split_text_into_chunks(word, max_chars=3000)

        # The splitter never breaks inside a word, so the oversized word
        # comes back untouched as a single chunk.
        self.assertEqual(len(pieces), 1)
        self.assertEqual(len(pieces[0]), 4000)

    def test_split_text_no_sentence_boundaries(self) -> None:
        """Long text without sentence boundaries is not split."""
        prose = "word " * 1000  # 5000 chars, no ". " to split on
        pieces = split_text_into_chunks(prose, max_chars=3000)

        # With no sentence boundary available, everything stays together
        # even though it exceeds max_chars.
        self.assertEqual(len(pieces), 1)
        self.assertGreater(len(pieces[0]), 3000)
+
+
def test() -> None:
    """Run the tests."""
    Test.run(
        App.Area.Test,
        [
            TestArticleExtraction,
            TestTextToSpeech,
            # TestIntroOutro was defined but missing from this list, so its
            # tests never ran.
            TestIntroOutro,
            TestMemoryEfficiency,
            TestJobProcessing,
            TestWorkerErrorHandling,
            TestTextChunking,
        ],
    )
+
+
def main() -> None:
    """Entry point for the worker.

    Runs the test suite when "test" appears on the command line;
    otherwise starts the normal worker loop.
    """
    entry = test if "test" in sys.argv else move
    entry()