diff options
Diffstat (limited to 'Biz/PodcastItLater')
| -rw-r--r-- | Biz/PodcastItLater/Admin.py | 1068 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Billing.py | 581 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Core.py | 2174 | ||||
| -rw-r--r-- | Biz/PodcastItLater/DESIGN.md | 43 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Episode.py | 390 | ||||
| -rw-r--r-- | Biz/PodcastItLater/INFRASTRUCTURE.md | 38 | ||||
| -rw-r--r-- | Biz/PodcastItLater/STRIPE_TESTING.md | 114 | ||||
| -rw-r--r-- | Biz/PodcastItLater/TESTING.md | 45 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Test.py | 276 | ||||
| -rw-r--r-- | Biz/PodcastItLater/TestMetricsView.py | 121 | ||||
| -rw-r--r-- | Biz/PodcastItLater/UI.py | 755 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Web.nix | 93 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Web.py | 3480 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Worker.nix | 63 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Worker.py | 2199 |
15 files changed, 11440 insertions, 0 deletions
# File: Biz/PodcastItLater/Admin.py (new file in this commit, mode 100644)
"""
PodcastItLater Admin Interface.

Admin pages and functionality for managing users and queue items.
"""

# : out podcastitlater-admin
# : dep ludic
# : dep httpx
# : dep starlette
# : dep pytest
# : dep pytest-asyncio
# : dep pytest-mock
import Biz.PodcastItLater.Core as Core
import Biz.PodcastItLater.UI as UI
import ludic.html as html

# i need to import these unused because bild cannot get local transitive python
# dependencies yet
import Omni.App as App  # noqa: F401
import Omni.Log as Log  # noqa: F401
import Omni.Test as Test  # noqa: F401
import sys
import typing
from ludic.attrs import Attrs
from ludic.components import Component
from ludic.types import AnyChildren
from ludic.web import Request
from ludic.web.datastructures import FormData
from ludic.web.responses import Response
from typing import override


class MetricsAttrs(Attrs):
    """Attributes for Metrics component."""

    # Summary dict; MetricsDashboard reads keys such as "total_episodes".
    metrics: dict[str, typing.Any]
    # Logged-in user forwarded to the page layout (None when absent).
    user: dict[str, typing.Any] | None


class MetricCardAttrs(Attrs):
    """Attributes for MetricCard component."""

    title: str
    value: int
    # Bootstrap-icons class name, e.g. "bi-people".
    icon: str


class MetricCard(Component[AnyChildren, MetricCardAttrs]):
    """Display a single metric card."""

    @override
    def render(self) -> html.div:
        """Render the card body: an icon column next to title and value."""
        title = self.attrs["title"]
        value = self.attrs["value"]
        # Fall back to a generic chart icon when none is supplied.
        icon = self.attrs.get("icon", "bi-bar-chart")

        return html.div(
            html.div(
                html.div(
                    html.i(classes=["bi", icon, "text-primary", "fs-2"]),
                    classes=["col-auto"],
                ),
                html.div(
                    html.h6(title, classes=["text-muted", "mb-1"]),
                    html.h3(str(value), classes=["mb-0"]),
                    classes=["col"],
                ),
                classes=["row", "align-items-center"],
            ),
            classes=["card-body"],
        )


class TopEpisodesTableAttrs(Attrs):
    """Attributes for TopEpisodesTable component."""

    episodes: list[dict[str, typing.Any]]
    # Column heading for the count column, e.g. "Plays".
    metric_name: str
    # Key in each episode dict that holds the count, e.g. "play_count".
    count_key: str


class TopEpisodesTable(Component[AnyChildren, TopEpisodesTableAttrs]):
    """Display a table of top episodes by a metric."""

    @override
    def render(self) -> html.div:
        """Render a ranked table, or a "No data yet" placeholder if empty."""
        episodes = self.attrs["episodes"]
        metric_name = self.attrs["metric_name"]
        count_key = self.attrs["count_key"]

        if not episodes:
            return html.div(
                html.p(
                    "No data yet",
                    classes=["text-muted", "text-center", "py-3"],
                ),
                classes=["card-body"],
            )

        return html.div(
            html.div(
                html.table(
                    html.thead(
                        html.tr(
                            html.th("#", classes=["text-muted"]),
                            html.th("Title"),
                            html.th("Author", classes=["text-muted"]),
                            html.th(
                                metric_name,
                                classes=["text-end", "text-muted"],
                            ),
                        ),
                        classes=["table-light"],
                    ),
                    html.tbody(
                        *[
                            html.tr(
                                # Rank is 1-based.
                                html.td(
                                    str(idx + 1),
                                    classes=["text-muted"],
                                ),
                                html.td(
                                    TruncatedText(
                                        text=episode["title"],
                                        max_length=Core.TITLE_TRUNCATE_LENGTH,
                                    ),
                                ),
                                html.td(
                                    episode.get("author") or "-",
                                    classes=["text-muted"],
                                ),
                                html.td(
                                    str(episode[count_key]),
                                    classes=["text-end"],
                                ),
                            )
                            for idx, episode in enumerate(episodes)
                        ],
                    ),
                    classes=["table", "table-hover", "mb-0"],
                ),
                classes=["table-responsive"],
            ),
            classes=["card-body", "p-0"],
        )


def _metric_column(title: str, value: int, icon: str) -> html.div:
    """Wrap a MetricCard in a shadowed card inside a quarter-width column."""
    return html.div(
        html.div(
            MetricCard(title=title, value=value, icon=icon),
            classes=["card", "shadow-sm"],
        ),
        classes=["col-md-3"],
    )


def _top_episodes_column(
    heading: str,
    icon: str,
    episodes: list[dict[str, typing.Any]],
    metric_name: str,
    count_key: str,
) -> html.div:
    """Build a third-width card with a titled header and a TopEpisodesTable."""
    return html.div(
        html.div(
            html.div(
                html.h5(
                    html.i(classes=["bi", icon, "me-2"]),
                    heading,
                    classes=["card-title", "mb-0"],
                ),
                classes=["card-header", "bg-white"],
            ),
            TopEpisodesTable(
                episodes=episodes,
                metric_name=metric_name,
                count_key=count_key,
            ),
            classes=["card", "shadow-sm"],
        ),
        classes=["col-lg-4"],
    )


class MetricsDashboard(Component[AnyChildren, MetricsAttrs]):
    """Admin metrics dashboard showing aggregate statistics."""

    @override
    def render(self) -> UI.PageLayout:
        """Render growth cards, episode summary cards and top-episode tables.

        The growth/usage values default to 0 when missing from ``metrics``;
        the episode totals and top lists are read with ``metrics[...]`` and
        therefore must be present.
        """
        metrics = self.attrs["metrics"]
        user = self.attrs.get("user")

        return UI.PageLayout(
            html.div(
                html.h2(
                    html.i(classes=["bi", "bi-people", "me-2"]),
                    "Growth & Usage",
                    classes=["mb-4"],
                ),
                # Growth & Usage cards
                html.div(
                    _metric_column(
                        "Total Users",
                        metrics.get("total_users", 0),
                        "bi-people",
                    ),
                    _metric_column(
                        "Active Subs",
                        metrics.get("active_subscriptions", 0),
                        "bi-credit-card",
                    ),
                    _metric_column(
                        "Submissions (24h)",
                        metrics.get("submissions_24h", 0),
                        "bi-activity",
                    ),
                    _metric_column(
                        "Submissions (7d)",
                        metrics.get("submissions_7d", 0),
                        "bi-calendar-week",
                    ),
                    classes=["row", "g-3", "mb-5"],
                ),
                html.h2(
                    html.i(classes=["bi", "bi-graph-up", "me-2"]),
                    "Episode Metrics",
                    classes=["mb-4"],
                ),
                # Summary cards
                html.div(
                    _metric_column(
                        "Total Episodes",
                        metrics["total_episodes"],
                        "bi-collection",
                    ),
                    _metric_column(
                        "Total Plays",
                        metrics["total_plays"],
                        "bi-play-circle",
                    ),
                    _metric_column(
                        "Total Downloads",
                        metrics["total_downloads"],
                        "bi-download",
                    ),
                    _metric_column(
                        "Total Adds",
                        metrics["total_adds"],
                        "bi-plus-circle",
                    ),
                    classes=["row", "g-3", "mb-4"],
                ),
                # Top episodes tables
                html.div(
                    _top_episodes_column(
                        "Most Played",
                        "bi-play-circle-fill",
                        metrics["most_played"],
                        "Plays",
                        "play_count",
                    ),
                    _top_episodes_column(
                        "Most Downloaded",
                        "bi-download",
                        metrics["most_downloaded"],
                        "Downloads",
                        "download_count",
                    ),
                    _top_episodes_column(
                        "Most Added to Feeds",
                        "bi-plus-circle-fill",
                        metrics["most_added"],
                        "Adds",
                        "add_count",
                    ),
                    classes=["row", "g-3"],
                ),
            ),
            user=user,
            current_page="admin-metrics",
            error=None,
        )


class AdminUsersAttrs(Attrs):
    """Attributes for AdminUsers component."""

    users: list[dict[str, typing.Any]]
    user: dict[str, typing.Any] | None


class StatusBadgeAttrs(Attrs):
    """Attributes for StatusBadge component."""

    status: str
    count: int | None


class StatusBadge(Component[AnyChildren, StatusBadgeAttrs]):
    """Display a status badge with optional count."""

    @override
    def render(self) -> html.span:
        """Render "STATUS: n" when a count is given, else the bare status."""
        status = self.attrs["status"]
        count = self.attrs.get("count", None)

        text = f"{status.upper()}: {count}" if count is not None else status
        classes = ["badge", self.get_status_badge_class(status)]
        if count is not None:
            # Only counted badges (shown side by side) get a right margin;
            # uncounted badges no longer receive an empty-string class.
            classes.append("me-3")

        return html.span(text, classes=classes)

    @staticmethod
    def get_status_badge_class(status: str) -> str:
        """Get Bootstrap badge class for status."""
        return {
            "pending": "bg-warning text-dark",
            "processing": "bg-primary",
            "completed": "bg-success",
            "active": "bg-success",
            "error": "bg-danger",
            "cancelled": "bg-secondary",
            "disabled": "bg-danger",
        }.get(status, "bg-secondary")


class TruncatedTextAttrs(Attrs):
    """Attributes for TruncatedText component."""

    text: str
    max_length: int
    # CSS max-width value; render() defaults it to "200px".
    max_width: str


class TruncatedText(Component[AnyChildren, TruncatedTextAttrs]):
    """Display truncated text with tooltip."""

    @override
    def render(self) -> html.div:
        """Cut the text at max_length and expose the full text via title."""
        text = self.attrs["text"]
        max_length = self.attrs["max_length"]
        max_width = self.attrs.get("max_width", "200px")

        truncated = (
            text[:max_length] + "..." if len(text) > max_length else text
        )

        return html.div(
            truncated,
            title=text,
            classes=["text-truncate"],
            style={"max-width": max_width},
        )


class ActionButtonsAttrs(Attrs):
    """Attributes for ActionButtons component."""

    job_id: int
    status: str


class ActionButtons(Component[AnyChildren, ActionButtonsAttrs]):
    """Render action buttons for queue items."""

    @override
    def render(self) -> html.div:
        """Render a Retry button (unless completed) and a Delete button."""
        job_id = self.attrs["job_id"]
        status = self.attrs["status"]

        buttons = []

        if status != "completed":
            # The original also passed disabled=(status == "completed"),
            # which is always False inside this branch, so it is dropped.
            buttons.append(
                html.button(
                    html.i(classes=["bi", "bi-arrow-clockwise", "me-1"]),
                    "Retry",
                    hx_post=f"/queue/{job_id}/retry",
                    hx_target="body",
                    hx_swap="outerHTML",
                    classes=["btn", "btn-sm", "btn-success", "me-1"],
                ),
            )

        buttons.append(
            html.button(
                html.i(classes=["bi", "bi-trash", "me-1"]),
                "Delete",
                hx_delete=f"/queue/{job_id}",
                hx_confirm="Are you sure you want to delete this queue item?",
                hx_target="body",
                hx_swap="outerHTML",
                classes=["btn", "btn-sm", "btn-danger"],
            ),
        )

        return html.div(
            *buttons,
            classes=["btn-group"],
        )


class QueueTableRowAttrs(Attrs):
    """Attributes for QueueTableRow component."""

    item: dict[str, typing.Any]


class QueueTableRow(Component[AnyChildren, QueueTableRowAttrs]):
    """Render a single queue table row."""

    @override
    def render(self) -> html.tr:
        """Render id, url, title, email, status, retries, date, error, actions.

        ``item`` must contain "id", "url", "email", "status", "created_at"
        and "error_message"; "title" and "retry_count" are optional.
        """
        item = self.attrs["item"]
        error_message = item["error_message"]

        return html.tr(
            html.td(str(item["id"])),
            html.td(
                TruncatedText(
                    text=item["url"],
                    max_length=Core.TITLE_TRUNCATE_LENGTH,
                    max_width="300px",
                ),
            ),
            html.td(
                TruncatedText(
                    text=item.get("title") or "-",
                    max_length=Core.TITLE_TRUNCATE_LENGTH,
                ),
            ),
            html.td(item["email"] or "-"),
            html.td(StatusBadge(status=item["status"])),
            html.td(str(item.get("retry_count", 0))),
            html.td(html.small(item["created_at"], classes=["text-muted"])),
            html.td(
                # Truthiness is checked here, so no `or "-"` fallback is
                # needed on the text itself.
                TruncatedText(
                    text=error_message,
                    max_length=Core.ERROR_TRUNCATE_LENGTH,
                )
                if error_message
                else html.span("-", classes=["text-muted"]),
            ),
            html.td(ActionButtons(job_id=item["id"], status=item["status"])),
        )


class EpisodeTableRowAttrs(Attrs):
    """Attributes for EpisodeTableRow component."""

    episode: dict[str, typing.Any]


class EpisodeTableRow(Component[AnyChildren, EpisodeTableRowAttrs]):
    """Render a single episode table row."""

    @override
    def render(self) -> html.tr:
        """Render id, title, listen link, duration, content length, date.

        Falsy duration or content length (including 0) is shown as "-".
        """
        episode = self.attrs["episode"]

        return html.tr(
            html.td(str(episode["id"])),
            html.td(
                TruncatedText(
                    text=episode["title"],
                    max_length=Core.TITLE_TRUNCATE_LENGTH,
                ),
            ),
            html.td(
                html.a(
                    html.i(classes=["bi", "bi-play-circle", "me-1"]),
                    "Listen",
                    href=episode["audio_url"],
                    target="_blank",
                    classes=["btn", "btn-sm", "btn-outline-primary"],
                ),
            ),
            html.td(
                f"{episode['duration']}s" if episode["duration"] else "-",
            ),
            html.td(
                f"{episode['content_length']:,} chars"
                if episode["content_length"]
                else "-",
            ),
            html.td(html.small(episode["created_at"], classes=["text-muted"])),
        )


class UserTableRowAttrs(Attrs):
    """Attributes for UserTableRow component."""

    user: dict[str, typing.Any]


class UserTableRow(Component[AnyChildren, UserTableRowAttrs]):
    """Render a single user table row."""

    @override
    def render(self) -> html.tr:
        """Render email, creation date, status badge and a status selector.

        Changing the selector posts to ``/admin/users/<id>/status`` via HTMX.
        """
        user = self.attrs["user"]

        return html.tr(
            html.td(user["email"]),
            html.td(html.small(user["created_at"], classes=["text-muted"])),
            html.td(StatusBadge(status=user.get("status", "pending"))),
            html.td(
                html.select(
                    html.option(
                        "Pending",
                        value="pending",
                        selected=user.get("status") == "pending",
                    ),
                    html.option(
                        "Active",
                        value="active",
                        selected=user.get("status") == "active",
                    ),
                    html.option(
                        "Disabled",
                        value="disabled",
                        selected=user.get("status") == "disabled",
                    ),
                    name="status",
                    hx_post=f"/admin/users/{user['id']}/status",
                    hx_trigger="change",
                    hx_target="body",
                    hx_swap="outerHTML",
                    classes=["form-select", "form-select-sm"],
                ),
            ),
        )


def create_table_header(columns: list[str]) -> html.thead:
    """Create a table header with given column names."""
    return html.thead(
        html.tr(*[html.th(col, scope="col") for col in columns]),
        classes=["table-light"],
    )


class AdminUsers(Component[AnyChildren, AdminUsersAttrs]):
    """Admin view for managing users."""

    @override
    def render(self) -> UI.PageLayout:
        """Render the user-management page inside the shared page layout."""
        users = self.attrs["users"]
        user = self.attrs.get("user")

        return UI.PageLayout(
            html.h2(
                "User Management",
                classes=["mb-4"],
            ),
            self._render_users_table(users),
            user=user,
            current_page="admin-users",
            error=None,
        )

    @staticmethod
    def _render_users_table(
        users: list[dict[str, typing.Any]],
    ) -> html.div:
        """Render users table."""
        return html.div(
            html.h2("All Users", classes=["mb-3"]),
            html.div(
                html.table(
                    create_table_header([
                        "Email",
                        "Created At",
                        "Status",
                        "Actions",
                    ]),
                    html.tbody(*[UserTableRow(user=user) for user in users]),
                    classes=["table", "table-hover", "table-striped"],
                ),
                classes=["table-responsive"],
            ),
        )


class AdminViewAttrs(Attrs):
    """Attributes for AdminView component."""

    queue_items: list[dict[str, typing.Any]]
    episodes: list[dict[str, typing.Any]]
    # Maps a queue status to the number of items in that status.
    status_counts: dict[str, int]
    user: dict[str, typing.Any] | None


class AdminView(Component[AnyChildren, AdminViewAttrs]):
    """Admin view showing all queue items and episodes in tables."""

    @override
    def render(self) -> UI.PageLayout:
        """Render the admin page; the content div re-fetches /admin every 10s."""
        queue_items = self.attrs["queue_items"]
        episodes = self.attrs["episodes"]
        status_counts = self.attrs.get("status_counts", {})
        user = self.attrs.get("user")

        return UI.PageLayout(
            html.div(
                AdminView.render_content(
                    queue_items,
                    episodes,
                    status_counts,
                ),
                id="admin-content",
                hx_get="/admin",
                hx_trigger="every 10s",
                hx_swap="innerHTML",
                hx_target="#admin-content",
            ),
            user=user,
            current_page="admin",
            error=None,
        )

    @staticmethod
    def render_content(
        queue_items: list[dict[str, typing.Any]],
        episodes: list[dict[str, typing.Any]],
        status_counts: dict[str, int],
    ) -> html.div:
        """Render the main content of the admin page."""
        return html.div(
            html.h2(
                "Admin Queue Status",
                classes=["mb-4"],
            ),
            AdminView.render_status_summary(status_counts),
            AdminView.render_queue_table(queue_items),
            AdminView.render_episodes_table(episodes),
        )

    @staticmethod
    def render_status_summary(status_counts: dict[str, int]) -> html.div:
        """Render status summary section."""
        return html.div(
            html.h2("Status Summary", classes=["mb-3"]),
            html.div(
                *[
                    StatusBadge(status=status, count=count)
                    for status, count in status_counts.items()
                ],
                classes=["mb-4"],
            ),
        )

    @staticmethod
    def render_queue_table(
        queue_items: list[dict[str, typing.Any]],
    ) -> html.div:
        """Render queue items table."""
        return html.div(
            html.h2("Queue Items", classes=["mb-3"]),
            html.div(
                html.table(
                    create_table_header([
                        "ID",
                        "URL",
                        "Title",
                        "Email",
                        "Status",
                        "Retries",
                        "Created",
                        "Error",
                        "Actions",
                    ]),
                    html.tbody(*[
                        QueueTableRow(item=item) for item in queue_items
                    ]),
                    classes=["table", "table-hover", "table-sm"],
                ),
                classes=["table-responsive", "mb-5"],
            ),
        )

    @staticmethod
    def render_episodes_table(
        episodes: list[dict[str, typing.Any]],
    ) -> html.div:
        """Render episodes table."""
        return html.div(
            html.h2("Completed Episodes", classes=["mb-3"]),
            html.div(
                html.table(
                    create_table_header([
                        "ID",
                        "Title",
                        "Audio URL",
                        "Duration",
                        "Content Length",
                        "Created",
                    ]),
                    html.tbody(*[
                        EpisodeTableRow(episode=episode) for episode in episodes
                    ]),
                    classes=["table", "table-hover", "table-sm"],
                ),
                classes=["table-responsive"],
            ),
        )


def admin_queue_status(request: Request) -> AdminView | Response | html.div:
    """Return admin view showing all queue items and episodes.

    Redirects (302) to "/" when there is no valid session and to
    "/?error=forbidden" when the user is not an admin.  For HTMX requests
    (``HX-Request: true``) only the content fragment is returned; otherwise
    the full AdminView page.
    """
    # Check if user is logged in
    user_id = request.session.get("user_id")
    if not user_id:
        # Redirect to login
        return Response(
            "",
            status_code=302,
            headers={"Location": "/"},
        )

    user = Core.Database.get_user_by_id(
        user_id,
    )
    if not user:
        # Invalid session
        return Response(
            "",
            status_code=302,
            headers={"Location": "/"},
        )

    # Check if user is admin
    if not Core.is_admin(user):
        # Forbidden - redirect to home with error
        return Response(
            "",
            status_code=302,
            headers={"Location": "/?error=forbidden"},
        )

    # Admins can see all data (excluding completed items)
    all_queue_items = [
        item
        for item in Core.Database.get_all_queue_items(None)
        if item.get("status") != "completed"
    ]
    all_episodes = Core.Database.get_all_episodes(
        None,
    )

    # Get overall status counts for all users
    status_counts: dict[str, int] = {}
    for item in all_queue_items:
        status = item.get("status", "unknown")
        status_counts[status] = status_counts.get(status, 0) + 1

    # Check if this is an HTMX request for auto-update
    if request.headers.get("HX-Request") == "true":
        # Return just the content div for HTMX updates
        # NOTE(review): this fragment carries its own 10s poll trigger and is
        # swapped inside #admin-content, which already polls -- confirm the
        # nested trigger is intended.
        content = AdminView.render_content(
            all_queue_items,
            all_episodes,
            status_counts,
        )
        return html.div(
            content,
            hx_get="/admin",
            hx_trigger="every 10s",
            hx_swap="innerHTML",
        )

    return AdminView(
        queue_items=all_queue_items,
        episodes=all_episodes,
        status_counts=status_counts,
        user=user,
    )
+ +def retry_queue_item(request: Request, job_id: int) -> Response: + """Retry a failed queue item.""" + try: + # Check if user owns this job or is admin + user_id = request.session.get("user_id") + if not user_id: + return Response("Unauthorized", status_code=401) + + job = Core.Database.get_job_by_id( + job_id, + ) + if job is None: + return Response("Job not found", status_code=404) + + # Check ownership or admin status + user = Core.Database.get_user_by_id(user_id) + if job.get("user_id") != user_id and not Core.is_admin(user): + return Response("Forbidden", status_code=403) + + Core.Database.retry_job(job_id) + + # Check if request is from admin page via referer header + is_from_admin = "/admin" in request.headers.get("referer", "") + + # Redirect to admin if from admin page, trigger update otherwise + if is_from_admin: + return Response( + "", + status_code=200, + headers={"HX-Redirect": "/admin"}, + ) + return Response( + "", + status_code=200, + headers={"HX-Trigger": "queue-updated"}, + ) + except (ValueError, KeyError) as e: + return Response( + f"Error retrying job: {e!s}", + status_code=500, + ) + + +def delete_queue_item(request: Request, job_id: int) -> Response: + """Delete a queue item.""" + try: + # Check if user owns this job or is admin + user_id = request.session.get("user_id") + if not user_id: + return Response("Unauthorized", status_code=401) + + job = Core.Database.get_job_by_id( + job_id, + ) + if job is None: + return Response("Job not found", status_code=404) + + # Check ownership or admin status + user = Core.Database.get_user_by_id(user_id) + if job.get("user_id") != user_id and not Core.is_admin(user): + return Response("Forbidden", status_code=403) + + Core.Database.delete_job(job_id) + + # Check if request is from admin page via referer header + is_from_admin = "/admin" in request.headers.get("referer", "") + + # Redirect to admin if from admin page, trigger update otherwise + if is_from_admin: + return Response( + "", + 
status_code=200, + headers={"HX-Redirect": "/admin"}, + ) + return Response( + "", + status_code=200, + headers={"HX-Trigger": "queue-updated"}, + ) + except (ValueError, KeyError) as e: + return Response( + f"Error deleting job: {e!s}", + status_code=500, + ) + + +def admin_users(request: Request) -> AdminUsers | Response: + """Admin page for managing users.""" + # Check if user is logged in and is admin + user_id = request.session.get("user_id") + if not user_id: + return Response( + "", + status_code=302, + headers={"Location": "/"}, + ) + + user = Core.Database.get_user_by_id( + user_id, + ) + if not user or not Core.is_admin(user): + return Response( + "", + status_code=302, + headers={"Location": "/?error=forbidden"}, + ) + + # Get all users + with Core.Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT id, email, created_at, status FROM users " + "ORDER BY created_at DESC", + ) + rows = cursor.fetchall() + users = [dict(row) for row in rows] + + return AdminUsers(users=users, user=user) + + +def update_user_status( + request: Request, + user_id: int, + data: FormData, +) -> Response: + """Update user account status.""" + # Check if user is logged in and is admin + session_user_id = request.session.get("user_id") + if not session_user_id: + return Response("Unauthorized", status_code=401) + + user = Core.Database.get_user_by_id( + session_user_id, + ) + if not user or not Core.is_admin(user): + return Response("Forbidden", status_code=403) + + # Get new status from form data + new_status_raw = data.get("status", "pending") + new_status = ( + new_status_raw if isinstance(new_status_raw, str) else "pending" + ) + if new_status not in {"pending", "active", "disabled"}: + return Response("Invalid status", status_code=400) + + # Update user status + Core.Database.update_user_status( + user_id, + new_status, + ) + + # Redirect back to users page + return Response( + "", + status_code=200, + headers={"HX-Redirect": 
"/admin/users"}, + ) + + +def toggle_episode_public(request: Request, episode_id: int) -> Response: + """Toggle episode public/private status.""" + # Check if user is logged in and is admin + session_user_id = request.session.get("user_id") + if not session_user_id: + return Response("Unauthorized", status_code=401) + + user = Core.Database.get_user_by_id(session_user_id) + if not user or not Core.is_admin(user): + return Response("Forbidden", status_code=403) + + # Get current episode status + episode = Core.Database.get_episode_by_id(episode_id) + if not episode: + return Response("Episode not found", status_code=404) + + # Toggle public status + current_public = episode.get("is_public", 0) == 1 + if current_public: + Core.Database.unmark_episode_public(episode_id) + else: + Core.Database.mark_episode_public(episode_id) + + # Redirect to home page to see updated status + return Response( + "", + status_code=200, + headers={"HX-Redirect": "/"}, + ) + + +def admin_metrics(request: Request) -> MetricsDashboard | Response: + """Admin metrics dashboard showing episode statistics.""" + # Check if user is logged in and is admin + user_id = request.session.get("user_id") + if not user_id: + return Response( + "", + status_code=302, + headers={"Location": "/"}, + ) + + user = Core.Database.get_user_by_id( + user_id, + ) + if not user or not Core.is_admin(user): + return Response( + "", + status_code=302, + headers={"Location": "/?error=forbidden"}, + ) + + # Get metrics data + metrics = Core.Database.get_metrics_summary() + + return MetricsDashboard(metrics=metrics, user=user) + + +def main() -> None: + """Admin tests are currently in Web.""" + if "test" in sys.argv: + sys.exit(0) diff --git a/Biz/PodcastItLater/Billing.py b/Biz/PodcastItLater/Billing.py new file mode 100644 index 0000000..9f3739d --- /dev/null +++ b/Biz/PodcastItLater/Billing.py @@ -0,0 +1,581 @@ +""" +PodcastItLater Billing Integration. + +Stripe subscription management and usage enforcement. 
+""" + +# : out podcastitlater-billing +# : dep stripe +# : dep pytest +# : dep pytest-mock +import Biz.PodcastItLater.Core as Core +import json +import logging +import Omni.App as App +import Omni.Log as Log +import Omni.Test as Test +import os +import stripe +import sys +import typing +from datetime import datetime +from datetime import timezone + +logger = logging.getLogger(__name__) +Log.setup(logger) + +# Stripe configuration +stripe.api_key = os.getenv("STRIPE_SECRET_KEY", "") +STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET", "") + +# Price IDs from Stripe dashboard +STRIPE_PRICE_ID_PAID = os.getenv("STRIPE_PRICE_ID_PAID", "") + +# Map Stripe price IDs to tier names +PRICE_TO_TIER = { + STRIPE_PRICE_ID_PAID: "paid", +} + +# Tier limits (None = unlimited) +TIER_LIMITS: dict[str, dict[str, int | None]] = { + "free": { + "articles_per_period": 10, + "minutes_per_period": None, + }, + "paid": { + "articles_per_period": None, + "minutes_per_period": None, + }, +} + +# Price map for checkout +PRICE_MAP = { + "paid": STRIPE_PRICE_ID_PAID, +} + + +def get_period_boundaries( + user: dict[str, typing.Any], +) -> tuple[datetime, datetime]: + """Get billing period boundaries for user. + + For paid users: use Stripe subscription period. + For free users: lifetime (from account creation to far future). 
+ + Returns: + tuple: (period_start, period_end) as datetime objects + """ + if user.get("plan_tier") != "free" and user.get("current_period_start"): + # Paid user - use Stripe billing period + period_start = datetime.fromisoformat(user["current_period_start"]) + period_end = datetime.fromisoformat(user["current_period_end"]) + else: + # Free user - lifetime limit from account creation + period_start = datetime.fromisoformat(user["created_at"]) + # Set far future end date (100 years from now) + now = datetime.now(timezone.utc) + period_end = now.replace(year=now.year + 100) + + return period_start, period_end + + +def get_usage( + user_id: int, + period_start: datetime, + period_end: datetime, +) -> dict[str, int]: + """Get usage stats for user in billing period. + + Returns: + dict with keys: articles (int), minutes (int) + """ + return Core.Database.get_usage(user_id, period_start, period_end) + + +def can_submit(user_id: int) -> tuple[bool, str, dict[str, int]]: + """Check if user can submit article based on tier limits. + + Returns: + tuple: (allowed: bool, message: str, usage: dict) + """ + user = Core.Database.get_user_by_id(user_id) + if not user: + return False, "User not found", {} + + tier = user.get("plan_tier", "free") + limits = TIER_LIMITS.get(tier, TIER_LIMITS["free"]) + + # Get billing period boundaries + period_start, period_end = get_period_boundaries(user) + + # Get current usage + usage = get_usage(user_id, period_start, period_end) + + # Check article limit + article_limit = limits.get("articles_per_period") + if article_limit is not None and usage["articles"] >= article_limit: + msg = ( + f"You've reached your limit of {article_limit} articles " + "per period. Upgrade to continue." 
+ ) + return (False, msg, usage) + + # Check minutes limit (if implemented) + minute_limit = limits.get("minutes_per_period") + if minute_limit is not None and usage.get("minutes", 0) >= minute_limit: + return ( + False, + f"You've reached your limit of {minute_limit} minutes per period. " + "Please upgrade to continue.", + usage, + ) + + return True, "", usage + + +def create_checkout_session(user_id: int, tier: str, base_url: str) -> str: + """Create Stripe Checkout session for subscription. + + Args: + user_id: User ID + tier: Subscription tier (paid) + base_url: Base URL for success/cancel redirects + + Returns: + Checkout session URL to redirect user to + + Raises: + ValueError: If tier is invalid or price ID not configured + """ + if tier not in PRICE_MAP: + msg = f"Invalid tier: {tier}" + raise ValueError(msg) + + price_id = PRICE_MAP[tier] + if not price_id: + msg = f"Stripe price ID not configured for tier: {tier}" + raise ValueError(msg) + + user = Core.Database.get_user_by_id(user_id) + if not user: + msg = f"User not found: {user_id}" + raise ValueError(msg) + + # Create checkout session + session_params = { + "mode": "subscription", + "line_items": [{"price": price_id, "quantity": 1}], + "success_url": f"{base_url}/?status=success", + "cancel_url": f"{base_url}/?status=cancel", + "client_reference_id": str(user_id), + "metadata": {"user_id": str(user_id), "tier": tier}, + "allow_promotion_codes": True, + } + + # Use existing customer if available + if user.get("stripe_customer_id"): + session_params["customer"] = user["stripe_customer_id"] + else: + session_params["customer_email"] = user["email"] + + session = stripe.checkout.Session.create(**session_params) # type: ignore[arg-type] + + logger.info( + "Created checkout session for user %s, tier %s: %s", + user_id, + tier, + session.id, + ) + + return session.url # type: ignore[return-value] + + +def create_portal_session(user_id: int, base_url: str) -> str: + """Create Stripe Billing Portal session. 
+ + Args: + user_id: User ID + base_url: Base URL for return redirect + + Returns: + Portal session URL to redirect user to + + Raises: + ValueError: If user has no Stripe customer ID or portal not configured + """ + user = Core.Database.get_user_by_id(user_id) + if not user: + msg = f"User not found: {user_id}" + raise ValueError(msg) + + if not user.get("stripe_customer_id"): + msg = "User has no Stripe customer ID" + raise ValueError(msg) + + session = stripe.billing_portal.Session.create( + customer=user["stripe_customer_id"], + return_url=f"{base_url}/account", + ) + + logger.info( + "Created portal session for user %s: %s", + user_id, + session.id, + ) + + return session.url + + +def handle_webhook_event(payload: bytes, sig_header: str) -> dict[str, str]: + """Verify and process Stripe webhook event. + + Args: + payload: Raw webhook body + sig_header: Stripe-Signature header value + + Returns: + dict with processing status + + Note: + May raise stripe.error.SignatureVerificationError if invalid signature + """ + # Verify webhook signature (skip in test mode if secret not configured) + if STRIPE_WEBHOOK_SECRET: + event = stripe.Webhook.construct_event( # type: ignore[no-untyped-call] + payload, + sig_header, + STRIPE_WEBHOOK_SECRET, + ) + else: + # Test mode without signature verification + logger.warning( + "Webhook signature verification skipped (no STRIPE_WEBHOOK_SECRET)", + ) + event = json.loads(payload.decode("utf-8")) + + event_id = event["id"] + event_type = event["type"] + + # Check if already processed (idempotency) + if Core.Database.has_processed_stripe_event(event_id): + logger.info("Skipping already processed event: %s", event_id) + return {"status": "skipped", "reason": "already_processed"} + + # Process event based on type + logger.info("Processing webhook event: %s (%s)", event_id, event_type) + + try: + if event_type == "checkout.session.completed": + _handle_checkout_completed(event["data"]["object"]) + elif event_type == 
"customer.subscription.created": + _handle_subscription_created(event["data"]["object"]) + elif event_type == "customer.subscription.updated": + _handle_subscription_updated(event["data"]["object"]) + elif event_type == "customer.subscription.deleted": + _handle_subscription_deleted(event["data"]["object"]) + elif event_type == "invoice.payment_failed": + _handle_payment_failed(event["data"]["object"]) + else: + logger.info("Unhandled event type: %s", event_type) + return {"status": "ignored", "type": event_type} + + # Mark event as processed + Core.Database.mark_stripe_event_processed(event_id, event_type, payload) + except Exception: + logger.exception("Error processing webhook event %s", event_id) + raise + else: + return {"status": "processed", "type": event_type} + + +def _handle_checkout_completed(session: dict[str, typing.Any]) -> None: + """Handle checkout.session.completed event.""" + client_ref = session.get("client_reference_id") or session.get( + "metadata", + {}, + ).get("user_id") + customer_id = session.get("customer") + + if not client_ref or not customer_id: + logger.warning( + "Missing user_id or customer_id in checkout session: %s", + session.get("id", "unknown"), + ) + return + + try: + user_id = int(client_ref) + except (ValueError, TypeError): + logger.warning( + "Invalid user_id in checkout session: %s", + client_ref, + ) + return + + # Link Stripe customer to user + Core.Database.set_user_stripe_customer(user_id, customer_id) + logger.info("Linked user %s to Stripe customer %s", user_id, customer_id) + + +def _handle_subscription_created(subscription: dict[str, typing.Any]) -> None: + """Handle customer.subscription.created event.""" + _update_subscription_state(subscription) + + +def _handle_subscription_updated(subscription: dict[str, typing.Any]) -> None: + """Handle customer.subscription.updated event.""" + _update_subscription_state(subscription) + + +def _handle_subscription_deleted(subscription: dict[str, typing.Any]) -> None: + 
def _add_one_month(moment: datetime) -> datetime:
    """Return ``moment`` moved forward by one calendar month.

    The day of month is clamped to the last day of the target month, so
    Jan 31 becomes Feb 28/29 instead of raising ValueError the way a bare
    ``moment.replace(month=moment.month + 1)`` does.
    """
    import calendar  # noqa: PLC0415  (only needed by this helper)

    year = moment.year + moment.month // 12  # December rolls the year over
    month = moment.month % 12 + 1
    last_day = calendar.monthrange(year, month)[1]
    return moment.replace(
        year=year,
        month=month,
        day=min(moment.day, last_day),
    )


def _update_subscription_state(subscription: dict[str, typing.Any]) -> None:
    """Update user subscription state from Stripe subscription object.

    Reads customer, id, status, billing period and first price id from the
    subscription payload and stores them on the matching user. The function
    logs a warning and returns without changes when required fields, the
    period start, the items list, the price id, or the user are missing.

    Args:
        subscription: Stripe subscription object as a plain dict.
    """
    customer_id = subscription.get("customer")
    subscription_id = subscription.get("id")
    status = subscription.get("status")
    cancel_at_period_end = subscription.get("cancel_at_period_end", False)

    if not customer_id or not subscription_id or not status:
        logger.warning(
            "Missing required fields in subscription: %s",
            subscription_id,
        )
        return

    # Get billing period - try multiple field names for API compatibility
    period_start_ts = (
        subscription.get("current_period_start")
        or subscription.get("billing_cycle_anchor")
        or subscription.get("start_date")
    )
    period_end_ts = subscription.get("current_period_end")

    if not period_start_ts:
        logger.warning(
            "Missing period start in subscription: %s",
            subscription_id,
        )
        return

    period_start = datetime.fromtimestamp(period_start_ts, tz=timezone.utc)

    if period_end_ts:
        period_end = datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
    else:
        # Period end not provided: assume a monthly billing cycle.
        period_end = _add_one_month(period_start)

    # Determine tier from the price ID of the first subscription item
    items = subscription.get("items", {})
    data = items.get("data", [])
    if not data:
        logger.warning("No items in subscription: %s", subscription_id)
        return

    price_id = data[0].get("price", {}).get("id")
    if not price_id:
        logger.warning("No price ID in subscription: %s", subscription_id)
        return

    # Price IDs absent from the mapping fall back to the free tier.
    tier = PRICE_TO_TIER.get(price_id, "free")

    # Find user by customer ID
    user = Core.Database.get_user_by_stripe_customer_id(customer_id)
    if not user:
        logger.warning("User not found for customer: %s", customer_id)
        return

    Core.Database.update_user_subscription(
        user["id"],
        subscription_id,
        status,
        period_start,
        period_end,
        tier,
        cancel_at_period_end,
    )

    logger.info(
        "Updated user %s subscription: tier=%s, status=%s",
        user["id"],
        tier,
        status,
    )
+ + Returns: + dict with keys: name, articles_limit, price, description + """ + tier_info: dict[str, dict[str, typing.Any]] = { + "free": { + "name": "Free", + "articles_limit": 10, + "price": "$0", + "description": "10 articles total", + }, + "paid": { + "name": "Paid", + "articles_limit": None, + "price": "$12/mo", + "description": "Unlimited articles", + }, + } + return tier_info.get(tier, tier_info["free"]) + + +# Tests +# ruff: noqa: PLR6301, PLW0603, S101 + + +class TestWebhookHandling(Test.TestCase): + """Test Stripe webhook handling.""" + + def setUp(self) -> None: + """Set up test database.""" + Core.Database.init_db() + + def tearDown(self) -> None: + """Clean up test database.""" + Core.Database.teardown() + + def test_full_checkout_flow(self) -> None: + """Test complete checkout flow from session to subscription.""" + # Create test user + user_id, _token = Core.Database.create_user("test@example.com") + + # Temporarily set price mapping for test + global PRICE_TO_TIER + old_mapping = PRICE_TO_TIER.copy() + PRICE_TO_TIER["price_test_paid"] = "paid" + + try: + # Step 1: Handle checkout.session.completed + checkout_session = { + "id": "cs_test123", + "client_reference_id": str(user_id), + "customer": "cus_test123", + "metadata": {"user_id": str(user_id), "tier": "paid"}, + } + _handle_checkout_completed(checkout_session) + + # Verify customer was linked + user = Core.Database.get_user_by_id(user_id) + self.assertIsNotNone(user) + assert user is not None + self.assertEqual(user["stripe_customer_id"], "cus_test123") + + # Step 2: Handle customer.subscription.created + # (newer API uses billing_cycle_anchor instead of current_period_*) + subscription = { + "id": "sub_test123", + "customer": "cus_test123", + "status": "active", + "billing_cycle_anchor": 1700000000, + "cancel_at_period_end": False, + "items": { + "data": [ + { + "price": { + "id": "price_test_paid", + }, + }, + ], + }, + } + _update_subscription_state(subscription) + + # Verify subscription was 
def main() -> None:
    """Run the billing test suite when invoked as ``<script> test``.

    Any other invocation only logs a usage message.
    """
    # A one-element slice equals ["test"] only when argv[1] exists and
    # is exactly "test".
    if sys.argv[1:2] != ["test"]:
        logger.error("Usage: billing.py test")
        return

    os.environ["AREA"] = "Test"
    Test.run(App.Area.Test, [TestWebhookHandling])
def normalize_url(url: str) -> str:
    """Return a canonical form of a URL for comparison and hashing.

    The canonical form:
    - always uses the https scheme
    - has a lowercase host with a leading "www." removed
    - has trailing slashes stripped from the path, except that a path of
      exactly "/" is kept
    - keeps params, query and fragment untouched, as they may be meaningful

    Note that an empty path stays empty, so "https://example.com" and
    "https://example.com/" produce different results.

    Args:
        url: URL to normalize (surrounding whitespace is ignored)

    Returns:
        Normalized URL string
    """
    parts = urllib.parse.urlparse(url.strip())

    host = parts.netloc.lower().removeprefix("www.")

    if parts.path == "/":
        clean_path = "/"
    else:
        clean_path = parts.path.rstrip("/")

    # parts[3:] is (params, query, fragment), carried over unchanged.
    canonical = ("https", host, clean_path, *parts[3:])
    return urllib.parse.urlunparse(canonical)
+ """ + db_path = DATA_DIR / "podcast.db" + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + try: + yield conn + finally: + conn.close() + + @staticmethod + def init_db() -> None: + """Initialize database with required tables.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Queue table for job processing + cursor.execute(""" + CREATE TABLE IF NOT EXISTS queue ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + url TEXT, + email TEXT, + status TEXT DEFAULT 'pending', + retry_count INTEGER DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + error_message TEXT, + title TEXT, + author TEXT + ) + """) + + # Episodes table for completed podcasts + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episodes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + title TEXT NOT NULL, + content_length INTEGER, + audio_url TEXT NOT NULL, + duration INTEGER, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Create indexes for performance + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_queue_status ON queue(status)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_queue_created " + "ON queue(created_at)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episodes_created " + "ON episodes(created_at)", + ) + + conn.commit() + logger.info("Database initialized successfully") + + # Run migration to add user support + Database.migrate_to_multi_user() + + # Run migration to add metadata fields + Database.migrate_add_metadata_fields() + + # Run migration to add episode metadata fields + Database.migrate_add_episode_metadata() + + # Run migration to add user status field + Database.migrate_add_user_status() + + # Run migration to add default titles + Database.migrate_add_default_titles() + + # Run migration to add billing fields + Database.migrate_add_billing_fields() + + # Run migration to add stripe events table + Database.migrate_add_stripe_events_table() + + # Run migration to add public feed features + 
@staticmethod
def add_to_queue(
    url: str,
    email: str,
    user_id: int,
    title: str | None = None,
    author: str | None = None,
) -> int:
    """Queue a new job for the given user and return its job ID.

    Args:
        url: Article URL to convert.
        email: Email address stored with the job.
        user_id: ID of the user who owns the job.
        title: Optional article title.
        author: Optional article author.

    Raises:
        ValueError: If job ID cannot be retrieved after insert.
    """
    insert_sql = (
        "INSERT INTO queue (url, email, user_id, title, author) "
        "VALUES (?, ?, ?, ?, ?)"
    )
    row_values = (url, email, user_id, title, author)

    with Database.get_connection() as conn:
        inserted = conn.cursor()
        inserted.execute(insert_sql, row_values)
        conn.commit()

        new_id = inserted.lastrowid
        if new_id is not None:
            logger.info("Added job %s to queue: %s", new_id, url)
            return new_id

        msg = "Failed to get job ID after insert"
        raise ValueError(msg)
@staticmethod
def create_episode(  # noqa: PLR0913, PLR0917
    title: str,
    audio_url: str,
    duration: int,
    content_length: int,
    user_id: int | None = None,
    author: str | None = None,
    original_url: str | None = None,
    original_url_hash: str | None = None,
) -> int:
    """Store a finished episode and return its ID.

    Args:
        title: Episode title.
        audio_url: Location of the audio file.
        duration: Value stored in the ``duration`` column.
        content_length: Value stored in the ``content_length`` column.
        user_id: Owning user, if any.
        author: Article author, if known.
        original_url: Source article URL, if known.
        original_url_hash: Hash of the source URL, if known.

    Raises:
        ValueError: If episode ID cannot be retrieved after insert.
    """
    insert_sql = (
        "INSERT INTO episodes "
        "(title, audio_url, duration, content_length, user_id, "
        "author, original_url, original_url_hash) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
    )
    # Order must match the column list above.
    row_values = (
        title,
        audio_url,
        duration,
        content_length,
        user_id,
        author,
        original_url,
        original_url_hash,
    )

    with Database.get_connection() as conn:
        inserted = conn.cursor()
        inserted.execute(insert_sql, row_values)
        conn.commit()

        new_id = inserted.lastrowid
        if new_id is not None:
            logger.info("Created episode %s: %s", new_id, title)
            return new_id

        msg = "Failed to get episode ID after insert"
        raise ValueError(msg)
@staticmethod
def get_all_episodes(
    user_id: int | None = None,
) -> list[dict[str, Any]]:
    """Return all episodes, newest first, for RSS feed generation.

    Args:
        user_id: When given, only that user's episodes are returned.
            ``None`` (the default) returns episodes of every user.

    Returns:
        One dict per episode with the selected columns.
    """
    with Database.get_connection() as conn:
        cursor = conn.cursor()
        # Compare against None explicitly: a falsy id such as 0 must
        # filter (and match nothing), not fall through to "all users".
        if user_id is not None:
            cursor.execute(
                """
                SELECT id, title, audio_url, duration, created_at,
                       content_length, author, original_url
                FROM episodes
                WHERE user_id = ?
                ORDER BY created_at DESC
                """,
                (user_id,),
            )
        else:
            cursor.execute("""
                SELECT id, title, audio_url, duration, created_at,
                       content_length, author, original_url
                FROM episodes
                ORDER BY created_at DESC
            """)
        rows = cursor.fetchall()
        return [dict(row) for row in rows]
+ ORDER BY created_at DESC + """, + (user_id,), + ) + else: + cursor.execute(""" + SELECT id, url, email, status, retry_count, created_at, + error_message, title, author + FROM queue + ORDER BY created_at DESC + """) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_status_counts() -> dict[str, int]: + """Get count of queue items by status.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT status, COUNT(*) as count + FROM queue + GROUP BY status + """) + rows = cursor.fetchall() + return {row["status"]: row["count"] for row in rows} + + @staticmethod + def get_user_status_counts( + user_id: int, + ) -> dict[str, int]: + """Get count of queue items by status for a specific user.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT status, COUNT(*) as count + FROM queue + WHERE user_id = ? + GROUP BY status + """, + (user_id,), + ) + rows = cursor.fetchall() + return {row["status"]: row["count"] for row in rows} + + @staticmethod + def migrate_to_multi_user() -> None: + """Migrate database to support multiple users.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Create users table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT UNIQUE NOT NULL, + token TEXT UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Add user_id columns to existing tables + # Check if columns already exist to make migration idempotent + cursor.execute("PRAGMA table_info(queue)") + queue_info = cursor.fetchall() + queue_columns = [col[1] for col in queue_info] + + if "user_id" not in queue_columns: + cursor.execute( + "ALTER TABLE queue ADD COLUMN user_id INTEGER " + "REFERENCES users(id)", + ) + + cursor.execute("PRAGMA table_info(episodes)") + episodes_info = cursor.fetchall() + episodes_columns = [col[1] for col in episodes_info] 
@staticmethod
def migrate_add_metadata_fields() -> None:
    """Ensure the queue table has ``title`` and ``author`` TEXT columns.

    Columns that already exist are left alone, so the migration can be
    run repeatedly.
    """
    with Database.get_connection() as conn:
        cursor = conn.cursor()

        # Field 1 of each PRAGMA table_info row is the column name.
        cursor.execute("PRAGMA table_info(queue)")
        present = [info[1] for info in cursor.fetchall()]

        for wanted in ("title", "author"):
            if wanted in present:
                continue
            cursor.execute(f"ALTER TABLE queue ADD COLUMN {wanted} TEXT")

        conn.commit()
        logger.info("Database migrated to support metadata fields")
@staticmethod
def migrate_add_billing_fields() -> None:
    """Add billing-related fields to users table.

    Existing columns are detected with ``PRAGMA table_info`` (as the
    other migrations do) and skipped, so the migration is idempotent.
    Any error raised by ``ALTER TABLE`` is a real failure and propagates
    instead of being mistaken for "column already exists".
    """
    # Columns are added one by one (SQLite limitation).
    # Note: SQLite ALTER TABLE doesn't support adding UNIQUE constraints,
    # so they are added without UNIQUE and application logic is relied on.
    columns_to_add = [
        ("plan_tier", "TEXT NOT NULL DEFAULT 'free'"),
        ("stripe_customer_id", "TEXT"),
        ("stripe_subscription_id", "TEXT"),
        ("subscription_status", "TEXT"),
        ("current_period_start", "TIMESTAMP"),
        ("current_period_end", "TIMESTAMP"),
        ("cancel_at_period_end", "INTEGER NOT NULL DEFAULT 0"),
    ]

    with Database.get_connection() as conn:
        cursor = conn.cursor()

        # Field 1 of each PRAGMA table_info row is the column name.
        cursor.execute("PRAGMA table_info(users)")
        existing = {info[1] for info in cursor.fetchall()}

        for column_name, column_def in columns_to_add:
            if column_name in existing:
                logger.debug("Column users.%s already exists", column_name)
                continue
            query = f"ALTER TABLE users ADD COLUMN {column_name} "
            cursor.execute(query + column_def)
            logger.info("Added column users.%s", column_name)

        conn.commit()
migrate_add_public_feed() -> None: + """Add is_public column and related tables for public feed feature.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Add is_public column to episodes + cursor.execute("PRAGMA table_info(episodes)") + episodes_info = cursor.fetchall() + episodes_columns = [col[1] for col in episodes_info] + + if "is_public" not in episodes_columns: + cursor.execute( + "ALTER TABLE episodes ADD COLUMN is_public INTEGER " + "NOT NULL DEFAULT 0", + ) + logger.info("Added is_public column to episodes") + + # Create user_episodes junction table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS user_episodes ( + user_id INTEGER NOT NULL, + episode_id INTEGER NOT NULL, + added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (user_id, episode_id), + FOREIGN KEY (user_id) REFERENCES users(id), + FOREIGN KEY (episode_id) REFERENCES episodes(id) + ) + """) + logger.info("Created user_episodes junction table") + + # Create index on episode_id for reverse lookups + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_user_episodes_episode " + "ON user_episodes(episode_id)", + ) + + # Add original_url_hash column to episodes + if "original_url_hash" not in episodes_columns: + cursor.execute( + "ALTER TABLE episodes ADD COLUMN original_url_hash TEXT", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episodes_url_hash " + "ON episodes(original_url_hash)", + ) + logger.info("Added original_url_hash column to episodes") + + # Create episode_metrics table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_metrics ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + episode_id INTEGER NOT NULL, + user_id INTEGER, + event_type TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (episode_id) REFERENCES episodes(id), + FOREIGN KEY (user_id) REFERENCES users(id) + ) + """) + logger.info("Created episode_metrics table") + + # Create indexes for metrics queries + cursor.execute( + "CREATE INDEX IF NOT EXISTS 
@staticmethod
def create_user(email: str, status: str = "active") -> tuple[int, str]:
    """Create a new user, or return the existing one, as (user_id, token).

    A fresh random token is generated and inserted with the email and
    status. If the insert violates a uniqueness constraint, the row
    already stored for ``email`` is looked up and its id and token are
    returned instead; in that case the requested ``status`` is not
    applied to the existing user.

    Args:
        email: User email address
        status: Initial status; one of "pending", "active" or "disabled"

    Returns:
        Tuple of (user_id, token) for the new or already existing user.

    Raises:
        ValueError: If status is invalid, if the user ID cannot be
            retrieved after insert, or if the insert hit a uniqueness
            conflict but no user with this email could be found.
    """
    # Validate before touching the database.
    if status not in {"pending", "active", "disabled"}:
        msg = f"Invalid status: {status}"
        raise ValueError(msg)

    # Generate a secure token for RSS feed access
    token = secrets.token_urlsafe(32)
    with Database.get_connection() as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(
                "INSERT INTO users (email, token, status) VALUES (?, ?, ?)",
                (email, token, status),
            )
            conn.commit()
            user_id = cursor.lastrowid
            if user_id is None:
                msg = "Failed to get user ID after insert"
                raise ValueError(msg)
            logger.info(
                "Created user %s with email %s (status: %s)",
                user_id,
                email,
                status,
            )
        except sqlite3.IntegrityError:
            # User already exists: treat the constraint violation as
            # "email already registered" and hand back the stored row.
            cursor.execute(
                "SELECT id, token FROM users WHERE email = ?",
                (email,),
            )
            row = cursor.fetchone()
            if row is None:
                msg = f"User with email {email} not found"
                raise ValueError(msg) from None
            return int(row["id"]), str(row["token"])
        else:
            # Insert succeeded: return the newly generated credentials.
            return int(user_id), str(token)
@staticmethod
def get_user_recent_episodes(
    user_id: int,
    limit: int = 20,
) -> list[dict[str, Any]]:
    """Return up to ``limit`` of a user's episodes, newest first.

    Args:
        user_id: Owner whose episodes are listed.
        limit: Maximum number of rows to return.

    Returns:
        One dict per episode row, with every column of the table.
    """
    select_sql = (
        "SELECT * FROM episodes WHERE user_id = ? "
        "ORDER BY created_at DESC LIMIT ?"
    )
    with Database.get_connection() as conn:
        found = conn.cursor()
        found.execute(select_sql, (user_id, limit))
        return list(map(dict, found.fetchall()))
" + "ORDER BY created_at DESC", + (user_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def update_user_status( + user_id: int, + status: str, + ) -> None: + """Update user account status.""" + if status not in {"pending", "active", "disabled"}: + msg = f"Invalid status: {status}" + raise ValueError(msg) + + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET status = ? WHERE id = ?", + (status, user_id), + ) + conn.commit() + logger.info("Updated user %s status to %s", user_id, status) + + @staticmethod + def delete_user(user_id: int) -> None: + """Delete user and all associated data.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # 1. Get owned episode IDs + cursor.execute( + "SELECT id FROM episodes WHERE user_id = ?", + (user_id,), + ) + owned_episode_ids = [row[0] for row in cursor.fetchall()] + + # 2. Delete references to owned episodes + if owned_episode_ids: + # Construct placeholders for IN clause + placeholders = ",".join("?" * len(owned_episode_ids)) + + # Delete from user_episodes where these episodes are referenced + query = f"DELETE FROM user_episodes WHERE episode_id IN ({placeholders})" # noqa: S608, E501 + cursor.execute(query, tuple(owned_episode_ids)) + + # Delete metrics for these episodes + query = f"DELETE FROM episode_metrics WHERE episode_id IN ({placeholders})" # noqa: S608, E501 + cursor.execute(query, tuple(owned_episode_ids)) + + # 3. Delete owned episodes + cursor.execute("DELETE FROM episodes WHERE user_id = ?", (user_id,)) + + # 4. Delete user's data referencing others or themselves + cursor.execute( + "DELETE FROM user_episodes WHERE user_id = ?", + (user_id,), + ) + cursor.execute( + "DELETE FROM episode_metrics WHERE user_id = ?", + (user_id,), + ) + cursor.execute("DELETE FROM queue WHERE user_id = ?", (user_id,)) + + # 5. 
Delete user + cursor.execute("DELETE FROM users WHERE id = ?", (user_id,)) + + conn.commit() + logger.info("Deleted user %s and all associated data", user_id) + + @staticmethod + def update_user_email(user_id: int, new_email: str) -> None: + """Update user's email address. + + Args: + user_id: ID of the user to update + new_email: New email address + + Raises: + ValueError: If email is already taken by another user + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + "UPDATE users SET email = ? WHERE id = ?", + (new_email, user_id), + ) + conn.commit() + logger.info("Updated user %s email to %s", user_id, new_email) + except sqlite3.IntegrityError: + msg = f"Email {new_email} is already taken" + raise ValueError(msg) from None + + @staticmethod + def mark_episode_public(episode_id: int) -> None: + """Mark an episode as public.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE episodes SET is_public = 1 WHERE id = ?", + (episode_id,), + ) + conn.commit() + logger.info("Marked episode %s as public", episode_id) + + @staticmethod + def unmark_episode_public(episode_id: int) -> None: + """Mark an episode as private (not public).""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE episodes SET is_public = 0 WHERE id = ?", + (episode_id,), + ) + conn.commit() + logger.info("Unmarked episode %s as public", episode_id) + + @staticmethod + def get_public_episodes(limit: int = 50) -> list[dict[str, Any]]: + """Get public episodes for public feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT id, title, audio_url, duration, created_at, + content_length, author, original_url + FROM episodes + WHERE is_public = 1 + ORDER BY created_at DESC + LIMIT ? 
+ """, + (limit,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def add_episode_to_user(user_id: int, episode_id: int) -> None: + """Add an episode to a user's feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + "INSERT INTO user_episodes (user_id, episode_id) " + "VALUES (?, ?)", + (user_id, episode_id), + ) + conn.commit() + logger.info( + "Added episode %s to user %s feed", + episode_id, + user_id, + ) + except sqlite3.IntegrityError: + # Episode already in user's feed + logger.info( + "Episode %s already in user %s feed", + episode_id, + user_id, + ) + + @staticmethod + def user_has_episode(user_id: int, episode_id: int) -> bool: + """Check if a user has an episode in their feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT 1 FROM user_episodes " + "WHERE user_id = ? AND episode_id = ?", + (user_id, episode_id), + ) + return cursor.fetchone() is not None + + @staticmethod + def track_episode_metric( + episode_id: int, + event_type: str, + user_id: int | None = None, + ) -> None: + """Track an episode metric event. 
+ + Args: + episode_id: ID of the episode + event_type: Type of event ('added', 'played', 'downloaded') + user_id: Optional user ID (None for anonymous events) + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO episode_metrics (episode_id, user_id, event_type) " + "VALUES (?, ?, ?)", + (episode_id, user_id, event_type), + ) + conn.commit() + logger.info( + "Tracked %s event for episode %s (user: %s)", + event_type, + episode_id, + user_id or "anonymous", + ) + + @staticmethod + def get_user_episodes(user_id: int) -> list[dict[str, Any]]: + """Get all episodes in a user's feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT e.id, e.title, e.audio_url, e.duration, e.created_at, + e.content_length, e.author, e.original_url, e.is_public, + ue.added_at + FROM episodes e + JOIN user_episodes ue ON e.id = ue.episode_id + WHERE ue.user_id = ? + ORDER BY ue.added_at DESC + """, + (user_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_episode_by_url_hash(url_hash: str) -> dict[str, Any] | None: + """Get episode by original URL hash.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes WHERE original_url_hash = ?", + (url_hash,), + ) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_metrics_summary() -> dict[str, Any]: + """Get aggregate metrics summary for admin dashboard. 
+ + Returns: + dict with keys: + - total_episodes: Total number of episodes + - total_plays: Total play events + - total_downloads: Total download events + - total_adds: Total add events + - most_played: List of top 10 most played episodes + - most_downloaded: List of top 10 most downloaded episodes + - most_added: List of top 10 most added episodes + - total_users: Total number of users + - active_subscriptions: Number of active subscriptions + - submissions_24h: Submissions in last 24 hours + - submissions_7d: Submissions in last 7 days + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Get total episodes + cursor.execute("SELECT COUNT(*) as count FROM episodes") + total_episodes = cursor.fetchone()["count"] + + # Get event counts + cursor.execute( + "SELECT COUNT(*) as count FROM episode_metrics " + "WHERE event_type = 'played'", + ) + total_plays = cursor.fetchone()["count"] + + cursor.execute( + "SELECT COUNT(*) as count FROM episode_metrics " + "WHERE event_type = 'downloaded'", + ) + total_downloads = cursor.fetchone()["count"] + + cursor.execute( + "SELECT COUNT(*) as count FROM episode_metrics " + "WHERE event_type = 'added'", + ) + total_adds = cursor.fetchone()["count"] + + # Get most played episodes + cursor.execute( + """ + SELECT e.id, e.title, e.author, COUNT(*) as play_count + FROM episode_metrics em + JOIN episodes e ON em.episode_id = e.id + WHERE em.event_type = 'played' + GROUP BY em.episode_id + ORDER BY play_count DESC + LIMIT 10 + """, + ) + most_played = [dict(row) for row in cursor.fetchall()] + + # Get most downloaded episodes + cursor.execute( + """ + SELECT e.id, e.title, e.author, COUNT(*) as download_count + FROM episode_metrics em + JOIN episodes e ON em.episode_id = e.id + WHERE em.event_type = 'downloaded' + GROUP BY em.episode_id + ORDER BY download_count DESC + LIMIT 10 + """, + ) + most_downloaded = [dict(row) for row in cursor.fetchall()] + + # Get most added episodes + cursor.execute( + """ + SELECT 
e.id, e.title, e.author, COUNT(*) as add_count + FROM episode_metrics em + JOIN episodes e ON em.episode_id = e.id + WHERE em.event_type = 'added' + GROUP BY em.episode_id + ORDER BY add_count DESC + LIMIT 10 + """, + ) + most_added = [dict(row) for row in cursor.fetchall()] + + # Get user metrics + cursor.execute("SELECT COUNT(*) as count FROM users") + total_users = cursor.fetchone()["count"] + + cursor.execute( + "SELECT COUNT(*) as count FROM users " + "WHERE subscription_status = 'active'", + ) + active_subscriptions = cursor.fetchone()["count"] + + # Get recent submission metrics + cursor.execute( + "SELECT COUNT(*) as count FROM queue " + "WHERE created_at >= datetime('now', '-1 day')", + ) + submissions_24h = cursor.fetchone()["count"] + + cursor.execute( + "SELECT COUNT(*) as count FROM queue " + "WHERE created_at >= datetime('now', '-7 days')", + ) + submissions_7d = cursor.fetchone()["count"] + + return { + "total_episodes": total_episodes, + "total_plays": total_plays, + "total_downloads": total_downloads, + "total_adds": total_adds, + "most_played": most_played, + "most_downloaded": most_downloaded, + "most_added": most_added, + "total_users": total_users, + "active_subscriptions": active_subscriptions, + "submissions_24h": submissions_24h, + "submissions_7d": submissions_7d, + } + + @staticmethod + def track_episode_event( + episode_id: int, + event_type: str, + user_id: int | None = None, + ) -> None: + """Track an episode event (added, played, downloaded).""" + if event_type not in {"added", "played", "downloaded"}: + msg = f"Invalid event type: {event_type}" + raise ValueError(msg) + + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO episode_metrics " + "(episode_id, user_id, event_type) VALUES (?, ?, ?)", + (episode_id, user_id, event_type), + ) + conn.commit() + logger.info( + "Tracked %s event for episode %s", + event_type, + episode_id, + ) + + @staticmethod + def get_episode_metrics(episode_id: 
int) -> dict[str, int]: + """Get aggregated metrics for an episode.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT event_type, COUNT(*) as count + FROM episode_metrics + WHERE episode_id = ? + GROUP BY event_type + """, + (episode_id,), + ) + rows = cursor.fetchall() + return {row["event_type"]: row["count"] for row in rows} + + @staticmethod + def get_episode_metric_events(episode_id: int) -> list[dict[str, Any]]: + """Get raw metric events for an episode (for testing).""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT id, episode_id, user_id, event_type, created_at + FROM episode_metrics + WHERE episode_id = ? + ORDER BY created_at DESC + """, + (episode_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def set_user_stripe_customer(user_id: int, customer_id: str) -> None: + """Link Stripe customer ID to user.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET stripe_customer_id = ? WHERE id = ?", + (customer_id, user_id), + ) + conn.commit() + logger.info( + "Linked user %s to Stripe customer %s", + user_id, + customer_id, + ) + + @staticmethod + def update_user_subscription( # noqa: PLR0913, PLR0917 + user_id: int, + subscription_id: str, + status: str, + period_start: Any, + period_end: Any, + tier: str, + cancel_at_period_end: bool, # noqa: FBT001 + ) -> None: + """Update user subscription details.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE users SET + stripe_subscription_id = ?, + subscription_status = ?, + current_period_start = ?, + current_period_end = ?, + plan_tier = ?, + cancel_at_period_end = ? + WHERE id = ? 
+ """, + ( + subscription_id, + status, + period_start.isoformat(), + period_end.isoformat(), + tier, + 1 if cancel_at_period_end else 0, + user_id, + ), + ) + conn.commit() + logger.info( + "Updated user %s subscription: tier=%s, status=%s", + user_id, + tier, + status, + ) + + @staticmethod + def update_subscription_status(user_id: int, status: str) -> None: + """Update only the subscription status (e.g., past_due).""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET subscription_status = ? WHERE id = ?", + (status, user_id), + ) + conn.commit() + logger.info( + "Updated user %s subscription status to %s", + user_id, + status, + ) + + @staticmethod + def downgrade_to_free(user_id: int) -> None: + """Downgrade user to free tier and clear subscription data.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE users SET + plan_tier = 'free', + subscription_status = 'canceled', + stripe_subscription_id = NULL, + current_period_start = NULL, + current_period_end = NULL, + cancel_at_period_end = 0 + WHERE id = ? 
+ """, + (user_id,), + ) + conn.commit() + logger.info("Downgraded user %s to free tier", user_id) + + @staticmethod + def get_user_by_stripe_customer_id( + customer_id: str, + ) -> dict[str, Any] | None: + """Get user by Stripe customer ID.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM users WHERE stripe_customer_id = ?", + (customer_id,), + ) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def has_processed_stripe_event(event_id: str) -> bool: + """Check if Stripe event has already been processed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT id FROM stripe_events WHERE id = ?", + (event_id,), + ) + return cursor.fetchone() is not None + + @staticmethod + def mark_stripe_event_processed( + event_id: str, + event_type: str, + payload: bytes, + ) -> None: + """Mark Stripe event as processed for idempotency.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT OR IGNORE INTO stripe_events (id, type, payload) " + "VALUES (?, ?, ?)", + (event_id, event_type, payload.decode("utf-8")), + ) + conn.commit() + + @staticmethod + def get_usage( + user_id: int, + period_start: Any, + period_end: Any, + ) -> dict[str, int]: + """Get usage stats for user in period. + + Counts episodes added to user's feed (via user_episodes table) + during the billing period, regardless of who created them. + + Returns: + dict with keys: articles (int), minutes (int) + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Count articles added to user's feed in period + # Uses user_episodes junction table to track when episodes + # were added, which correctly handles shared/existing episodes + cursor.execute( + """ + SELECT COUNT(*) as count, SUM(e.duration) as total_seconds + FROM user_episodes ue + JOIN episodes e ON e.id = ue.episode_id + WHERE ue.user_id = ? 
AND ue.added_at >= ? AND ue.added_at < ? + """, + (user_id, period_start.isoformat(), period_end.isoformat()), + ) + row = cursor.fetchone() + + articles = row["count"] if row else 0 + total_seconds = ( + row["total_seconds"] if row and row["total_seconds"] else 0 + ) + minutes = total_seconds // 60 + + return {"articles": articles, "minutes": minutes} + + +class TestDatabase(Test.TestCase): + """Test the Database class.""" + + @staticmethod + def setUp() -> None: + """Set up test database.""" + Database.init_db() + + def tearDown(self) -> None: + """Clean up test database.""" + Database.teardown() + # Clear user ID + self.user_id = None + + def test_init_db(self) -> None: + """Verify all tables and indexes are created correctly.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Check tables exist + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = {row[0] for row in cursor.fetchall()} + self.assertIn("queue", tables) + self.assertIn("episodes", tables) + self.assertIn("users", tables) + + # Check indexes exist + cursor.execute("SELECT name FROM sqlite_master WHERE type='index'") + indexes = {row[0] for row in cursor.fetchall()} + self.assertIn("idx_queue_status", indexes) + self.assertIn("idx_queue_created", indexes) + self.assertIn("idx_episodes_created", indexes) + self.assertIn("idx_queue_user_id", indexes) + self.assertIn("idx_episodes_user_id", indexes) + + def test_connection_context_manager(self) -> None: + """Ensure connections are properly closed.""" + # Get a connection and verify it works + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1") + result = cursor.fetchone() + self.assertEqual(result[0], 1) + + # Connection should be closed after context manager + with pytest.raises(sqlite3.ProgrammingError): + cursor.execute("SELECT 1") + + def test_migration_idempotency(self) -> None: + """Verify migrations can run multiple times safely.""" + # Run migration 
multiple times + Database.migrate_to_multi_user() + Database.migrate_to_multi_user() + Database.migrate_to_multi_user() + + # Should still work fine + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users") + # Should not raise an error + # Test completed successfully - migration worked + self.assertIsNotNone(conn) + + def test_get_metrics_summary_extended(self) -> None: + """Verify extended metrics summary.""" + # Create some data + user_id, _ = Database.create_user("test@example.com") + Database.create_episode( + "Test Article", + "url", + 100, + 1000, + user_id, + ) + + # Create a queue item + Database.add_to_queue( + "https://example.com", + "test@example.com", + user_id, + ) + + metrics = Database.get_metrics_summary() + + self.assertIn("total_users", metrics) + self.assertIn("active_subscriptions", metrics) + self.assertIn("submissions_24h", metrics) + self.assertIn("submissions_7d", metrics) + + self.assertEqual(metrics["total_users"], 1) + self.assertEqual(metrics["submissions_24h"], 1) + self.assertEqual(metrics["submissions_7d"], 1) + + +class TestUserManagement(Test.TestCase): + """Test user management functionality.""" + + @staticmethod + def setUp() -> None: + """Set up test database.""" + Database.init_db() + + @staticmethod + def tearDown() -> None: + """Clean up test database.""" + Database.teardown() + + def test_create_user(self) -> None: + """Create user with unique email and token.""" + user_id, token = Database.create_user("test@example.com") + + self.assertIsInstance(user_id, int) + self.assertIsInstance(token, str) + self.assertGreater(len(token), 20) # Should be a secure token + + def test_create_duplicate_user(self) -> None: + """Verify duplicate emails return existing user.""" + # Create first user + user_id1, token1 = Database.create_user( + "test@example.com", + ) + + # Try to create duplicate + user_id2, token2 = Database.create_user( + "test@example.com", + ) + + # Should return same 
user + self.assertIsNotNone(user_id1) + self.assertIsNotNone(user_id2) + self.assertEqual(user_id1, user_id2) + self.assertEqual(token1, token2) + + def test_get_user_by_email(self) -> None: + """Retrieve user by email.""" + user_id, token = Database.create_user("test@example.com") + + user = Database.get_user_by_email("test@example.com") + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["id"], user_id) + self.assertEqual(user["email"], "test@example.com") + self.assertEqual(user["token"], token) + + def test_get_user_by_token(self) -> None: + """Retrieve user by RSS token.""" + user_id, token = Database.create_user("test@example.com") + + user = Database.get_user_by_token(token) + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["id"], user_id) + self.assertEqual(user["email"], "test@example.com") + + def test_get_user_by_id(self) -> None: + """Retrieve user by ID.""" + user_id, token = Database.create_user("test@example.com") + + user = Database.get_user_by_id(user_id) + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["email"], "test@example.com") + self.assertEqual(user["token"], token) + + def test_invalid_user_lookups(self) -> None: + """Verify None returned for non-existent users.""" + self.assertIsNone( + Database.get_user_by_email("nobody@example.com"), + ) + self.assertIsNone( + Database.get_user_by_token("invalid-token"), + ) + self.assertIsNone(Database.get_user_by_id(9999)) + + def test_token_uniqueness(self) -> None: + """Ensure tokens are cryptographically unique.""" + tokens = set() + for i in range(10): + _, token = Database.create_user( + f"user{i}@example.com", + ) + tokens.add(token) + + # All tokens should be unique + self.assertEqual(len(tokens), 10) + + def test_delete_user(self) -> None: + """Test user deletion and cleanup.""" + # Create user + user_id, _ = 
Database.create_user("delete_me@example.com") + + # Create some data for the user + Database.add_to_queue( + "https://example.com/article", + "delete_me@example.com", + user_id, + ) + + ep_id = Database.create_episode( + title="Test Episode", + audio_url="url", + duration=100, + content_length=1000, + user_id=user_id, + ) + Database.add_episode_to_user(user_id, ep_id) + Database.track_episode_metric(ep_id, "played", user_id) + + # Delete user + Database.delete_user(user_id) + + # Verify user is gone + self.assertIsNone(Database.get_user_by_id(user_id)) + + # Verify queue items are gone + queue = Database.get_user_queue_status(user_id) + self.assertEqual(len(queue), 0) + + # Verify episodes are gone (direct lookup) + self.assertIsNone(Database.get_episode_by_id(ep_id)) + + def test_update_user_email(self) -> None: + """Update user email address.""" + user_id, _ = Database.create_user("old@example.com") + + # Update email + Database.update_user_email(user_id, "new@example.com") + + # Verify update + user = Database.get_user_by_id(user_id) + self.assertIsNotNone(user) + if user: + self.assertEqual(user["email"], "new@example.com") + + # Old email should not exist + self.assertIsNone(Database.get_user_by_email("old@example.com")) + + @staticmethod + def test_update_user_email_duplicate() -> None: + """Cannot update to an existing email.""" + user_id1, _ = Database.create_user("user1@example.com") + Database.create_user("user2@example.com") + + # Try to update user1 to user2's email + with pytest.raises(ValueError, match="already taken"): + Database.update_user_email(user_id1, "user2@example.com") + + +class TestQueueOperations(Test.TestCase): + """Test queue operations.""" + + def setUp(self) -> None: + """Set up test database with a user.""" + Database.init_db() + self.user_id, _ = Database.create_user("test@example.com") + + @staticmethod + def tearDown() -> None: + """Clean up test database.""" + Database.teardown() + + def test_add_to_queue(self) -> None: + """Add 
job with user association.""" + job_id = Database.add_to_queue( + "https://example.com/article", + "test@example.com", + self.user_id, + ) + + self.assertIsInstance(job_id, int) + self.assertGreater(job_id, 0) + + def test_get_pending_jobs(self) -> None: + """Retrieve jobs in correct order.""" + # Add multiple jobs + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) # Ensure different timestamps + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + # Get pending jobs + jobs = Database.get_pending_jobs(limit=10) + + self.assertEqual(len(jobs), 3) + # Should be in order of creation (oldest first) + self.assertEqual(jobs[0]["id"], job1) + self.assertEqual(jobs[1]["id"], job2) + self.assertEqual(jobs[2]["id"], job3) + + def test_update_job_status(self) -> None: + """Update status and error messages.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + + # Update to processing + Database.update_job_status(job_id, "processing") + job = Database.get_job_by_id(job_id) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "processing") + + # Update to error with message + Database.update_job_status( + job_id, + "error", + "Network timeout", + ) + job = Database.get_job_by_id(job_id) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "error") + self.assertEqual(job["error_message"], "Network timeout") + self.assertEqual(job["retry_count"], 1) + + def test_retry_job(self) -> None: + """Reset failed jobs for retry.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + + # Set to error + 
Database.update_job_status(job_id, "error", "Failed") + + # Retry + Database.retry_job(job_id) + job = Database.get_job_by_id(job_id) + + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "pending") + self.assertIsNone(job["error_message"]) + + def test_delete_job(self) -> None: + """Remove jobs from queue.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + + # Delete job + Database.delete_job(job_id) + + # Should not exist + job = Database.get_job_by_id(job_id) + self.assertIsNone(job) + + def test_get_retryable_jobs(self) -> None: + """Find jobs eligible for retry.""" + # Add job and mark as error + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + Database.update_job_status(job_id, "error", "Failed") + + # Should be retryable + retryable = Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 1) + self.assertEqual(retryable[0]["id"], job_id) + + # Exceed retry limit + Database.update_job_status( + job_id, + "error", + "Failed again", + ) + Database.update_job_status( + job_id, + "error", + "Failed yet again", + ) + + # Should not be retryable anymore + retryable = Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 0) + + def test_user_queue_isolation(self) -> None: + """Ensure users only see their own jobs.""" + # Create second user + user2_id, _ = Database.create_user("user2@example.com") + + # Add jobs for both users + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + ) + job2 = Database.add_to_queue( + "https://example.com/2", + "user2@example.com", + user2_id, + ) + + # Get user-specific queue status + user1_jobs = Database.get_user_queue_status(self.user_id) + user2_jobs = Database.get_user_queue_status(user2_id) + + self.assertEqual(len(user1_jobs), 1) + 
self.assertEqual(user1_jobs[0]["id"], job1) + + self.assertEqual(len(user2_jobs), 1) + self.assertEqual(user2_jobs[0]["id"], job2) + + def test_status_counts(self) -> None: + """Verify status aggregation queries.""" + # Add jobs with different statuses + Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + ) + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + Database.update_job_status(job2, "processing") + Database.update_job_status(job3, "error", "Failed") + + # Get status counts + counts = Database.get_user_status_counts(self.user_id) + + self.assertEqual(counts.get("pending", 0), 1) + self.assertEqual(counts.get("processing", 0), 1) + self.assertEqual(counts.get("error", 0), 1) + + def test_queue_position(self) -> None: + """Verify queue position calculation.""" + # Add multiple pending jobs + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + # Check positions + self.assertEqual(Database.get_queue_position(job1), 1) + self.assertEqual(Database.get_queue_position(job2), 2) + self.assertEqual(Database.get_queue_position(job3), 3) + + # Move job 2 to processing + Database.update_job_status(job2, "processing") + + # Check positions (job 3 should now be 2nd pending job) + self.assertEqual(Database.get_queue_position(job1), 1) + self.assertIsNone(Database.get_queue_position(job2)) + self.assertEqual(Database.get_queue_position(job3), 2) + + +class TestEpisodeManagement(Test.TestCase): + """Test episode management functionality.""" + + def setUp(self) -> None: + """Set up test database with a 
user.""" + Database.init_db() + self.user_id, _ = Database.create_user("test@example.com") + + @staticmethod + def tearDown() -> None: + """Clean up test database.""" + Database.teardown() + + def test_create_episode(self) -> None: + """Create episode with user association.""" + episode_id = Database.create_episode( + title="Test Article", + audio_url="https://example.com/audio.mp3", + duration=300, + content_length=5000, + user_id=self.user_id, + ) + + self.assertIsInstance(episode_id, int) + self.assertGreater(episode_id, 0) + + def test_get_recent_episodes(self) -> None: + """Retrieve episodes in reverse chronological order.""" + # Create multiple episodes + ep1 = Database.create_episode( + "Article 1", + "url1", + 100, + 1000, + self.user_id, + ) + time.sleep(0.01) + ep2 = Database.create_episode( + "Article 2", + "url2", + 200, + 2000, + self.user_id, + ) + time.sleep(0.01) + ep3 = Database.create_episode( + "Article 3", + "url3", + 300, + 3000, + self.user_id, + ) + + # Get recent episodes + episodes = Database.get_recent_episodes(limit=10) + + self.assertEqual(len(episodes), 3) + # Should be in reverse chronological order + self.assertEqual(episodes[0]["id"], ep3) + self.assertEqual(episodes[1]["id"], ep2) + self.assertEqual(episodes[2]["id"], ep1) + + def test_get_user_episodes(self) -> None: + """Ensure user isolation for episodes.""" + # Create second user + user2_id, _ = Database.create_user("user2@example.com") + + # Create episodes for both users + ep1 = Database.create_episode( + "User1 Article", + "url1", + 100, + 1000, + self.user_id, + ) + ep2 = Database.create_episode( + "User2 Article", + "url2", + 200, + 2000, + user2_id, + ) + + # Get user-specific episodes + user1_episodes = Database.get_user_all_episodes( + self.user_id, + ) + user2_episodes = Database.get_user_all_episodes(user2_id) + + self.assertEqual(len(user1_episodes), 1) + self.assertEqual(user1_episodes[0]["id"], ep1) + + self.assertEqual(len(user2_episodes), 1) + 
self.assertEqual(user2_episodes[0]["id"], ep2) + + def test_episode_metadata(self) -> None: + """Verify duration and content_length storage.""" + Database.create_episode( + title="Test Article", + audio_url="https://example.com/audio.mp3", + duration=12345, + content_length=98765, + user_id=self.user_id, + ) + + episodes = Database.get_user_all_episodes(self.user_id) + episode = episodes[0] + + self.assertEqual(episode["duration"], 12345) + self.assertEqual(episode["content_length"], 98765) + + +def test() -> None: + """Run the tests.""" + Test.run( + App.Area.Test, + [ + TestDatabase, + TestUserManagement, + TestQueueOperations, + TestEpisodeManagement, + ], + ) + + +def main() -> None: + """Run all PodcastItLater.Core tests.""" + if "test" in sys.argv: + test() diff --git a/Biz/PodcastItLater/DESIGN.md b/Biz/PodcastItLater/DESIGN.md new file mode 100644 index 0000000..29c4837 --- /dev/null +++ b/Biz/PodcastItLater/DESIGN.md @@ -0,0 +1,43 @@ +# PodcastItLater Design & Architecture + +## Overview +Service converting web articles to podcast episodes via email/web submission. + +## Architecture +- **Web**: `Biz/PodcastItLater/Web.py` (Ludic + HTMX + Starlette) +- **Worker**: `Biz/PodcastItLater/Worker.py` (Background processing) +- **Core**: `Biz/PodcastItLater/Core.py` (DB & Shared Logic) +- **Billing**: `Biz/PodcastItLater/Billing.py` (Stripe Integration) + +## Key Features +1. **User Management**: Email-based magic links, RSS tokens. +2. **Article Processing**: Trafilatura extraction -> LLM cleanup -> TTS. +3. **Billing (In Progress)**: Stripe Checkout/Portal, Freemium model. + +## Path to Paid Product (Epic: t-143KQl2) + +### 1. Billing Infrastructure +- **Stripe**: Use Stripe Checkout for subs, Portal for management. +- **Webhooks**: Handle `checkout.session.completed`, `customer.subscription.*`. +- **Tiers**: + - `Free`: 10 articles/month. + - `Paid`: Unlimited (initially). + +### 2. 
Usage Tracking +- **Table**: `users` table needs `plan_tier`, `subscription_status`, `stripe_customer_id`. +- **Logic**: Check usage count vs tier limit before allowing submission. +- **Reset**: Usage counters reset at billing period boundary. + +### 3. Admin Dashboard +- View all users and their status. +- Manually retry/delete jobs. +- View metrics (signups, conversions). + +## UX Polish (Epic: t-1vIPJYG) +- **Mobile**: Ensure all pages work on mobile. +- **Feedback**: Better error messages for failed URLs. +- **Navigation**: Clean up navbar, account management access. + +## Audio Improvements +- **Intro/Outro**: Add metadata-rich intro ("Title by Author"). +- **Sound Design**: Crossfade intro music. diff --git a/Biz/PodcastItLater/Episode.py b/Biz/PodcastItLater/Episode.py new file mode 100644 index 0000000..7090c70 --- /dev/null +++ b/Biz/PodcastItLater/Episode.py @@ -0,0 +1,390 @@ +""" +PodcastItLater Episode Detail Components. + +Components for displaying individual episode pages with media player, +share functionality, and signup prompts for non-authenticated users. 
+""" + +# : out podcastitlater-episode +# : dep ludic +import Biz.PodcastItLater.UI as UI +import ludic.html as html +import sys +import typing +from ludic.attrs import Attrs +from ludic.components import Component +from ludic.types import AnyChildren +from typing import override + + +class EpisodePlayerAttrs(Attrs): + """Attributes for EpisodePlayer component.""" + + audio_url: str + title: str + episode_id: int + + +class EpisodePlayer(Component[AnyChildren, EpisodePlayerAttrs]): + """HTML5 audio player for episode playback.""" + + @override + def render(self) -> html.div: + audio_url = self.attrs["audio_url"] + episode_id = self.attrs["episode_id"] + + return html.div( + html.div( + html.div( + html.h5( + html.i(classes=["bi", "bi-play-circle", "me-2"]), + "Listen", + classes=["card-title", "mb-3"], + ), + html.audio( + html.source(src=audio_url, type="audio/mpeg"), + "Your browser does not support the audio element.", + controls=True, + preload="metadata", + id=f"audio-player-{episode_id}", + classes=["w-100"], + style={"max-width": "100%"}, + ), + # JavaScript to track play events + html.script( + f""" + (function() {{ + var player = document.getElementById( + 'audio-player-{episode_id}' + ); + var hasTrackedPlay = false; + + player.addEventListener('play', function() {{ + // Track first play only + if (!hasTrackedPlay) {{ + hasTrackedPlay = true; + + // Send play event to server + fetch('/episode/{episode_id}/track', {{ + method: 'POST', + headers: {{ + 'Content-Type': + 'application/x-www-form-urlencoded' + }}, + body: 'event_type=played' + }}); + }} + }}); + }})(); + """, + ), + classes=["card-body"], + ), + classes=["card", "mb-4"], + ), + ) + + +class ShareButtonAttrs(Attrs): + """Attributes for ShareButton component.""" + + share_url: str + + +class ShareButton(Component[AnyChildren, ShareButtonAttrs]): + """Button to copy episode URL to clipboard.""" + + @override + def render(self) -> html.div: + share_url = self.attrs["share_url"] + + return html.div( 
+ html.div( + html.div( + html.h5( + html.i(classes=["bi", "bi-share", "me-2"]), + "Share Episode", + classes=["card-title", "mb-3"], + ), + html.div( + html.div( + html.button( + html.i(classes=["bi", "bi-copy", "me-1"]), + "Copy", + type="button", + id="share-button", + on_click=f"navigator.clipboard.writeText('{share_url}'); " # noqa: E501 + "const btn = document.getElementById('share-button'); " # noqa: E501 + "const originalHTML = btn.innerHTML; " + "btn.innerHTML = '<i class=\"bi bi-check me-1\"></i>Copied!'; " # noqa: E501 + "btn.classList.remove('btn-outline-secondary'); " # noqa: E501 + "btn.classList.add('btn-success'); " + "setTimeout(() => {{ " + "btn.innerHTML = originalHTML; " + "btn.classList.remove('btn-success'); " + "btn.classList.add('btn-outline-secondary'); " + "}}, 2000);", + classes=["btn", "btn-outline-secondary"], + ), + html.input( + type="text", + value=share_url, + readonly=True, + on_focus="this.select()", + classes=["form-control"], + ), + classes=["input-group"], + ), + ), + classes=["card-body"], + ), + classes=["card"], + ), + classes=["mb-4"], + ) + + +class SignupBannerAttrs(Attrs): + """Attributes for SignupBanner component.""" + + creator_email: str + base_url: str + + +class SignupBanner(Component[AnyChildren, SignupBannerAttrs]): + """Banner prompting non-authenticated users to sign up.""" + + @override + def render(self) -> html.div: + return html.div( + html.div( + html.div( + html.div( + html.i( + classes=[ + "bi", + "bi-info-circle-fill", + "me-2", + ], + ), + html.strong( + "This episode was created using PodcastItLater.", + ), + classes=["mb-3"], + ), + html.div( + html.p( + "Want to convert your own articles " + "to podcast episodes?", + classes=["mb-2"], + ), + html.form( + html.div( + html.input( + type="email", + name="email", + placeholder="Enter your email to start", + required=True, + classes=["form-control"], + ), + html.button( + html.i( + classes=[ + "bi", + "bi-arrow-right-circle", + "me-2", + ], + ), + "Sign 
Up", + type="submit", + classes=["btn", "btn-primary"], + ), + classes=["input-group"], + ), + hx_post="/login", + hx_target="#signup-result", + hx_swap="innerHTML", + ), + html.div(id="signup-result", classes=["mt-2"]), + ), + classes=["card-body"], + ), + classes=["card", "border-primary"], + ), + classes=["mb-4"], + ) + + +class EpisodeDetailPageAttrs(Attrs): + """Attributes for EpisodeDetailPage component.""" + + episode: dict[str, typing.Any] + episode_sqid: str + creator_email: str | None + user: dict[str, typing.Any] | None + base_url: str + user_has_episode: bool + + +class EpisodeDetailPage(Component[AnyChildren, EpisodeDetailPageAttrs]): + """Full page view for a single episode.""" + + @override + def render(self) -> UI.PageLayout: + episode = self.attrs["episode"] + episode_sqid = self.attrs["episode_sqid"] + creator_email = self.attrs.get("creator_email") + user = self.attrs.get("user") + base_url = self.attrs["base_url"] + user_has_episode = self.attrs.get("user_has_episode", False) + + share_url = f"{base_url}/episode/{episode_sqid}" + duration_str = UI.format_duration(episode.get("duration")) + + # Build page title + page_title = f"{episode['title']} - PodcastItLater" + + # Build meta tags for Open Graph + meta_tags = [ + html.meta(property="og:title", content=episode["title"]), + html.meta(property="og:type", content="website"), + html.meta(property="og:url", content=share_url), + html.meta( + property="og:description", + content=f"Listen to this article read aloud. 
" + f"Duration: {duration_str}" + + (f" by {episode['author']}" if episode.get("author") else ""), + ), + html.meta( + property="og:site_name", + content="PodcastItLater", + ), + html.meta(property="og:audio", content=episode["audio_url"]), + html.meta(property="og:audio:type", content="audio/mpeg"), + ] + + # Add Twitter Card tags + meta_tags.extend([ + html.meta(name="twitter:card", content="summary"), + html.meta(name="twitter:title", content=episode["title"]), + html.meta( + name="twitter:description", + content=f"Listen to this article. Duration: {duration_str}", + ), + ]) + + # Add author if available + if episode.get("author"): + meta_tags.append( + html.meta(property="article:author", content=episode["author"]), + ) + + return UI.PageLayout( + # Show signup banner if user is not logged in + SignupBanner( + creator_email=creator_email or "a user", + base_url=base_url, + ) + if not user and creator_email + else html.div(), + # Episode title and metadata + html.div( + html.h2( + episode["title"], + classes=["display-6", "mb-3"], + ), + html.div( + html.span( + html.i(classes=["bi", "bi-person", "me-1"]), + f"by {episode['author']}", + classes=["text-muted", "me-3"], + ) + if episode.get("author") + else html.span(), + html.span( + html.i(classes=["bi", "bi-clock", "me-1"]), + f"Duration: {duration_str}", + classes=["text-muted", "me-3"], + ), + html.span( + html.i(classes=["bi", "bi-calendar", "me-1"]), + f"Created: {episode['created_at']}", + classes=["text-muted"], + ), + classes=["mb-3"], + ), + html.div( + html.a( + html.i(classes=["bi", "bi-link-45deg", "me-1"]), + "View original article", + href=episode["original_url"], + target="_blank", + rel="noopener", + classes=["btn", "btn-sm", "btn-outline-secondary"], + ), + ) + if episode.get("original_url") + else html.div(), + classes=["mb-4"], + ), + # Audio player + EpisodePlayer( + audio_url=episode["audio_url"], + title=episode["title"], + episode_id=episode["id"], + ), + # Share button + 
ShareButton(share_url=share_url), + # Add to feed button (logged-in users without episode) + html.div( + html.div( + html.div( + html.h5( + html.i(classes=["bi", "bi-plus-circle", "me-2"]), + "Add to Your Feed", + classes=["card-title", "mb-3"], + ), + html.p( + "Save this episode to your personal feed " + "to listen later.", + classes=["text-muted", "mb-3"], + ), + html.button( + html.i(classes=["bi", "bi-plus-lg", "me-1"]), + "Add to My Feed", + hx_post=f"/episode/{episode['id']}/add-to-feed", + hx_target="#add-to-feed-result", + hx_swap="innerHTML", + classes=["btn", "btn-primary"], + ), + html.div(id="add-to-feed-result", classes=["mt-2"]), + classes=["card-body"], + ), + classes=["card"], + ), + classes=["mb-4"], + ) + if user and not user_has_episode + else html.div(), + # Back to home link + html.div( + html.a( + html.i(classes=["bi", "bi-arrow-left", "me-1"]), + "Back to Home", + href="/", + classes=["btn", "btn-link"], + ), + classes=["mt-4"], + ), + user=user, + current_page="", + error=None, + page_title=page_title, + meta_tags=meta_tags, + ) + + +def main() -> None: + """Episode module has no tests currently.""" + if "test" in sys.argv: + sys.exit(0) diff --git a/Biz/PodcastItLater/INFRASTRUCTURE.md b/Biz/PodcastItLater/INFRASTRUCTURE.md new file mode 100644 index 0000000..1c61618 --- /dev/null +++ b/Biz/PodcastItLater/INFRASTRUCTURE.md @@ -0,0 +1,38 @@ +# Infrastructure Setup for PodcastItLater + +## Mailgun Setup + +Since PodcastItLater requires sending transactional emails (magic links), we use Mailgun. + +### 1. Sign up for Mailgun +Sign up at [mailgun.com](https://www.mailgun.com/). + +### 2. Add Domain +Add `podcastitlater.com` (or `mg.podcastitlater.com`) to Mailgun. +We recommend using the root domain `podcastitlater.com` if you want emails to come from `@podcastitlater.com`. + +### 3. Configure DNS +Mailgun will provide DNS records to verify the domain and authorize email sending. 
You must add these to your DNS provider (e.g., Cloudflare, Namecheap). + +Required records usually include: +- **TXT** (SPF): `v=spf1 include:mailgun.org ~all` +- **TXT** (DKIM): `k=rsa; p=...` (Provided by Mailgun) +- **MX** (if receiving email, optional for just sending): `10 mxa.mailgun.org`, `10 mxb.mailgun.org` +- **CNAME** (for tracking, optional): `email.podcastitlater.com` -> `mailgun.org` + +### 4. Verify Domain +Click "Verify DNS Settings" in the Mailgun dashboard. This may take up to 24 hours but is usually instant. + +### 5. Generate API Key / SMTP Credentials +Go to "Sending" -> "Domain Settings" -> "SMTP Credentials". +Create a new SMTP user (e.g., `postmaster@podcastitlater.com`). +**Save the password immediately.** + +### 6. Update Secrets +Update the production secrets file on the server (`/run/podcastitlater/env`): + +```bash +SMTP_SERVER=smtp.mailgun.org +SMTP_PASSWORD=your-new-smtp-password +EMAIL_FROM=noreply@podcastitlater.com +``` diff --git a/Biz/PodcastItLater/STRIPE_TESTING.md b/Biz/PodcastItLater/STRIPE_TESTING.md new file mode 100644 index 0000000..1461c06 --- /dev/null +++ b/Biz/PodcastItLater/STRIPE_TESTING.md @@ -0,0 +1,114 @@ +# Stripe Testing Guide + +## Testing Stripe Integration Without Real Transactions + +### 1. Use Stripe Test Mode + +Stripe provides test API keys that allow you to simulate payments without real money: + +1. Get your test keys from https://dashboard.stripe.com/test/apikeys +2. Set environment variables with test keys: + ```bash + export STRIPE_SECRET_KEY="sk_test_..." + export STRIPE_WEBHOOK_SECRET="whsec_..." + export STRIPE_PRICE_ID_PRO="price_..." + ``` + +### 2. Use Stripe Test Cards + +In test mode, use these test card numbers: +- **Success**: `4242 4242 4242 4242` +- **Decline**: `4000 0000 0000 0002` +- **3D Secure**: `4000 0025 0000 3155` + +Any future expiry date and any 3-digit CVC will work. + +### 3. 
Trigger Test Webhooks + +Use the Stripe CLI to trigger webhook events locally: + +```bash +# Install Stripe CLI +# https://stripe.com/docs/stripe-cli + +# Login +stripe login + +# Forward webhooks to local server +stripe listen --forward-to localhost:8000/stripe/webhook + +# Trigger specific events +stripe trigger checkout.session.completed +stripe trigger customer.subscription.created +stripe trigger customer.subscription.updated +stripe trigger invoice.payment_failed +``` + +### 4. Run Unit Tests + +The billing module includes unit tests that mock Stripe webhooks: + +```bash +# Run billing tests +AREA=Test python3 Biz/PodcastItLater/Billing.py test + +# Or use bild +bild --test Biz/PodcastItLater/Billing.py +``` + +### 5. Test Migration on Production + +To fix missing columns in the production database, you need to trigger the migration. + +The migration runs automatically when `Database.init_db()` is called, but production may have an old database. + +**Option A: Restart the web service** +`Database.init_db()` runs on startup, so restarting the service should apply the migrations. + +**Option B: Run the migration manually** +```bash +# SSH to production +# Run Python REPL with proper environment +python3 +>>> import os +>>> os.environ["AREA"] = "Live" +>>> os.environ["DATA_DIR"] = "/var/podcastitlater" +>>> import Biz.PodcastItLater.Core as Core +>>> Core.Database.init_db() +``` + +### 6. Verify Database Schema + +Check that billing columns exist: + +```bash +sqlite3 /var/podcastitlater/podcast.db +.schema users +``` + +The output should include: +- `stripe_customer_id TEXT` +- `stripe_subscription_id TEXT` +- `subscription_status TEXT` +- `current_period_start TEXT` +- `current_period_end TEXT` +- `plan_tier TEXT NOT NULL DEFAULT 'free'` +- `cancel_at_period_end INTEGER DEFAULT 0` + +### 7. End-to-End Test Flow + +1. Start in test mode: `AREA=Test PORT=8000 python3 Biz/PodcastItLater/Web.py` +2. Log in with a test account +3. Go to /billing +4. Click "Upgrade Now" +5. 
Use test card: 4242 4242 4242 4242 +6. The Stripe CLI will forward the webhook to your local server +7. Verify that the subscription was updated in the database + +### 8. Common Issues + +**KeyError in webhook**: Make sure you're using safe `.get()` access for all Stripe object fields, as the structure can vary. + +**Database column missing**: Run migrations by restarting the service or calling `Database.init_db()`. + +**Webhook signature verification fails**: Make sure `STRIPE_WEBHOOK_SECRET` matches the endpoint's signing secret from the Stripe dashboard (or, when forwarding with `stripe listen`, the signing secret the CLI prints). diff --git a/Biz/PodcastItLater/TESTING.md b/Biz/PodcastItLater/TESTING.md new file mode 100644 index 0000000..2911610 --- /dev/null +++ b/Biz/PodcastItLater/TESTING.md @@ -0,0 +1,45 @@ +# PodcastItLater Testing Strategy + +## Overview +We use `pytest` with `Omni.Test` integration. Tests are co-located with the code they cover; end-to-end tests live in `Biz/PodcastItLater/Test.py`. + +## Test Categories + +### 1. Core (Database/Logic) +- **Location**: `Biz/PodcastItLater/Core.py` +- **Scope**: User creation, Job queue ops, Episode management. +- **Key Tests**: + - `test_create_user`: Unique tokens. + - `test_queue_isolation`: Users see only their jobs. + +### 2. Web (HTTP/UI) +- **Location**: `Biz/PodcastItLater/Web.py` +- **Scope**: Routes, Auth, HTMX responses. +- **Key Tests**: + - `test_submit_requires_auth`. + - `test_rss_feed_xml`. + - `test_admin_access_control`. + +### 3. Worker (Processing) +- **Location**: `Biz/PodcastItLater/Worker.py` +- **Scope**: Extraction, TTS, S3 upload. +- **Key Tests**: + - `test_extract_content`: Mocked network calls. + - `test_tts_chunking`: Handle long text. + - **Error Handling**: Ensure retries work and errors are logged. + +### 4. Billing (Stripe) +- **Location**: `Biz/PodcastItLater/Billing.py` +- **Scope**: Webhook processing, Entitlement checks. +- **Key Tests**: + - `test_webhook_subscription_update`: Update local DB. + - `test_enforce_limits`: Block submission if over limit. 
+ +## Running Tests +```bash +# Run all +bild --test Biz/PodcastItLater.hs + +# Run specific file +./Biz/PodcastItLater/Web.py test +``` diff --git a/Biz/PodcastItLater/Test.py b/Biz/PodcastItLater/Test.py new file mode 100644 index 0000000..86b04f4 --- /dev/null +++ b/Biz/PodcastItLater/Test.py @@ -0,0 +1,276 @@ +"""End-to-end tests for PodcastItLater.""" + +# : dep boto3 +# : dep botocore +# : dep feedgen +# : dep httpx +# : dep itsdangerous +# : dep ludic +# : dep openai +# : dep psutil +# : dep pydub +# : dep pytest +# : dep pytest-asyncio +# : dep pytest-mock +# : dep sqids +# : dep starlette +# : dep stripe +# : dep trafilatura +# : dep uvicorn +# : out podcastitlater-e2e-test +# : run ffmpeg +import Biz.PodcastItLater.Core as Core +import Biz.PodcastItLater.UI as UI +import Biz.PodcastItLater.Web as Web +import Biz.PodcastItLater.Worker as Worker +import Omni.App as App +import Omni.Test as Test +import pathlib +import re +import sys +import unittest.mock +from starlette.testclient import TestClient + + +class BaseWebTest(Test.TestCase): + """Base test class with common setup.""" + + def setUp(self) -> None: + """Set up test environment.""" + self.app = Web.app + self.client = TestClient(self.app) + + # Initialize database for each test + Core.Database.init_db() + + @staticmethod + def tearDown() -> None: + """Clean up after each test.""" + Core.Database.teardown() + + +class TestEndToEnd(BaseWebTest): + """Test complete end-to-end flows.""" + + def setUp(self) -> None: + """Set up test client with logged-in user.""" + super().setUp() + + # Create and login user + self.user_id, self.token = Core.Database.create_user( + "test@example.com", + ) + Core.Database.update_user_status( + self.user_id, + "active", + ) + self.client.post("/login", data={"email": "test@example.com"}) + + def test_full_article_to_rss_flow(self) -> None: # noqa: PLR0915 + """Test complete flow: submit URL → process → appears in RSS feed.""" + # Step 1: Submit article URL + response = 
self.client.post( + "/submit", + data={"url": "https://example.com/great-article"}, + ) + + self.assertEqual(response.status_code, 200) + self.assertIn("Article submitted successfully", response.text) + + # Extract job ID from response + match = re.search(r"Job ID: (\d+)", response.text) + self.assertIsNotNone(match) + if match is None: + self.fail("Job ID not found in response") + job_id = int(match.group(1)) + + # Verify job was created + job = Core.Database.get_job_by_id(job_id) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "pending") + self.assertEqual(job["user_id"], self.user_id) + + # Step 2: Process the job with mocked external services + shutdown_handler = Worker.ShutdownHandler() + processor = Worker.ArticleProcessor(shutdown_handler) + + # Mock external dependencies + mock_audio_data = b"fake-mp3-audio-content-12345" + + with ( + unittest.mock.patch.object( + Worker.ArticleProcessor, + "extract_article_content", + return_value=( + "Great Article Title", + "This is the article content.", + ), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=["This is the article content."], + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=50.0, + ), + unittest.mock.patch.object( + processor.openai_client.audio.speech, + "create", + ) as mock_tts, + unittest.mock.patch.object( + processor, + "upload_to_s3", + return_value="https://cdn.example.com/episode_123_Great_Article.mp3", + ), + unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + ) as mock_audio_segment, + unittest.mock.patch( + "pathlib.Path.read_bytes", + return_value=mock_audio_data, + ), + ): + # Mock TTS response + mock_tts_response = unittest.mock.MagicMock() + mock_tts_response.content = mock_audio_data + mock_tts.return_value = mock_tts_response + + # Mock audio segment + mock_segment = unittest.mock.MagicMock() + mock_segment.export = lambda 
path, **_kwargs: pathlib.Path( + path, + ).write_bytes( + mock_audio_data, + ) + mock_audio_segment.return_value = mock_segment + + # Process the pending job + Worker.process_pending_jobs(processor) + + # Step 3: Verify job was marked completed + job = Core.Database.get_job_by_id(job_id) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "completed") + + # Step 4: Verify episode was created + episodes = Core.Database.get_user_all_episodes(self.user_id) + self.assertEqual(len(episodes), 1) + + episode = episodes[0] + self.assertEqual(episode["title"], "Great Article Title") + self.assertEqual( + episode["audio_url"], + "https://cdn.example.com/episode_123_Great_Article.mp3", + ) + self.assertGreater(episode["duration"], 0) + self.assertEqual(episode["user_id"], self.user_id) + + # Step 5: Verify episode appears in RSS feed + response = self.client.get(f"/feed/{self.token}.rss") + + self.assertEqual(response.status_code, 200) + self.assertEqual( + response.headers["content-type"], + "application/rss+xml; charset=utf-8", + ) + + # Check RSS contains the episode + self.assertIn("Great Article Title", response.text) + self.assertIn( + "https://cdn.example.com/episode_123_Great_Article.mp3", + response.text, + ) + self.assertIn("<enclosure", response.text) + self.assertIn('type="audio/mpeg"', response.text) + + # Step 6: Verify only this user's episode is in their feed + # Create another user with their own episode + user2_id, token2 = Core.Database.create_user("other@example.com") + Core.Database.create_episode( + "Other User's Article", + "https://cdn.example.com/other.mp3", + 200, + 3000, + user2_id, + ) + + # Original user's feed should not contain other user's episode + response = self.client.get(f"/feed/{self.token}.rss") + self.assertIn("Great Article Title", response.text) + self.assertNotIn("Other User's Article", response.text) + + # Other user's feed should only contain their episode + response 
= self.client.get(f"/feed/{token2}.rss") + self.assertNotIn("Great Article Title", response.text) + self.assertIn("Other User's Article", response.text) + + +class TestUI(Test.TestCase): + """Test UI components.""" + + def test_render_navbar(self) -> None: + """Test navbar rendering.""" + user = {"email": "test@example.com", "id": 1} + layout = UI.PageLayout( + user=user, + current_page="home", + error=None, + page_title="Test", + meta_tags=[], + ) + navbar = layout._render_navbar(user, "home") # noqa: SLF001 + html_output = navbar.to_html() + + # Check basic structure + self.assertIn("navbar", html_output) + self.assertIn("Home", html_output) + self.assertIn("Public Feed", html_output) + self.assertIn("Pricing", html_output) + self.assertIn("Manage Account", html_output) + + # Check active state + self.assertIn("active", html_output) + + # Check non-admin user doesn't see admin menu + self.assertNotIn("Admin", html_output) + + def test_render_navbar_admin(self) -> None: + """Test navbar rendering for admin.""" + user = {"email": "ben@bensima.com", "id": 1} # Admin email + layout = UI.PageLayout( + user=user, + current_page="admin", + error=None, + page_title="Test", + meta_tags=[], + ) + navbar = layout._render_navbar(user, "admin") # noqa: SLF001 + html_output = navbar.to_html() + + # Check admin menu present + self.assertIn("Admin", html_output) + self.assertIn("Queue Status", html_output) + + +def test() -> None: + """Run all end-to-end tests.""" + Test.run( + App.Area.Test, + [ + TestEndToEnd, + TestUI, + ], + ) + + +def main() -> None: + """Run the tests.""" + if "test" in sys.argv: + test() + else: + test() diff --git a/Biz/PodcastItLater/TestMetricsView.py b/Biz/PodcastItLater/TestMetricsView.py new file mode 100644 index 0000000..c6fdd46 --- /dev/null +++ b/Biz/PodcastItLater/TestMetricsView.py @@ -0,0 +1,121 @@ +"""Tests for Admin metrics view.""" + +# : out podcastitlater-test-metrics +# : dep pytest +# : dep starlette +# : dep httpx +# : dep ludic +# : 
dep feedgen +# : dep itsdangerous +# : dep uvicorn +# : dep stripe +# : dep sqids + +import Biz.PodcastItLater.Core as Core +import Biz.PodcastItLater.Web as Web +import Omni.Test as Test +from starlette.testclient import TestClient + + +class BaseWebTest(Test.TestCase): + """Base class for web tests.""" + + def setUp(self) -> None: + """Set up test database and client.""" + Core.Database.init_db() + self.client = TestClient(Web.app) + + @staticmethod + def tearDown() -> None: + """Clean up test database.""" + Core.Database.teardown() + + +class TestMetricsView(BaseWebTest): + """Test Admin Metrics View.""" + + def test_admin_metrics_view_access(self) -> None: + """Admin user should be able to access metrics view.""" + # Create admin user + _admin_id, _ = Core.Database.create_user("ben@bensima.com") + self.client.post("/login", data={"email": "ben@bensima.com"}) + + response = self.client.get("/admin/metrics") + self.assertEqual(response.status_code, 200) + self.assertIn("Growth & Usage", response.text) + self.assertIn("Total Users", response.text) + + def test_admin_metrics_data(self) -> None: + """Metrics view should show correct data.""" + # Create admin user + admin_id, _ = Core.Database.create_user("ben@bensima.com") + self.client.post("/login", data={"email": "ben@bensima.com"}) + + # Create some data + # 1. Users + Core.Database.create_user("user1@example.com") + user2_id, _ = Core.Database.create_user("user2@example.com") + + # 2. Subscriptions (simulate by setting subscription_status) + with Core.Database.get_connection() as conn: + conn.execute( + "UPDATE users SET subscription_status = 'active' WHERE id = ?", + (user2_id,), + ) + conn.commit() + + # 3. 
Submissions + Core.Database.add_to_queue( + "http://example.com/1", + "user1@example.com", + admin_id, + ) + + # Get metrics page + response = self.client.get("/admin/metrics") + self.assertEqual(response.status_code, 200) + + # Check labels + self.assertIn("Total Users", response.text) + self.assertIn("Active Subs", response.text) + self.assertIn("Submissions (24h)", response.text) + + # Check values (metrics dict is passed to template, + # we check rendered HTML) + # Total users: 3 (admin + user1 + user2) + # Active subs: 1 (user2) + # Submissions 24h: 1 + + # Check for values in HTML + # Note: This is a bit brittle, but effective for quick verification + self.assertIn('<h3 class="mb-0">3</h3>', response.text) + self.assertIn('<h3 class="mb-0">1</h3>', response.text) + + def test_non_admin_access_denied(self) -> None: + """Non-admin users should be denied access.""" + # Create regular user + Core.Database.create_user("regular@example.com") + self.client.post("/login", data={"email": "regular@example.com"}) + + response = self.client.get("/admin/metrics") + # Should redirect to /?error=forbidden + self.assertEqual(response.status_code, 302) + self.assertIn("error=forbidden", response.headers["Location"]) + + def test_anonymous_access_redirect(self) -> None: + """Anonymous users should be redirected to login.""" + response = self.client.get("/admin/metrics") + self.assertEqual(response.status_code, 302) + self.assertEqual(response.headers["Location"], "/") + + +def main() -> None: + """Run the tests.""" + Test.run( + Web.area, + [TestMetricsView], + ) + + +if __name__ == "__main__": + main() diff --git a/Biz/PodcastItLater/UI.py b/Biz/PodcastItLater/UI.py new file mode 100644 index 0000000..e9ef27d --- /dev/null +++ b/Biz/PodcastItLater/UI.py @@ -0,0 +1,755 @@ +""" +PodcastItLater Shared UI Components. + +Common UI components and utilities shared across web pages. 
+""" + +# : lib +# : dep ludic +import Biz.PodcastItLater.Core as Core +import ludic.html as html +import typing +from ludic.attrs import Attrs +from ludic.components import Component +from ludic.types import AnyChildren +from typing import override + + +def format_duration(seconds: int | None) -> str: + """Format duration from seconds to human-readable format. + + Durations are rounded up to the next whole minute. Returns "Unknown" + when seconds is None or not positive. + + Examples: + 300 -> "5m" + 3599 -> "60m" + 3600 -> "1h" + 3840 -> "1h 4m" + 11520 -> "3h 12m" + """ + if seconds is None or seconds <= 0: + return "Unknown" + + # Constants for time conversion + seconds_per_minute = 60 + minutes_per_hour = 60 + seconds_per_hour = 3600 + + # Round up to nearest minute + minutes = (seconds + seconds_per_minute - 1) // seconds_per_minute + + # Show as minutes only when the rounded-up total is at most 60 minutes + # 3599 seconds rounds up to 60 minutes, which we keep as "60m" + if minutes <= minutes_per_hour: + # If exactly 3600 seconds (already 60 full minutes without rounding) + if seconds >= seconds_per_hour: + return "1h" + return f"{minutes}m" + + hours = minutes // minutes_per_hour + remaining_minutes = minutes % minutes_per_hour + + if remaining_minutes == 0: + return f"{hours}h" + + return f"{hours}h {remaining_minutes}m" + + +def create_bootstrap_styles() -> html.style: + """Load Bootstrap CSS and icons.""" + return html.style( + "@import url('https://cdn.jsdelivr.net/npm/bootstrap@5.3.2" + "/dist/css/bootstrap.min.css');" + "@import url('https://cdn.jsdelivr.net/npm/bootstrap-icons" + "@1.11.3/font/bootstrap-icons.min.css');", + ) + + +def create_auto_dark_mode_style() -> html.style: + """Create CSS for automatic dark mode based on prefers-color-scheme.""" + return html.style( + """ + /* Auto dark mode - applies Bootstrap dark theme via media query */ + @media (prefers-color-scheme: dark) { + :root { + color-scheme: dark; + --bs-body-color: #dee2e6; + --bs-body-color-rgb: 222, 226, 230; + --bs-body-bg: #212529; + --bs-body-bg-rgb: 33, 37, 41; + --bs-emphasis-color: #fff; + --bs-emphasis-color-rgb: 255, 255, 255; 
+ --bs-secondary-color: rgba(222, 226, 230, 0.75); + --bs-secondary-bg: #343a40; + --bs-tertiary-color: rgba(222, 226, 230, 0.5); + --bs-tertiary-bg: #2b3035; + --bs-heading-color: inherit; + --bs-link-color: #6ea8fe; + --bs-link-hover-color: #8bb9fe; + --bs-link-color-rgb: 110, 168, 254; + --bs-link-hover-color-rgb: 139, 185, 254; + --bs-code-color: #e685b5; + --bs-border-color: #495057; + --bs-border-color-translucent: rgba(255, 255, 255, 0.15); + } + + /* Navbar dark mode */ + .navbar.bg-body-tertiary { + background-color: #2b3035 !important; + } + + .navbar .navbar-text { + color: #dee2e6 !important; + } + + /* Table header dark mode */ + .table-light { + --bs-table-bg: #343a40; + --bs-table-color: #dee2e6; + background-color: #343a40 !important; + color: #dee2e6 !important; + } + } + """, + ) + + +def create_htmx_script() -> html.script: + """Load HTMX library.""" + return html.script( + src="https://unpkg.com/htmx.org@1.9.10", + integrity="sha384-D1Kt99CQMDuVetoL1lrYwg5t+9QdHe7NLX/SoJYkXDFfX37iInKRy5xLSi8nO7UC", + crossorigin="anonymous", + ) + + +def create_bootstrap_js() -> html.script: + """Load Bootstrap JavaScript bundle.""" + return html.script( + src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/js/bootstrap.bundle.min.js", + integrity="sha384-C6RzsynM9kWDrMNeT87bh95OGNyZPhcTNXj1NW7RuBCsyN/o0jlpcV8Qyq46cDfL", + crossorigin="anonymous", + ) + + +class PageLayoutAttrs(Attrs): + """Attributes for PageLayout component.""" + + user: dict[str, typing.Any] | None + current_page: str + error: str | None + page_title: str | None + meta_tags: list[html.meta] | None + + +class PageLayout(Component[AnyChildren, PageLayoutAttrs]): + """Reusable page layout with header and navbar.""" + + @staticmethod + def _render_nav_item( + label: str, + href: str, + icon: str, + *, + is_active: bool, + ) -> html.li: + return html.li( + html.a( + html.i(classes=["bi", f"bi-{icon}", "me-1"]), + label, + href=href, + classes=[ + "nav-link", + "active" if is_active else "", + 
], + ), + classes=["nav-item"], + ) + + @staticmethod + def _render_admin_dropdown( + is_active_func: typing.Callable[[str], bool], + ) -> html.li: + is_active = is_active_func("admin") or is_active_func("admin-users") + return html.li( + html.a( # type: ignore[call-arg] + html.i(classes=["bi", "bi-gear-fill", "me-1"]), + "Admin", + href="#", + id="adminDropdown", + role="button", + data_bs_toggle="dropdown", + aria_expanded="false", + classes=[ + "nav-link", + "dropdown-toggle", + "active" if is_active else "", + ], + ), + html.ul( # type: ignore[call-arg] + html.li( + html.a( + html.i(classes=["bi", "bi-list-task", "me-2"]), + "Queue Status", + href="/admin", + classes=["dropdown-item"], + ), + ), + html.li( + html.a( + html.i(classes=["bi", "bi-people-fill", "me-2"]), + "Manage Users", + href="/admin/users", + classes=["dropdown-item"], + ), + ), + html.li( + html.a( + html.i(classes=["bi", "bi-graph-up", "me-2"]), + "Metrics", + href="/admin/metrics", + classes=["dropdown-item"], + ), + ), + classes=["dropdown-menu"], + aria_labelledby="adminDropdown", + ), + classes=["nav-item", "dropdown"], + ) + + @staticmethod + def _render_navbar( + user: dict[str, typing.Any] | None, + current_page: str, + ) -> html.nav: + """Render navigation bar.""" + + def is_active(page: str) -> bool: + return current_page == page + + return html.nav( + html.div( + html.button( # type: ignore[call-arg] + html.span(classes=["navbar-toggler-icon"]), + classes=["navbar-toggler", "ms-auto"], + type="button", + data_bs_toggle="collapse", + data_bs_target="#navbarNav", + aria_controls="navbarNav", + aria_expanded="false", + aria_label="Toggle navigation", + ), + html.div( + html.ul( + PageLayout._render_nav_item( + "Home", + "/", + "house-fill", + is_active=is_active("home"), + ), + PageLayout._render_nav_item( + "Public Feed", + "/public", + "globe", + is_active=is_active("public"), + ), + PageLayout._render_nav_item( + "Pricing", + "/pricing", + "stars", + is_active=is_active("pricing"), 
@override
def render(self) -> html.html:
    """Render the full HTML document that wraps every page.

    Reads ``user``, ``current_page``, ``error``, ``page_title`` and
    ``meta_tags`` from the component attrs. The error alert is emitted
    only when ``error`` is truthy and the navbar only when ``user`` is
    truthy; an empty ``div`` takes their place otherwise.
    """
    attrs = self.attrs
    user = attrs.get("user")
    current_page = attrs.get("current_page", "")
    error = attrs.get("error")
    page_title = attrs.get("page_title") or "PodcastItLater"
    meta_tags = attrs.get("meta_tags") or []

    head = html.head(
        html.meta(charset="utf-8"),
        html.meta(
            name="viewport",
            content="width=device-width, initial-scale=1",
        ),
        html.meta(
            name="color-scheme",
            content="light dark",
        ),
        html.title(page_title),
        *meta_tags,
        create_htmx_script(),
    )

    styles = create_bootstrap_styles()
    dark_mode_style = create_auto_dark_mode_style()

    banner = html.div(
        html.h1(
            "PodcastItLater",
            classes=["display-4", "mb-2"],
        ),
        html.p(
            "Convert web articles to podcast episodes",
            classes=["lead", "text-muted"],
        ),
        classes=["text-center", "mb-4", "pt-4"],
    )

    if error:
        alert = html.div(
            html.div(
                html.i(
                    classes=[
                        "bi",
                        "bi-exclamation-triangle-fill",
                        "me-2",
                    ],
                ),
                error,
                classes=[
                    "alert",
                    "alert-danger",
                    "d-flex",
                    "align-items-center",
                ],
                role="alert",  # type: ignore[call-arg]
            ),
        )
    else:
        alert = html.div()

    # Logged-out visitors get no navigation bar at all.
    if user:
        navbar = self._render_navbar(user, current_page)
    else:
        navbar = html.div()

    container = html.div(
        banner,
        alert,
        navbar,
        *self.children,
        classes=["container", "px-3", "px-md-4"],
        style={"max-width": "900px"},
    )

    return html.html(
        head,
        html.body(
            styles,
            dark_mode_style,
            container,
            create_bootstrap_js(),
        ),
    )
portal_url: str | None + + +class AccountPage(Component[AnyChildren, AccountPageAttrs]): + """Account management page component.""" + + @override + def render(self) -> PageLayout: + user = self.attrs["user"] + usage = self.attrs["usage"] + limits = self.attrs["limits"] + portal_url = self.attrs["portal_url"] + + plan_tier = user.get("plan_tier", "free") + is_paid = plan_tier == "paid" + + article_limit = limits.get("articles_per_period") + article_usage = usage.get("articles", 0) + + limit_text = ( + "Unlimited" if article_limit is None else str(article_limit) + ) + + usage_percent = 0 + if article_limit: + usage_percent = min(100, int((article_usage / article_limit) * 100)) + + progress_style = ( + {"width": f"{usage_percent}%"} if article_limit else {"width": "0%"} + ) + + return PageLayout( + html.div( + html.div( + html.div( + html.div( + html.div( + html.h2( + html.i( + classes=[ + "bi", + "bi-person-circle", + "me-2", + ], + ), + "My Account", + classes=["card-title", "mb-4"], + ), + # User Info Section + html.div( + html.h5("Profile", classes=["mb-3"]), + html.div( + html.strong("Email: "), + html.span(user.get("email", "")), + html.button( + "Change", + classes=[ + "btn", + "btn-sm", + "btn-outline-secondary", + "ms-2", + "py-0", + ], + hx_get="/settings/email/edit", + hx_target="closest div", + hx_swap="outerHTML", + ), + classes=[ + "mb-2", + "d-flex", + "align-items-center", + ], + ), + html.p( + html.strong("Member since: "), + user.get("created_at", "").split("T")[ + 0 + ], + classes=["mb-4"], + ), + classes=["mb-5"], + ), + # Subscription Section + html.div( + html.h5("Subscription", classes=["mb-3"]), + html.div( + html.div( + html.strong("Current Plan"), + html.span( + plan_tier.title(), + classes=[ + "badge", + "bg-success" + if is_paid + else "bg-secondary", + "ms-2", + ], + ), + classes=[ + "d-flex", + "align-items-center", + "mb-3", + ], + ), + # Usage Stats + html.div( + html.p( + "Usage this period:", + classes=["mb-2", "text-muted"], + ), + 
html.div( + html.div( + f"{article_usage} / " + f"{limit_text}", + classes=["mb-1"], + ), + html.div( + html.div( + classes=[ + "progress-bar", + ], + role="progressbar", # type: ignore[call-arg] + style=progress_style, # type: ignore[arg-type] + ), + classes=[ + "progress", + "mb-3", + ], + style={"height": "10px"}, + ) + if article_limit + else html.div(), + classes=["mb-3"], + ), + ), + # Actions + html.div( + html.form( + html.button( + html.i( + classes=[ + "bi", + "bi-credit-card", + "me-2", + ], + ), + "Manage Subscription", + type="submit", + classes=[ + "btn", + "btn-outline-primary", + ], + ), + method="post", + action=portal_url, + ) + if is_paid and portal_url + else html.a( + html.i( + classes=[ + "bi", + "bi-star-fill", + "me-2", + ], + ), + "Upgrade to Pro", + href="/pricing", + classes=["btn", "btn-primary"], + ), + classes=["d-flex", "gap-2"], + ), + classes=[ + "card", + "card-body", + "bg-light", + ], + ), + classes=["mb-5"], + ), + # Logout Section + html.div( + html.form( + html.button( + html.i( + classes=[ + "bi", + "bi-box-arrow-right", + "me-2", + ], + ), + "Log Out", + type="submit", + classes=[ + "btn", + "btn-outline-danger", + ], + ), + action="/logout", + method="post", + ), + classes=["border-top", "pt-4"], + ), + # Delete Account Section + html.div( + html.h5( + "Danger Zone", + classes=["text-danger", "mb-3"], + ), + html.div( + html.h6("Delete Account"), + html.p( + "Once you delete your account, " + "there is no going back. " + "Please be certain.", + classes=["card-text"], + ), + html.button( + html.i( + classes=[ + "bi", + "bi-trash", + "me-2", + ], + ), + "Delete Account", + hx_delete="/account", + hx_confirm=( + "Are you absolutely sure you " + "want to delete your account? " + "This action cannot be undone." 
+ ), + classes=["btn", "btn-danger"], + ), + classes=[ + "card", + "card-body", + "border-danger", + ], + ), + classes=["mt-5", "pt-4", "border-top"], + ), + classes=["card-body", "p-4"], + ), + classes=["card", "shadow-sm"], + ), + classes=["col-lg-8", "mx-auto"], + ), + classes=["row"], + ), + ), + user=user, + current_page="account", + page_title="Account - PodcastItLater", + error=None, + meta_tags=[], + ) + + +class PricingPageAttrs(Attrs): + """Attributes for PricingPage component.""" + + user: dict[str, typing.Any] | None + + +class PricingPage(Component[AnyChildren, PricingPageAttrs]): + """Pricing page component.""" + + @override + def render(self) -> PageLayout: + user = self.attrs.get("user") + current_tier = user.get("plan_tier", "free") if user else "free" + + return PageLayout( + html.div( + html.div( + # Free Tier + html.div( + html.div( + html.div( + html.h3("Free", classes=["card-title"]), + html.h4( + "$0", + classes=[ + "card-subtitle", + "mb-3", + "text-muted", + ], + ), + html.p( + "10 articles total", + classes=["card-text"], + ), + html.ul( + html.li("Convert 10 articles"), + html.li("Basic features"), + classes=["list-unstyled", "mb-4"], + ), + html.button( + "Current Plan", + classes=[ + "btn", + "btn-outline-primary", + "w-100", + ], + disabled=True, + ) + if current_tier == "free" + else html.div(), + classes=["card-body"], + ), + classes=["card", "mb-4", "shadow-sm", "h-100"], + ), + classes=["col-md-6"], + ), + # Paid Tier + html.div( + html.div( + html.div( + html.h3( + "Unlimited", + classes=["card-title"], + ), + html.h4( + "$12/mo", + classes=[ + "card-subtitle", + "mb-3", + "text-muted", + ], + ), + html.p( + "Unlimited articles", + classes=["card-text"], + ), + html.ul( + html.li("Unlimited conversions"), + html.li("Priority processing"), + html.li("Support independent software"), + classes=["list-unstyled", "mb-4"], + ), + html.form( + html.button( + "Upgrade Now", + type="submit", + classes=[ + "btn", + "btn-primary", + "w-100", 
+ ], + ), + action="/upgrade", + method="post", + ) + if user and current_tier == "free" + else ( + html.button( + "Current Plan", + classes=[ + "btn", + "btn-success", + "w-100", + ], + disabled=True, + ) + if user and current_tier == "paid" + else html.a( + "Login to Upgrade", + href="/", + classes=[ + "btn", + "btn-primary", + "w-100", + ], + ) + ), + classes=["card-body"], + ), + classes=[ + "card", + "mb-4", + "shadow-sm", + "border-primary", + "h-100", + ], + ), + classes=["col-md-6"], + ), + classes=["row"], + ), + ), + user=user, + current_page="pricing", + page_title="Pricing - PodcastItLater", + error=None, + meta_tags=[], + ) diff --git a/Biz/PodcastItLater/Web.nix b/Biz/PodcastItLater/Web.nix new file mode 100644 index 0000000..7533ca4 --- /dev/null +++ b/Biz/PodcastItLater/Web.nix @@ -0,0 +1,93 @@ +{ + options, + lib, + config, + ... +}: let + cfg = config.services.podcastitlater-web; + rootDomain = "podcastitlater.com"; + ports = import ../../Omni/Cloud/Ports.nix; +in { + options.services.podcastitlater-web = { + enable = lib.mkEnableOption "Enable the PodcastItLater web service"; + port = lib.mkOption { + type = lib.types.int; + default = 8000; + description = '' + The port on which PodcastItLater web will listen for + incoming HTTP traffic. 
+ ''; + }; + dataDir = lib.mkOption { + type = lib.types.path; + default = "/var/podcastitlater"; + description = "Data directory for PodcastItLater (shared with worker)"; + }; + package = lib.mkOption { + type = lib.types.package; + description = "PodcastItLater web package to use"; + }; + }; + config = lib.mkIf cfg.enable { + systemd.services.podcastitlater-web = { + path = [cfg.package]; + wantedBy = ["multi-user.target"]; + preStart = '' + # Create data directory if it doesn't exist + mkdir -p ${cfg.dataDir} + + # Manual step: create this file with secrets + # SECRET_KEY=your-secret-key-for-sessions + # SESSION_SECRET=your-session-secret + # EMAIL_FROM=noreply@podcastitlater.com + # SMTP_SERVER=smtp.mailgun.org + # SMTP_PASSWORD=your-smtp-password + # STRIPE_SECRET_KEY=sk_live_your_stripe_secret_key + # STRIPE_WEBHOOK_SECRET=whsec_your_webhook_secret + # STRIPE_PRICE_ID_PRO=price_your_pro_price_id + test -f /run/podcastitlater/env + ''; + script = '' + ${cfg.package}/bin/podcastitlater-web + ''; + description = '' + PodcastItLater Web Service + ''; + serviceConfig = { + Environment = [ + "PORT=${toString cfg.port}" + "AREA=Live" + "DATA_DIR=${cfg.dataDir}" + "BASE_URL=https://${rootDomain}" + ]; + EnvironmentFile = "/run/podcastitlater/env"; + KillSignal = "INT"; + Type = "simple"; + Restart = "on-abort"; + RestartSec = "1"; + }; + }; + + # Nginx configuration + services.nginx = { + enable = true; + recommendedGzipSettings = true; + recommendedOptimisation = true; + recommendedProxySettings = true; + recommendedTlsSettings = true; + statusPage = true; + + virtualHosts."${rootDomain}" = { + forceSSL = true; + enableACME = true; + locations."/" = { + proxyPass = "http://127.0.0.1:${toString cfg.port}"; + proxyWebsockets = true; + }; + }; + }; + + # Ensure firewall allows web traffic + networking.firewall.allowedTCPPorts = [ports.ssh ports.http ports.https]; + }; +} diff --git a/Biz/PodcastItLater/Web.py b/Biz/PodcastItLater/Web.py new file mode 100644 index 
0000000..30b5236 --- /dev/null +++ b/Biz/PodcastItLater/Web.py @@ -0,0 +1,3480 @@ +""" +PodcastItLater Web Service. + +Web frontend for converting articles to podcast episodes. +Provides ludic + htmx interface and RSS feed generation. +""" + +# : out podcastitlater-web +# : dep ludic +# : dep feedgen +# : dep httpx +# : dep itsdangerous +# : dep uvicorn +# : dep pytest +# : dep pytest-asyncio +# : dep pytest-mock +# : dep starlette +# : dep stripe +# : dep sqids +import Biz.EmailAgent +import Biz.PodcastItLater.Admin as Admin +import Biz.PodcastItLater.Billing as Billing +import Biz.PodcastItLater.Core as Core +import Biz.PodcastItLater.Episode as Episode +import Biz.PodcastItLater.UI as UI +import html as html_module +import httpx +import logging +import ludic.html as html +import Omni.App as App +import Omni.Log as Log +import Omni.Test as Test +import os +import pathlib +import re +import sys +import tempfile +import typing +import urllib.parse +import uvicorn +from datetime import datetime +from datetime import timezone +from feedgen.feed import FeedGenerator # type: ignore[import-untyped] +from itsdangerous import URLSafeTimedSerializer +from ludic.attrs import Attrs +from ludic.components import Component +from ludic.types import AnyChildren +from ludic.web import LudicApp +from ludic.web import Request +from ludic.web.datastructures import FormData +from ludic.web.responses import Response +from sqids import Sqids +from starlette.middleware.sessions import SessionMiddleware +from starlette.responses import RedirectResponse +from starlette.testclient import TestClient +from typing import override +from unittest.mock import patch + +logger = logging.getLogger(__name__) +Log.setup(logger) + + +# Configuration +area = App.from_env() +BASE_URL = os.getenv("BASE_URL", "http://localhost:8000") +PORT = int(os.getenv("PORT", "8000")) + +# Initialize sqids for episode URL encoding +sqids = Sqids(min_length=8) + + +def encode_episode_id(episode_id: int) -> str: + 
"""Encode episode ID to sqid for URLs.""" + return str(sqids.encode([episode_id])) + + +def decode_episode_id(sqid: str) -> int | None: + """Decode sqid to episode ID. Returns None if invalid.""" + try: + decoded = sqids.decode(sqid) + return decoded[0] if decoded else None + except (ValueError, IndexError): + return None + + +# Authentication configuration +MAGIC_LINK_MAX_AGE = 3600 # 1 hour +SESSION_MAX_AGE = 30 * 24 * 3600 # 30 days +EMAIL_FROM = os.getenv("EMAIL_FROM", "noreply@podcastitlater.com") +SMTP_SERVER = os.getenv("SMTP_SERVER", "smtp.mailgun.org") +SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "") + +# Initialize serializer for magic links +magic_link_serializer = URLSafeTimedSerializer( + os.getenv("SECRET_KEY", "dev-secret-key"), +) + + +RSS_CONFIG = { + "author": "PodcastItLater", + "language": "en-US", + "base_url": BASE_URL, +} + + +def extract_og_metadata(url: str) -> tuple[str | None, str | None]: + """Extract Open Graph title and author from URL. + + Returns: + tuple: (title, author) - both may be None if extraction fails + """ + try: + # Use httpx to fetch the page with a timeout + response = httpx.get(url, timeout=10.0, follow_redirects=True) + response.raise_for_status() + + # Simple regex-based extraction to avoid heavy dependencies + html_content = response.text + + # Extract og:title + title_match = re.search( + r'<meta\s+(?:property|name)=["\']og:title["\']\s+content=["\'](.*?)["\']', + html_content, + re.IGNORECASE, + ) + title = title_match.group(1) if title_match else None + + # Extract author - try article:author first, then og:site_name + author_match = re.search( + r'<meta\s+(?:property|name)=["\']article:author["\']\s+content=["\'](.*?)["\']', + html_content, + re.IGNORECASE, + ) + if not author_match: + author_match = re.search( + r'<meta\s+(?:property|name)=["\']og:site_name["\']\s+content=["\'](.*?)["\']', + html_content, + re.IGNORECASE, + ) + author = author_match.group(1) if author_match else None + + # Clean up HTML entities 
def send_magic_link(email: str, token: str) -> None:
    """Send the magic-link login email to ``email``.

    The message body is written to a temporary text file because
    ``Biz.EmailAgent.send_email`` is given the body as a path. The file is
    always removed afterwards, whether writing or sending succeeds or not.

    Args:
        email: Recipient address.
        token: Signed login token embedded in the verification URL.
    """
    subject = "Login to PodcastItLater"
    magic_link = f"{BASE_URL}/auth/verify?token={token}"
    body = f"""
Hello,

Click this link to login to PodcastItLater:
{magic_link}

This link will expire in 1 hour.

If you didn't request this, please ignore this email.

Best,
PodcastItLater
"""

    # Reserve a unique path; delete=False so the file outlives the handle
    # and can be handed to the email agent by name.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".txt",
        delete=False,
        encoding="utf-8",
    ) as f:
        body_text_path = pathlib.Path(f.name)

    try:
        # Write with an explicit encoding (the locale default may not be
        # UTF-8), and inside the try so a failed write cannot leak the file.
        body_text_path.write_text(body, encoding="utf-8")
        Biz.EmailAgent.send_email(
            to_addrs=[email],
            from_addr=EMAIL_FROM,
            smtp_server=SMTP_SERVER,
            password=SMTP_PASSWORD,
            subject=subject,
            body_text=body_text_path,
        )
    finally:
        # Clean up temporary file
        body_text_path.unlink(missing_ok=True)
"bi-envelope-fill", "me-2"]), + "Login / Register", + classes=["card-title", "mb-3"], + ), + html.form( + html.div( + html.label( + "Email address", + for_="email", + classes=["form-label"], + ), + html.input( + type="email", + id="email", + name="email", + placeholder="your@email.com", + value="demo@example.com" if is_dev_mode else "", + required=True, + classes=["form-control", "mb-3"], + ), + ), + html.button( + html.i( + classes=["bi", "bi-arrow-right-circle", "me-2"], + ), + "Continue", + type="submit", + classes=["btn", "btn-primary", "w-100"], + ), + hx_post="/login", + hx_target="#login-result", + hx_swap="innerHTML", + ), + html.div( + error or "", + id="login-result", + classes=["mt-3"], + ), + classes=["card-body"], + ), + classes=["card"], + ), + classes=["mb-4"], + ) + + +class SubmitForm(Component[AnyChildren, Attrs]): + """Article submission form with HTMX.""" + + @override + def render(self) -> html.div: + return html.div( + html.div( + html.div( + html.h4( + html.i(classes=["bi", "bi-file-earmark-plus", "me-2"]), + "Submit Article", + classes=["card-title", "mb-3"], + ), + html.form( + html.div( + html.label( + "Article URL", + for_="url", + classes=["form-label"], + ), + html.div( + html.input( + type="url", + id="url", + name="url", + placeholder="https://example.com/article", + required=True, + classes=["form-control"], + on_focus="this.select()", + ), + html.button( + html.i(classes=["bi", "bi-send-fill"]), + type="submit", + classes=["btn", "btn-primary"], + ), + classes=["input-group", "mb-3"], + ), + ), + hx_post="/submit", + hx_target="#submit-result", + hx_swap="innerHTML", + hx_on=( + "htmx:afterRequest: " + "if(event.detail.successful) " + "document.getElementById('url').value = ''" + ), + ), + html.div(id="submit-result", classes=["mt-2"]), + classes=["card-body"], + ), + classes=["card"], + ), + classes=["mb-4"], + ) + + +class QueueStatusAttrs(Attrs): + """Attributes for QueueStatus component.""" + + items: list[dict[str, typing.Any]] 
+ + +class QueueStatus(Component[AnyChildren, QueueStatusAttrs]): + """Display queue items with auto-refresh.""" + + @override + def render(self) -> html.div: + items = self.attrs["items"] + if not items: + return html.div( + html.h4( + html.i(classes=["bi", "bi-list-check", "me-2"]), + "Queue Status", + classes=["mb-3"], + ), + html.p("No items in queue", classes=["text-muted"]), + ) + + # Map status to Bootstrap badge classes + status_classes = { + "pending": "bg-warning text-dark", + "processing": "bg-primary", + "extracting": "bg-info text-dark", + "synthesizing": "bg-primary", + "uploading": "bg-success", + "error": "bg-danger", + "cancelled": "bg-secondary", + } + + status_icons = { + "pending": "bi-clock", + "processing": "bi-arrow-repeat", + "extracting": "bi-file-text", + "synthesizing": "bi-mic", + "uploading": "bi-cloud-arrow-up", + "error": "bi-exclamation-triangle", + "cancelled": "bi-x-circle", + } + + queue_items = [] + for item in items: + badge_class = status_classes.get(item["status"], "bg-secondary") + icon_class = status_icons.get(item["status"], "bi-question-circle") + + # Get queue position for pending items + queue_pos = None + if item["status"] == "pending": + queue_pos = Core.Database.get_queue_position(item["id"]) + + queue_items.append( + html.div( + html.div( + html.div( + html.strong(f"#{item['id']}", classes=["me-2"]), + html.span( + html.i(classes=["bi", icon_class, "me-1"]), + item["status"].upper(), + classes=["badge", badge_class], + ), + classes=[ + "d-flex", + "align-items-center", + "justify-content-between", + ], + ), + # Add title and author if available + *( + [ + html.div( + html.strong( + item["title"], + classes=["d-block"], + ), + html.small( + f"by {item['author']}", + classes=["text-muted"], + ) + if item.get("author") + else html.span(), + classes=["mt-2"], + ), + ] + if item.get("title") + else [] + ), + html.small( + html.i(classes=["bi", "bi-link-45deg", "me-1"]), + item["url"][: Core.URL_TRUNCATE_LENGTH] + + ( + 
"..." + if len(item["url"]) > Core.URL_TRUNCATE_LENGTH + else "" + ), + classes=["text-muted", "d-block", "mt-2"], + ), + html.small( + html.i(classes=["bi", "bi-calendar", "me-1"]), + f"Created: {item['created_at']}", + classes=["text-muted", "d-block", "mt-1"], + ), + # Display queue position if available + html.small( + html.i( + classes=["bi", "bi-hourglass-split", "me-1"], + ), + f"Position in queue: #{queue_pos}", + classes=["text-info", "d-block", "mt-1"], + ) + if queue_pos + else html.span(), + *( + [ + html.div( + html.i( + classes=[ + "bi", + "bi-exclamation-circle", + "me-1", + ], + ), + f"Error: {item['error_message']}", + classes=[ + "alert", + "alert-danger", + "mt-2", + "mb-0", + "py-1", + "px-2", + "small", + ], + ), + ] + if item["error_message"] + else [] + ), + # Add cancel button for pending jobs, remove for others + html.div( + # Retry button for error items + html.button( + html.i( + classes=[ + "bi", + "bi-arrow-clockwise", + "me-1", + ], + ), + "Retry", + hx_post=f"/queue/{item['id']}/retry", + hx_trigger="click", + hx_on=( + "htmx:afterRequest: " + "if(event.detail.successful) " + "htmx.trigger('body', 'queue-updated')" + ), + classes=[ + "btn", + "btn-sm", + "btn-outline-primary", + "mt-2", + "me-2", + ], + ) + if item["status"] == "error" + else html.span(), + html.button( + html.i(classes=["bi", "bi-x-lg", "me-1"]), + "Cancel", + hx_post=f"/queue/{item['id']}/cancel", + hx_trigger="click", + hx_on=( + "htmx:afterRequest: " + "if(event.detail.successful) " + "htmx.trigger('body', 'queue-updated')" + ), + classes=[ + "btn", + "btn-sm", + "btn-outline-danger", + "mt-2", + ], + ) + if item["status"] == "pending" + else html.button( + html.i(classes=["bi", "bi-trash", "me-1"]), + "Remove", + hx_delete=f"/queue/{item['id']}", + hx_trigger="click", + hx_confirm="Remove this item from the queue?", + hx_on=( + "htmx:afterRequest: " + "if(event.detail.successful) " + "htmx.trigger('body', 'queue-updated')" + ), + classes=[ + "btn", + "btn-sm", + 
"btn-outline-secondary", + "mt-2", + ], + ), + classes=["mt-2"], + ), + classes=["card-body"], + ), + classes=["card", "mb-2"], + ), + ) + + return html.div( + html.h4( + html.i(classes=["bi", "bi-list-check", "me-2"]), + "Queue Status", + classes=["mb-3"], + ), + *queue_items, + ) + + +class EpisodeListAttrs(Attrs): + """Attributes for EpisodeList component.""" + + episodes: list[dict[str, typing.Any]] + rss_url: str | None + user: dict[str, typing.Any] | None + viewing_own_feed: bool + + +class EpisodeList(Component[AnyChildren, EpisodeListAttrs]): + """List recent episodes (no audio player - use podcast app).""" + + @override + def render(self) -> html.div: + episodes = self.attrs["episodes"] + rss_url = self.attrs.get("rss_url") + user = self.attrs.get("user") + viewing_own_feed = self.attrs.get("viewing_own_feed", False) + + if not episodes: + return html.div( + html.h4( + html.i(classes=["bi", "bi-broadcast", "me-2"]), + "Recent Episodes", + classes=["mb-3"], + ), + html.p("No episodes yet", classes=["text-muted"]), + ) + + episode_items = [] + for episode in episodes: + duration_str = UI.format_duration(episode.get("duration")) + episode_sqid = encode_episode_id(episode["id"]) + is_public = episode.get("is_public", 0) == 1 + + # Admin "Add to public feed" button at bottom of card + admin_button: html.div | html.button = html.div() + if user and Core.is_admin(user): + if is_public: + admin_button = html.button( + html.i(classes=["bi", "bi-check-circle-fill", "me-1"]), + "Added to public feed", + hx_post=f"/admin/episode/{episode['id']}/toggle-public", + hx_target="body", + hx_swap="outerHTML", + classes=["btn", "btn-sm", "btn-success", "mt-2"], + ) + else: + admin_button = html.button( + html.i(classes=["bi", "bi-plus-circle", "me-1"]), + "Add to public feed", + hx_post=f"/admin/episode/{episode['id']}/toggle-public", + hx_target="body", + hx_swap="outerHTML", + classes=[ + "btn", + "btn-sm", + "btn-outline-success", + "mt-2", + ], + ) + + # "Add to my feed" 
button for logged-in users + # (only when NOT viewing own feed) + user_button: html.div | html.button = html.div() + if user and not viewing_own_feed: + # Check if user already has this episode + user_has_episode = Core.Database.user_has_episode( + user["id"], + episode["id"], + ) + if user_has_episode: + user_button = html.button( + html.i(classes=["bi", "bi-check-circle-fill", "me-1"]), + "In your feed", + disabled=True, + classes=[ + "btn", + "btn-sm", + "btn-secondary", + "mt-2", + "ms-2", + ], + ) + else: + user_button = html.button( + html.i(classes=["bi", "bi-plus-circle", "me-1"]), + "Add to my feed", + hx_post=f"/episode/{episode['id']}/add-to-feed", + hx_swap="none", + classes=[ + "btn", + "btn-sm", + "btn-outline-primary", + "mt-2", + "ms-2", + ], + ) + + episode_items.append( + html.div( + html.div( + html.h5( + html.a( + episode["title"], + href=f"/episode/{episode_sqid}", + classes=["text-decoration-none"], + ), + classes=["card-title", "mb-2"], + ), + # Show author if available + html.p( + html.i(classes=["bi", "bi-person", "me-1"]), + f"by {episode['author']}", + classes=["text-muted", "small", "mb-3"], + ) + if episode.get("author") + else html.div(), + html.div( + html.small( + html.i(classes=["bi", "bi-clock", "me-1"]), + f"Duration: {duration_str}", + classes=["text-muted", "me-3"], + ), + html.small( + html.i(classes=["bi", "bi-calendar", "me-1"]), + f"Created: {episode['created_at']}", + classes=["text-muted"], + ), + classes=["mb-2"], + ), + # Show link to original article if available + html.div( + html.a( + html.i(classes=["bi", "bi-link-45deg", "me-1"]), + "View original article", + href=episode["original_url"], + target="_blank", + classes=[ + "btn", + "btn-sm", + "btn-outline-primary", + ], + ), + ) + if episode.get("original_url") + else html.div(), + # Buttons row (admin and user buttons) + html.div( + admin_button, + user_button, + classes=["d-flex", "flex-wrap"], + ), + classes=["card-body"], + ), + classes=["card", "mb-3"], + ), + ) 
class PublicFeedPage(Component[AnyChildren, PublicFeedPageAttrs]):
    """Static page listing the public feed (no auto-refresh polling)."""

    @override
    def render(self) -> UI.PageLayout:
        """Wrap the heading, blurb and episode list in the page layout."""
        viewer = self.attrs.get("user")

        heading = html.h2(
            html.i(classes=["bi", "bi-globe", "me-2"]),
            "Public Feed",
            classes=["mb-3"],
        )
        blurb = html.p(
            "Featured articles converted to audio by our community. "
            "Subscribe to get new episodes in your podcast app!",
            classes=["lead", "text-muted", "mb-4"],
        )
        # The public feed always advertises the shared RSS URL and is never
        # treated as the viewer's own feed.
        listing = EpisodeList(
            episodes=self.attrs["episodes"],
            rss_url=f"{BASE_URL}/public.rss",
            user=viewer,
            viewing_own_feed=False,
        )

        return UI.PageLayout(
            html.div(heading, blurb, listing),
            user=viewer,
            current_page="public",
            error=None,
        )
", + html.br(), + "Upgrade to ", + html.strong("Paid Plan"), + " for unlimited articles at $12/month.", + ), + html.form( + html.input( + type="hidden", + name="tier", + value="paid", + ), + html.button( + html.i( + classes=[ + "bi", + "bi-arrow-up-circle", + "me-1", + ], + ), + "Upgrade Now", + type="submit", + classes=[ + "btn", + "btn-success", + "btn-sm", + "mt-2", + ], + ), + method="post", + action="/billing/checkout", + ), + classes=[ + "alert", + "alert-info", + "d-flex", + "justify-content-between", + "align-items-center", + "mb-4", + ], + ), + classes=["mb-4"], + ) + # Paid user - no callout needed + return html.div() + + @override + def render(self) -> UI.PageLayout | html.html: + queue_items = self.attrs["queue_items"] + episodes = self.attrs["episodes"] + user = self.attrs.get("user") + error = self.attrs.get("error") + + if not user: + # Show public feed with login form for logged-out users + return UI.PageLayout( + LoginForm(error=error), + html.div( + html.h4( + html.i(classes=["bi", "bi-broadcast", "me-2"]), + "Public Feed", + classes=["mb-3", "mt-4"], + ), + html.p( + "Featured articles converted to audio. 
" + "Sign up to create your own personal feed!", + classes=["text-muted", "mb-3"], + ), + EpisodeList( + episodes=episodes, + rss_url=None, + user=None, + viewing_own_feed=False, + ), + ), + user=None, + current_page="home", + error=error, + ) + + return UI.PageLayout( + self._render_plan_callout(user), + SubmitForm(), + html.div( + QueueStatus(items=queue_items), + EpisodeList( + episodes=episodes, + rss_url=f"{BASE_URL}/feed/{user['token']}.rss", + user=user, + viewing_own_feed=True, + ), + id="dashboard-content", + hx_get="/dashboard-updates", + hx_trigger="every 3s, queue-updated from:body", + hx_swap="innerHTML", + ), + user=user, + current_page="home", + error=error, + ) + + +# Create ludic app with session support +app = LudicApp() +app.add_middleware( + SessionMiddleware, + secret_key=os.getenv("SESSION_SECRET", "dev-secret-key"), + max_age=SESSION_MAX_AGE, # 30 days + same_site="lax", + https_only=App.from_env() == App.Area.Live, # HTTPS only in production +) + + +@app.get("/") +def index(request: Request) -> HomePage: + """Display main page with form and status.""" + user_id = request.session.get("user_id") + user = None + queue_items = [] + episodes = [] + error = request.query_params.get("error") + status = request.query_params.get("status") + + # Map error codes to user-friendly messages + error_messages = { + "invalid_link": "Invalid login link", + "expired_link": "Login link has expired. Please request a new one.", + "user_not_found": "User not found. Please try logging in again.", + "forbidden": "Access denied. 
@app.get("/public")
def public_feed(request: Request) -> PublicFeedPage:
    """Display the public feed page.

    Public episodes are shown whether or not a user is logged in; the
    session user (or None) is passed through to the page component.
    """
    episodes = Core.Database.get_public_episodes(50)
    user_id = request.session.get("user_id")
    user = Core.Database.get_user_by_id(user_id) if user_id else None

    return PublicFeedPage(
        episodes=episodes,
        user=user,
    )


@app.get("/pricing")
def pricing(request: Request) -> UI.PricingPage:
    """Display the pricing page, passing along the session user if any."""
    user_id = request.session.get("user_id")
    user = Core.Database.get_user_by_id(user_id) if user_id else None

    return UI.PricingPage(
        user=user,
    )


@app.post("/upgrade")
def upgrade(request: Request) -> RedirectResponse:
    """Start the upgrade checkout flow for the logged-in user.

    Every redirect is a 303 See Other. This handler is reached via POST and
    RedirectResponse defaults to 307, which would make the browser repeat
    the POST against the GET-only "/" and "/pricing" routes.
    """
    user_id = request.session.get("user_id")
    if not user_id:
        return RedirectResponse(
            url="/?error=login_required",
            status_code=303,
        )

    try:
        checkout_url = Billing.create_checkout_session(
            user_id,
            "paid",
            BASE_URL,
        )
        return RedirectResponse(url=checkout_url, status_code=303)
    except ValueError:
        logger.exception("Failed to create checkout session")
        return RedirectResponse(
            url="/pricing?error=checkout_failed",
            status_code=303,
        )


@app.post("/logout")
def logout(request: Request) -> RedirectResponse:
    """Log out the user by clearing the session, then go to the home page."""
    request.session.clear()
    return RedirectResponse(url="/", status_code=303)


@app.post("/billing/portal")
def billing_portal(request: Request) -> RedirectResponse:
    """Redirect to the Stripe billing portal.

    Uses 303 for every redirect for the same reason as ``upgrade``: the
    request method is POST and the targets are GET pages.
    """
    user_id = request.session.get("user_id")
    if not user_id:
        return RedirectResponse(
            url="/?error=login_required",
            status_code=303,
        )

    try:
        portal_url = Billing.create_portal_session(user_id, BASE_URL)
        return RedirectResponse(url=portal_url, status_code=303)
    except ValueError as e:
        logger.warning("Failed to create portal session: %s", e)
        # If user has no customer ID (e.g. free tier), redirect to pricing
        return RedirectResponse(url="/pricing", status_code=303)


def _pending_account_response() -> Response:
    """Build the HTML fragment shown when an account is not active."""
    pending_message = (
        '<div class="alert alert-warning">'
        "Account created, currently pending. "
        'Email <a href="mailto:ben@bensima.com" '
        'class="alert-link">ben@bensima.com</a> '
        'or message <a href="https://x.com/bensima" '
        'target="_blank" class="alert-link">@bensima</a> '
        "to get your account activated.</div>"
    )
    return Response(pending_message, status_code=200)


def _handle_test_login(email: str, request: Request) -> Response:
    """Handle login in test mode: no email is sent, the session is set now.

    Unknown emails are registered as active users. The demo account is
    re-activated automatically if it exists but is not active.
    """
    # Special handling for demo account
    is_demo_account = email == "demo@example.com"

    user = Core.Database.get_user_by_email(email)
    if not user:
        # Create new user
        status = "active"
        user_id, token = Core.Database.create_user(email, status=status)
        user = {
            "id": user_id,
            "email": email,
            "token": token,
            "status": status,
        }
    elif is_demo_account and user.get("status") != "active":
        # Auto-activate demo account if it exists but isn't active
        Core.Database.update_user_status(user["id"], "active")
        user["status"] = "active"

    # Only active users get a session
    if user.get("status") != "active":
        return _pending_account_response()

    # Set session with extended lifetime
    request.session["user_id"] = user["id"]
    request.session["permanent"] = True

    return Response(
        '<div class="alert alert-success">✓ Logged in (dev mode)</div>',
        status_code=200,
        headers={"HX-Redirect": "/"},
    )


def _handle_production_login(email: str) -> Response:
    """Handle login in production mode by emailing a magic link.

    The confirmation fragment echoes the submitted address, so it is
    HTML-escaped before being interpolated into the markup.
    """
    # Local import: the module-level name `html` refers to ludic's html.
    from html import escape

    # Get or create user
    user = Core.Database.get_user_by_email(email)
    if not user:
        user_id, token = Core.Database.create_user(email)
        # NOTE(review): assumes create_user defaults to "active" — confirm
        user = {
            "id": user_id,
            "email": email,
            "token": token,
            "status": "active",
        }

    # Only active users are sent a link
    if user.get("status") != "active":
        return _pending_account_response()

    # Generate magic link token
    magic_token = magic_link_serializer.dumps({
        "user_id": user["id"],
        "email": email,
    })

    # Send email
    send_magic_link(email, magic_token)

    return Response(
        '<div class="alert alert-success">✓ Magic link sent to '
        f"{escape(email)}. "
        "Check your email!</div>",
        status_code=200,
    )


@app.post("/login")
def login(request: Request, data: FormData) -> Response:
    """Handle login/registration.

    Normalizes the submitted email (strip + lowercase) and dispatches to the
    test-mode or production login flow depending on the app area.
    """
    # Local import: the module-level name `html` refers to ludic's html.
    from html import escape

    try:
        email_raw = data.get("email", "")
        email = email_raw.strip().lower() if isinstance(email_raw, str) else ""

        if not email:
            return Response(
                '<div class="alert alert-danger">Email is required</div>',
                status_code=400,
            )

        if App.from_env() == App.Area.Test:
            return _handle_test_login(email, request)
        return _handle_production_login(email)

    except Exception as e:
        # Top-level boundary for the login form: log, then show the error.
        logger.exception("Login error")
        return Response(
            '<div class="alert alert-danger">'
            f"Error: {escape(str(e))}</div>",
            status_code=500,
        )


@app.get("/auth/verify")
def verify_magic_link(request: Request) -> Response:
    """Verify a magic link token and log the user in."""
    token = request.query_params.get("token")

    if not token:
        return RedirectResponse("/?error=invalid_link")

    try:
        # Verify token
        data = magic_link_serializer.loads(token, max_age=MAGIC_LINK_MAX_AGE)
        user_id = data["user_id"]

        # Verify user still exists
        user = Core.Database.get_user_by_id(user_id)
        if not user:
            return RedirectResponse("/?error=user_not_found")

        # Set session with extended lifetime
        request.session["user_id"] = user_id
        request.session["permanent"] = True

        return RedirectResponse("/")

    except (ValueError, KeyError):
        # Token is invalid or expired
        # NOTE(review): confirm the serializer's bad-signature/expiry errors
        # are ValueError subclasses; otherwise they escape this handler.
        return RedirectResponse("/?error=expired_link")
required=True, + classes=[ + "form-control", + "form-control-sm", + "d-inline-block", + "w-auto", + "me-2", + ], + ), + html.button( + "Save", + type="submit", + classes=["btn", "btn-sm", "btn-primary", "me-1"], + ), + html.button( + "Cancel", + hx_get="/settings/email/cancel", + hx_target="closest div", + hx_swap="outerHTML", + classes=["btn", "btn-sm", "btn-secondary"], + ), + hx_post="/settings/email", + hx_target="closest div", + hx_swap="outerHTML", + classes=["d-flex", "align-items-center"], + ), + classes=["mb-2"], + ) + + +@app.get("/settings/email/cancel") +def cancel_edit_email(request: Request) -> typing.Any: + """Cancel email editing and show original view.""" + user_id = request.session.get("user_id") + if not user_id: + return Response("Unauthorized", status_code=401) + + user = Core.Database.get_user_by_id(user_id) + if not user: + return Response("User not found", status_code=404) + + return html.div( + html.strong("Email: "), + html.span(user["email"]), + html.button( + "Change", + classes=[ + "btn", + "btn-sm", + "btn-outline-secondary", + "ms-2", + "py-0", + ], + hx_get="/settings/email/edit", + hx_target="closest div", + hx_swap="outerHTML", + ), + classes=["mb-2", "d-flex", "align-items-center"], + ) + + +@app.post("/settings/email") +def update_email(request: Request, data: FormData) -> typing.Any: + """Update user email.""" + user_id = request.session.get("user_id") + if not user_id: + return Response("Unauthorized", status_code=401) + + new_email_raw = data.get("email", "") + new_email = ( + new_email_raw.strip().lower() if isinstance(new_email_raw, str) else "" + ) + + if not new_email: + return Response("Email required", status_code=400) + + try: + Core.Database.update_user_email(user_id, new_email) + return cancel_edit_email(request) + except ValueError as e: + # Return form with error + return html.div( + html.form( + html.strong("Email: ", classes=["me-2"]), + html.input( + type="email", + name="email", + value=new_email, + 
@app.get("/account")
def account_page(request: Request) -> typing.Any:
    """Render the account management page for the logged-in user.

    Redirects to the home page with an error code when there is no session
    or the session's user no longer exists.
    """
    user_id = request.session.get("user_id")
    if not user_id:
        return RedirectResponse(url="/?error=login_required")

    user = Core.Database.get_user_by_id(user_id)
    if not user:
        return RedirectResponse(url="/?error=user_not_found")

    # Get usage stats
    period_start, period_end = Billing.get_period_boundaries(user)
    usage = Billing.get_usage(user["id"], period_start, period_end)

    # Get limits; unknown tiers fall back to the free tier's limits
    tier = user.get("plan_tier", "free")
    limits = Billing.TIER_LIMITS.get(tier, Billing.TIER_LIMITS["free"])

    return UI.AccountPage(
        user=user,
        usage=usage,
        limits=limits,
        # Only paid users get a link to the billing portal
        portal_url="/billing/portal" if tier == "paid" else None,
    )


@app.delete("/account")
def delete_account(request: Request) -> Response:
    """Delete the logged-in user's account and clear the session.

    On success the HX-Redirect header tells HTMX to navigate to the home
    page. The unauthenticated redirect is a 303 See Other, because
    RedirectResponse defaults to 307 and that would replay the DELETE
    against "/".
    """
    user_id = request.session.get("user_id")
    if not user_id:
        return RedirectResponse(
            url="/?error=login_required",
            status_code=303,
        )

    Core.Database.delete_user(user_id)
    request.session.clear()

    return Response(
        "Account deleted",
        headers={"HX-Redirect": "/?message=account_deleted"},
    )
+ user_id = request.session.get("user_id") + if not user_id: + return html.div( + html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]), + "Error: Please login first", + classes=["alert", "alert-danger"], + ) + + user = Core.Database.get_user_by_id(user_id) + if not user: + return html.div( + html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]), + "Error: Invalid session", + classes=["alert", "alert-danger"], + ) + + url_raw = data.get("url", "") + url = url_raw.strip() if isinstance(url_raw, str) else "" + + if not url: + return html.div( + html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]), + "Error: URL is required", + classes=["alert", "alert-danger"], + ) + + # Basic URL validation + parsed = urllib.parse.urlparse(url) + if not parsed.scheme or not parsed.netloc: + return html.div( + html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]), + "Error: Invalid URL format", + classes=["alert", "alert-danger"], + ) + + # Check usage limits + allowed, _msg, usage = Billing.can_submit(user_id) + if not allowed: + tier = user.get("plan_tier", "free") + tier_info = Billing.get_tier_info(tier) + limit = tier_info.get("articles_limit", 0) + return html.div( + html.i(classes=["bi", "bi-exclamation-circle", "me-2"]), + html.strong("Limit reached: "), + f"You've used {usage['articles']}/{limit} articles " + "this period. 
", + html.a( + "Upgrade your plan", + href="/billing", + classes=["alert-link"], + ), + " to continue.", + classes=["alert", "alert-warning"], + ) + + # Check if episode already exists for this URL + url_hash = Core.hash_url(url) + existing_episode = Core.Database.get_episode_by_url_hash(url_hash) + + if existing_episode: + # Episode already processed - check if user has it + episode_id = existing_episode["id"] + if Core.Database.user_has_episode(user_id, episode_id): + return html.div( + html.i(classes=["bi", "bi-info-circle", "me-2"]), + "This episode is already in your feed.", + classes=["alert", "alert-info"], + ) + # Add existing episode to user's feed + Core.Database.add_episode_to_user(user_id, episode_id) + Core.Database.track_episode_event( + episode_id, + "added", + user_id, + ) + return html.div( + html.i(classes=["bi", "bi-check-circle", "me-2"]), + "✓ Episode added to your feed! ", + html.a( + "View episode", + href=f"/episode/{encode_episode_id(episode_id)}", + classes=["alert-link"], + ), + classes=["alert", "alert-success"], + ) + + # Episode doesn't exist yet - extract metadata and queue for processing + title, author = extract_og_metadata(url) + + job_id = Core.Database.add_to_queue( + url, + user["email"], + user_id, + title=title, + author=author, + ) + return html.div( + html.i(classes=["bi", "bi-check-circle", "me-2"]), + f"✓ Article submitted successfully! 
def _build_feed_response(
    title: str,
    description: str,
    feed_url: str,
    episodes: typing.Iterable[dict[str, typing.Any]],
) -> Response:
    """Render episodes as an RSS document and wrap it in a Response.

    Shared by the per-user and public feeds so both emit identical entries.
    Episodes are added in the order given; the iteration order is left
    exactly as the callers supply it.
    """
    fg = FeedGenerator()
    fg.title(title)
    fg.description(description)
    fg.author(name=RSS_CONFIG["author"])
    fg.language(RSS_CONFIG["language"])
    fg.link(href=feed_url)
    fg.id(feed_url)

    for episode in episodes:
        fe = fg.add_entry()
        episode_sqid = encode_episode_id(episode["id"])
        fe.id(f"{RSS_CONFIG['base_url']}/episode/{episode_sqid}")
        fe.title(episode["title"])
        fe.description(episode["title"])
        fe.enclosure(
            episode["audio_url"],
            # `or 0` also covers a stored None, which .get()'s default
            # would pass through and render as the string "None".
            str(episode.get("content_length") or 0),
            "audio/mpeg",
        )
        # SQLite timestamps don't have timezone info, so add UTC
        created_at = datetime.fromisoformat(episode["created_at"])
        if created_at.tzinfo is None:
            created_at = created_at.replace(tzinfo=timezone.utc)
        fe.pubDate(created_at)

    return Response(
        fg.rss_str(pretty=True),
        media_type="application/rss+xml; charset=utf-8",
    )


@app.get("/feed/{token}.rss")
def rss_feed(request: Request, token: str) -> Response:  # noqa: ARG001
    """Generate the user-specific RSS podcast feed.

    The feed token identifies the user; an unknown token yields a 404.
    """
    try:
        # Validate token and get user
        user = Core.Database.get_user_by_token(token)
        if not user:
            return Response("Invalid feed token", status_code=404)

        # Get episodes for this user only
        episodes = Core.Database.get_user_all_episodes(
            user["id"],
        )

        # Extract first name from email for personalization
        email_name = user["email"].split("@")[0].split(".")[0].title()

        return _build_feed_response(
            title=f"{email_name}'s Article Podcast",
            description=(
                f"Web articles converted to audio for {user['email']}"
            ),
            feed_url=f"{RSS_CONFIG['base_url']}/feed/{token}.rss",
            episodes=episodes,
        )

    except (ValueError, KeyError, AttributeError) as e:
        logger.exception("Error generating user feed")
        return Response(f"Error generating feed: {e}", status_code=500)


# Backwards compatibility: .xml extension
@app.get("/feed/{token}.xml")
def rss_feed_xml_alias(request: Request, token: str) -> Response:
    """Alias for .rss feed (backwards compatibility)."""
    return rss_feed(request, token)


@app.get("/public.rss")
def public_rss_feed(request: Request) -> Response:  # noqa: ARG001
    """Generate the public RSS podcast feed."""
    try:
        # Get public episodes
        episodes = Core.Database.get_public_episodes(50)

        return _build_feed_response(
            title="PodcastItLater Public Feed",
            description="Curated articles converted to audio",
            feed_url=f"{RSS_CONFIG['base_url']}/public.rss",
            episodes=episodes,
        )

    except (ValueError, KeyError, AttributeError) as e:
        logger.exception("Error generating public feed")
        return Response(f"Error generating feed: {e}", status_code=500)
+ + Deprecated: This route exists for backward compatibility. + Will be removed in a future version. + """ + episode_sqid = encode_episode_id(episode_id) + return RedirectResponse( + url=f"/episode/{episode_sqid}", + status_code=301, # Permanent redirect + ) + + +@app.get("/episode/{episode_sqid}") +def episode_detail( + request: Request, + episode_sqid: str, +) -> Episode.EpisodeDetailPage | Response: + """Display individual episode page (public, no auth required).""" + try: + # Decode sqid to episode ID + episode_id = decode_episode_id(episode_sqid) + if episode_id is None: + return Response("Invalid episode ID", status_code=404) + + # Get episode from database + episode = Core.Database.get_episode_by_id(episode_id) + + if not episode: + return Response("Episode not found", status_code=404) + + # Get creator email if episode has user_id + creator_email = None + if episode.get("user_id"): + creator = Core.Database.get_user_by_id(episode["user_id"]) + creator_email = creator["email"] if creator else None + + # Check if current user is logged in + user_id = request.session.get("user_id") + user = None + user_has_episode = False + if user_id: + user = Core.Database.get_user_by_id(user_id) + user_has_episode = Core.Database.user_has_episode( + user_id, + episode_id, + ) + + return Episode.EpisodeDetailPage( + episode=episode, + episode_sqid=episode_sqid, + creator_email=creator_email, + user=user, + base_url=BASE_URL, + user_has_episode=user_has_episode, + ) + + except (ValueError, KeyError) as e: + logger.exception("Error loading episode") + return Response(f"Error loading episode: {e}", status_code=500) + + +@app.get("/status") +def queue_status(request: Request) -> QueueStatus: + """Return HTMX endpoint for live queue updates.""" + # Check if user is logged in + user_id = request.session.get("user_id") + if not user_id: + return QueueStatus(items=[]) + + # Get user-specific queue items + queue_items = Core.Database.get_user_queue_status( + user_id, + ) + return 
QueueStatus(items=queue_items) + + +@app.get("/dashboard-updates") +def dashboard_updates(request: Request) -> Response: + """Return both queue status and recent episodes for dashboard updates.""" + # Check if user is logged in + user_id = request.session.get("user_id") + if not user_id: + queue_status = QueueStatus(items=[]) + episode_list = EpisodeList( + episodes=[], + rss_url=None, + user=None, + viewing_own_feed=False, + ) + # Return HTML as string with both components + return Response( + str(queue_status) + str(episode_list), + media_type="text/html", + ) + + # Get user info for RSS URL + user = Core.Database.get_user_by_id(user_id) + rss_url = f"{BASE_URL}/feed/{user['token']}.rss" if user else None + + # Get user-specific queue items and episodes + queue_items = Core.Database.get_user_queue_status(user_id) + episodes = Core.Database.get_user_recent_episodes(user_id, 10) + + # Return just the content components, not the wrapper div + # The wrapper div with HTMX attributes is in HomePage + queue_status = QueueStatus(items=queue_items) + episode_list = EpisodeList( + episodes=episodes, + rss_url=rss_url, + user=user, + viewing_own_feed=True, + ) + return Response( + str(queue_status) + str(episode_list), + media_type="text/html", + ) + + +# Register admin routes +app.get("/admin")(Admin.admin_queue_status) +app.post("/queue/{job_id}/retry")(Admin.retry_queue_item) + + +@app.post("/billing/checkout") +def billing_checkout(request: Request, data: FormData) -> Response: + """Create Stripe Checkout session.""" + user_id = request.session.get("user_id") + if not user_id: + return Response("Unauthorized", status_code=401) + + tier_raw = data.get("tier", "paid") + tier = tier_raw if isinstance(tier_raw, str) else "paid" + if tier != "paid": + return Response("Invalid tier", status_code=400) + + try: + checkout_url = Billing.create_checkout_session(user_id, tier, BASE_URL) + return RedirectResponse(url=checkout_url, status_code=303) + except ValueError as e: + 
@app.post("/stripe/webhook")
async def stripe_webhook(request: Request) -> Response:
    """Handle Stripe webhook events.

    The raw body and the ``stripe-signature`` header are handed to
    ``Billing.handle_webhook_event``; any failure is logged and answered
    with a 400.
    """
    payload = await request.body()
    sig_header = request.headers.get("stripe-signature", "")

    try:
        result = Billing.handle_webhook_event(payload, sig_header)
        return Response(f"OK: {result['status']}", status_code=200)
    except Exception as e:
        # Top-level boundary for the webhook: log everything, never raise.
        logger.exception("Webhook error")
        return Response(f"Error: {e!s}", status_code=400)


@app.post("/queue/{job_id}/cancel")
def cancel_queue_item(request: Request, job_id: int) -> Response:
    """Cancel a pending queue item owned by the logged-in user."""
    try:
        # Check if user is logged in
        user_id = request.session.get("user_id")
        if not user_id:
            return Response("Unauthorized", status_code=401)

        # Get job and verify ownership; a missing job is answered the same
        # way as somebody else's job
        job = Core.Database.get_job_by_id(job_id)
        if job is None or job.get("user_id") != user_id:
            return Response("Forbidden", status_code=403)

        # Only allow canceling pending jobs
        if job.get("status") != "pending":
            return Response("Can only cancel pending jobs", status_code=400)

        # Update status to cancelled
        Core.Database.update_job_status(
            job_id,
            "cancelled",
            error="Cancelled by user",
        )

        # Return success with HTMX trigger to refresh
        return Response(
            "",
            status_code=200,
            headers={"HX-Trigger": "queue-updated"},
        )
    except (ValueError, KeyError) as e:
        return Response(
            f"Error cancelling job: {e!s}",
            status_code=500,
        )


# Remaining admin routes
app.delete("/queue/{job_id}")(Admin.delete_queue_item)
app.get("/admin/users")(Admin.admin_users)
app.get("/admin/metrics")(Admin.admin_metrics)
app.post("/admin/users/{user_id}/status")(Admin.update_user_status)
app.post("/admin/episode/{episode_id}/toggle-public")(
    Admin.toggle_episode_public,
)


@app.post("/episode/{episode_id}/add-to-feed")
def add_episode_to_feed(request: Request, episode_id: int) -> Response:
    """Add an episode to the logged-in user's feed.

    Responds with small HTML alert fragments for the not-logged-in,
    not-found and already-added cases; on success it asks HTMX to reload
    the referring page.
    """
    # Check if user is logged in
    user_id = request.session.get("user_id")
    if not user_id:
        return Response(
            '<div class="alert alert-warning">Please login first</div>',
            status_code=200,
        )

    # Check if episode exists
    episode = Core.Database.get_episode_by_id(episode_id)
    if not episode:
        return Response(
            '<div class="alert alert-danger">Episode not found</div>',
            status_code=404,
        )

    # Check if user already has this episode
    if Core.Database.user_has_episode(user_id, episode_id):
        return Response(
            '<div class="alert alert-info">Already in your feed</div>',
            status_code=200,
        )

    # Add episode to user's feed
    Core.Database.add_episode_to_user(user_id, episode_id)

    # Track the "added" event
    Core.Database.track_episode_event(episode_id, "added", user_id)

    # Reload the current page to show updated button state
    # Check referer to determine where to redirect
    referer = request.headers.get("referer", "/")
    return Response(
        "",
        status_code=200,
        headers={"HX-Redirect": referer},
    )


@app.post("/episode/{episode_id}/track")
def track_episode(
    request: Request,
    episode_id: int,
    data: FormData,
) -> Response:
    """Track an episode metric event (play, download).

    Anonymous callers are allowed (the event is stored with user_id None),
    so the episode id is checked first: events for ids that do not exist
    are rejected with a 404 instead of being recorded.
    """
    # Get event type from form data
    event_type_raw = data.get("event_type", "")
    event_type = event_type_raw if isinstance(event_type_raw, str) else ""

    # Validate event type
    if event_type not in {"played", "downloaded"}:
        return Response("Invalid event type", status_code=400)

    # Refuse to record metrics for episodes that do not exist
    if not Core.Database.get_episode_by_id(episode_id):
        return Response("Episode not found", status_code=404)

    # Get user ID if logged in (None for anonymous)
    user_id = request.session.get("user_id")

    # Track the event
    Core.Database.track_episode_event(episode_id, event_type, user_id)

    return Response("", status_code=200)
self.client = TestClient(app) + + @staticmethod + def tearDown() -> None: + """Clean up test database.""" + Core.Database.teardown() + + +class TestDurationFormatting(Test.TestCase): + """Test duration formatting functionality.""" + + def test_format_duration_minutes_only(self) -> None: + """Test formatting durations less than an hour.""" + self.assertEqual(UI.format_duration(60), "1m") + self.assertEqual(UI.format_duration(240), "4m") + self.assertEqual(UI.format_duration(300), "5m") + self.assertEqual(UI.format_duration(3599), "60m") + + def test_format_duration_hours_and_minutes(self) -> None: + """Test formatting durations with hours and minutes.""" + self.assertEqual(UI.format_duration(3600), "1h") + self.assertEqual(UI.format_duration(3840), "1h 4m") + self.assertEqual(UI.format_duration(11520), "3h 12m") + self.assertEqual(UI.format_duration(7320), "2h 2m") + + def test_format_duration_round_up(self) -> None: + """Test that seconds are rounded up to nearest minute.""" + self.assertEqual(UI.format_duration(61), "2m") + self.assertEqual(UI.format_duration(119), "2m") + self.assertEqual(UI.format_duration(121), "3m") + self.assertEqual(UI.format_duration(3601), "1h 1m") + + def test_format_duration_edge_cases(self) -> None: + """Test edge cases for duration formatting.""" + self.assertEqual(UI.format_duration(None), "Unknown") + self.assertEqual(UI.format_duration(0), "Unknown") + self.assertEqual(UI.format_duration(-100), "Unknown") + + +class TestAuthentication(BaseWebTest): + """Test authentication functionality.""" + + def test_login_new_user_active(self) -> None: + """New users should be created with active status.""" + response = self.client.post("/login", data={"email": "new@example.com"}) + self.assertEqual(response.status_code, 200) + + # Verify user was created with active status + user = Core.Database.get_user_by_email( + "new@example.com", + ) + self.assertIsNotNone(user) + if user is None: + msg = "no user found" + raise Test.TestError(msg) + 
self.assertEqual(user.get("status"), "active") + + def test_login_active_user(self) -> None: + """Active users should be able to login.""" + # Create user and set to active + user_id, _ = Core.Database.create_user( + "active@example.com", + ) + Core.Database.update_user_status(user_id, "active") + + response = self.client.post( + "/login", + data={"email": "active@example.com"}, + ) + + self.assertEqual(response.status_code, 200) + self.assertIn("HX-Redirect", response.headers) + + def test_login_existing_pending_user(self) -> None: + """Existing pending users should see the pending message.""" + # Create a pending user + _user_id, _ = Core.Database.create_user( + "pending@example.com", + status="pending", + ) + + response = self.client.post( + "/login", + data={"email": "pending@example.com"}, + ) + + self.assertEqual(response.status_code, 200) + self.assertIn("Account created, currently pending", response.text) + self.assertIn("ben@bensima.com", response.text) + self.assertIn("@bensima", response.text) + + def test_login_disabled_user(self) -> None: + """Disabled users should not be able to login.""" + # Create user and set to disabled + user_id, _ = Core.Database.create_user( + "disabled@example.com", + ) + Core.Database.update_user_status( + user_id, + "disabled", + ) + + response = self.client.post( + "/login", + data={"email": "disabled@example.com"}, + ) + + self.assertEqual(response.status_code, 200) + self.assertIn("Account created, currently pending", response.text) + + def test_login_invalid_email(self) -> None: + """Reject malformed emails.""" + response = self.client.post("/login", data={"email": ""}) + + self.assertEqual(response.status_code, 400) + self.assertIn("Email is required", response.text) + + def test_session_persistence(self) -> None: + """Verify session across requests.""" + # Create active user + _user_id, _ = Core.Database.create_user( + "test@example.com", + status="active", + ) + # Login + self.client.post("/login", data={"email": 
class TestArticleSubmission(BaseWebTest):
    """Test article submission functionality."""

    def setUp(self) -> None:
        """Set up test client with a logged-in, active user."""
        super().setUp()
        # Create active user and login
        user_id, _ = Core.Database.create_user(
            "test@example.com",
        )
        Core.Database.update_user_status(user_id, "active")
        self.client.post("/login", data={"email": "test@example.com"})

    def test_submit_valid_url(self) -> None:
        """Accept well-formed URLs."""
        response = self.client.post(
            "/submit",
            data={"url": "https://example.com/article"},
        )

        self.assertEqual(response.status_code, 200)
        self.assertIn("Article submitted successfully", response.text)
        self.assertIn("Job ID:", response.text)

    def test_submit_invalid_url(self) -> None:
        """Reject malformed URLs."""
        response = self.client.post("/submit", data={"url": "not-a-url"})

        self.assertIn("Invalid URL format", response.text)

    def test_submit_without_auth(self) -> None:
        """Reject unauthenticated submissions."""
        # Clear session. /logout is registered for POST only, so a GET
        # would be rejected and leave the session from setUp in place.
        logout_response = self.client.post("/logout")
        self.assertLess(logout_response.status_code, 400)

        response = self.client.post(
            "/submit",
            data={"url": "https://example.com"},
        )

        self.assertIn("Please login first", response.text)

    def test_submit_creates_job(self) -> None:
        """Verify job creation in database."""
        response = self.client.post(
            "/submit",
            data={"url": "https://example.com/test"},
        )

        # Extract job ID from response
        match = re.search(r"Job ID: (\d+)", response.text)
        self.assertIsNotNone(match)
        if match is None:
            self.fail("Job ID not found in response")
        job_id = int(match.group(1))

        # Verify job in database
        job = Core.Database.get_job_by_id(job_id)
        self.assertIsNotNone(job)
        if job is None:  # Type guard for mypy
            self.fail("Job should not be None")
        self.assertEqual(job["url"], "https://example.com/test")
        self.assertEqual(job["status"], "pending")

    def test_htmx_response(self) -> None:
        """Ensure proper HTMX response format."""
        response = self.client.post(
            "/submit",
            data={"url": "https://example.com"},
        )

        # Should return HTML fragment, not full page
        self.assertNotIn("<!DOCTYPE", response.text)
        self.assertIn("<div", response.text)
XML.""" + response = self.client.get(f"/feed/{self.token}.rss") + + self.assertEqual(response.status_code, 200) + self.assertEqual( + response.headers["content-type"], + "application/rss+xml; charset=utf-8", + ) + + # Verify RSS structure + self.assertIn("<?xml", response.text) + self.assertIn("<rss", response.text) + self.assertIn("<channel>", response.text) + self.assertIn("<item>", response.text) + + def test_feed_user_isolation(self) -> None: + """Only show user's episodes.""" + # Create another user with episodes + user2_id, _ = Core.Database.create_user( + "other@example.com", + ) + other_ep_id = Core.Database.create_episode( + "Other Episode", + "https://example.com/other.mp3", + 400, + 6000, + user2_id, + ) + Core.Database.add_episode_to_user(user2_id, other_ep_id) + + # Get first user's feed + response = self.client.get(f"/feed/{self.token}.rss") + + # Should only have user's episodes + self.assertIn("Episode 1", response.text) + self.assertIn("Episode 2", response.text) + self.assertNotIn("Other Episode", response.text) + + def test_feed_invalid_token(self) -> None: + """Return 404 for bad tokens.""" + response = self.client.get("/feed/invalid-token.rss") + + self.assertEqual(response.status_code, 404) + + def test_feed_metadata(self) -> None: + """Verify personalized feed titles.""" + response = self.client.get(f"/feed/{self.token}.rss") + + # Should personalize based on email + self.assertIn("Test's Article Podcast", response.text) + self.assertIn("test@example.com", response.text) + + def test_feed_episode_order(self) -> None: + """Ensure reverse chronological order.""" + response = self.client.get(f"/feed/{self.token}.rss") + + # Episode 2 should appear before Episode 1 + ep2_pos = response.text.find("Episode 2") + ep1_pos = response.text.find("Episode 1") + self.assertLess(ep2_pos, ep1_pos) + + def test_feed_enclosures(self) -> None: + """Verify audio URLs and metadata.""" + response = self.client.get(f"/feed/{self.token}.rss") + + # Check enclosure 
tags + self.assertIn("<enclosure", response.text) + self.assertIn('type="audio/mpeg"', response.text) + + def test_feed_xml_alias_works(self) -> None: + """Test .xml extension works for backwards compatibility.""" + # Get feed with .xml extension + response_xml = self.client.get(f"/feed/{self.token}.xml") + # Get feed with .rss extension + response_rss = self.client.get(f"/feed/{self.token}.rss") + + # Both should work and return same content + self.assertEqual(response_xml.status_code, 200) + self.assertEqual(response_rss.status_code, 200) + self.assertEqual(response_xml.text, response_rss.text) + + def test_public_feed_xml_alias_works(self) -> None: + """Test .xml extension works for public feed.""" + # Get feed with .xml extension + response_xml = self.client.get("/public.xml") + # Get feed with .rss extension + response_rss = self.client.get("/public.rss") + + # Both should work and return same content + self.assertEqual(response_xml.status_code, 200) + self.assertEqual(response_rss.status_code, 200) + self.assertEqual(response_xml.text, response_rss.text) + + +class TestAdminInterface(BaseWebTest): + """Test admin interface functionality.""" + + def setUp(self) -> None: + """Set up test client with logged-in user.""" + super().setUp() + + # Create and login admin user + self.user_id, _ = Core.Database.create_user( + "ben@bensima.com", + ) + Core.Database.update_user_status( + self.user_id, + "active", + ) + self.client.post("/login", data={"email": "ben@bensima.com"}) + + # Create test data + self.job_id = Core.Database.add_to_queue( + "https://example.com/test", + "ben@bensima.com", + self.user_id, + ) + + def test_queue_status_view(self) -> None: + """Verify queue display.""" + response = self.client.get("/admin") + + self.assertEqual(response.status_code, 200) + self.assertIn("Queue Status", response.text) + self.assertIn("https://example.com/test", response.text) + + def test_retry_action(self) -> None: + """Test retry button functionality.""" + # Set job 
to error state + Core.Database.update_job_status( + self.job_id, + "error", + "Failed", + ) + + # Retry + response = self.client.post(f"/queue/{self.job_id}/retry") + + self.assertEqual(response.status_code, 200) + self.assertIn("HX-Redirect", response.headers) + + # Job should be pending again + job = Core.Database.get_job_by_id(self.job_id) + self.assertIsNotNone(job) + if job is not None: + self.assertEqual(job["status"], "pending") + + def test_delete_action(self) -> None: + """Test delete button functionality.""" + response = self.client.delete( + f"/queue/{self.job_id}", + headers={"referer": "/admin"}, + ) + + self.assertEqual(response.status_code, 200) + self.assertIn("HX-Redirect", response.headers) + + # Job should be gone + job = Core.Database.get_job_by_id(self.job_id) + self.assertIsNone(job) + + def test_user_data_isolation(self) -> None: + """Ensure admin sees all data.""" + # Create another user's job + user2_id, _ = Core.Database.create_user( + "other@example.com", + ) + Core.Database.add_to_queue( + "https://example.com/other", + "other@example.com", + user2_id, + ) + + # View queue status as admin + response = self.client.get("/admin") + + # Admin should see all jobs + self.assertIn("https://example.com/test", response.text) + self.assertIn("https://example.com/other", response.text) + + def test_status_summary(self) -> None: + """Verify status counts display.""" + # Create jobs with different statuses + Core.Database.update_job_status( + self.job_id, + "error", + "Failed", + ) + job2 = Core.Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + Core.Database.update_job_status( + job2, + "processing", + ) + + response = self.client.get("/admin") + + # Should show status counts + self.assertIn("ERROR: 1", response.text) + self.assertIn("PROCESSING: 1", response.text) + + +class TestMetricsDashboard(BaseWebTest): + """Test metrics dashboard functionality.""" + + def setUp(self) -> None: + """Set up test client 
with logged-in admin user.""" + super().setUp() + + # Create and login admin user + self.user_id, _ = Core.Database.create_user( + "ben@bensima.com", + ) + Core.Database.update_user_status( + self.user_id, + "active", + ) + self.client.post("/login", data={"email": "ben@bensima.com"}) + + def test_metrics_page_requires_admin(self) -> None: + """Verify non-admin users cannot access metrics.""" + # Create non-admin user + user_id, _ = Core.Database.create_user("user@example.com") + Core.Database.update_user_status(user_id, "active") + + # Login as non-admin + self.client.get("/logout") + self.client.post("/login", data={"email": "user@example.com"}) + + # Try to access metrics + response = self.client.get("/admin/metrics", follow_redirects=False) + + # Should redirect + self.assertEqual(response.status_code, 302) + self.assertEqual(response.headers["Location"], "/?error=forbidden") + + def test_metrics_page_requires_login(self) -> None: + """Verify unauthenticated users are redirected.""" + self.client.get("/logout") + + response = self.client.get("/admin/metrics", follow_redirects=False) + + self.assertEqual(response.status_code, 302) + self.assertEqual(response.headers["Location"], "/") + + def test_metrics_displays_summary(self) -> None: + """Verify metrics summary is displayed.""" + # Create test episode + episode_id = Core.Database.create_episode( + title="Test Episode", + audio_url="http://example.com/audio.mp3", + content_length=1000, + duration=300, + ) + Core.Database.add_episode_to_user(self.user_id, episode_id) + + # Track some events + Core.Database.track_episode_event(episode_id, "played") + Core.Database.track_episode_event(episode_id, "played") + Core.Database.track_episode_event(episode_id, "downloaded") + Core.Database.track_episode_event(episode_id, "added", self.user_id) + + # Get metrics page + response = self.client.get("/admin/metrics") + + self.assertEqual(response.status_code, 200) + self.assertIn("Episode Metrics", response.text) + 
self.assertIn("Total Episodes", response.text) + self.assertIn("Total Plays", response.text) + + def test_growth_metrics_display(self) -> None: + """Verify growth and usage metrics are displayed.""" + # Create an active subscriber + user2_id, _ = Core.Database.create_user("active@example.com") + Core.Database.update_user_subscription( + user2_id, + subscription_id="sub_test", + status="active", + period_start=datetime.now(timezone.utc), + period_end=datetime.now(timezone.utc), + tier="paid", + cancel_at_period_end=False, + ) + + # Create a queue item + Core.Database.add_to_queue( + "https://example.com/new", + "active@example.com", + user2_id, + ) + + # Get metrics page + response = self.client.get("/admin/metrics") + + self.assertEqual(response.status_code, 200) + self.assertIn("Growth & Usage", response.text) + self.assertIn("Total Users", response.text) + self.assertIn("Active Subs", response.text) + self.assertIn("Submissions (24h)", response.text) + + self.assertIn("Total Downloads", response.text) + self.assertIn("Total Adds", response.text) + + def test_metrics_shows_top_episodes(self) -> None: + """Verify top episodes tables are displayed.""" + # Create test episodes + episode1 = Core.Database.create_episode( + title="Popular Episode", + audio_url="http://example.com/popular.mp3", + content_length=1000, + duration=300, + author="Test Author", + ) + Core.Database.add_episode_to_user(self.user_id, episode1) + + episode2 = Core.Database.create_episode( + title="Less Popular Episode", + audio_url="http://example.com/less.mp3", + content_length=1000, + duration=300, + ) + Core.Database.add_episode_to_user(self.user_id, episode2) + + # Track events - more for episode1 + for _ in range(5): + Core.Database.track_episode_event(episode1, "played") + for _ in range(2): + Core.Database.track_episode_event(episode2, "played") + + for _ in range(3): + Core.Database.track_episode_event(episode1, "downloaded") + Core.Database.track_episode_event(episode2, "downloaded") + + 
# Get metrics page + response = self.client.get("/admin/metrics") + + self.assertEqual(response.status_code, 200) + self.assertIn("Most Played", response.text) + self.assertIn("Most Downloaded", response.text) + self.assertIn("Popular Episode", response.text) + + def test_metrics_empty_state(self) -> None: + """Verify metrics page works with no data.""" + response = self.client.get("/admin/metrics") + + self.assertEqual(response.status_code, 200) + self.assertIn("Episode Metrics", response.text) + # Should show 0 for counts + self.assertIn("Total Episodes", response.text) + + +class TestJobCancellation(BaseWebTest): + """Test job cancellation functionality.""" + + def setUp(self) -> None: + """Set up test client with logged-in user and pending job.""" + super().setUp() + + # Create and login user + self.user_id, _ = Core.Database.create_user( + "test@example.com", + ) + Core.Database.update_user_status( + self.user_id, + "active", + ) + self.client.post("/login", data={"email": "test@example.com"}) + + # Create pending job + self.job_id = Core.Database.add_to_queue( + "https://example.com/test", + "test@example.com", + self.user_id, + ) + + def test_cancel_pending_job(self) -> None: + """Successfully cancel a pending job.""" + response = self.client.post(f"/queue/{self.job_id}/cancel") + + self.assertEqual(response.status_code, 200) + self.assertIn("HX-Trigger", response.headers) + self.assertEqual(response.headers["HX-Trigger"], "queue-updated") + + # Verify job status is cancelled + job = Core.Database.get_job_by_id(self.job_id) + self.assertIsNotNone(job) + if job is not None: + self.assertEqual(job["status"], "cancelled") + self.assertEqual(job.get("error_message", ""), "Cancelled by user") + + def test_cannot_cancel_processing_job(self) -> None: + """Prevent cancelling jobs that are already processing.""" + # Set job to processing + Core.Database.update_job_status( + self.job_id, + "processing", + ) + + response = 
self.client.post(f"/queue/{self.job_id}/cancel") + + self.assertEqual(response.status_code, 400) + self.assertIn("Can only cancel pending jobs", response.text) + + def test_cannot_cancel_completed_job(self) -> None: + """Prevent cancelling completed jobs.""" + # Set job to completed + Core.Database.update_job_status( + self.job_id, + "completed", + ) + + response = self.client.post(f"/queue/{self.job_id}/cancel") + + self.assertEqual(response.status_code, 400) + + def test_cannot_cancel_other_users_job(self) -> None: + """Prevent users from cancelling other users' jobs.""" + # Create another user's job + user2_id, _ = Core.Database.create_user( + "other@example.com", + ) + other_job_id = Core.Database.add_to_queue( + "https://example.com/other", + "other@example.com", + user2_id, + ) + + # Try to cancel it + response = self.client.post(f"/queue/{other_job_id}/cancel") + + self.assertEqual(response.status_code, 403) + + def test_cancel_without_auth(self) -> None: + """Require authentication to cancel jobs.""" + # Logout + self.client.get("/logout") + + response = self.client.post(f"/queue/{self.job_id}/cancel") + + self.assertEqual(response.status_code, 401) + + def test_cancel_button_visibility(self) -> None: + """Cancel button only shows for pending jobs.""" + # Create jobs with different statuses + processing_job = Core.Database.add_to_queue( + "https://example.com/processing", + "test@example.com", + self.user_id, + ) + Core.Database.update_job_status( + processing_job, + "processing", + ) + + # Get status view + response = self.client.get("/status") + + # Should have cancel button for pending job + self.assertIn(f'hx-post="/queue/{self.job_id}/cancel"', response.text) + self.assertIn("Cancel", response.text) + + # Should NOT have cancel button for processing job + self.assertNotIn( + f'hx-post="/queue/{processing_job}/cancel"', + response.text, + ) + + +class TestEpisodeDetailPage(BaseWebTest): + """Test episode detail page functionality.""" + + def setUp(self) 
-> None: + """Set up test client with user and episode.""" + super().setUp() + + # Create user and episode + self.user_id, self.token = Core.Database.create_user( + "creator@example.com", + status="active", + ) + self.episode_id = Core.Database.create_episode( + title="Test Episode", + audio_url="https://example.com/audio.mp3", + duration=300, + content_length=5000, + user_id=self.user_id, + author="Test Author", + original_url="https://example.com/article", + ) + Core.Database.add_episode_to_user(self.user_id, self.episode_id) + self.episode_sqid = encode_episode_id(self.episode_id) + + def test_episode_page_loads(self) -> None: + """Episode page should load successfully.""" + response = self.client.get(f"/episode/{self.episode_sqid}") + + self.assertEqual(response.status_code, 200) + self.assertIn("Test Episode", response.text) + self.assertIn("Test Author", response.text) + + def test_episode_not_found(self) -> None: + """Non-existent episode should return 404.""" + response = self.client.get("/episode/invalidcode") + + self.assertEqual(response.status_code, 404) + + def test_audio_player_present(self) -> None: + """Audio player should be present on episode page.""" + response = self.client.get(f"/episode/{self.episode_sqid}") + + self.assertIn("<audio", response.text) + self.assertIn("controls", response.text) + self.assertIn("https://example.com/audio.mp3", response.text) + + def test_share_button_present(self) -> None: + """Share button should be present.""" + response = self.client.get(f"/episode/{self.episode_sqid}") + + self.assertIn("Share Episode", response.text) + self.assertIn("navigator.clipboard.writeText", response.text) + + def test_original_article_link(self) -> None: + """Original article link should be present.""" + response = self.client.get(f"/episode/{self.episode_sqid}") + + self.assertIn("View original article", response.text) + self.assertIn("https://example.com/article", response.text) + + def 
test_signup_banner_for_non_authenticated(self) -> None: + """Non-authenticated users should see signup banner.""" + response = self.client.get(f"/episode/{self.episode_sqid}") + + self.assertIn("This episode was created by", response.text) + self.assertIn("creator@example.com", response.text) + self.assertIn("Sign Up", response.text) + + def test_no_signup_banner_for_authenticated(self) -> None: + """Authenticated users should not see signup banner.""" + # Login + self.client.post("/login", data={"email": "creator@example.com"}) + + response = self.client.get(f"/episode/{self.episode_sqid}") + + self.assertNotIn("This episode was created by", response.text) + + def test_episode_links_from_home_page(self) -> None: + """Episode titles on home page should link to detail page.""" + # Login to see episodes + self.client.post("/login", data={"email": "creator@example.com"}) + + response = self.client.get("/") + + self.assertIn(f'href="/episode/{self.episode_sqid}"', response.text) + self.assertIn("Test Episode", response.text) + + def test_legacy_integer_id_redirects(self) -> None: + """Legacy integer episode IDs should redirect to sqid URLs.""" + response = self.client.get( + f"/episode/{self.episode_id}", + follow_redirects=False, + ) + + self.assertEqual(response.status_code, 301) + self.assertEqual( + response.headers["location"], + f"/episode/{self.episode_sqid}", + ) + + +class TestPublicFeed(BaseWebTest): + """Test public feed functionality.""" + + def setUp(self) -> None: + """Set up test database, client, and create sample episodes.""" + super().setUp() + + # Create admin user + self.admin_id, _ = Core.Database.create_user( + "ben@bensima.com", + status="active", + ) + + # Create some episodes, some public, some private + self.public_episode_id = Core.Database.create_episode( + title="Public Episode", + audio_url="https://example.com/public.mp3", + duration=300, + content_length=1000, + user_id=self.admin_id, + author="Test Author", + 
original_url="https://example.com/public", + original_url_hash=Core.hash_url("https://example.com/public"), + ) + Core.Database.mark_episode_public(self.public_episode_id) + + self.private_episode_id = Core.Database.create_episode( + title="Private Episode", + audio_url="https://example.com/private.mp3", + duration=200, + content_length=800, + user_id=self.admin_id, + author="Test Author", + original_url="https://example.com/private", + original_url_hash=Core.hash_url("https://example.com/private"), + ) + + def test_public_feed_page(self) -> None: + """Public feed page should show only public episodes.""" + response = self.client.get("/public") + + self.assertEqual(response.status_code, 200) + self.assertIn("Public Episode", response.text) + self.assertNotIn("Private Episode", response.text) + + def test_home_page_shows_public_feed_when_logged_out(self) -> None: + """Home page should show public episodes when user is not logged in.""" + response = self.client.get("/") + + self.assertEqual(response.status_code, 200) + self.assertIn("Public Episode", response.text) + self.assertNotIn("Private Episode", response.text) + + def test_admin_can_toggle_episode_public(self) -> None: + """Admin should be able to toggle episode public/private status.""" + # Login as admin + self.client.post("/login", data={"email": "ben@bensima.com"}) + + # Toggle private episode to public + response = self.client.post( + f"/admin/episode/{self.private_episode_id}/toggle-public", + ) + + self.assertEqual(response.status_code, 200) + + # Verify it's now public + episode = Core.Database.get_episode_by_id(self.private_episode_id) + self.assertEqual(episode["is_public"], 1) # type: ignore[index] + + def test_non_admin_cannot_toggle_public(self) -> None: + """Non-admin users should not be able to toggle public status.""" + # Create and login as regular user + _user_id, _ = Core.Database.create_user("user@example.com") + self.client.post("/login", data={"email": "user@example.com"}) + + # Try to 
toggle + response = self.client.post( + f"/admin/episode/{self.private_episode_id}/toggle-public", + ) + + self.assertEqual(response.status_code, 403) + + def test_admin_can_add_user_episode_to_own_feed(self) -> None: + """Admin can add another user's episode to their own feed.""" + # Create regular user and their episode + user_id, _ = Core.Database.create_user( + "user@example.com", + status="active", + ) + user_episode_id = Core.Database.create_episode( + title="User Episode", + audio_url="https://example.com/user.mp3", + duration=400, + content_length=1200, + user_id=user_id, + author="User Author", + original_url="https://example.com/user-article", + original_url_hash=Core.hash_url("https://example.com/user-article"), + ) + Core.Database.add_episode_to_user(user_id, user_episode_id) + + # Login as admin + self.client.post("/login", data={"email": "ben@bensima.com"}) + + # Admin adds user's episode to their own feed + response = self.client.post(f"/episode/{user_episode_id}/add-to-feed") + + self.assertEqual(response.status_code, 200) + + # Verify episode is now in admin's feed + admin_episodes = Core.Database.get_user_episodes(self.admin_id) + episode_ids = [e["id"] for e in admin_episodes] + self.assertIn(user_episode_id, episode_ids) + + # Verify "added" event was tracked + metrics = Core.Database.get_episode_metric_events(user_episode_id) + added_events = [m for m in metrics if m["event_type"] == "added"] + self.assertEqual(len(added_events), 1) + self.assertEqual(added_events[0]["user_id"], self.admin_id) + + def test_admin_can_add_user_episode_to_public_feed(self) -> None: + """Admin should be able to add another user's episode to public feed.""" + # Create regular user and their episode + user_id, _ = Core.Database.create_user( + "user@example.com", + status="active", + ) + user_episode_id = Core.Database.create_episode( + title="User Episode for Public", + audio_url="https://example.com/user-public.mp3", + duration=500, + content_length=1500, + 
user_id=user_id, + author="User Author", + original_url="https://example.com/user-public-article", + original_url_hash=Core.hash_url( + "https://example.com/user-public-article", + ), + ) + Core.Database.add_episode_to_user(user_id, user_episode_id) + + # Verify episode is private initially + episode = Core.Database.get_episode_by_id(user_episode_id) + self.assertEqual(episode["is_public"], 0) # type: ignore[index] + + # Login as admin + self.client.post("/login", data={"email": "ben@bensima.com"}) + + # Admin toggles episode to public + response = self.client.post( + f"/admin/episode/{user_episode_id}/toggle-public", + ) + + self.assertEqual(response.status_code, 200) + + # Verify episode is now public + episode = Core.Database.get_episode_by_id(user_episode_id) + self.assertEqual(episode["is_public"], 1) # type: ignore[index] + + # Verify episode appears in public feed + public_episodes = Core.Database.get_public_episodes() + episode_ids = [e["id"] for e in public_episodes] + self.assertIn(user_episode_id, episode_ids) + + +class TestEpisodeDeduplication(BaseWebTest): + """Test episode deduplication functionality.""" + + def setUp(self) -> None: + """Set up test database, client, and create test user.""" + super().setUp() + + self.user_id, self.token = Core.Database.create_user( + "test@example.com", + status="active", + ) + + # Create an existing episode + self.existing_url = "https://example.com/article" + self.url_hash = Core.hash_url(self.existing_url) + + self.episode_id = Core.Database.create_episode( + title="Existing Article", + audio_url="https://example.com/audio.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test Author", + original_url=self.existing_url, + original_url_hash=self.url_hash, + ) + + def test_url_normalization(self) -> None: + """URLs should be normalized for deduplication.""" + # Different URL variations that should be normalized to same hash + urls = [ + "http://example.com/article", + 
"https://example.com/article", + "https://www.example.com/article", + "https://EXAMPLE.COM/article", + "https://example.com/article/", + ] + + hashes = [Core.hash_url(url) for url in urls] + + # All should produce the same hash + self.assertEqual(len(set(hashes)), 1) + + def test_find_existing_episode_by_hash(self) -> None: + """Should find existing episode by normalized URL hash.""" + # Try different URL variations + similar_urls = [ + "http://example.com/article", + "https://www.example.com/article", + ] + + for url in similar_urls: + url_hash = Core.hash_url(url) + episode = Core.Database.get_episode_by_url_hash(url_hash) + + self.assertIsNotNone(episode) + if episode is not None: + self.assertEqual(episode["id"], self.episode_id) + + def test_add_existing_episode_to_user_feed(self) -> None: + """Should add existing episode to new user's feed.""" + # Create second user + user2_id, _ = Core.Database.create_user("user2@example.com") + + # Add existing episode to their feed + Core.Database.add_episode_to_user(user2_id, self.episode_id) + + # Verify it appears in their feed + episodes = Core.Database.get_user_episodes(user2_id) + episode_ids = [e["id"] for e in episodes] + + self.assertIn(self.episode_id, episode_ids) + + +class TestMetricsTracking(BaseWebTest): + """Test episode metrics tracking.""" + + def setUp(self) -> None: + """Set up test database, client, and create test episode.""" + super().setUp() + + self.user_id, _ = Core.Database.create_user( + "test@example.com", + status="active", + ) + + self.episode_id = Core.Database.create_episode( + title="Test Episode", + audio_url="https://example.com/audio.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test Author", + original_url="https://example.com/article", + original_url_hash=Core.hash_url("https://example.com/article"), + ) + + def test_track_episode_added(self) -> None: + """Should track when episode is added to feed.""" + Core.Database.track_episode_event( + 
self.episode_id, + "added", + self.user_id, + ) + + # Verify metric was recorded + metrics = Core.Database.get_episode_metric_events(self.episode_id) + self.assertEqual(len(metrics), 1) + self.assertEqual(metrics[0]["event_type"], "added") + self.assertEqual(metrics[0]["user_id"], self.user_id) + + def test_track_episode_played(self) -> None: + """Should track when episode is played.""" + Core.Database.track_episode_event( + self.episode_id, + "played", + self.user_id, + ) + + metrics = Core.Database.get_episode_metric_events(self.episode_id) + self.assertEqual(len(metrics), 1) + self.assertEqual(metrics[0]["event_type"], "played") + + def test_track_anonymous_play(self) -> None: + """Should track plays from anonymous users.""" + Core.Database.track_episode_event( + self.episode_id, + "played", + user_id=None, + ) + + metrics = Core.Database.get_episode_metric_events(self.episode_id) + self.assertEqual(len(metrics), 1) + self.assertEqual(metrics[0]["event_type"], "played") + self.assertIsNone(metrics[0]["user_id"]) + + def test_track_endpoint(self) -> None: + """POST /episode/{id}/track should record metrics.""" + # Login as user + self.client.post("/login", data={"email": "test@example.com"}) + + response = self.client.post( + f"/episode/{self.episode_id}/track", + data={"event_type": "played"}, + ) + + self.assertEqual(response.status_code, 200) + + # Verify metric was recorded + metrics = Core.Database.get_episode_metric_events(self.episode_id) + played_metrics = [m for m in metrics if m["event_type"] == "played"] + self.assertGreater(len(played_metrics), 0) + + +class TestUsageLimits(BaseWebTest): + """Test usage tracking and limit enforcement.""" + + def setUp(self) -> None: + """Set up test with free tier user.""" + super().setUp() + + # Create free tier user + self.user_id, self.token = Core.Database.create_user( + "free@example.com", + status="active", + ) + # Login + self.client.post("/login", data={"email": "free@example.com"}) + + def 
test_usage_counts_episodes_added_to_feed(self) -> None: + """Usage should count episodes added via user_episodes table.""" + user = Core.Database.get_user_by_id(self.user_id) + self.assertIsNotNone(user) + assert user is not None # type narrowing # noqa: S101 + period_start, period_end = Billing.get_period_boundaries(user) + + # Initially no usage + usage = Billing.get_usage(self.user_id, period_start, period_end) + self.assertEqual(usage["articles"], 0) + + # Add an episode to user's feed + ep_id = Core.Database.create_episode( + title="Test Episode", + audio_url="https://example.com/test.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test", + original_url="https://example.com/article", + original_url_hash=Core.hash_url("https://example.com/article"), + ) + Core.Database.add_episode_to_user(self.user_id, ep_id) + + # Usage should now be 1 + usage = Billing.get_usage(self.user_id, period_start, period_end) + self.assertEqual(usage["articles"], 1) + + def test_usage_counts_existing_episodes_correctly(self) -> None: + """Adding existing episodes should count toward usage.""" + # Create another user who creates an episode + other_user_id, _ = Core.Database.create_user("other@example.com") + ep_id = Core.Database.create_episode( + title="Other User Episode", + audio_url="https://example.com/other.mp3", + duration=400, + content_length=1200, + user_id=other_user_id, + author="Other", + original_url="https://example.com/other-article", + original_url_hash=Core.hash_url( + "https://example.com/other-article", + ), + ) + Core.Database.add_episode_to_user(other_user_id, ep_id) + + # Free user adds it to their feed + Core.Database.add_episode_to_user(self.user_id, ep_id) + + # Check usage for free user + user = Core.Database.get_user_by_id(self.user_id) + self.assertIsNotNone(user) + assert user is not None # type narrowing # noqa: S101 + period_start, period_end = Billing.get_period_boundaries(user) + usage = Billing.get_usage(self.user_id, 
period_start, period_end) + + # Should count as 1 article for free user + self.assertEqual(usage["articles"], 1) + + def test_free_tier_limit_enforcement(self) -> None: + """Free tier users should be blocked at 10 articles.""" + # Add 10 episodes (the free tier limit) + for i in range(10): + ep_id = Core.Database.create_episode( + title=f"Episode {i}", + audio_url=f"https://example.com/ep{i}.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test", + original_url=f"https://example.com/article{i}", + original_url_hash=Core.hash_url( + f"https://example.com/article{i}", + ), + ) + Core.Database.add_episode_to_user(self.user_id, ep_id) + + # Try to submit 11th article + response = self.client.post( + "/submit", + data={"url": "https://example.com/article11"}, + ) + + # Should be blocked + self.assertEqual(response.status_code, 200) + self.assertIn("Limit reached", response.text) + self.assertIn("10", response.text) + self.assertIn("Upgrade", response.text) + + def test_can_submit_blocks_at_limit(self) -> None: + """can_submit should return False at limit.""" + # Add 10 episodes + for i in range(10): + ep_id = Core.Database.create_episode( + title=f"Episode {i}", + audio_url=f"https://example.com/ep{i}.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test", + original_url=f"https://example.com/article{i}", + original_url_hash=Core.hash_url( + f"https://example.com/article{i}", + ), + ) + Core.Database.add_episode_to_user(self.user_id, ep_id) + + # Check can_submit + allowed, msg, usage = Billing.can_submit(self.user_id) + + self.assertFalse(allowed) + self.assertIn("10", msg) + self.assertIn("limit", msg.lower()) + self.assertEqual(usage["articles"], 10) + + def test_paid_tier_unlimited(self) -> None: + """Paid tier should have no article limits.""" + # Create a paid tier user directly + paid_user_id, _ = Core.Database.create_user("paid@example.com") + + # Simulate paid subscription via update_user_subscription 
+ now = datetime.now(timezone.utc) + period_start = now + december = 12 + january = 1 + period_end = now.replace( + month=now.month + 1 if now.month < december else january, + ) + + Core.Database.update_user_subscription( + paid_user_id, + subscription_id="sub_test123", + status="active", + period_start=period_start, + period_end=period_end, + tier="paid", + cancel_at_period_end=False, + ) + + # Add 20 episodes (more than free limit) + for i in range(20): + ep_id = Core.Database.create_episode( + title=f"Episode {i}", + audio_url=f"https://example.com/ep{i}.mp3", + duration=300, + content_length=1000, + user_id=paid_user_id, + author="Test", + original_url=f"https://example.com/article{i}", + original_url_hash=Core.hash_url( + f"https://example.com/article{i}", + ), + ) + Core.Database.add_episode_to_user(paid_user_id, ep_id) + + # Should still be allowed to submit + allowed, msg, usage = Billing.can_submit(paid_user_id) + + self.assertTrue(allowed) + self.assertEqual(msg, "") + self.assertEqual(usage["articles"], 20) + + +class TestAccountPage(BaseWebTest): + """Test account page functionality.""" + + def setUp(self) -> None: + """Set up test with user.""" + super().setUp() + self.user_id, _ = Core.Database.create_user( + "test@example.com", + status="active", + ) + self.client.post("/login", data={"email": "test@example.com"}) + + def test_account_page_logged_in(self) -> None: + """Account page should render for logged-in users.""" + # Create some usage to verify stats are shown + ep_id = Core.Database.create_episode( + title="Test Episode", + audio_url="https://example.com/audio.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test Author", + original_url="https://example.com/article", + original_url_hash=Core.hash_url("https://example.com/article"), + ) + Core.Database.add_episode_to_user(self.user_id, ep_id) + + response = self.client.get("/account") + + self.assertEqual(response.status_code, 200) + self.assertIn("My Account", 
response.text) + self.assertIn("test@example.com", response.text) + self.assertIn("1 / 10", response.text) # Usage / Limit for free tier + + def test_account_page_login_required(self) -> None: + """Should redirect to login if not logged in.""" + self.client.post("/logout") + response = self.client.get("/account", follow_redirects=False) + self.assertEqual(response.status_code, 307) + self.assertEqual(response.headers["location"], "/?error=login_required") + + def test_logout(self) -> None: + """Logout should clear session.""" + response = self.client.post("/logout", follow_redirects=False) + self.assertEqual(response.status_code, 303) + self.assertEqual(response.headers["location"], "/") + + # Verify session cleared + response = self.client.get("/account", follow_redirects=False) + self.assertEqual(response.status_code, 307) + + def test_billing_portal_redirect(self) -> None: + """Billing portal should redirect to Stripe.""" + # First set a customer ID + Core.Database.set_user_stripe_customer(self.user_id, "cus_test") + + # Mock the create_portal_session method + with patch( + "Biz.PodcastItLater.Billing.create_portal_session", + ) as mock_portal: + mock_portal.return_value = "https://billing.stripe.com/test" + + response = self.client.post( + "/billing/portal", + follow_redirects=False, + ) + + self.assertEqual(response.status_code, 303) + self.assertEqual( + response.headers["location"], + "https://billing.stripe.com/test", + ) + + def test_update_email_success(self) -> None: + """Should allow updating email.""" + # POST new email + response = self.client.post( + "/settings/email", + data={"email": "new@example.com"}, + ) + self.assertEqual(response.status_code, 200) + + # Verify update in DB + user = Core.Database.get_user_by_id(self.user_id) + self.assertEqual(user["email"], "new@example.com") # type: ignore[index] + + def test_update_email_duplicate(self) -> None: + """Should prevent updating to existing email.""" + # Create another user + 
Core.Database.create_user("other@example.com") + + # Try to update to their email + response = self.client.post( + "/settings/email", + data={"email": "other@example.com"}, + ) + + # Should show error (return 200 with error message in form) + self.assertEqual(response.status_code, 200) + self.assertIn("already taken", response.text.lower()) + + def test_delete_account(self) -> None: + """Should allow user to delete their account.""" + # Delete account + response = self.client.delete("/account") + self.assertEqual(response.status_code, 200) + self.assertIn("HX-Redirect", response.headers) + + # Verify user gone + user = Core.Database.get_user_by_id(self.user_id) + self.assertIsNone(user) + + # Verify session cleared + response = self.client.get("/account", follow_redirects=False) + self.assertEqual(response.status_code, 307) + + +class TestAdminUsers(BaseWebTest): + """Test admin user management functionality.""" + + def setUp(self) -> None: + """Set up test client with logged-in admin user.""" + super().setUp() + + # Create and login admin user + self.user_id, _ = Core.Database.create_user( + "ben@bensima.com", + ) + Core.Database.update_user_status( + self.user_id, + "active", + ) + self.client.post("/login", data={"email": "ben@bensima.com"}) + + # Create another regular user + self.other_user_id, _ = Core.Database.create_user("user@example.com") + Core.Database.update_user_status(self.other_user_id, "active") + + def test_admin_users_page_access(self) -> None: + """Admin can access users page.""" + response = self.client.get("/admin/users") + self.assertEqual(response.status_code, 200) + self.assertIn("User Management", response.text) + self.assertIn("user@example.com", response.text) + + def test_non_admin_users_page_access(self) -> None: + """Non-admin cannot access users page.""" + # Login as regular user + self.client.get("/logout") + self.client.post("/login", data={"email": "user@example.com"}) + + response = self.client.get("/admin/users") + 
self.assertEqual(response.status_code, 302) + self.assertIn("error=forbidden", response.headers["Location"]) + + def test_admin_can_update_user_status(self) -> None: + """Admin can update user status.""" + response = self.client.post( + f"/admin/users/{self.other_user_id}/status", + data={"status": "disabled"}, + ) + self.assertEqual(response.status_code, 200) + + user = Core.Database.get_user_by_id(self.other_user_id) + assert user is not None # noqa: S101 + self.assertEqual(user["status"], "disabled") + + def test_non_admin_cannot_update_user_status(self) -> None: + """Non-admin cannot update user status.""" + # Login as regular user + self.client.get("/logout") + self.client.post("/login", data={"email": "user@example.com"}) + + response = self.client.post( + f"/admin/users/{self.other_user_id}/status", + data={"status": "disabled"}, + ) + self.assertEqual(response.status_code, 403) + + user = Core.Database.get_user_by_id(self.other_user_id) + assert user is not None # noqa: S101 + self.assertEqual(user["status"], "active") + + def test_update_user_status_invalid_status(self) -> None: + """Invalid status validation.""" + response = self.client.post( + f"/admin/users/{self.other_user_id}/status", + data={"status": "invalid_status"}, + ) + self.assertEqual(response.status_code, 400) + + user = Core.Database.get_user_by_id(self.other_user_id) + assert user is not None # noqa: S101 + self.assertEqual(user["status"], "active") + + +def test() -> None: + """Run all tests for the web module.""" + Test.run( + App.Area.Test, + [ + TestDurationFormatting, + TestAuthentication, + TestArticleSubmission, + TestRSSFeed, + TestAdminInterface, + TestJobCancellation, + TestEpisodeDetailPage, + TestPublicFeed, + TestEpisodeDeduplication, + TestMetricsTracking, + TestUsageLimits, + TestAccountPage, + TestAdminUsers, + ], + ) + + +def main() -> None: + """Run the web server.""" + if "test" in sys.argv: + test() + else: + # Initialize database on startup + Core.Database.init_db() + 
uvicorn.run(app, host="0.0.0.0", port=PORT) # noqa: S104 diff --git a/Biz/PodcastItLater/Worker.nix b/Biz/PodcastItLater/Worker.nix new file mode 100644 index 0000000..974a3ba --- /dev/null +++ b/Biz/PodcastItLater/Worker.nix @@ -0,0 +1,63 @@ +{ + options, + lib, + config, + pkgs, + ... +}: let + cfg = config.services.podcastitlater-worker; +in { + options.services.podcastitlater-worker = { + enable = lib.mkEnableOption "Enable the PodcastItLater worker service"; + dataDir = lib.mkOption { + type = lib.types.path; + default = "/var/podcastitlater"; + description = "Data directory for PodcastItLater (shared with web)"; + }; + package = lib.mkOption { + type = lib.types.package; + description = "PodcastItLater worker package to use"; + }; + }; + config = lib.mkIf cfg.enable { + systemd.services.podcastitlater-worker = { + path = [cfg.package pkgs.ffmpeg]; # ffmpeg needed for pydub + wantedBy = ["multi-user.target"]; + after = ["network.target"]; + preStart = '' + # Create data directory if it doesn't exist + mkdir -p ${cfg.dataDir} + + # Manual step: create this file with secrets + # OPENAI_API_KEY=your-openai-api-key + # S3_ENDPOINT=https://your-s3-endpoint.digitaloceanspaces.com + # S3_BUCKET=your-bucket-name + # S3_ACCESS_KEY=your-s3-access-key + # S3_SECRET_KEY=your-s3-secret-key + test -f /run/podcastitlater/worker-env + ''; + script = '' + ${cfg.package}/bin/podcastitlater-worker + ''; + description = '' + PodcastItLater Worker Service - processes articles to podcasts + ''; + serviceConfig = { + Environment = [ + "AREA=Live" + "DATA_DIR=${cfg.dataDir}" + ]; + EnvironmentFile = "/run/podcastitlater/worker-env"; + KillSignal = "TERM"; + KillMode = "mixed"; + Type = "simple"; + Restart = "always"; + RestartSec = "10"; + # Give the worker time to finish current job + TimeoutStopSec = "300"; # 5 minutes + # Send SIGTERM first, then SIGKILL after timeout + SendSIGKILL = "yes"; + }; + }; + }; +} diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py 
new file mode 100644 index 0000000..bf6ef9e --- /dev/null +++ b/Biz/PodcastItLater/Worker.py @@ -0,0 +1,2199 @@ +"""Background worker for processing article-to-podcast conversions.""" + +# : dep boto3 +# : dep botocore +# : dep openai +# : dep psutil +# : dep pydub +# : dep pytest +# : dep pytest-asyncio +# : dep pytest-mock +# : dep trafilatura +# : out podcastitlater-worker +# : run ffmpeg +import Biz.PodcastItLater.Core as Core +import boto3 # type: ignore[import-untyped] +import concurrent.futures +import io +import json +import logging +import Omni.App as App +import Omni.Log as Log +import Omni.Test as Test +import openai +import operator +import os +import psutil # type: ignore[import-untyped] +import pytest +import signal +import sys +import tempfile +import threading +import time +import trafilatura +import typing +import unittest.mock +from botocore.exceptions import ClientError # type: ignore[import-untyped] +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from pathlib import Path +from pydub import AudioSegment # type: ignore[import-untyped] +from typing import Any + +logger = logging.getLogger(__name__) +Log.setup(logger) + +# Configuration from environment variables +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +S3_ENDPOINT = os.getenv("S3_ENDPOINT") +S3_BUCKET = os.getenv("S3_BUCKET") +S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY") +S3_SECRET_KEY = os.getenv("S3_SECRET_KEY") +area = App.from_env() + +# Worker configuration +MAX_CONTENT_LENGTH = 5000 # characters for TTS +MAX_ARTICLE_SIZE = 500_000 # 500KB character limit for articles +POLL_INTERVAL = 30 # seconds +MAX_RETRIES = 3 +TTS_MODEL = "tts-1" +TTS_VOICE = "alloy" +MEMORY_THRESHOLD = 80 # Percentage threshold for memory usage +CROSSFADE_DURATION = 500 # ms for crossfading segments +PAUSE_DURATION = 1000 # ms for silence between segments + + +class ShutdownHandler: + """Handles graceful shutdown of the worker.""" + + def __init__(self) -> None: + 
"""Initialize shutdown handler.""" + self.shutdown_requested = threading.Event() + self.current_job_id: int | None = None + self.lock = threading.Lock() + + # Register signal handlers + signal.signal(signal.SIGTERM, self._handle_signal) + signal.signal(signal.SIGINT, self._handle_signal) + + def _handle_signal(self, signum: int, _frame: Any) -> None: + """Handle shutdown signals.""" + logger.info( + "Received signal %d, initiating graceful shutdown...", + signum, + ) + self.shutdown_requested.set() + + def is_shutdown_requested(self) -> bool: + """Check if shutdown has been requested.""" + return self.shutdown_requested.is_set() + + def set_current_job(self, job_id: int | None) -> None: + """Set the currently processing job.""" + with self.lock: + self.current_job_id = job_id + + def get_current_job(self) -> int | None: + """Get the currently processing job.""" + with self.lock: + return self.current_job_id + + +class ArticleProcessor: + """Handles the complete article-to-podcast conversion pipeline.""" + + def __init__(self, shutdown_handler: ShutdownHandler) -> None: + """Initialize the processor with required services. + + Raises: + ValueError: If OPENAI_API_KEY environment variable is not set. + """ + if not OPENAI_API_KEY: + msg = "OPENAI_API_KEY environment variable is required" + raise ValueError(msg) + + self.openai_client: openai.OpenAI = openai.OpenAI( + api_key=OPENAI_API_KEY, + ) + self.shutdown_handler = shutdown_handler + + # Initialize S3 client for Digital Ocean Spaces + if all([S3_ENDPOINT, S3_BUCKET, S3_ACCESS_KEY, S3_SECRET_KEY]): + self.s3_client: Any = boto3.client( + "s3", + endpoint_url=S3_ENDPOINT, + aws_access_key_id=S3_ACCESS_KEY, + aws_secret_access_key=S3_SECRET_KEY, + ) + else: + logger.warning("S3 configuration incomplete, uploads will fail") + self.s3_client = None + + @staticmethod + def extract_article_content( + url: str, + ) -> tuple[str, str, str | None, str | None]: + """Extract title, content, author, and date from article URL. 
+ + Returns: + tuple: (title, content, author, publication_date) + + Raises: + ValueError: If content cannot be downloaded, extracted, or large. + """ + try: + downloaded = trafilatura.fetch_url(url) + if not downloaded: + msg = f"Failed to download content from {url}" + raise ValueError(msg) # noqa: TRY301 + + # Check size before processing + if ( + len(downloaded) > MAX_ARTICLE_SIZE * 4 + ): # Rough HTML to text ratio + msg = f"Article too large: {len(downloaded)} bytes" + raise ValueError(msg) # noqa: TRY301 + + # Extract with metadata + result = trafilatura.extract( + downloaded, + include_comments=False, + include_tables=False, + with_metadata=True, + output_format="json", + ) + + if not result: + msg = f"Failed to extract content from {url}" + raise ValueError(msg) # noqa: TRY301 + + data = json.loads(result) + + title = data.get("title", "Untitled Article") + content = data.get("text", "") + author = data.get("author") + pub_date = data.get("date") + + if not content: + msg = f"No content extracted from {url}" + raise ValueError(msg) # noqa: TRY301 + + # Enforce content size limit + if len(content) > MAX_ARTICLE_SIZE: + logger.warning( + "Article content truncated from %d to %d characters", + len(content), + MAX_ARTICLE_SIZE, + ) + content = content[:MAX_ARTICLE_SIZE] + + logger.info( + "Extracted article: %s (%d chars, author: %s, date: %s)", + title, + len(content), + author or "unknown", + pub_date or "unknown", + ) + except Exception: + logger.exception("Failed to extract content from %s", url) + raise + else: + return title, content, author, pub_date + + def text_to_speech( + self, + text: str, + title: str, + author: str | None = None, + pub_date: str | None = None, + ) -> bytes: + """Convert text to speech with intro/outro using OpenAI TTS API. + + Uses parallel processing for chunks while maintaining order. + Adds intro with metadata and outro with attribution. 
+ + Args: + text: Article content to convert + title: Article title + author: Article author (optional) + pub_date: Publication date (optional) + + Raises: + ValueError: If no chunks are generated from text. + """ + try: + # Generate intro audio + intro_text = self._create_intro_text(title, author, pub_date) + intro_audio = self._generate_tts_segment(intro_text) + + # Generate outro audio + outro_text = self._create_outro_text(title, author) + outro_audio = self._generate_tts_segment(outro_text) + + # Use LLM to prepare and chunk the main content + chunks = prepare_text_for_tts(text, title) + + if not chunks: + msg = "No chunks generated from text" + raise ValueError(msg) # noqa: TRY301 + + logger.info("Processing %d chunks for TTS", len(chunks)) + + # Check memory before parallel processing + mem_usage = check_memory_usage() + if mem_usage > MEMORY_THRESHOLD - 20: # Leave 20% buffer + logger.warning( + "High memory usage (%.1f%%), falling back to serial " + "processing", + mem_usage, + ) + content_audio_bytes = self._text_to_speech_serial(chunks) + else: + # Determine max workers + max_workers = min( + 4, # Reasonable limit to avoid rate limiting + len(chunks), # No more workers than chunks + max(1, psutil.cpu_count() // 2), # Use half of CPU cores + ) + + logger.info( + "Using %d workers for parallel TTS processing", + max_workers, + ) + + # Process chunks in parallel + chunk_results: list[tuple[int, bytes]] = [] + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers, + ) as executor: + # Submit all chunks for processing + future_to_index = { + executor.submit(self._process_tts_chunk, chunk, i): i + for i, chunk in enumerate(chunks) + } + + # Collect results as they complete + for future in concurrent.futures.as_completed( + future_to_index, + ): + index = future_to_index[future] + try: + audio_data = future.result() + chunk_results.append((index, audio_data)) + except Exception: + logger.exception( + "Failed to process chunk %d", + index, + ) 
+ raise + + # Sort results by index to maintain order + chunk_results.sort(key=operator.itemgetter(0)) + + # Combine audio chunks + content_audio_bytes = self._combine_audio_chunks([ + data for _, data in chunk_results + ]) + + # Combine intro, content, and outro with pauses + return ArticleProcessor._combine_intro_content_outro( + intro_audio, + content_audio_bytes, + outro_audio, + ) + + except Exception: + logger.exception("TTS generation failed") + raise + + @staticmethod + def _create_intro_text( + title: str, + author: str | None, + pub_date: str | None, + ) -> str: + """Create intro text with available metadata.""" + parts = [f"Title: {title}"] + + if author: + parts.append(f"Author: {author}") + + if pub_date: + parts.append(f"Published: {pub_date}") + + return ". ".join(parts) + "." + + @staticmethod + def _create_outro_text(title: str, author: str | None) -> str: + """Create outro text with attribution.""" + if author: + return ( + f"This has been an audio version of {title} " + f"by {author}, created using Podcast It Later." + ) + return ( + f"This has been an audio version of {title}, " + "created using Podcast It Later." + ) + + def _generate_tts_segment(self, text: str) -> bytes: + """Generate TTS audio for a single segment (intro/outro). + + Args: + text: Text to convert to speech + + Returns: + MP3 audio bytes + """ + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=text, + ) + return response.content + + @staticmethod + def _combine_intro_content_outro( + intro_audio: bytes, + content_audio: bytes, + outro_audio: bytes, + ) -> bytes: + """Combine intro, content, and outro with crossfades. 
+ + Args: + intro_audio: MP3 bytes for intro + content_audio: MP3 bytes for main content + outro_audio: MP3 bytes for outro + + Returns: + Combined MP3 audio bytes + """ + # Load audio segments + intro = AudioSegment.from_mp3(io.BytesIO(intro_audio)) + content = AudioSegment.from_mp3(io.BytesIO(content_audio)) + outro = AudioSegment.from_mp3(io.BytesIO(outro_audio)) + + # Create bridge silence (pause + 2 * crossfade to account for overlap) + bridge = AudioSegment.silent( + duration=PAUSE_DURATION + 2 * CROSSFADE_DURATION + ) + + def safe_append( + seg1: AudioSegment, seg2: AudioSegment, crossfade: int + ) -> AudioSegment: + if len(seg1) < crossfade or len(seg2) < crossfade: + logger.warning( + "Segment too short for crossfade (%dms vs %dms/%dms), using concatenation", + crossfade, + len(seg1), + len(seg2), + ) + return seg1 + seg2 + return seg1.append(seg2, crossfade=crossfade) + + # Combine segments with crossfades + # Intro -> Bridge -> Content -> Bridge -> Outro + # This effectively fades out the previous segment and fades in the next one + combined = safe_append(intro, bridge, CROSSFADE_DURATION) + combined = safe_append(combined, content, CROSSFADE_DURATION) + combined = safe_append(combined, bridge, CROSSFADE_DURATION) + combined = safe_append(combined, outro, CROSSFADE_DURATION) + + # Export to bytes + output = io.BytesIO() + combined.export(output, format="mp3") + return output.getvalue() + + def _process_tts_chunk(self, chunk: str, index: int) -> bytes: + """Process a single TTS chunk. 
+ + Args: + chunk: Text to convert to speech + index: Chunk index for logging + + Returns: + Audio data as bytes + """ + logger.info( + "Generating TTS for chunk %d (%d chars)", + index + 1, + len(chunk), + ) + + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunk, + response_format="mp3", + ) + + return response.content + + @staticmethod + def _combine_audio_chunks(audio_chunks: list[bytes]) -> bytes: + """Combine multiple audio chunks with silence gaps. + + Args: + audio_chunks: List of audio data in order + + Returns: + Combined audio data + """ + if not audio_chunks: + msg = "No audio chunks to combine" + raise ValueError(msg) + + # Create a temporary file for the combined audio + with tempfile.NamedTemporaryFile( + suffix=".mp3", + delete=False, + ) as temp_file: + temp_path = temp_file.name + + try: + # Start with the first chunk + combined_audio = AudioSegment.from_mp3(io.BytesIO(audio_chunks[0])) + + # Add remaining chunks with silence gaps + for chunk_data in audio_chunks[1:]: + chunk_audio = AudioSegment.from_mp3(io.BytesIO(chunk_data)) + silence = AudioSegment.silent(duration=300) # 300ms gap + combined_audio = combined_audio + silence + chunk_audio + + # Export to file + combined_audio.export(temp_path, format="mp3", bitrate="128k") + + # Read back the combined audio + return Path(temp_path).read_bytes() + + finally: + # Clean up temp file + if Path(temp_path).exists(): + Path(temp_path).unlink() + + def _text_to_speech_serial(self, chunks: list[str]) -> bytes: + """Fallback serial processing for high memory situations. + + This is the original serial implementation. 
+ """ + # Create a temporary file for streaming audio concatenation + with tempfile.NamedTemporaryFile( + suffix=".mp3", + delete=False, + ) as temp_file: + temp_path = temp_file.name + + try: + # Process first chunk + logger.info("Generating TTS for chunk 1/%d", len(chunks)) + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunks[0], + response_format="mp3", + ) + + # Write first chunk directly to file + Path(temp_path).write_bytes(response.content) + + # Process remaining chunks + for i, chunk in enumerate(chunks[1:], 1): + logger.info( + "Generating TTS for chunk %d/%d (%d chars)", + i + 1, + len(chunks), + len(chunk), + ) + + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunk, + response_format="mp3", + ) + + # Append to existing file with silence gap + # Load only the current segment + current_segment = AudioSegment.from_mp3( + io.BytesIO(response.content), + ) + + # Load existing audio, append, and save back + existing_audio = AudioSegment.from_mp3(temp_path) + silence = AudioSegment.silent(duration=300) + combined = existing_audio + silence + current_segment + + # Export back to the same file + combined.export(temp_path, format="mp3", bitrate="128k") + + # Force garbage collection to free memory + del existing_audio, current_segment, combined + + # Small delay between API calls + if i < len(chunks) - 1: + time.sleep(0.5) + + # Read final result + audio_data = Path(temp_path).read_bytes() + + logger.info( + "Generated combined TTS audio: %d bytes", + len(audio_data), + ) + return audio_data + + finally: + # Clean up temp file + temp_file_path = Path(temp_path) + if temp_file_path.exists(): + temp_file_path.unlink() + + def upload_to_s3(self, audio_data: bytes, filename: str) -> str: + """Upload audio file to S3-compatible storage and return public URL. + + Raises: + ValueError: If S3 client is not configured. + ClientError: If S3 upload fails. 
+ """ + if not self.s3_client: + msg = "S3 client not configured" + raise ValueError(msg) + + try: + # Upload file using streaming to minimize memory usage + audio_stream = io.BytesIO(audio_data) + self.s3_client.upload_fileobj( + audio_stream, + S3_BUCKET, + filename, + ExtraArgs={ + "ContentType": "audio/mpeg", + "ACL": "public-read", + }, + ) + + # Construct public URL + audio_url = f"{S3_ENDPOINT}/{S3_BUCKET}/{filename}" + logger.info( + "Uploaded audio to: %s (%d bytes)", + audio_url, + len(audio_data), + ) + except ClientError: + logger.exception("S3 upload failed") + raise + else: + return audio_url + + @staticmethod + def estimate_duration(audio_data: bytes) -> int: + """Estimate audio duration in seconds based on file size and bitrate.""" + # Rough estimation: MP3 at 128kbps = ~16KB per second + estimated_seconds = len(audio_data) // 16000 + return max(1, estimated_seconds) # Minimum 1 second + + @staticmethod + def generate_filename(job_id: int, title: str) -> str: + """Generate unique filename for audio file.""" + timestamp = int(datetime.now(tz=timezone.utc).timestamp()) + # Create safe filename from title + safe_title = "".join( + c for c in title if c.isalnum() or c in {" ", "-", "_"} + ).rstrip() + safe_title = safe_title.replace(" ", "_")[:50] # Limit length + return f"episode_{timestamp}_{job_id}_{safe_title}.mp3" + + def process_job( + self, + job: dict[str, Any], + ) -> None: + """Process a single job through the complete pipeline.""" + job_id = job["id"] + url = job["url"] + + # Check memory before starting + mem_usage = check_memory_usage() + if mem_usage > MEMORY_THRESHOLD: + logger.warning( + "High memory usage (%.1f%%), deferring job %d", + mem_usage, + job_id, + ) + return + + # Track current job for graceful shutdown + self.shutdown_handler.set_current_job(job_id) + + try: + logger.info("Processing job %d: %s", job_id, url) + + # Update status to processing + Core.Database.update_job_status( + job_id, + "processing", + ) + + # Check for 
shutdown before each major step + if self.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, aborting job %d", job_id) + Core.Database.update_job_status(job_id, "pending") + return + + # Step 1: Extract article content + Core.Database.update_job_status(job_id, "extracting") + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content(url) + ) + + if self.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, aborting job %d", job_id) + Core.Database.update_job_status(job_id, "pending") + return + + # Step 2: Generate audio with metadata + Core.Database.update_job_status(job_id, "synthesizing") + audio_data = self.text_to_speech(content, title, author, pub_date) + + if self.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, aborting job %d", job_id) + Core.Database.update_job_status(job_id, "pending") + return + + # Step 3: Upload to S3 + Core.Database.update_job_status(job_id, "uploading") + filename = ArticleProcessor.generate_filename(job_id, title) + audio_url = self.upload_to_s3(audio_data, filename) + + # Step 4: Calculate duration + duration = ArticleProcessor.estimate_duration(audio_data) + + # Step 5: Create episode record + url_hash = Core.hash_url(url) + episode_id = Core.Database.create_episode( + title=title, + audio_url=audio_url, + duration=duration, + content_length=len(content), + user_id=job.get("user_id"), + author=job.get("author"), # Pass author from job + original_url=url, # Pass the original article URL + original_url_hash=url_hash, + ) + + # Add episode to user's feed + user_id = job.get("user_id") + if user_id: + Core.Database.add_episode_to_user(user_id, episode_id) + Core.Database.track_episode_event( + episode_id, + "added", + user_id, + ) + + # Step 6: Mark job as complete + Core.Database.update_job_status( + job_id, + "completed", + ) + + logger.info( + "Successfully processed job %d -> episode %d", + job_id, + episode_id, + ) + + except 
def _split_oversized_text(piece: str, max_chars: int) -> list[str]:
    """Break text longer than max_chars into parts of at most max_chars.

    Splits at whitespace; a single token longer than max_chars is cut hard.
    """
    parts: list[str] = []
    current = ""
    for word in piece.split():
        token = word
        # A token that cannot fit in any chunk is sliced into full-size parts
        while len(token) > max_chars:
            if current:
                parts.append(current)
                current = ""
            parts.append(token[:max_chars])
            token = token[max_chars:]

        candidate = f"{current} {token}" if current else token
        if len(candidate) > max_chars:
            parts.append(current)
            current = token
        else:
            current = candidate

    if current:
        parts.append(current)
    return parts


def split_text_into_chunks(text: str, max_chars: int = 3000) -> list[str]:
    """Split text into chunks at paragraph and sentence boundaries.

    Paragraphs (separated by blank lines) are packed into chunks joined by
    single spaces. A paragraph longer than max_chars is split on ". "; a
    sentence that is still longer than max_chars is split at whitespace, so
    no returned chunk exceeds max_chars.

    Args:
        text: Text to split
        max_chars: Maximum length of each chunk; must be at least 1

    Returns:
        Non-empty, stripped chunks in original order ([] for blank text)

    Raises:
        ValueError: If max_chars is less than 1.
    """
    if max_chars < 1:
        msg = "max_chars must be a positive integer"
        raise ValueError(msg)

    chunks: list[str] = []
    current_chunk = ""

    # Split into paragraphs first
    for para in text.split("\n\n"):
        para_stripped = para.strip()
        if not para_stripped:
            continue

        # If paragraph itself is too long, split by sentences
        if len(para_stripped) > max_chars:
            sentences = para_stripped.split(". ")
            last_index = len(sentences) - 1
            for index, sentence in enumerate(sentences):
                # split() removed the ". " separators, so restore the period
                # on every sentence except the last one, which still carries
                # its own ending (re-adding it there produced "..").
                piece = sentence if index == last_index else sentence + "."

                if len(piece) > max_chars:
                    # A single sentence that cannot fit in one chunk
                    if current_chunk:
                        chunks.append(current_chunk.strip())
                    parts = _split_oversized_text(piece, max_chars)
                    chunks.extend(parts[:-1])
                    current_chunk = parts[-1] + " " if parts else ""
                elif len(current_chunk) + len(piece) + 1 < max_chars:
                    current_chunk += piece + " "
                else:
                    if current_chunk:
                        chunks.append(current_chunk.strip())
                    current_chunk = piece + " "
        # If adding this paragraph would exceed limit, start new chunk
        elif len(current_chunk) + len(para_stripped) + 2 > max_chars:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = para_stripped + " "
        else:
            current_chunk += para_stripped + " "

    # Don't forget the last chunk
    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks
response.choices[0].message.content + if not content: + msg = "No content returned from LLM" + raise ValueError(msg) # noqa: TRY301 + + # Ensure the chunk isn't too long + max_chunk_length = 4000 + if len(content) > max_chunk_length: + # Truncate at sentence boundary + sentences = content.split(". ") + truncated = "" + for sentence in sentences: + if len(truncated) + len(sentence) + 2 < max_chunk_length: + truncated += sentence + ". " + else: + break + content = truncated.strip() + + except Exception: + logger.exception("LLM chunk editing failed") + raise + else: + return content + + +def parse_datetime_with_timezone(created_at: str | datetime) -> datetime: + """Parse datetime string and ensure it has timezone info.""" + if isinstance(created_at, str): + # Handle timezone-aware datetime strings + if created_at.endswith("Z"): + created_at = created_at[:-1] + "+00:00" + last_attempt = datetime.fromisoformat(created_at) + if last_attempt.tzinfo is None: + last_attempt = last_attempt.replace(tzinfo=timezone.utc) + else: + last_attempt = created_at + if last_attempt.tzinfo is None: + last_attempt = last_attempt.replace(tzinfo=timezone.utc) + return last_attempt + + +def should_retry_job(job: dict[str, Any], max_retries: int) -> bool: + """Check if a job should be retried based on retry count and backoff time. + + Uses exponential backoff to determine if enough time has passed. 
def check_memory_usage() -> float:
    """Check current memory usage percentage.

    Returns:
        The value reported by ``psutil.Process().memory_percent()`` coerced
        to ``float``, or ``0.0`` if the measurement fails.
    """
    try:
        # psutil is untyped; coerce so the result always has the same type
        # as the 0.0 fallback below.
        return float(psutil.Process().memory_percent())
    except (psutil.Error, OSError):
        logger.warning("Failed to check memory usage")
        return 0.0
def main_loop() -> None:
    """Run the worker: reset stale jobs, then poll the queue until shutdown."""
    handler = ShutdownHandler()
    worker = ArticleProcessor(handler)

    # Anything a previous run left behind is reset before polling starts.
    cleanup_stale_jobs()

    logger.info("Worker started, polling for jobs...")

    while True:
        if handler.is_shutdown_requested():
            break

        try:
            process_pending_jobs(worker)
            process_retryable_jobs()

            # Query once more so an idle poll shows up in the debug log.
            still_pending = Core.Database.get_pending_jobs(
                limit=1,
            )
            still_retryable = Core.Database.get_retryable_jobs(
                MAX_RETRIES,
            )
            if not (still_pending or still_retryable):
                logger.debug("No jobs to process, sleeping...")
        except Exception:
            # One failed poll is logged; the loop carries on.
            logger.exception("Error in main loop")

        # Pause between polls by waiting on the shutdown flag with a timeout
        # (the "interruptible sleep").
        if not handler.is_shutdown_requested():
            handler.shutdown_requested.wait(timeout=POLL_INTERVAL)

    # Shutdown path: report a job the handler still tracks as current.
    active_job = handler.get_current_job()
    if active_job:
        logger.info(
            "Waiting for job %d to complete before shutdown...",
            active_job,
        )

    logger.info("Worker shutdown complete")
HTML.""" + # Mock trafilatura.fetch_url and extract + mock_html = ( + "<html><body><h1>Test Article</h1><p>Content here</p></body></html>" + ) + mock_result = json.dumps({ + "title": "Test Article", + "text": "Content here", + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(title, "Test Article") + self.assertEqual(content, "Content here") + self.assertIsNone(author) + self.assertIsNone(pub_date) + + def test_extract_missing_title(self) -> None: + """Handle articles without titles.""" + mock_html = "<html><body><p>Content without title</p></body></html>" + mock_result = json.dumps({"text": "Content without title"}) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(title, "Untitled Article") + self.assertEqual(content, "Content without title") + self.assertIsNone(author) + self.assertIsNone(pub_date) + + def test_extract_empty_content(self) -> None: + """Handle empty articles.""" + mock_html = "<html><body></body></html>" + mock_result = json.dumps({"title": "Empty Article", "text": ""}) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + pytest.raises(ValueError, match="No content extracted") as cm, + ): + ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertIn("No content extracted", str(cm.value)) + + def test_extract_network_error(self) -> None: + """Handle connection failures.""" + 
with ( + unittest.mock.patch("trafilatura.fetch_url", return_value=None), + pytest.raises(ValueError, match="Failed to download") as cm, + ): + ArticleProcessor.extract_article_content("https://example.com") + + self.assertIn("Failed to download", str(cm.value)) + + @staticmethod + def test_extract_timeout() -> None: + """Handle slow responses.""" + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + side_effect=TimeoutError("Timeout"), + ), + pytest.raises(TimeoutError), + ): + ArticleProcessor.extract_article_content("https://example.com") + + def test_content_sanitization(self) -> None: + """Remove unwanted elements.""" + mock_html = """ + <html><body> + <h1>Article</h1> + <p>Good content</p> + <script>alert('bad')</script> + <table><tr><td>data</td></tr></table> + </body></html> + """ + mock_result = json.dumps({ + "title": "Article", + "text": "Good content", # Tables and scripts removed + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + _title, content, _author, _pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(content, "Good content") + self.assertNotIn("script", content) + self.assertNotIn("table", content) + + +class TestTextToSpeech(Test.TestCase): + """Test text-to-speech functionality.""" + + def setUp(self) -> None: + """Set up mocks.""" + # Mock OpenAI API key + self.env_patcher = unittest.mock.patch.dict( + os.environ, + {"OPENAI_API_KEY": "test-key"}, + ) + self.env_patcher.start() + + # Mock OpenAI response + self.mock_audio_response: unittest.mock.MagicMock = ( + unittest.mock.MagicMock() + ) + self.mock_audio_response.content = b"fake-audio-data" + + # Mock AudioSegment to avoid ffmpeg issues in tests + self.mock_audio_segment: unittest.mock.MagicMock = ( + unittest.mock.MagicMock() + ) + self.mock_audio_segment.export.return_value = None + 
self.audio_segment_patcher = unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + return_value=self.mock_audio_segment, + ) + self.audio_segment_patcher.start() + + # Mock the concatenation operations + self.mock_audio_segment.__add__.return_value = self.mock_audio_segment + + def tearDown(self) -> None: + """Clean up mocks.""" + self.env_patcher.stop() + self.audio_segment_patcher.stop() + + def test_tts_generation(self) -> None: + """Generate audio from text.""" + + # Mock the export to write test audio data + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=["Test content"], + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech( + "Test content", + "Test Title", + ) + + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_tts_chunking(self) -> None: + """Handle long articles with chunking.""" + long_text = "Long content " * 1000 + chunks = ["Chunk 1", "Chunk 2", "Chunk 3"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock AudioSegment.silent + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + 
mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech( + long_text, + "Long Article", + ) + + # Should have called TTS for each chunk + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_tts_empty_text(self) -> None: + """Handle empty input.""" + with unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=[], + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + with pytest.raises(ValueError, match="No chunks generated") as cm: + processor.text_to_speech("", "Empty") + + self.assertIn("No chunks generated", str(cm.value)) + + def test_tts_special_characters(self) -> None: + """Handle unicode and special chars.""" + special_text = 'Unicode: 你好世界 Émojis: 🎙️📰 Special: <>&"' + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=[special_text], + ), + ): + shutdown_handler = ShutdownHandler() + 
processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech( + special_text, + "Special", + ) + + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_llm_text_preparation(self) -> None: + """Verify LLM editing.""" + # Test the actual text preparation functions + chunks = split_text_into_chunks("Short text", max_chars=100) + self.assertEqual(len(chunks), 1) + self.assertEqual(chunks[0], "Short text") + + # Test long text splitting + long_text = "Sentence one. " * 100 + chunks = split_text_into_chunks(long_text, max_chars=100) + self.assertGreater(len(chunks), 1) + for chunk in chunks: + self.assertLessEqual(len(chunk), 100) + + @staticmethod + def test_llm_failure_fallback() -> None: + """Handle LLM API failures.""" + # Mock LLM failure + with unittest.mock.patch("openai.OpenAI") as mock_openai: + mock_client = mock_openai.return_value + mock_client.chat.completions.create.side_effect = Exception( + "API Error", + ) + + # Should fall back to raw text + with pytest.raises(Exception, match="API Error"): + edit_chunk_for_speech("Test chunk", "Title", is_first=True) + + def test_chunk_concatenation(self) -> None: + """Verify audio joining.""" + # Mock multiple audio segments + chunks = ["Chunk 1", "Chunk 2"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + 
"pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test", "Title") + + # Should produce combined audio + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_parallel_tts_generation(self) -> None: + """Test parallel TTS processing.""" + chunks = ["Chunk 1", "Chunk 2", "Chunk 3", "Chunk 4"] + + # Mock responses for each chunk + mock_responses = [] + for i in range(len(chunks)): + mock_resp = unittest.mock.MagicMock() + mock_resp.content = f"audio-{i}".encode() + mock_responses.append(mock_resp) + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + + # Make create return different responses for each call + mock_speech.create.side_effect = mock_responses + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + # Mock AudioSegment operations + mock_segment = unittest.mock.MagicMock() + mock_segment.__add__.return_value = mock_segment + + def mock_export(path: str, **_kwargs: typing.Any) -> None: + Path(path).write_bytes(b"combined-audio") + + mock_segment.export = mock_export + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + return_value=mock_segment, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=mock_segment, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=50.0, # Normal memory usage + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test content", "Test Title") + + # Verify all chunks were processed + 
self.assertEqual(mock_speech.create.call_count, len(chunks)) + self.assertEqual(audio_data, b"combined-audio") + + def test_parallel_tts_high_memory_fallback(self) -> None: + """Test fallback to serial processing when memory is high.""" + chunks = ["Chunk 1", "Chunk 2"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"serial-audio") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=65.0, # High memory usage + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test content", "Test Title") + + # Should use serial processing + self.assertEqual(audio_data, b"serial-audio") + + @staticmethod + def test_parallel_tts_error_handling() -> None: + """Test error handling in parallel TTS processing.""" + chunks = ["Chunk 1", "Chunk 2"] + + # Mock OpenAI client with one failure + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + + # First call succeeds, second fails + mock_resp1 = unittest.mock.MagicMock() + mock_resp1.content = b"audio-1" + mock_speech.create.side_effect = [mock_resp1, Exception("API Error")] + + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + # Set up the test context + 
shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=50.0, + ), + pytest.raises(Exception, match="API Error"), + ): + processor.text_to_speech("Test content", "Test Title") + + def test_parallel_tts_order_preservation(self) -> None: + """Test that chunks are combined in the correct order.""" + chunks = ["First", "Second", "Third", "Fourth", "Fifth"] + + # Create mock responses with identifiable content + mock_responses = [] + for chunk in chunks: + mock_resp = unittest.mock.MagicMock() + mock_resp.content = f"audio-{chunk}".encode() + mock_responses.append(mock_resp) + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.side_effect = mock_responses + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + # Track the order of segments being combined + combined_order = [] + + def mock_from_mp3(data: io.BytesIO) -> unittest.mock.MagicMock: + content = data.read() + combined_order.append(content.decode()) + segment = unittest.mock.MagicMock() + segment.__add__.return_value = segment + return segment + + mock_segment = unittest.mock.MagicMock() + mock_segment.__add__.return_value = mock_segment + + def mock_export(path: str, **_kwargs: typing.Any) -> None: + # Verify order is preserved + expected_order = [f"audio-{chunk}" for chunk in chunks] + if combined_order != expected_order: + msg = f"Order mismatch: {combined_order} != {expected_order}" + raise AssertionError(msg) + Path(path).write_bytes(b"ordered-audio") + + mock_segment.export = mock_export + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + 
class TestIntroOutro(Test.TestCase):
    """Cover intro/outro text creation and extraction of article metadata."""

    def test_create_intro_text_full_metadata(self) -> None:
        """Intro built from title, author and date mentions all three."""
        spoken = ArticleProcessor._create_intro_text(  # noqa: SLF001
            title="Test Article",
            author="John Doe",
            pub_date="2024-01-15",
        )
        for fragment in (
            "Title: Test Article",
            "Author: John Doe",
            "Published: 2024-01-15",
        ):
            self.assertIn(fragment, spoken)

    def test_create_intro_text_no_author(self) -> None:
        """Intro without an author keeps title and date, omits the author."""
        spoken = ArticleProcessor._create_intro_text(  # noqa: SLF001
            title="Test Article",
            author=None,
            pub_date="2024-01-15",
        )
        self.assertIn("Title: Test Article", spoken)
        self.assertNotIn("Author:", spoken)
        self.assertIn("Published: 2024-01-15", spoken)

    def test_create_intro_text_minimal(self) -> None:
        """Intro with only a title is exactly the title sentence."""
        spoken = ArticleProcessor._create_intro_text(  # noqa: SLF001
            title="Test Article",
            author=None,
            pub_date=None,
        )
        self.assertEqual(spoken, "Title: Test Article.")

    def test_create_outro_text_with_author(self) -> None:
        """Outro names the title, the author and the product."""
        closing = ArticleProcessor._create_outro_text(  # noqa: SLF001
            title="Test Article",
            author="Jane Smith",
        )
        for fragment in ("Test Article", "Jane Smith", "Podcast It Later"):
            self.assertIn(fragment, closing)

    def test_create_outro_text_no_author(self) -> None:
        """Outro without an author carries no attribution."""
        closing = ArticleProcessor._create_outro_text(  # noqa: SLF001
            title="Test Article",
            author=None,
        )
        self.assertIn("Test Article", closing)
        self.assertNotIn("by", closing)
        self.assertIn("Podcast It Later", closing)

    def test_extract_with_metadata(self) -> None:
        """Extraction hands back title, text, author and date."""
        page = "<html><body><p>Content</p></body></html>"
        extracted_json = json.dumps({
            "title": "Test Article",
            "text": "Article content",
            "author": "Test Author",
            "date": "2024-01-15",
        })

        fetch_patch = unittest.mock.patch(
            "trafilatura.fetch_url",
            return_value=page,
        )
        extract_patch = unittest.mock.patch(
            "trafilatura.extract",
            return_value=extracted_json,
        )
        with fetch_patch, extract_patch:
            title, content, author, pub_date = (
                ArticleProcessor.extract_article_content(
                    "https://example.com",
                )
            )

        self.assertEqual(
            (title, content, author, pub_date),
            ("Test Article", "Article content", "Test Author", "2024-01-15"),
        )
+ return_value="<html><body>content</body></html>", + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content, _author, _pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(title, "Large Article") + self.assertLessEqual(len(content), MAX_ARTICLE_SIZE) + + def test_memory_usage_check(self) -> None: + """Test memory usage monitoring.""" + usage = check_memory_usage() + self.assertIsInstance(usage, float) + self.assertGreaterEqual(usage, 0.0) + self.assertLessEqual(usage, 100.0) + + +class TestJobProcessing(Test.TestCase): + """Test job processing functionality.""" + + def setUp(self) -> None: + """Set up test environment.""" + Core.Database.init_db() + + # Create test user and job + self.user_id, _ = Core.Database.create_user( + "test@example.com", + ) + self.job_id = Core.Database.add_to_queue( + "https://example.com/article", + "test@example.com", + self.user_id, + ) + + # Mock environment + self.env_patcher = unittest.mock.patch.dict( + os.environ, + { + "OPENAI_API_KEY": "test-key", + "S3_ENDPOINT": "https://s3.example.com", + "S3_BUCKET": "test-bucket", + "S3_ACCESS_KEY": "test-access", + "S3_SECRET_KEY": "test-secret", + }, + ) + self.env_patcher.start() + + def tearDown(self) -> None: + """Clean up.""" + self.env_patcher.stop() + Core.Database.teardown() + + def test_process_job_success(self) -> None: + """Complete pipeline execution.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + # Mock all external calls + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=( + "Test Title", + "Test content", + "Test Author", + "2024-01-15", + ), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + 
return_value=b"audio-data", + ), + unittest.mock.patch.object( + processor, + "upload_to_s3", + return_value="https://s3.example.com/audio.mp3", + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.create_episode", + ) as mock_create, + ): + mock_create.return_value = 1 + processor.process_job(job) + + # Verify job was marked complete + mock_update.assert_called_with(self.job_id, "completed") + mock_create.assert_called_once() + + def test_process_job_extraction_failure(self) -> None: + """Handle bad URLs.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + side_effect=ValueError("Bad URL"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + pytest.raises(ValueError, match="Bad URL"), + ): + processor.process_job(job) + + # Job should be marked as error + mock_update.assert_called_with(self.job_id, "error", "Bad URL") + + def test_process_job_tts_failure(self) -> None: + """Handle TTS errors.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=("Title", "Content"), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + side_effect=Exception("TTS Error"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + pytest.raises(Exception, match="TTS Error"), + ): + processor.process_job(job) + + 
mock_update.assert_called_with(self.job_id, "error", "TTS Error") + + def test_process_job_s3_failure(self) -> None: + """Handle upload errors.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=("Title", "Content"), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + return_value=b"audio", + ), + unittest.mock.patch.object( + processor, + "upload_to_s3", + side_effect=ClientError({}, "PutObject"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ), + pytest.raises(ClientError), + ): + processor.process_job(job) + + def test_job_retry_logic(self) -> None: + """Verify exponential backoff.""" + # Set job to error with retry count + Core.Database.update_job_status( + self.job_id, + "error", + "First failure", + ) + Core.Database.update_job_status( + self.job_id, + "error", + "Second failure", + ) + + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + self.assertEqual(job["retry_count"], 2) + + # Should be retryable + retryable = Core.Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 1) + + def test_max_retries(self) -> None: + """Stop after max attempts.""" + # Exceed retry limit + for i in range(4): + Core.Database.update_job_status( + self.job_id, + "error", + f"Failure {i}", + ) + + # Should not be retryable + retryable = Core.Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 0) + + def test_graceful_shutdown(self) -> None: + """Test graceful shutdown during job processing.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = 
Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + def mock_tts(*_args: Any) -> bytes: + shutdown_handler.shutdown_requested.set() + return b"audio-data" + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=( + "Test Title", + "Test content", + "Test Author", + "2024-01-15", + ), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + side_effect=mock_tts, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + ): + processor.process_job(job) + + # Job should be reset to pending due to shutdown + mock_update.assert_any_call(self.job_id, "pending") + + def test_cleanup_stale_jobs(self) -> None: + """Test cleanup of stale processing jobs.""" + # Manually set job to processing + Core.Database.update_job_status(self.job_id, "processing") + + # Run cleanup + cleanup_stale_jobs() + + # Job should be back to pending + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + self.assertEqual(job["status"], "pending") + + def test_concurrent_processing(self) -> None: + """Handle multiple jobs.""" + # Create multiple jobs + job2 = Core.Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + job3 = Core.Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + # Get pending jobs + jobs = Core.Database.get_pending_jobs(limit=5) + + self.assertEqual(len(jobs), 3) + self.assertEqual({j["id"] for j in jobs}, {self.job_id, job2, job3}) + + def test_memory_threshold_deferral(self) -> None: + """Test that jobs are deferred when memory usage is high.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found 
class TestWorkerErrorHandling(Test.TestCase):
    """Exercise the worker's failure and retry paths against the database."""

    def setUp(self) -> None:
        """Create a fresh database with one user, one queued job, a worker."""
        Core.Database.init_db()
        created = Core.Database.create_user("test@example.com")
        self.user_id, _ = created
        self.job_id = Core.Database.add_to_queue(
            "https://example.com",
            "test@example.com",
            self.user_id,
        )
        self.shutdown_handler = ShutdownHandler()
        self.processor = ArticleProcessor(self.shutdown_handler)

    @staticmethod
    def tearDown() -> None:
        """Tear the database down again."""
        Core.Database.teardown()

    def test_process_pending_jobs_exception_handling(self) -> None:
        """An exception escaping process_job leaves the job in 'error'."""

        def fail_after_marking_processing(job: dict[str, Any]) -> None:
            # Mimic process_job: flag the job as processing, then blow up.
            Core.Database.update_job_status(job["id"], "processing")
            msg = "Unexpected Error"
            raise ValueError(msg)

        with (
            unittest.mock.patch.object(
                self.processor,
                "process_job",
                side_effect=fail_after_marking_processing,
            ),
            unittest.mock.patch(
                "Biz.PodcastItLater.Core.Database.update_job_status",
                side_effect=Core.Database.update_job_status,
            ),
        ):
            process_pending_jobs(self.processor)

        # The stored job must now carry the error status and message.
        stored = Core.Database.get_job_by_id(self.job_id)
        self.assertIsNotNone(stored)
        if stored:
            self.assertEqual(stored["status"], "error")
            self.assertIn("Unexpected Error", stored["error_message"])

    def test_process_retryable_jobs_success(self) -> None:
        """A failed job whose backoff has elapsed goes back to 'pending'."""
        Core.Database.update_job_status(self.job_id, "error", "Fail 1")

        # Backdate created_at so the backoff window has already passed.
        with Core.Database.get_connection() as conn:
            five_minutes_ago = datetime.now(tz=timezone.utc) - timedelta(
                minutes=5,
            )
            conn.execute(
                "UPDATE queue SET created_at = ? WHERE id = ?",
                (five_minutes_ago.isoformat(), self.job_id),
            )
            conn.commit()

        process_retryable_jobs()

        stored = Core.Database.get_job_by_id(self.job_id)
        self.assertIsNotNone(stored)
        if stored:
            self.assertEqual(stored["status"], "pending")

    def test_process_retryable_jobs_not_ready(self) -> None:
        """A job that only just failed stays in 'error'."""
        Core.Database.update_job_status(self.job_id, "error", "Fail 1")

        # created_at is "now", so the backoff period has not elapsed yet.
        process_retryable_jobs()

        stored = Core.Database.get_job_by_id(self.job_id)
        self.assertIsNotNone(stored)
        if stored:
            self.assertEqual(stored["status"], "error")
" + self.assertEqual(len(chunks), 1) + self.assertGreater(len(chunks[0]), 3000) + + +def test() -> None: + """Run the tests.""" + Test.run( + App.Area.Test, + [ + TestArticleExtraction, + TestTextToSpeech, + TestMemoryEfficiency, + TestJobProcessing, + TestWorkerErrorHandling, + TestTextChunking, + ], + ) + + +def main() -> None: + """Entry point for the worker.""" + if "test" in sys.argv: + test() + else: + move() |
