diff options
Diffstat (limited to 'Biz')
| -rwxr-xr-x | Biz/Dragons.hs | 14 | ||||
| -rwxr-xr-x | Biz/Dragons/Analysis.nix | 7 | ||||
| -rwxr-xr-x | Biz/EmailAgent.py | 151 | ||||
| -rw-r--r-- | Biz/Packages.nix | 15 | ||||
| -rw-r--r-- | Biz/PodcastItLater.md | 338 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Admin.py | 1068 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Billing.py | 581 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Core.py | 2174 | ||||
| -rw-r--r-- | Biz/PodcastItLater/DESIGN.md | 43 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Episode.py | 390 | ||||
| -rw-r--r-- | Biz/PodcastItLater/INFRASTRUCTURE.md | 38 | ||||
| -rw-r--r-- | Biz/PodcastItLater/STRIPE_TESTING.md | 114 | ||||
| -rw-r--r-- | Biz/PodcastItLater/TESTING.md | 45 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Test.py | 276 | ||||
| -rw-r--r-- | Biz/PodcastItLater/TestMetricsView.py | 121 | ||||
| -rw-r--r-- | Biz/PodcastItLater/UI.py | 755 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Web.nix | 93 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Web.py | 3480 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Worker.nix | 63 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Worker.py | 2199 | ||||
| -rwxr-xr-x | Biz/Que/Host.hs | 73 | ||||
| -rwxr-xr-x | Biz/Storybook.py | 10 |
22 files changed, 12030 insertions, 18 deletions
diff --git a/Biz/Dragons.hs b/Biz/Dragons.hs index 7ba7fa0..cfe211e 100755 --- a/Biz/Dragons.hs +++ b/Biz/Dragons.hs @@ -742,7 +742,19 @@ startup quiet = do cfg <- Envy.decodeWithDefaults Envy.defConfig oAuthArgs <- Envy.decodeWithDefaults Envy.defConfig kp <- Acid.openLocalStateFrom (keep cfg) init :: IO (Acid.AcidState Keep) - jwk <- Auth.generateKey -- TODO: store this in a file somewhere + let jwkPath = keep cfg </> "jwk.json" + jwkExists <- Directory.doesFileExist jwkPath + jwk <- + if jwkExists + then do + maybeKey <- Aeson.decodeFileStrict jwkPath + case maybeKey of + Nothing -> panic <| "Could not decode JWK from " <> str jwkPath + Just k -> pure k + else do + k <- Auth.generateKey + Aeson.encodeFile jwkPath k + pure k let url = case homeExample cfg of ForgeURL u -> u CLISubmission -> "<CLISubmission>" diff --git a/Biz/Dragons/Analysis.nix b/Biz/Dragons/Analysis.nix index 5ea8713..b0e0cc9 100755 --- a/Biz/Dragons/Analysis.nix +++ b/Biz/Dragons/Analysis.nix @@ -1,5 +1,8 @@ #!/usr/bin/env run.sh -{bild}: +{ + bild, + packages ? import ../Packages.nix {inherit bild;}, +}: # Run this like so: # # bild Biz/Dragons/Analysis.nix @@ -11,6 +14,6 @@ bild.image { fromImage = null; fromImageName = null; fromImageTag = "latest"; - contents = [bild.pkgs.git (bild.run ./Analysis.hs)]; + contents = [bild.pkgs.git packages.dragons-analysis]; config.Cmd = ["/bin/dragons-analyze"]; } diff --git a/Biz/EmailAgent.py b/Biz/EmailAgent.py new file mode 100755 index 0000000..6ac4c95 --- /dev/null +++ b/Biz/EmailAgent.py @@ -0,0 +1,151 @@ +#!/usr/bin/env run.sh +""" +Email sending utility that can be used as a script or imported as a library. + +Password is provided through systemd's LoadCredential feature. This is intended +to be used by automated agents in a systemd timer. 
+""" + +import argparse +import email.message +import email.utils +import errno +import os +import pathlib +import smtplib +import sys + + +# ruff: noqa: PLR0917, PLR0913 +def send_email( + to_addrs: list[str], + from_addr: str, + smtp_server: str, + password: str, + subject: str, + body_text: pathlib.Path, + body_html: pathlib.Path | None = None, + port: int = 587, +) -> dict[str, tuple[int, bytes]]: + """ + Send an email using the provided parameters. + + Args: + to_addr: Recipient email addresses + from_addr: Sender email address + smtp_server: SMTP server hostname + password: Password for authentication + subject: Email subject + body_text: File with email body text + body_html: File with email body html + port: SMTP server port (default: 587) + + """ + msg = email.message.EmailMessage() + msg["Subject"] = subject + msg["From"] = from_addr + msg["To"] = ", ".join(to_addrs) + msg["Message-ID"] = email.utils.make_msgid( + idstring=__name__, + domain=smtp_server, + ) + msg["Date"] = email.utils.formatdate(localtime=True) + with body_text.open(encoding="utf-8") as txt: + msg.set_content(txt.read()) + if body_html: + with body_html.open(encoding="utf-*") as html: + msg.add_alternative(html.read(), subtype="html") + with smtplib.SMTP(smtp_server, port) as server: + server.starttls() + server.login(from_addr, password) + return server.send_message( + msg, + from_addr=from_addr, + to_addrs=to_addrs, + ) + + +def main() -> None: + """Parse command line arguments and send email. 
def main() -> None:
    """Parse command line arguments and send email.

    When the literal argument ``test`` appears anywhere on the command line
    the program exits with status 0 before parsing anything else. Otherwise
    the SMTP password is read from ``--password-file`` (surrounding whitespace
    stripped) and the message is sent with ``send_email``. If the server
    refuses any recipient, the refusal mapping is written to stdout and the
    process exits with status 1.

    Raises:
        FileNotFoundError: if --password-file does not exist
    """
    if "test" in sys.argv:
        sys.exit(0)
    parser = argparse.ArgumentParser(
        description="Send an email",
    )
    parser.add_argument(
        "--to",
        required=True,
        help="Recipient email addresses, can be specified multiple times",
        nargs="+",
        action="extend",
    )
    parser.add_argument(
        "--from",
        dest="from_addr",
        required=True,
        help="Sender email address",
    )
    parser.add_argument(
        "--smtp-server",
        required=True,
        help="SMTP server hostname",
    )
    parser.add_argument("--subject", required=True, help="Email subject")
    parser.add_argument(
        "--body-text",
        required=True,
        help="File with email body text",
    )
    parser.add_argument(
        "--body-html",
        help="File with email body html",
        default=None,
    )
    parser.add_argument(
        "--port",
        type=int,
        default=587,
        help="SMTP server port (default: 587)",
    )
    parser.add_argument(
        "--password-file",
        default="smtp-password",
        help="Where to find the password file",
    )

    args = parser.parse_args()

    credential_path = pathlib.Path(args.password_file)
    if not credential_path.exists():
        # Raising is the whole error path; the sys.exit(1) that used to follow
        # this statement could never run and has been dropped.
        raise FileNotFoundError(
            errno.ENOENT,
            os.strerror(errno.ENOENT),
            credential_path,
        )

    password = credential_path.read_text(encoding="utf-8").strip()

    results = send_email(
        to_addrs=args.to,
        from_addr=args.from_addr,
        smtp_server=args.smtp_server,
        subject=args.subject,
        body_text=pathlib.Path(args.body_text),
        body_html=pathlib.Path(args.body_html) if args.body_html else None,
        password=password,
        port=args.port,
    )
    # A non-empty mapping means at least one recipient was refused.
    if results:
        sys.stdout.write(str(results))
        sys.stdout.flush()
        sys.exit(1)


if __name__ == "__main__":
    main()
+# +# This file builds all Biz packages and returns them as an attribute set. +# The NixOS config (Biz.nix) will accept these as inputs rather than +# building them during OS evaluation. +# +# Usage: +# nix-build Biz/Packages.nix # builds all packages +# nix-build Biz/Packages.nix -A storybook # builds one package +{bild ? import ../Omni/Bild.nix {}}: { + storybook = bild.run ../Biz/Storybook.py; + podcastitlater-web = bild.run ../Biz/PodcastItLater/Web.py; + podcastitlater-worker = bild.run ../Biz/PodcastItLater/Worker.py; + dragons-analysis = bild.run ../Biz/Dragons/Analysis.hs; +} diff --git a/Biz/PodcastItLater.md b/Biz/PodcastItLater.md new file mode 100644 index 0000000..c3d1708 --- /dev/null +++ b/Biz/PodcastItLater.md @@ -0,0 +1,338 @@ +# PodcastItLater + +A service that converts web articles to podcast episodes via email submission or web interface. Users can submit articles and receive them as audio episodes in their personal podcast feed. + +## Current Implementation Status + +### Architecture +- **Web Service** (`Biz/PodcastItLater/Web.py`) - Ludic web app with HTMX interface +- **Background Worker** (`Biz/PodcastItLater/Worker.py`) - Processes articles to audio +- **Core/Database** (`Biz/PodcastItLater/Core.py`) - Shared database operations + +### Features Implemented + +#### User Management +- Email-based registration/login (no passwords) +- Session-based authentication +- Personal RSS feed tokens +- User-specific data isolation + +#### Article Processing +- Manual URL submission via web form +- Content extraction with trafilatura +- LLM-powered text preparation for natural speech +- OpenAI TTS conversion with chunking for long articles +- S3-compatible storage (Digital Ocean Spaces) + +#### Web Interface +- Login/logout functionality +- Submit article form +- Live queue status updates (HTMX) +- Recent episodes with audio player +- Personal RSS feed URL display +- Admin queue view with retry/delete actions + +#### RSS Feeds +- Personalized feeds at 
`/feed/{user_token}.xml` +- User-specific episode filtering +- Customized feed titles based on user email + +### Database Schema +```sql +-- Users table +CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT UNIQUE NOT NULL, + token TEXT UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Queue table with user support +CREATE TABLE queue ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + url TEXT, + email TEXT, + user_id INTEGER REFERENCES users(id), + status TEXT DEFAULT 'pending', + retry_count INTEGER DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + error_message TEXT +); + +-- Episodes table with user support +CREATE TABLE episodes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + title TEXT NOT NULL, + content_length INTEGER, + audio_url TEXT NOT NULL, + duration INTEGER, + user_id INTEGER REFERENCES users(id), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); +``` + +## Phase 3: Path to Paid Product + +### Immediate Priorities + +#### 1. Usage Limits & Billing Infrastructure +- Add usage tracking to users table (articles_processed, audio_minutes) +- Implement free tier limits (e.g., 10 articles/month) +- Add subscription status and tier to users +- Integrate Stripe for payments +- Create billing webhook handlers + +#### 2. Enhanced User Experience +- Implement article preview/editing before conversion +- Add voice selection options +- Support for multiple TTS providers (cost optimization) +- Batch processing for multiple URLs + +#### 3. Content Quality Improvements +- Better handling of different article types (news, blogs, research papers) +- Improved code block and technical content handling +- Table/chart description generation +- Multi-language support +- Custom intro/outro options + +#### 4. 
Admin & Analytics +- Admin dashboard for monitoring all users +- Usage analytics and metrics +- Cost tracking per user +- System health monitoring +- Automated error alerting + +### Technical Improvements Needed + +#### Security & Reliability +- Add rate limiting per user +- Implement proper API authentication (not just session-based) +- Add request signing for webhook security +- Backup and disaster recovery for database +- Queue persistence across worker restarts + +#### Performance & Scalability +- Move from SQLite to PostgreSQL +- Implement proper job queue (Redis/RabbitMQ) +- Add caching layer for processed articles +- CDN for audio file delivery +- Horizontal scaling for workers + +#### Code Quality +- Add comprehensive test suite +- API documentation +- Error tracking (Sentry) +- Structured logging with correlation IDs +- Configuration management (not just env vars) + +### Pricing Model Considerations +- Free tier: 5-10 articles/month, basic voice +- Personal: $5-10/month, 50 articles, voice selection +- Pro: $20-30/month, unlimited articles, priority processing +- API access for developers + +### MVP for Paid Launch +1. Stripe integration with subscription management +2. Usage tracking and enforcement +3. Email notifications +4. Basic admin dashboard +5. Improved error handling and retry logic +6. PostgreSQL migration +7. Basic API with authentication + +### Environment Variables Required +```bash +# Current +OPENAI_API_KEY= +S3_ENDPOINT= +S3_BUCKET= +S3_ACCESS_KEY= +S3_SECRET_KEY= +BASE_URL= +DATA_DIR= # Used by both Web and Worker services +SESSION_SECRET= +PORT= + +# Needed for paid version +STRIPE_SECRET_KEY= +STRIPE_WEBHOOK_SECRET= +STRIPE_PRICE_ID_PERSONAL= +STRIPE_PRICE_ID_PRO= +SENDGRID_API_KEY= # for transactional emails +SENTRY_DSN= +REDIS_URL= +``` + +### Next Implementation Steps +1. Create `Biz/PodcastItLater/Billing.py` for Stripe integration +2. Add usage tracking to Core.py database operations +3. 
Implement email notifications in Worker.py +4. Create admin interface endpoints in Web.py +5. Add comprehensive error handling and logging +6. Write test suite +7. Create deployment configuration + +## Test Plan + +### Overview +The test suite will ensure reliability and correctness of all components before launching the paid product. Tests will be organized into three main categories matching the architecture: Core (database), Web (frontend/API), and Worker (background processing). + +### Test Structure +Tests will be placed in the same file as the code they test, following the pattern established in the codebase. Each module will contain its test classes nearby the functionality that class is testing: + +- `Biz/PodcastItLater/Core.py` - Contains database logic and TestDatabase, TestUserManagement, TestQueueOperations, TestEpisodeManagement classes +- `Biz/PodcastItLater/Web.py` - Contains web interface and TestAuthentication, TestArticleSubmission, TestRSSFeed, TestAdminInterface classes +- `Biz/PodcastItLater/Worker.py` - Contains background worker and TestArticleExtraction, TestTextToSpeech, TestJobProcessing classes + +Each file will follow this pattern: +```python +# Main code implementation +class Database: + ... + +# Test class next to the class it is testing +class TestDatabase(Test.TestCase): + """Test the Database class.""" + + def test_init_db(self) -> None: + """Verify all tables and indexes are created correctly.""" + ... +``` + +This keeps tests close to the code they test, making it easier to maintain and understand the relationship between implementation and tests. 
+ +### Core Tests (Core.py) + +#### TestDatabase +- `test_init_db` - Verify all tables and indexes are created correctly +- `test_connection_context_manager` - Ensure connections are properly closed +- `test_migration_idempotency` - Verify migrations can run multiple times safely + +#### TestUserManagement +- `test_create_user` - Create user with unique email and token +- `test_create_duplicate_user` - Verify duplicate emails return existing user +- `test_get_user_by_email` - Retrieve user by email +- `test_get_user_by_token` - Retrieve user by RSS token +- `test_get_user_by_id` - Retrieve user by ID +- `test_invalid_user_lookups` - Verify None returned for non-existent users +- `test_token_uniqueness` - Ensure tokens are cryptographically unique + +#### TestQueueOperations +- `test_add_to_queue` - Add job with user association +- `test_get_pending_jobs` - Retrieve jobs in correct order +- `test_update_job_status` - Update status and error messages +- `test_retry_job` - Reset failed jobs for retry +- `test_delete_job` - Remove jobs from queue +- `test_get_retryable_jobs` - Find jobs eligible for retry +- `test_user_queue_isolation` - Ensure users only see their own jobs +- `test_status_counts` - Verify status aggregation queries + +#### TestEpisodeManagement +- `test_create_episode` - Create episode with user association +- `test_get_recent_episodes` - Retrieve episodes in reverse chronological order +- `test_get_user_episodes` - Ensure user isolation for episodes +- `test_episode_metadata` - Verify duration and content_length storage + +### Web Tests (Web.py) + +#### TestAuthentication +- `test_login_new_user` - Auto-create user on first login +- `test_login_existing_user` - Login with existing email +- `test_login_invalid_email` - Reject malformed emails +- `test_session_persistence` - Verify session across requests +- `test_protected_routes` - Ensure auth required for user actions + +#### TestArticleSubmission +- `test_submit_valid_url` - Accept well-formed URLs 
+- `test_submit_invalid_url` - Reject malformed URLs +- `test_submit_without_auth` - Reject unauthenticated submissions +- `test_submit_creates_job` - Verify job creation in database +- `test_htmx_response` - Ensure proper HTMX response format + +#### TestRSSFeed +- `test_feed_generation` - Generate valid RSS XML +- `test_feed_user_isolation` - Only show user's episodes +- `test_feed_invalid_token` - Return 404 for bad tokens +- `test_feed_metadata` - Verify personalized feed titles +- `test_feed_episode_order` - Ensure reverse chronological order +- `test_feed_enclosures` - Verify audio URLs and metadata + + +#### TestAdminInterface +- `test_queue_status_view` - Verify queue display +- `test_retry_action` - Test retry button functionality +- `test_delete_action` - Test delete button functionality +- `test_user_data_isolation` - Ensure users only see own data +- `test_status_summary` - Verify status counts display + +### Worker Tests (Worker.py) + +#### TestArticleExtraction +- `test_extract_valid_article` - Extract from well-formed HTML +- `test_extract_missing_title` - Handle articles without titles +- `test_extract_empty_content` - Handle empty articles +- `test_extract_network_error` - Handle connection failures +- `test_extract_timeout` - Handle slow responses +- `test_content_sanitization` - Remove unwanted elements + +#### TestTextToSpeech +- `test_tts_generation` - Generate audio from text +- `test_tts_chunking` - Handle long articles with chunking +- `test_tts_empty_text` - Handle empty input +- `test_tts_special_characters` - Handle unicode and special chars +- `test_llm_text_preparation` - Verify LLM editing +- `test_llm_failure_fallback` - Handle LLM API failures +- `test_chunk_concatenation` - Verify audio joining + +#### TestJobProcessing +- `test_process_job_success` - Complete pipeline execution +- `test_process_job_extraction_failure` - Handle bad URLs +- `test_process_job_tts_failure` - Handle TTS errors +- `test_process_job_s3_failure` - Handle 
upload errors +- `test_job_retry_logic` - Verify exponential backoff +- `test_max_retries` - Stop after max attempts +- `test_concurrent_processing` - Handle multiple jobs + +### Integration Tests + +#### TestEndToEnd +- `test_web_to_podcast` - Full pipeline from web submission +- `test_multiple_users` - Concurrent multi-user scenarios +- `test_error_recovery` - System recovery from failures + +### Test Infrastructure + +#### Fixtures and Mocks +- Mock OpenAI API responses +- Mock S3/Digital Ocean Spaces +- In-memory SQLite for fast tests +- Test data generators for articles + +#### Test Configuration +- Separate test database +- Mock external services by default +- Optional integration tests with real services +- Test coverage reporting +- Performance benchmarks for TTS chunking + +### Testing Best Practices +1. Each test should be independent and idempotent +2. Use descriptive test names that explain the scenario +3. Test both happy paths and error conditions +4. Mock external services to avoid dependencies +5. Use fixtures for common test data +6. Measure test coverage (aim for >80%) +7. Run tests in CI/CD pipeline +8. Keep tests fast (< 30 seconds total) + +### Pre-Launch Testing Checklist +- [x] All unit tests passing +- [ ] Integration tests with real services +- [ ] Load testing (100 concurrent users) +- [ ] Security testing (SQL injection, XSS) +- [ ] RSS feed validation +- [ ] Audio quality verification +- [ ] Error handling and logging +- [ ] Database backup/restore +- [ ] User data isolation verification +- [ ] Billing integration tests (when implemented) diff --git a/Biz/PodcastItLater/Admin.py b/Biz/PodcastItLater/Admin.py new file mode 100644 index 0000000..6f60948 --- /dev/null +++ b/Biz/PodcastItLater/Admin.py @@ -0,0 +1,1068 @@ +""" +PodcastItLater Admin Interface. + +Admin pages and functionality for managing users and queue items. 
class MetricCard(Component[AnyChildren, MetricCardAttrs]):
    """Card body for one metric: an icon next to a title and its value."""

    @override
    def render(self) -> html.div:
        heading = self.attrs["title"]
        amount = self.attrs["value"]
        # Fall back to a generic chart glyph when no icon was supplied.
        icon_name = self.attrs.get("icon", "bi-bar-chart")

        icon_column = html.div(
            html.i(classes=["bi", icon_name, "text-primary", "fs-2"]),
            classes=["col-auto"],
        )
        text_column = html.div(
            html.h6(heading, classes=["text-muted", "mb-1"]),
            html.h3(str(amount), classes=["mb-0"]),
            classes=["col"],
        )
        row = html.div(
            icon_column,
            text_column,
            classes=["row", "align-items-center"],
        )
        return html.div(row, classes=["card-body"])
class MetricsDashboard(Component[AnyChildren, MetricsAttrs]):
    """Admin metrics dashboard showing aggregate statistics.

    The page has three rows: growth/usage cards, episode summary cards, and
    three "top episodes" tables. The growth cards read their values with
    ``metrics.get(..., 0)``; the episode cards and tables index ``metrics``
    directly, so those keys must be present.
    """

    @staticmethod
    def _card_column(title: str, value: int, icon: str) -> html.div:
        """Wrap one MetricCard in a shadowed card inside a quarter-width column."""
        return html.div(
            html.div(
                MetricCard(title=title, value=value, icon=icon),
                classes=["card", "shadow-sm"],
            ),
            classes=["col-md-3"],
        )

    @staticmethod
    def _top_episodes_column(
        heading: str,
        icon: str,
        episodes: list[dict[str, typing.Any]],
        metric_name: str,
        count_key: str,
    ) -> html.div:
        """Build a third-width card with a header and a TopEpisodesTable."""
        return html.div(
            html.div(
                html.div(
                    html.h5(
                        html.i(classes=["bi", icon, "me-2"]),
                        heading,
                        classes=["card-title", "mb-0"],
                    ),
                    classes=["card-header", "bg-white"],
                ),
                TopEpisodesTable(
                    episodes=episodes,
                    metric_name=metric_name,
                    count_key=count_key,
                ),
                classes=["card", "shadow-sm"],
            ),
            classes=["col-lg-4"],
        )

    @override
    def render(self) -> UI.PageLayout:
        metrics = self.attrs["metrics"]
        user = self.attrs.get("user")
        card = self._card_column
        top = self._top_episodes_column

        return UI.PageLayout(
            html.div(
                html.h2(
                    html.i(classes=["bi", "bi-people", "me-2"]),
                    "Growth & Usage",
                    classes=["mb-4"],
                ),
                # Growth & Usage cards
                html.div(
                    card(
                        "Total Users",
                        metrics.get("total_users", 0),
                        "bi-people",
                    ),
                    card(
                        "Active Subs",
                        metrics.get("active_subscriptions", 0),
                        "bi-credit-card",
                    ),
                    card(
                        "Submissions (24h)",
                        metrics.get("submissions_24h", 0),
                        "bi-activity",
                    ),
                    card(
                        "Submissions (7d)",
                        metrics.get("submissions_7d", 0),
                        "bi-calendar-week",
                    ),
                    classes=["row", "g-3", "mb-5"],
                ),
                html.h2(
                    html.i(classes=["bi", "bi-graph-up", "me-2"]),
                    "Episode Metrics",
                    classes=["mb-4"],
                ),
                # Summary cards
                html.div(
                    card(
                        "Total Episodes",
                        metrics["total_episodes"],
                        "bi-collection",
                    ),
                    card(
                        "Total Plays",
                        metrics["total_plays"],
                        "bi-play-circle",
                    ),
                    card(
                        "Total Downloads",
                        metrics["total_downloads"],
                        "bi-download",
                    ),
                    card(
                        "Total Adds",
                        metrics["total_adds"],
                        "bi-plus-circle",
                    ),
                    classes=["row", "g-3", "mb-4"],
                ),
                # Top episodes tables
                html.div(
                    top(
                        "Most Played",
                        "bi-play-circle-fill",
                        metrics["most_played"],
                        "Plays",
                        "play_count",
                    ),
                    top(
                        "Most Downloaded",
                        "bi-download",
                        metrics["most_downloaded"],
                        "Downloads",
                        "download_count",
                    ),
                    top(
                        "Most Added to Feeds",
                        "bi-plus-circle-fill",
                        metrics["most_added"],
                        "Adds",
                        "add_count",
                    ),
                    classes=["row", "g-3"],
                ),
            ),
            user=user,
            current_page="admin-metrics",
            error=None,
        )
class StatusBadge(Component[AnyChildren, StatusBadgeAttrs]):
    """Bootstrap badge for a status, optionally labelled with a count."""

    @override
    def render(self) -> html.span:
        status = self.attrs["status"]
        count = self.attrs.get("count")

        if count is None:
            # Plain badge: the status as given, no trailing margin class.
            label = status
            spacing = ""
        else:
            # Summary badge: upper-cased status plus the count, spaced apart.
            label = f"{status.upper()}: {count}"
            spacing = "me-3"

        colour = self.get_status_badge_class(status)
        return html.span(label, classes=["badge", colour, spacing])

    @staticmethod
    def get_status_badge_class(status: str) -> str:
        """Map a status name to its Bootstrap badge classes.

        Unknown statuses fall back to the grey ``bg-secondary`` badge.
        """
        badge_classes = {
            "pending": "bg-warning text-dark",
            "processing": "bg-primary",
            "completed": "bg-success",
            "active": "bg-success",
            "error": "bg-danger",
            "cancelled": "bg-secondary",
            "disabled": "bg-danger",
        }
        return badge_classes.get(status, "bg-secondary")
class ActionButtons(Component[AnyChildren, ActionButtonsAttrs]):
    """Render action buttons for queue items.

    Every item gets a Delete button. Items whose status is anything other
    than ``"completed"`` also get a Retry button in front of it. Both buttons
    issue HTMX requests and replace the whole ``body`` with the response.
    """

    @override
    def render(self) -> html.div:
        job_id = self.attrs["job_id"]
        status = self.attrs["status"]

        buttons = []

        if status != "completed":
            # The button used to carry disabled=(status == "completed"), which
            # is always False inside this branch, so the attribute is omitted.
            buttons.append(
                html.button(
                    html.i(classes=["bi", "bi-arrow-clockwise", "me-1"]),
                    "Retry",
                    hx_post=f"/queue/{job_id}/retry",
                    hx_target="body",
                    hx_swap="outerHTML",
                    classes=["btn", "btn-sm", "btn-success", "me-1"],
                ),
            )

        buttons.append(
            html.button(
                html.i(classes=["bi", "bi-trash", "me-1"]),
                "Delete",
                hx_delete=f"/queue/{job_id}",
                hx_confirm="Are you sure you want to delete this queue item?",
                hx_target="body",
                hx_swap="outerHTML",
                classes=["btn", "btn-sm", "btn-danger"],
            ),
        )

        return html.div(
            *buttons,
            classes=["btn-group"],
        )
class EpisodeTableRow(Component[AnyChildren, EpisodeTableRowAttrs]):
    """Table row for one episode: id, title, listen link, sizes, timestamp."""

    @override
    def render(self) -> html.tr:
        ep = self.attrs["episode"]

        id_cell = html.td(str(ep["id"]))
        title_cell = html.td(
            TruncatedText(
                text=ep["title"],
                max_length=Core.TITLE_TRUNCATE_LENGTH,
            ),
        )
        # The audio link opens in a new tab.
        listen_cell = html.td(
            html.a(
                html.i(classes=["bi", "bi-play-circle", "me-1"]),
                "Listen",
                href=ep["audio_url"],
                target="_blank",
                classes=["btn", "btn-sm", "btn-outline-primary"],
            ),
        )

        # Falsy duration / content length are shown as a dash.
        if ep["duration"]:
            duration_text = f"{ep['duration']}s"
        else:
            duration_text = "-"
        if ep["content_length"]:
            length_text = f"{ep['content_length']:,} chars"
        else:
            length_text = "-"

        created_cell = html.td(
            html.small(ep["created_at"], classes=["text-muted"]),
        )

        return html.tr(
            id_cell,
            title_cell,
            listen_cell,
            html.td(duration_text),
            html.td(length_text),
            created_cell,
        )
class AdminUsers(Component[AnyChildren, AdminUsersAttrs]):
    """Admin page listing every user with a status selector."""

    @override
    def render(self) -> UI.PageLayout:
        all_users = self.attrs["users"]
        current_user = self.attrs.get("user")

        page_heading = html.h2("User Management", classes=["mb-4"])
        table_section = self._render_users_table(all_users)

        return UI.PageLayout(
            page_heading,
            table_section,
            user=current_user,
            current_page="admin-users",
            error=None,
        )

    @staticmethod
    def _render_users_table(
        users: list[dict[str, typing.Any]],
    ) -> html.div:
        """Build the "All Users" section: a heading above a responsive table."""
        header = create_table_header([
            "Email",
            "Created At",
            "Status",
            "Actions",
        ])
        rows = [UserTableRow(user=entry) for entry in users]
        table = html.table(
            header,
            html.tbody(*rows),
            classes=["table", "table-hover", "table-striped"],
        )
        return html.div(
            html.h2("All Users", classes=["mb-3"]),
            html.div(table, classes=["table-responsive"]),
        )
queue_items, + episodes, + status_counts, + ), + id="admin-content", + hx_get="/admin", + hx_trigger="every 10s", + hx_swap="innerHTML", + hx_target="#admin-content", + ), + user=user, + current_page="admin", + error=None, + ) + + @staticmethod + def render_content( + queue_items: list[dict[str, typing.Any]], + episodes: list[dict[str, typing.Any]], + status_counts: dict[str, int], + ) -> html.div: + """Render the main content of the admin page.""" + return html.div( + html.h2( + "Admin Queue Status", + classes=["mb-4"], + ), + AdminView.render_status_summary(status_counts), + AdminView.render_queue_table(queue_items), + AdminView.render_episodes_table(episodes), + ) + + @staticmethod + def render_status_summary(status_counts: dict[str, int]) -> html.div: + """Render status summary section.""" + return html.div( + html.h2("Status Summary", classes=["mb-3"]), + html.div( + *[ + StatusBadge(status=status, count=count) + for status, count in status_counts.items() + ], + classes=["mb-4"], + ), + ) + + @staticmethod + def render_queue_table( + queue_items: list[dict[str, typing.Any]], + ) -> html.div: + """Render queue items table.""" + return html.div( + html.h2("Queue Items", classes=["mb-3"]), + html.div( + html.table( + create_table_header([ + "ID", + "URL", + "Title", + "Email", + "Status", + "Retries", + "Created", + "Error", + "Actions", + ]), + html.tbody(*[ + QueueTableRow(item=item) for item in queue_items + ]), + classes=["table", "table-hover", "table-sm"], + ), + classes=["table-responsive", "mb-5"], + ), + ) + + @staticmethod + def render_episodes_table( + episodes: list[dict[str, typing.Any]], + ) -> html.div: + """Render episodes table.""" + return html.div( + html.h2("Completed Episodes", classes=["mb-3"]), + html.div( + html.table( + create_table_header([ + "ID", + "Title", + "Audio URL", + "Duration", + "Content Length", + "Created", + ]), + html.tbody(*[ + EpisodeTableRow(episode=episode) for episode in episodes + ]), + classes=["table", 
def admin_queue_status(request: Request) -> AdminView | Response | html.div:
    """Render the admin dashboard, or only its inner content for HTMX polls.

    Args:
        request: Incoming request; the session supplies ``user_id`` and the
            ``HX-Request`` header marks an HTMX refresh.

    Returns:
        A 302 ``Response`` to ``/`` when the session is missing or stale, a
        302 to ``/?error=forbidden`` for non-admins, the bare content
        ``html.div`` for HTMX requests, otherwise the full ``AdminView``.
    """
    user_id = request.session.get("user_id")
    if not user_id:
        # Not logged in.
        return Response(
            "",
            status_code=302,
            headers={"Location": "/"},
        )

    user = Core.Database.get_user_by_id(user_id)
    if not user:
        # The session points at a user that cannot be loaded.
        return Response(
            "",
            status_code=302,
            headers={"Location": "/"},
        )

    if not Core.is_admin(user):
        return Response(
            "",
            status_code=302,
            headers={"Location": "/?error=forbidden"},
        )

    # Admins see every user's data; completed jobs are left out of the
    # queue listing.
    all_queue_items = [
        item
        for item in Core.Database.get_all_queue_items(None)
        if item.get("status") != "completed"
    ]
    all_episodes = Core.Database.get_all_episodes(None)

    # Tally the remaining queue items per status for the summary badges.
    status_counts: dict[str, int] = {}
    for item in all_queue_items:
        status = item.get("status", "unknown")
        status_counts[status] = status_counts.get(status, 0) + 1

    if request.headers.get("HX-Request") == "true":
        # AdminView.render() wraps the content in #admin-content, which
        # already polls /admin every 10s and swaps only its innerHTML.
        # Returning another polling wrapper here nested a second poller
        # inside it and doubled the request rate, so send the content only.
        return AdminView.render_content(
            all_queue_items,
            all_episodes,
            status_counts,
        )

    return AdminView(
        queue_items=all_queue_items,
        episodes=all_episodes,
        status_counts=status_counts,
        user=user,
    )
def delete_queue_item(request: Request, job_id: int) -> Response:
    """Delete a queue item on behalf of its owner or an admin.

    Returns:
        401 without a session, 404 for an unknown job, 403 when the caller
        neither owns the job nor is an admin, 500 if a ``ValueError`` or
        ``KeyError`` escapes, otherwise 200 with an HTMX header telling the
        page how to refresh.
    """
    try:
        session_user_id = request.session.get("user_id")
        if not session_user_id:
            return Response("Unauthorized", status_code=401)

        job = Core.Database.get_job_by_id(job_id)
        if job is None:
            return Response("Job not found", status_code=404)

        # Owners may delete their own jobs; admins may delete any job.
        requester = Core.Database.get_user_by_id(session_user_id)
        owns_job = job.get("user_id") == session_user_id
        if not owns_job and not Core.is_admin(requester):
            return Response("Forbidden", status_code=403)

        Core.Database.delete_job(job_id)

        # The referer decides the follow-up: the admin page is reloaded,
        # any other page just receives a "queue-updated" trigger.
        came_from_admin = "/admin" in request.headers.get("referer", "")
        if came_from_admin:
            htmx_headers = {"HX-Redirect": "/admin"}
        else:
            htmx_headers = {"HX-Trigger": "queue-updated"}
        return Response("", status_code=200, headers=htmx_headers)
    except (ValueError, KeyError) as exc:
        return Response(
            f"Error deleting job: {exc!s}",
            status_code=500,
        )
def update_user_status(
    request: Request,
    user_id: int,
    data: FormData,
) -> Response:
    """Set the account status of ``user_id`` on behalf of an admin.

    Returns:
        401 without a session, 403 when the caller is not an admin, 400 for
        a status outside pending/active/disabled, otherwise 200 with an
        ``HX-Redirect`` back to the users page.
    """
    acting_user_id = request.session.get("user_id")
    if not acting_user_id:
        return Response("Unauthorized", status_code=401)

    acting_user = Core.Database.get_user_by_id(acting_user_id)
    if not (acting_user and Core.is_admin(acting_user)):
        return Response("Forbidden", status_code=403)

    # A missing field, or a value that is not a plain string, is treated
    # as "pending".
    submitted = data.get("status", "pending")
    if not isinstance(submitted, str):
        submitted = "pending"
    if submitted not in {"pending", "active", "disabled"}:
        return Response("Invalid status", status_code=400)

    Core.Database.update_user_status(user_id, submitted)

    return Response(
        "",
        status_code=200,
        headers={"HX-Redirect": "/admin/users"},
    )
def admin_metrics(request: Request) -> MetricsDashboard | Response:
    """Serve the metrics dashboard to admins; redirect everyone else home.

    Returns:
        A 302 ``Response`` to ``/`` without a session, a 302 to
        ``/?error=forbidden`` for unknown or non-admin users, otherwise the
        ``MetricsDashboard`` built from the metrics summary.
    """
    session_user_id = request.session.get("user_id")
    if not session_user_id:
        return Response(
            "",
            status_code=302,
            headers={"Location": "/"},
        )

    viewer = Core.Database.get_user_by_id(session_user_id)
    if not (viewer and Core.is_admin(viewer)):
        return Response(
            "",
            status_code=302,
            headers={"Location": "/?error=forbidden"},
        )

    summary = Core.Database.get_metrics_summary()
    return MetricsDashboard(metrics=summary, user=viewer)
def get_period_boundaries(
    user: dict[str, typing.Any],
) -> tuple[datetime, datetime]:
    """Get billing period boundaries for user.

    For paid users the stored subscription period is used. Free users, and
    any row whose subscription period is only partially recorded, get a
    lifetime window from account creation to 100 years from now.

    Args:
        user: User row; reads ``plan_tier``, ``current_period_start``,
            ``current_period_end`` and ``created_at`` (ISO-format strings).

    Returns:
        tuple: (period_start, period_end) as datetime objects

    Raises:
        KeyError: If the lifetime window is needed and ``created_at`` is
            absent.
    """
    stored_start = user.get("current_period_start")
    stored_end = user.get("current_period_end")

    # Require BOTH boundaries: parsing a missing end used to raise, which
    # made every caller fail for a half-populated subscription row.
    if user.get("plan_tier") != "free" and stored_start and stored_end:
        return (
            datetime.fromisoformat(stored_start),
            datetime.fromisoformat(stored_end),
        )

    # Lifetime limit: count everything since the account was created.
    period_start = datetime.fromisoformat(user["created_at"])
    # Far-future end date (100 years from now).
    now = datetime.now(timezone.utc)
    period_end = now.replace(year=now.year + 100)

    return period_start, period_end
def create_checkout_session(user_id: int, tier: str, base_url: str) -> str:
    """Create a Stripe Checkout session for a subscription purchase.

    Args:
        user_id: User ID
        tier: Subscription tier; must be a key of ``PRICE_MAP``
        base_url: Base URL for success/cancel redirects

    Returns:
        Checkout session URL to redirect user to

    Raises:
        ValueError: If the tier is unknown, its price ID is not configured,
            or the user does not exist
    """
    if tier not in PRICE_MAP:
        msg = f"Invalid tier: {tier}"
        raise ValueError(msg)

    price_id = PRICE_MAP[tier]
    if not price_id:
        msg = f"Stripe price ID not configured for tier: {tier}"
        raise ValueError(msg)

    user = Core.Database.get_user_by_id(user_id)
    if not user:
        msg = f"User not found: {user_id}"
        raise ValueError(msg)

    # Reuse the linked Stripe customer when there is one; otherwise pass
    # the account's email instead.
    linked_customer = user.get("stripe_customer_id")
    if linked_customer:
        customer_params = {"customer": linked_customer}
    else:
        customer_params = {"customer_email": user["email"]}

    user_ref = str(user_id)
    session_params = {
        "mode": "subscription",
        "line_items": [{"price": price_id, "quantity": 1}],
        "success_url": f"{base_url}/?status=success",
        "cancel_url": f"{base_url}/?status=cancel",
        "client_reference_id": user_ref,
        "metadata": {"user_id": user_ref, "tier": tier},
        "allow_promotion_codes": True,
        **customer_params,
    }

    session = stripe.checkout.Session.create(**session_params)  # type: ignore[arg-type]

    logger.info(
        "Created checkout session for user %s, tier %s: %s",
        user_id,
        tier,
        session.id,
    )

    return session.url  # type: ignore[return-value]
def handle_webhook_event(payload: bytes, sig_header: str) -> dict[str, str]:
    """Verify a Stripe webhook, dispatch it to its handler, and record it.

    Args:
        payload: Raw webhook body
        sig_header: Stripe-Signature header value

    Returns:
        dict with ``status`` of "skipped", "ignored" or "processed"

    Note:
        ``stripe.Webhook.construct_event`` may raise on an invalid
        signature. Any exception raised by a handler is logged and
        re-raised.
    """
    if not STRIPE_WEBHOOK_SECRET:
        # NOTE(review): with no secret configured the payload is trusted
        # as-is; confirm the secret is always set outside of tests.
        logger.warning(
            "Webhook signature verification skipped (no STRIPE_WEBHOOK_SECRET)",
        )
        event = json.loads(payload.decode("utf-8"))
    else:
        event = stripe.Webhook.construct_event(  # type: ignore[no-untyped-call]
            payload,
            sig_header,
            STRIPE_WEBHOOK_SECRET,
        )

    event_id = event["id"]
    event_type = event["type"]

    # Idempotency: events that were already recorded are not replayed.
    if Core.Database.has_processed_stripe_event(event_id):
        logger.info("Skipping already processed event: %s", event_id)
        return {"status": "skipped", "reason": "already_processed"}

    logger.info("Processing webhook event: %s (%s)", event_id, event_type)

    routes = (
        ("checkout.session.completed", _handle_checkout_completed),
        ("customer.subscription.created", _handle_subscription_created),
        ("customer.subscription.updated", _handle_subscription_updated),
        ("customer.subscription.deleted", _handle_subscription_deleted),
        ("invoice.payment_failed", _handle_payment_failed),
    )
    handler = next(
        (func for name, func in routes if event_type == name),
        None,
    )
    if handler is None:
        # Unhandled types are not recorded as processed.
        logger.info("Unhandled event type: %s", event_type)
        return {"status": "ignored", "type": event_type}

    try:
        handler(event["data"]["object"])
        # Record the event only after its handler finished without error.
        Core.Database.mark_stripe_event_processed(event_id, event_type, payload)
    except Exception:
        logger.exception("Error processing webhook event %s", event_id)
        raise
    else:
        return {"status": "processed", "type": event_type}
def _one_month_after(moment: datetime) -> datetime:
    """Return ``moment`` moved one calendar month ahead, clamping the day."""
    import calendar  # noqa: PLC0415

    year_carry, month_index = divmod(moment.month, 12)
    year = moment.year + year_carry
    month = month_index + 1
    # Jan 31 + 1 month has no Feb 31; fall back to the month's last day.
    last_day = calendar.monthrange(year, month)[1]
    return moment.replace(
        year=year,
        month=month,
        day=min(moment.day, last_day),
    )


def _update_subscription_state(subscription: dict[str, typing.Any]) -> None:
    """Update user subscription state from Stripe subscription object.

    Reads customer, id, status, billing period and the first item's price
    from ``subscription``, maps the price to a tier through
    ``PRICE_TO_TIER`` and stores the result on the user linked to the
    customer. Incomplete objects and unknown customers are logged and
    ignored rather than raising.

    Args:
        subscription: Stripe subscription object as a plain dict.
    """
    customer_id = subscription.get("customer")
    subscription_id = subscription.get("id")
    status = subscription.get("status")
    cancel_at_period_end = subscription.get("cancel_at_period_end", False)

    if not customer_id or not subscription_id or not status:
        logger.warning(
            "Missing required fields in subscription: %s",
            subscription_id,
        )
        return

    # Get billing period - try multiple field names for API compatibility
    period_start_ts = (
        subscription.get("current_period_start")
        or subscription.get("billing_cycle_anchor")
        or subscription.get("start_date")
    )
    period_end_ts = subscription.get("current_period_end")

    if not period_start_ts:
        logger.warning(
            "Missing period start in subscription: %s",
            subscription_id,
        )
        return

    period_start = datetime.fromtimestamp(period_start_ts, tz=timezone.utc)

    if period_end_ts:
        period_end = datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
    else:
        # No explicit end: assume a monthly period. A bare
        # replace(month=month + 1) raised ValueError for days 29-31 when
        # the following month is shorter, so the day is clamped instead.
        period_end = _one_month_after(period_start)

    # Determine tier from the price ID of the first subscription item
    items = subscription.get("items", {})
    data = items.get("data", [])
    if not data:
        logger.warning("No items in subscription: %s", subscription_id)
        return

    price_id = data[0].get("price", {}).get("id")
    if not price_id:
        logger.warning("No price ID in subscription: %s", subscription_id)
        return

    # Prices that are not in the mapping fall back to the free tier.
    tier = PRICE_TO_TIER.get(price_id, "free")

    # Find user by customer ID
    user = Core.Database.get_user_by_stripe_customer_id(customer_id)
    if not user:
        logger.warning("User not found for customer: %s", customer_id)
        return

    Core.Database.update_user_subscription(
        user["id"],
        subscription_id,
        status,
        period_start,
        period_end,
        tier,
        cancel_at_period_end,
    )

    logger.info(
        "Updated user %s subscription: tier=%s, status=%s",
        user["id"],
        tier,
        status,
    )
class TestWebhookHandling(Test.TestCase):
    """Test Stripe webhook handling against a freshly initialised database."""

    def setUp(self) -> None:
        """Set up test database."""
        Core.Database.init_db()

    def tearDown(self) -> None:
        """Clean up test database."""
        Core.Database.teardown()

    def test_full_checkout_flow(self) -> None:
        """Test complete checkout flow from session to subscription."""
        # Create test user
        user_id, _token = Core.Database.create_user("test@example.com")

        # Register a test price in place and remove it again afterwards.
        # Rebinding the module global to a saved copy left the original
        # dict object still holding the test entry.
        PRICE_TO_TIER["price_test_paid"] = "paid"

        try:
            # Step 1: Handle checkout.session.completed
            checkout_session = {
                "id": "cs_test123",
                "client_reference_id": str(user_id),
                "customer": "cus_test123",
                "metadata": {"user_id": str(user_id), "tier": "paid"},
            }
            _handle_checkout_completed(checkout_session)

            # Verify customer was linked
            user = Core.Database.get_user_by_id(user_id)
            self.assertIsNotNone(user)
            assert user is not None
            self.assertEqual(user["stripe_customer_id"], "cus_test123")

            # Step 2: Handle customer.subscription.created
            # (newer API uses billing_cycle_anchor instead of current_period_*)
            subscription = {
                "id": "sub_test123",
                "customer": "cus_test123",
                "status": "active",
                "billing_cycle_anchor": 1700000000,
                "cancel_at_period_end": False,
                "items": {
                    "data": [
                        {
                            "price": {
                                "id": "price_test_paid",
                            },
                        },
                    ],
                },
            }
            _update_subscription_state(subscription)

            # Verify subscription was created and user upgraded
            user = Core.Database.get_user_by_id(user_id)
            self.assertIsNotNone(user)
            assert user is not None
            self.assertEqual(user["plan_tier"], "paid")
            self.assertEqual(user["subscription_status"], "active")
            self.assertEqual(user["stripe_subscription_id"], "sub_test123")
            self.assertEqual(user["stripe_customer_id"], "cus_test123")
        finally:
            PRICE_TO_TIER.pop("price_test_paid", None)

    def test_webhook_missing_fields(self) -> None:
        """Test handling webhook with missing required fields."""
        # Create test user
        user_id, _token = Core.Database.create_user("test@example.com")
        Core.Database.set_user_stripe_customer(user_id, "cus_test456")

        # Subscription without any period-start field and without items
        subscription = {
            "id": "sub_test456",
            "customer": "cus_test456",
            "status": "active",
            # Missing current_period_start and current_period_end
            "cancel_at_period_end": False,
            "items": {"data": []},
        }

        # Should not crash, just log warning and return
        _update_subscription_state(subscription)

        # User should remain on free tier
        user = Core.Database.get_user_by_id(user_id)
        self.assertIsNotNone(user)
        assert user is not None
        self.assertEqual(user["plan_tier"], "free")
def normalize_url(url: str) -> str:
    """Normalize URL for comparison and hashing.

    Normalizes:
    - Protocol (always https)
    - Domain case (lowercase)
    - www prefix (removed)
    - Trailing slashes (removed; an empty or slash-only path becomes "/")
    - Preserves query params and fragments as they may be meaningful

    Args:
        url: URL to normalize

    Returns:
        Normalized URL string
    """
    parsed = urllib.parse.urlparse(url.strip())

    # Normalize domain to lowercase, remove www prefix
    domain = parsed.netloc.lower().removeprefix("www.")

    # Strip trailing slashes, keeping "/" for the site root. An empty path
    # used to stay empty while "/" stayed "/", so "https://example.com" and
    # "https://example.com/" normalized (and hashed) differently.
    path = parsed.path.rstrip("/") or "/"

    # Rebuild URL with normalized components, https as canonical protocol
    return urllib.parse.urlunparse((
        "https",
        domain,
        path,
        parsed.params,
        parsed.query,
        parsed.fragment,
    ))
+ """ + db_path = DATA_DIR / "podcast.db" + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + try: + yield conn + finally: + conn.close() + + @staticmethod + def init_db() -> None: + """Initialize database with required tables.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Queue table for job processing + cursor.execute(""" + CREATE TABLE IF NOT EXISTS queue ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + url TEXT, + email TEXT, + status TEXT DEFAULT 'pending', + retry_count INTEGER DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + error_message TEXT, + title TEXT, + author TEXT + ) + """) + + # Episodes table for completed podcasts + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episodes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + title TEXT NOT NULL, + content_length INTEGER, + audio_url TEXT NOT NULL, + duration INTEGER, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Create indexes for performance + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_queue_status ON queue(status)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_queue_created " + "ON queue(created_at)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episodes_created " + "ON episodes(created_at)", + ) + + conn.commit() + logger.info("Database initialized successfully") + + # Run migration to add user support + Database.migrate_to_multi_user() + + # Run migration to add metadata fields + Database.migrate_add_metadata_fields() + + # Run migration to add episode metadata fields + Database.migrate_add_episode_metadata() + + # Run migration to add user status field + Database.migrate_add_user_status() + + # Run migration to add default titles + Database.migrate_add_default_titles() + + # Run migration to add billing fields + Database.migrate_add_billing_fields() + + # Run migration to add stripe events table + Database.migrate_add_stripe_events_table() + + # Run migration to add public feed features + 
Database.migrate_add_public_feed() + + @staticmethod + def add_to_queue( + url: str, + email: str, + user_id: int, + title: str | None = None, + author: str | None = None, + ) -> int: + """Insert new job into queue with metadata, return job ID. + + Raises: + ValueError: If job ID cannot be retrieved after insert. + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO queue (url, email, user_id, title, author) " + "VALUES (?, ?, ?, ?, ?)", + (url, email, user_id, title, author), + ) + conn.commit() + job_id = cursor.lastrowid + if job_id is None: + msg = "Failed to get job ID after insert" + raise ValueError(msg) + logger.info("Added job %s to queue: %s", job_id, url) + return job_id + + @staticmethod + def get_pending_jobs( + limit: int = 10, + ) -> list[dict[str, Any]]: + """Fetch jobs with status='pending' ordered by creation time.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM queue WHERE status = 'pending' " + "ORDER BY created_at ASC LIMIT ?", + (limit,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def update_job_status( + job_id: int, + status: str, + error: str | None = None, + ) -> None: + """Update job status and error message.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + if error is not None: + if status == "error": + cursor.execute( + "UPDATE queue SET status = ?, error_message = ?, " + "retry_count = retry_count + 1 WHERE id = ?", + (status, error, job_id), + ) + else: + cursor.execute( + "UPDATE queue SET status = ?, " + "error_message = ? WHERE id = ?", + (status, error, job_id), + ) + else: + cursor.execute( + "UPDATE queue SET status = ? 
WHERE id = ?", + (status, job_id), + ) + conn.commit() + logger.info("Updated job %s status to %s", job_id, status) + + @staticmethod + def get_job_by_id( + job_id: int, + ) -> dict[str, Any] | None: + """Fetch single job by ID.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM queue WHERE id = ?", (job_id,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def create_episode( # noqa: PLR0913, PLR0917 + title: str, + audio_url: str, + duration: int, + content_length: int, + user_id: int | None = None, + author: str | None = None, + original_url: str | None = None, + original_url_hash: str | None = None, + ) -> int: + """Insert episode record, return episode ID. + + Raises: + ValueError: If episode ID cannot be retrieved after insert. + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO episodes " + "(title, audio_url, duration, content_length, user_id, " + "author, original_url, original_url_hash) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ( + title, + audio_url, + duration, + content_length, + user_id, + author, + original_url, + original_url_hash, + ), + ) + conn.commit() + episode_id = cursor.lastrowid + if episode_id is None: + msg = "Failed to get episode ID after insert" + raise ValueError(msg) + logger.info("Created episode %s: %s", episode_id, title) + return episode_id + + @staticmethod + def get_recent_episodes( + limit: int = 20, + ) -> list[dict[str, Any]]: + """Get recent episodes for RSS feed generation.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes ORDER BY created_at DESC LIMIT ?", + (limit,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_queue_status_summary() -> dict[str, Any]: + """Get queue status summary for web interface.""" + with Database.get_connection() as conn: + cursor = 
conn.cursor() + + # Count jobs by status + cursor.execute( + "SELECT status, COUNT(*) as count FROM queue GROUP BY status", + ) + rows = cursor.fetchall() + status_counts = {row["status"]: row["count"] for row in rows} + + # Get recent jobs + cursor.execute( + "SELECT * FROM queue ORDER BY created_at DESC LIMIT 10", + ) + rows = cursor.fetchall() + recent_jobs = [dict(row) for row in rows] + + return {"status_counts": status_counts, "recent_jobs": recent_jobs} + + @staticmethod + def get_queue_status() -> list[dict[str, Any]]: + """Return pending/processing/error items for web interface.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT id, url, email, status, created_at, error_message, + title, author + FROM queue + WHERE status IN ( + 'pending', 'processing', 'extracting', + 'synthesizing', 'uploading', 'error' + ) + ORDER BY created_at DESC + LIMIT 20 + """) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_episode_by_id(episode_id: int) -> dict[str, Any] | None: + """Fetch single episode by ID.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT id, title, audio_url, duration, created_at, + content_length, author, original_url, user_id, is_public + FROM episodes + WHERE id = ? + """, + (episode_id,), + ) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_all_episodes( + user_id: int | None = None, + ) -> list[dict[str, Any]]: + """Return all episodes for RSS feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + if user_id: + cursor.execute( + """ + SELECT id, title, audio_url, duration, created_at, + content_length, author, original_url + FROM episodes + WHERE user_id = ? 
+ ORDER BY created_at DESC + """, + (user_id,), + ) + else: + cursor.execute(""" + SELECT id, title, audio_url, duration, created_at, + content_length, author, original_url + FROM episodes + ORDER BY created_at DESC + """) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_retryable_jobs( + max_retries: int = 3, + ) -> list[dict[str, Any]]: + """Get failed jobs that can be retried.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM queue WHERE status = 'error' " + "AND retry_count < ? ORDER BY created_at ASC", + (max_retries,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def retry_job(job_id: int) -> None: + """Reset a job to pending status for retry.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE queue SET status = 'pending', " + "error_message = NULL WHERE id = ?", + (job_id,), + ) + conn.commit() + logger.info("Reset job %s to pending for retry", job_id) + + @staticmethod + def delete_job(job_id: int) -> None: + """Delete a job from the queue.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM queue WHERE id = ?", (job_id,)) + conn.commit() + logger.info("Deleted job %s from queue", job_id) + + @staticmethod + def get_all_queue_items( + user_id: int | None = None, + ) -> list[dict[str, Any]]: + """Return all queue items for admin view.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + if user_id: + cursor.execute( + """ + SELECT id, url, email, status, retry_count, created_at, + error_message, title, author + FROM queue + WHERE user_id = ? 
+ ORDER BY created_at DESC + """, + (user_id,), + ) + else: + cursor.execute(""" + SELECT id, url, email, status, retry_count, created_at, + error_message, title, author + FROM queue + ORDER BY created_at DESC + """) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_status_counts() -> dict[str, int]: + """Get count of queue items by status.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT status, COUNT(*) as count + FROM queue + GROUP BY status + """) + rows = cursor.fetchall() + return {row["status"]: row["count"] for row in rows} + + @staticmethod + def get_user_status_counts( + user_id: int, + ) -> dict[str, int]: + """Get count of queue items by status for a specific user.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT status, COUNT(*) as count + FROM queue + WHERE user_id = ? + GROUP BY status + """, + (user_id,), + ) + rows = cursor.fetchall() + return {row["status"]: row["count"] for row in rows} + + @staticmethod + def migrate_to_multi_user() -> None: + """Migrate database to support multiple users.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Create users table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT UNIQUE NOT NULL, + token TEXT UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Add user_id columns to existing tables + # Check if columns already exist to make migration idempotent + cursor.execute("PRAGMA table_info(queue)") + queue_info = cursor.fetchall() + queue_columns = [col[1] for col in queue_info] + + if "user_id" not in queue_columns: + cursor.execute( + "ALTER TABLE queue ADD COLUMN user_id INTEGER " + "REFERENCES users(id)", + ) + + cursor.execute("PRAGMA table_info(episodes)") + episodes_info = cursor.fetchall() + episodes_columns = [col[1] for col in episodes_info] 
+ + if "user_id" not in episodes_columns: + cursor.execute( + "ALTER TABLE episodes ADD COLUMN user_id INTEGER " + "REFERENCES users(id)", + ) + + # Create indexes + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_queue_user_id " + "ON queue(user_id)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episodes_user_id " + "ON episodes(user_id)", + ) + + conn.commit() + logger.info("Database migrated to support multiple users") + + @staticmethod + def migrate_add_metadata_fields() -> None: + """Add title and author fields to queue table.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Check if columns already exist + cursor.execute("PRAGMA table_info(queue)") + queue_info = cursor.fetchall() + queue_columns = [col[1] for col in queue_info] + + if "title" not in queue_columns: + cursor.execute("ALTER TABLE queue ADD COLUMN title TEXT") + + if "author" not in queue_columns: + cursor.execute("ALTER TABLE queue ADD COLUMN author TEXT") + + conn.commit() + logger.info("Database migrated to support metadata fields") + + @staticmethod + def migrate_add_episode_metadata() -> None: + """Add author and original_url fields to episodes table.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Check if columns already exist + cursor.execute("PRAGMA table_info(episodes)") + episodes_info = cursor.fetchall() + episodes_columns = [col[1] for col in episodes_info] + + if "author" not in episodes_columns: + cursor.execute("ALTER TABLE episodes ADD COLUMN author TEXT") + + if "original_url" not in episodes_columns: + cursor.execute( + "ALTER TABLE episodes ADD COLUMN original_url TEXT", + ) + + conn.commit() + logger.info("Database migrated to support episode metadata fields") + + @staticmethod + def migrate_add_user_status() -> None: + """Add status field to users table.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Check if column already exists + cursor.execute("PRAGMA table_info(users)") + 
users_info = cursor.fetchall() + users_columns = [col[1] for col in users_info] + + if "status" not in users_columns: + # Add status column with default 'active' + cursor.execute( + "ALTER TABLE users ADD COLUMN status TEXT DEFAULT 'active'", + ) + + conn.commit() + logger.info("Database migrated to support user status") + + @staticmethod + def migrate_add_billing_fields() -> None: + """Add billing-related fields to users table.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Add columns one by one (SQLite limitation) + # Note: SQLite ALTER TABLE doesn't support adding UNIQUE constraints + # We add them without UNIQUE and rely on application logic + columns_to_add = [ + ("plan_tier", "TEXT NOT NULL DEFAULT 'free'"), + ("stripe_customer_id", "TEXT"), + ("stripe_subscription_id", "TEXT"), + ("subscription_status", "TEXT"), + ("current_period_start", "TIMESTAMP"), + ("current_period_end", "TIMESTAMP"), + ("cancel_at_period_end", "INTEGER NOT NULL DEFAULT 0"), + ] + + for column_name, column_def in columns_to_add: + try: + query = f"ALTER TABLE users ADD COLUMN {column_name} " + cursor.execute(query + column_def) + logger.info("Added column users.%s", column_name) + except sqlite3.OperationalError as e: # noqa: PERF203 + # Column already exists, skip + logger.debug( + "Column users.%s already exists: %s", + column_name, + e, + ) + + conn.commit() + + @staticmethod + def migrate_add_stripe_events_table() -> None: + """Create stripe_events table for webhook idempotency.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS stripe_events ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + payload TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_stripe_events_created " + "ON stripe_events(created_at)", + ) + conn.commit() + logger.info("Created stripe_events table") + + @staticmethod + def 
migrate_add_public_feed() -> None: + """Add is_public column and related tables for public feed feature.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Add is_public column to episodes + cursor.execute("PRAGMA table_info(episodes)") + episodes_info = cursor.fetchall() + episodes_columns = [col[1] for col in episodes_info] + + if "is_public" not in episodes_columns: + cursor.execute( + "ALTER TABLE episodes ADD COLUMN is_public INTEGER " + "NOT NULL DEFAULT 0", + ) + logger.info("Added is_public column to episodes") + + # Create user_episodes junction table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS user_episodes ( + user_id INTEGER NOT NULL, + episode_id INTEGER NOT NULL, + added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (user_id, episode_id), + FOREIGN KEY (user_id) REFERENCES users(id), + FOREIGN KEY (episode_id) REFERENCES episodes(id) + ) + """) + logger.info("Created user_episodes junction table") + + # Create index on episode_id for reverse lookups + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_user_episodes_episode " + "ON user_episodes(episode_id)", + ) + + # Add original_url_hash column to episodes + if "original_url_hash" not in episodes_columns: + cursor.execute( + "ALTER TABLE episodes ADD COLUMN original_url_hash TEXT", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episodes_url_hash " + "ON episodes(original_url_hash)", + ) + logger.info("Added original_url_hash column to episodes") + + # Create episode_metrics table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_metrics ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + episode_id INTEGER NOT NULL, + user_id INTEGER, + event_type TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (episode_id) REFERENCES episodes(id), + FOREIGN KEY (user_id) REFERENCES users(id) + ) + """) + logger.info("Created episode_metrics table") + + # Create indexes for metrics queries + cursor.execute( + "CREATE INDEX IF NOT EXISTS 
idx_episode_metrics_episode " + "ON episode_metrics(episode_id)", + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episode_metrics_type " + "ON episode_metrics(event_type)", + ) + + # Create index on is_public for efficient public feed queries + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_episodes_public " + "ON episodes(is_public)", + ) + + conn.commit() + logger.info("Database migrated for public feed feature") + + @staticmethod + def migrate_add_default_titles() -> None: + """Add default titles to queue items that have None titles.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Update queue items with NULL titles to have a default + cursor.execute(""" + UPDATE queue + SET title = 'Untitled Article' + WHERE title IS NULL + """) + + # Get count of updated rows + updated_count = cursor.rowcount + + conn.commit() + logger.info( + "Updated %s queue items with default titles", + updated_count, + ) + + @staticmethod + def create_user(email: str, status: str = "active") -> tuple[int, str]: + """Create a new user and return (user_id, token). + + Args: + email: User email address + status: Initial status (active or disabled) + + Raises: + ValueError: If user ID cannot be retrieved after insert or if user + not found, or if status is invalid. 
+ """ + if status not in {"pending", "active", "disabled"}: + msg = f"Invalid status: {status}" + raise ValueError(msg) + + # Generate a secure token for RSS feed access + token = secrets.token_urlsafe(32) + with Database.get_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + "INSERT INTO users (email, token, status) VALUES (?, ?, ?)", + (email, token, status), + ) + conn.commit() + user_id = cursor.lastrowid + if user_id is None: + msg = "Failed to get user ID after insert" + raise ValueError(msg) + logger.info( + "Created user %s with email %s (status: %s)", + user_id, + email, + status, + ) + except sqlite3.IntegrityError: + # User already exists + cursor.execute( + "SELECT id, token FROM users WHERE email = ?", + (email,), + ) + row = cursor.fetchone() + if row is None: + msg = f"User with email {email} not found" + raise ValueError(msg) from None + return int(row["id"]), str(row["token"]) + else: + return int(user_id), str(token) + + @staticmethod + def get_user_by_email( + email: str, + ) -> dict[str, Any] | None: + """Get user by email address.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE email = ?", (email,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_user_by_token( + token: str, + ) -> dict[str, Any] | None: + """Get user by RSS token.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE token = ?", (token,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_user_by_id( + user_id: int, + ) -> dict[str, Any] | None: + """Get user by ID.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,)) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_queue_position(job_id: int) 
-> int | None: + """Get position of job in pending queue.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + # Get created_at of this job + cursor.execute( + "SELECT created_at FROM queue WHERE id = ?", + (job_id,), + ) + row = cursor.fetchone() + if not row: + return None + created_at = row[0] + + # Count pending items created before or at same time + cursor.execute( + """ + SELECT COUNT(*) FROM queue + WHERE status = 'pending' AND created_at <= ? + """, + (created_at,), + ) + return int(cursor.fetchone()[0]) + + @staticmethod + def get_user_queue_status( + user_id: int, + ) -> list[dict[str, Any]]: + """Return pending/processing/error items for a specific user.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT id, url, email, status, created_at, error_message, + title, author + FROM queue + WHERE user_id = ? AND + status IN ( + 'pending', 'processing', 'extracting', + 'synthesizing', 'uploading', 'error' + ) + ORDER BY created_at DESC + LIMIT 20 + """, + (user_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_user_recent_episodes( + user_id: int, + limit: int = 20, + ) -> list[dict[str, Any]]: + """Get recent episodes for a specific user.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes WHERE user_id = ? " + "ORDER BY created_at DESC LIMIT ?", + (user_id, limit), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_user_all_episodes( + user_id: int, + ) -> list[dict[str, Any]]: + """Get all episodes for a specific user for RSS feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes WHERE user_id = ? 
" + "ORDER BY created_at DESC", + (user_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def update_user_status( + user_id: int, + status: str, + ) -> None: + """Update user account status.""" + if status not in {"pending", "active", "disabled"}: + msg = f"Invalid status: {status}" + raise ValueError(msg) + + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET status = ? WHERE id = ?", + (status, user_id), + ) + conn.commit() + logger.info("Updated user %s status to %s", user_id, status) + + @staticmethod + def delete_user(user_id: int) -> None: + """Delete user and all associated data.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # 1. Get owned episode IDs + cursor.execute( + "SELECT id FROM episodes WHERE user_id = ?", + (user_id,), + ) + owned_episode_ids = [row[0] for row in cursor.fetchall()] + + # 2. Delete references to owned episodes + if owned_episode_ids: + # Construct placeholders for IN clause + placeholders = ",".join("?" * len(owned_episode_ids)) + + # Delete from user_episodes where these episodes are referenced + query = f"DELETE FROM user_episodes WHERE episode_id IN ({placeholders})" # noqa: S608, E501 + cursor.execute(query, tuple(owned_episode_ids)) + + # Delete metrics for these episodes + query = f"DELETE FROM episode_metrics WHERE episode_id IN ({placeholders})" # noqa: S608, E501 + cursor.execute(query, tuple(owned_episode_ids)) + + # 3. Delete owned episodes + cursor.execute("DELETE FROM episodes WHERE user_id = ?", (user_id,)) + + # 4. Delete user's data referencing others or themselves + cursor.execute( + "DELETE FROM user_episodes WHERE user_id = ?", + (user_id,), + ) + cursor.execute( + "DELETE FROM episode_metrics WHERE user_id = ?", + (user_id,), + ) + cursor.execute("DELETE FROM queue WHERE user_id = ?", (user_id,)) + + # 5. 
Delete user + cursor.execute("DELETE FROM users WHERE id = ?", (user_id,)) + + conn.commit() + logger.info("Deleted user %s and all associated data", user_id) + + @staticmethod + def update_user_email(user_id: int, new_email: str) -> None: + """Update user's email address. + + Args: + user_id: ID of the user to update + new_email: New email address + + Raises: + ValueError: If email is already taken by another user + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + "UPDATE users SET email = ? WHERE id = ?", + (new_email, user_id), + ) + conn.commit() + logger.info("Updated user %s email to %s", user_id, new_email) + except sqlite3.IntegrityError: + msg = f"Email {new_email} is already taken" + raise ValueError(msg) from None + + @staticmethod + def mark_episode_public(episode_id: int) -> None: + """Mark an episode as public.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE episodes SET is_public = 1 WHERE id = ?", + (episode_id,), + ) + conn.commit() + logger.info("Marked episode %s as public", episode_id) + + @staticmethod + def unmark_episode_public(episode_id: int) -> None: + """Mark an episode as private (not public).""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE episodes SET is_public = 0 WHERE id = ?", + (episode_id,), + ) + conn.commit() + logger.info("Unmarked episode %s as public", episode_id) + + @staticmethod + def get_public_episodes(limit: int = 50) -> list[dict[str, Any]]: + """Get public episodes for public feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT id, title, audio_url, duration, created_at, + content_length, author, original_url + FROM episodes + WHERE is_public = 1 + ORDER BY created_at DESC + LIMIT ? 
+ """, + (limit,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def add_episode_to_user(user_id: int, episode_id: int) -> None: + """Add an episode to a user's feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + "INSERT INTO user_episodes (user_id, episode_id) " + "VALUES (?, ?)", + (user_id, episode_id), + ) + conn.commit() + logger.info( + "Added episode %s to user %s feed", + episode_id, + user_id, + ) + except sqlite3.IntegrityError: + # Episode already in user's feed + logger.info( + "Episode %s already in user %s feed", + episode_id, + user_id, + ) + + @staticmethod + def user_has_episode(user_id: int, episode_id: int) -> bool: + """Check if a user has an episode in their feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT 1 FROM user_episodes " + "WHERE user_id = ? AND episode_id = ?", + (user_id, episode_id), + ) + return cursor.fetchone() is not None + + @staticmethod + def track_episode_metric( + episode_id: int, + event_type: str, + user_id: int | None = None, + ) -> None: + """Track an episode metric event. 
+ + Args: + episode_id: ID of the episode + event_type: Type of event ('added', 'played', 'downloaded') + user_id: Optional user ID (None for anonymous events) + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO episode_metrics (episode_id, user_id, event_type) " + "VALUES (?, ?, ?)", + (episode_id, user_id, event_type), + ) + conn.commit() + logger.info( + "Tracked %s event for episode %s (user: %s)", + event_type, + episode_id, + user_id or "anonymous", + ) + + @staticmethod + def get_user_episodes(user_id: int) -> list[dict[str, Any]]: + """Get all episodes in a user's feed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT e.id, e.title, e.audio_url, e.duration, e.created_at, + e.content_length, e.author, e.original_url, e.is_public, + ue.added_at + FROM episodes e + JOIN user_episodes ue ON e.id = ue.episode_id + WHERE ue.user_id = ? + ORDER BY ue.added_at DESC + """, + (user_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def get_episode_by_url_hash(url_hash: str) -> dict[str, Any] | None: + """Get episode by original URL hash.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM episodes WHERE original_url_hash = ?", + (url_hash,), + ) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def get_metrics_summary() -> dict[str, Any]: + """Get aggregate metrics summary for admin dashboard. 
+ + Returns: + dict with keys: + - total_episodes: Total number of episodes + - total_plays: Total play events + - total_downloads: Total download events + - total_adds: Total add events + - most_played: List of top 10 most played episodes + - most_downloaded: List of top 10 most downloaded episodes + - most_added: List of top 10 most added episodes + - total_users: Total number of users + - active_subscriptions: Number of active subscriptions + - submissions_24h: Submissions in last 24 hours + - submissions_7d: Submissions in last 7 days + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Get total episodes + cursor.execute("SELECT COUNT(*) as count FROM episodes") + total_episodes = cursor.fetchone()["count"] + + # Get event counts + cursor.execute( + "SELECT COUNT(*) as count FROM episode_metrics " + "WHERE event_type = 'played'", + ) + total_plays = cursor.fetchone()["count"] + + cursor.execute( + "SELECT COUNT(*) as count FROM episode_metrics " + "WHERE event_type = 'downloaded'", + ) + total_downloads = cursor.fetchone()["count"] + + cursor.execute( + "SELECT COUNT(*) as count FROM episode_metrics " + "WHERE event_type = 'added'", + ) + total_adds = cursor.fetchone()["count"] + + # Get most played episodes + cursor.execute( + """ + SELECT e.id, e.title, e.author, COUNT(*) as play_count + FROM episode_metrics em + JOIN episodes e ON em.episode_id = e.id + WHERE em.event_type = 'played' + GROUP BY em.episode_id + ORDER BY play_count DESC + LIMIT 10 + """, + ) + most_played = [dict(row) for row in cursor.fetchall()] + + # Get most downloaded episodes + cursor.execute( + """ + SELECT e.id, e.title, e.author, COUNT(*) as download_count + FROM episode_metrics em + JOIN episodes e ON em.episode_id = e.id + WHERE em.event_type = 'downloaded' + GROUP BY em.episode_id + ORDER BY download_count DESC + LIMIT 10 + """, + ) + most_downloaded = [dict(row) for row in cursor.fetchall()] + + # Get most added episodes + cursor.execute( + """ + SELECT 
e.id, e.title, e.author, COUNT(*) as add_count + FROM episode_metrics em + JOIN episodes e ON em.episode_id = e.id + WHERE em.event_type = 'added' + GROUP BY em.episode_id + ORDER BY add_count DESC + LIMIT 10 + """, + ) + most_added = [dict(row) for row in cursor.fetchall()] + + # Get user metrics + cursor.execute("SELECT COUNT(*) as count FROM users") + total_users = cursor.fetchone()["count"] + + cursor.execute( + "SELECT COUNT(*) as count FROM users " + "WHERE subscription_status = 'active'", + ) + active_subscriptions = cursor.fetchone()["count"] + + # Get recent submission metrics + cursor.execute( + "SELECT COUNT(*) as count FROM queue " + "WHERE created_at >= datetime('now', '-1 day')", + ) + submissions_24h = cursor.fetchone()["count"] + + cursor.execute( + "SELECT COUNT(*) as count FROM queue " + "WHERE created_at >= datetime('now', '-7 days')", + ) + submissions_7d = cursor.fetchone()["count"] + + return { + "total_episodes": total_episodes, + "total_plays": total_plays, + "total_downloads": total_downloads, + "total_adds": total_adds, + "most_played": most_played, + "most_downloaded": most_downloaded, + "most_added": most_added, + "total_users": total_users, + "active_subscriptions": active_subscriptions, + "submissions_24h": submissions_24h, + "submissions_7d": submissions_7d, + } + + @staticmethod + def track_episode_event( + episode_id: int, + event_type: str, + user_id: int | None = None, + ) -> None: + """Track an episode event (added, played, downloaded).""" + if event_type not in {"added", "played", "downloaded"}: + msg = f"Invalid event type: {event_type}" + raise ValueError(msg) + + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO episode_metrics " + "(episode_id, user_id, event_type) VALUES (?, ?, ?)", + (episode_id, user_id, event_type), + ) + conn.commit() + logger.info( + "Tracked %s event for episode %s", + event_type, + episode_id, + ) + + @staticmethod + def get_episode_metrics(episode_id: 
int) -> dict[str, int]: + """Get aggregated metrics for an episode.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT event_type, COUNT(*) as count + FROM episode_metrics + WHERE episode_id = ? + GROUP BY event_type + """, + (episode_id,), + ) + rows = cursor.fetchall() + return {row["event_type"]: row["count"] for row in rows} + + @staticmethod + def get_episode_metric_events(episode_id: int) -> list[dict[str, Any]]: + """Get raw metric events for an episode (for testing).""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT id, episode_id, user_id, event_type, created_at + FROM episode_metrics + WHERE episode_id = ? + ORDER BY created_at DESC + """, + (episode_id,), + ) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + @staticmethod + def set_user_stripe_customer(user_id: int, customer_id: str) -> None: + """Link Stripe customer ID to user.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET stripe_customer_id = ? WHERE id = ?", + (customer_id, user_id), + ) + conn.commit() + logger.info( + "Linked user %s to Stripe customer %s", + user_id, + customer_id, + ) + + @staticmethod + def update_user_subscription( # noqa: PLR0913, PLR0917 + user_id: int, + subscription_id: str, + status: str, + period_start: Any, + period_end: Any, + tier: str, + cancel_at_period_end: bool, # noqa: FBT001 + ) -> None: + """Update user subscription details.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE users SET + stripe_subscription_id = ?, + subscription_status = ?, + current_period_start = ?, + current_period_end = ?, + plan_tier = ?, + cancel_at_period_end = ? + WHERE id = ? 
+ """, + ( + subscription_id, + status, + period_start.isoformat(), + period_end.isoformat(), + tier, + 1 if cancel_at_period_end else 0, + user_id, + ), + ) + conn.commit() + logger.info( + "Updated user %s subscription: tier=%s, status=%s", + user_id, + tier, + status, + ) + + @staticmethod + def update_subscription_status(user_id: int, status: str) -> None: + """Update only the subscription status (e.g., past_due).""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET subscription_status = ? WHERE id = ?", + (status, user_id), + ) + conn.commit() + logger.info( + "Updated user %s subscription status to %s", + user_id, + status, + ) + + @staticmethod + def downgrade_to_free(user_id: int) -> None: + """Downgrade user to free tier and clear subscription data.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE users SET + plan_tier = 'free', + subscription_status = 'canceled', + stripe_subscription_id = NULL, + current_period_start = NULL, + current_period_end = NULL, + cancel_at_period_end = 0 + WHERE id = ? 
+ """, + (user_id,), + ) + conn.commit() + logger.info("Downgraded user %s to free tier", user_id) + + @staticmethod + def get_user_by_stripe_customer_id( + customer_id: str, + ) -> dict[str, Any] | None: + """Get user by Stripe customer ID.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM users WHERE stripe_customer_id = ?", + (customer_id,), + ) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def has_processed_stripe_event(event_id: str) -> bool: + """Check if Stripe event has already been processed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT id FROM stripe_events WHERE id = ?", + (event_id,), + ) + return cursor.fetchone() is not None + + @staticmethod + def mark_stripe_event_processed( + event_id: str, + event_type: str, + payload: bytes, + ) -> None: + """Mark Stripe event as processed for idempotency.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT OR IGNORE INTO stripe_events (id, type, payload) " + "VALUES (?, ?, ?)", + (event_id, event_type, payload.decode("utf-8")), + ) + conn.commit() + + @staticmethod + def get_usage( + user_id: int, + period_start: Any, + period_end: Any, + ) -> dict[str, int]: + """Get usage stats for user in period. + + Counts episodes added to user's feed (via user_episodes table) + during the billing period, regardless of who created them. + + Returns: + dict with keys: articles (int), minutes (int) + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Count articles added to user's feed in period + # Uses user_episodes junction table to track when episodes + # were added, which correctly handles shared/existing episodes + cursor.execute( + """ + SELECT COUNT(*) as count, SUM(e.duration) as total_seconds + FROM user_episodes ue + JOIN episodes e ON e.id = ue.episode_id + WHERE ue.user_id = ? 
AND ue.added_at >= ? AND ue.added_at < ? + """, + (user_id, period_start.isoformat(), period_end.isoformat()), + ) + row = cursor.fetchone() + + articles = row["count"] if row else 0 + total_seconds = ( + row["total_seconds"] if row and row["total_seconds"] else 0 + ) + minutes = total_seconds // 60 + + return {"articles": articles, "minutes": minutes} + + +class TestDatabase(Test.TestCase): + """Test the Database class.""" + + @staticmethod + def setUp() -> None: + """Set up test database.""" + Database.init_db() + + def tearDown(self) -> None: + """Clean up test database.""" + Database.teardown() + # Clear user ID + self.user_id = None + + def test_init_db(self) -> None: + """Verify all tables and indexes are created correctly.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Check tables exist + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = {row[0] for row in cursor.fetchall()} + self.assertIn("queue", tables) + self.assertIn("episodes", tables) + self.assertIn("users", tables) + + # Check indexes exist + cursor.execute("SELECT name FROM sqlite_master WHERE type='index'") + indexes = {row[0] for row in cursor.fetchall()} + self.assertIn("idx_queue_status", indexes) + self.assertIn("idx_queue_created", indexes) + self.assertIn("idx_episodes_created", indexes) + self.assertIn("idx_queue_user_id", indexes) + self.assertIn("idx_episodes_user_id", indexes) + + def test_connection_context_manager(self) -> None: + """Ensure connections are properly closed.""" + # Get a connection and verify it works + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1") + result = cursor.fetchone() + self.assertEqual(result[0], 1) + + # Connection should be closed after context manager + with pytest.raises(sqlite3.ProgrammingError): + cursor.execute("SELECT 1") + + def test_migration_idempotency(self) -> None: + """Verify migrations can run multiple times safely.""" + # Run migration 
multiple times + Database.migrate_to_multi_user() + Database.migrate_to_multi_user() + Database.migrate_to_multi_user() + + # Should still work fine + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users") + # Should not raise an error + # Test completed successfully - migration worked + self.assertIsNotNone(conn) + + def test_get_metrics_summary_extended(self) -> None: + """Verify extended metrics summary.""" + # Create some data + user_id, _ = Database.create_user("test@example.com") + Database.create_episode( + "Test Article", + "url", + 100, + 1000, + user_id, + ) + + # Create a queue item + Database.add_to_queue( + "https://example.com", + "test@example.com", + user_id, + ) + + metrics = Database.get_metrics_summary() + + self.assertIn("total_users", metrics) + self.assertIn("active_subscriptions", metrics) + self.assertIn("submissions_24h", metrics) + self.assertIn("submissions_7d", metrics) + + self.assertEqual(metrics["total_users"], 1) + self.assertEqual(metrics["submissions_24h"], 1) + self.assertEqual(metrics["submissions_7d"], 1) + + +class TestUserManagement(Test.TestCase): + """Test user management functionality.""" + + @staticmethod + def setUp() -> None: + """Set up test database.""" + Database.init_db() + + @staticmethod + def tearDown() -> None: + """Clean up test database.""" + Database.teardown() + + def test_create_user(self) -> None: + """Create user with unique email and token.""" + user_id, token = Database.create_user("test@example.com") + + self.assertIsInstance(user_id, int) + self.assertIsInstance(token, str) + self.assertGreater(len(token), 20) # Should be a secure token + + def test_create_duplicate_user(self) -> None: + """Verify duplicate emails return existing user.""" + # Create first user + user_id1, token1 = Database.create_user( + "test@example.com", + ) + + # Try to create duplicate + user_id2, token2 = Database.create_user( + "test@example.com", + ) + + # Should return same 
user + self.assertIsNotNone(user_id1) + self.assertIsNotNone(user_id2) + self.assertEqual(user_id1, user_id2) + self.assertEqual(token1, token2) + + def test_get_user_by_email(self) -> None: + """Retrieve user by email.""" + user_id, token = Database.create_user("test@example.com") + + user = Database.get_user_by_email("test@example.com") + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["id"], user_id) + self.assertEqual(user["email"], "test@example.com") + self.assertEqual(user["token"], token) + + def test_get_user_by_token(self) -> None: + """Retrieve user by RSS token.""" + user_id, token = Database.create_user("test@example.com") + + user = Database.get_user_by_token(token) + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["id"], user_id) + self.assertEqual(user["email"], "test@example.com") + + def test_get_user_by_id(self) -> None: + """Retrieve user by ID.""" + user_id, token = Database.create_user("test@example.com") + + user = Database.get_user_by_id(user_id) + self.assertIsNotNone(user) + if user is None: + self.fail("User should not be None") + self.assertEqual(user["email"], "test@example.com") + self.assertEqual(user["token"], token) + + def test_invalid_user_lookups(self) -> None: + """Verify None returned for non-existent users.""" + self.assertIsNone( + Database.get_user_by_email("nobody@example.com"), + ) + self.assertIsNone( + Database.get_user_by_token("invalid-token"), + ) + self.assertIsNone(Database.get_user_by_id(9999)) + + def test_token_uniqueness(self) -> None: + """Ensure tokens are cryptographically unique.""" + tokens = set() + for i in range(10): + _, token = Database.create_user( + f"user{i}@example.com", + ) + tokens.add(token) + + # All tokens should be unique + self.assertEqual(len(tokens), 10) + + def test_delete_user(self) -> None: + """Test user deletion and cleanup.""" + # Create user + user_id, _ = 
Database.create_user("delete_me@example.com") + + # Create some data for the user + Database.add_to_queue( + "https://example.com/article", + "delete_me@example.com", + user_id, + ) + + ep_id = Database.create_episode( + title="Test Episode", + audio_url="url", + duration=100, + content_length=1000, + user_id=user_id, + ) + Database.add_episode_to_user(user_id, ep_id) + Database.track_episode_metric(ep_id, "played", user_id) + + # Delete user + Database.delete_user(user_id) + + # Verify user is gone + self.assertIsNone(Database.get_user_by_id(user_id)) + + # Verify queue items are gone + queue = Database.get_user_queue_status(user_id) + self.assertEqual(len(queue), 0) + + # Verify episodes are gone (direct lookup) + self.assertIsNone(Database.get_episode_by_id(ep_id)) + + def test_update_user_email(self) -> None: + """Update user email address.""" + user_id, _ = Database.create_user("old@example.com") + + # Update email + Database.update_user_email(user_id, "new@example.com") + + # Verify update + user = Database.get_user_by_id(user_id) + self.assertIsNotNone(user) + if user: + self.assertEqual(user["email"], "new@example.com") + + # Old email should not exist + self.assertIsNone(Database.get_user_by_email("old@example.com")) + + @staticmethod + def test_update_user_email_duplicate() -> None: + """Cannot update to an existing email.""" + user_id1, _ = Database.create_user("user1@example.com") + Database.create_user("user2@example.com") + + # Try to update user1 to user2's email + with pytest.raises(ValueError, match="already taken"): + Database.update_user_email(user_id1, "user2@example.com") + + +class TestQueueOperations(Test.TestCase): + """Test queue operations.""" + + def setUp(self) -> None: + """Set up test database with a user.""" + Database.init_db() + self.user_id, _ = Database.create_user("test@example.com") + + @staticmethod + def tearDown() -> None: + """Clean up test database.""" + Database.teardown() + + def test_add_to_queue(self) -> None: + """Add 
job with user association.""" + job_id = Database.add_to_queue( + "https://example.com/article", + "test@example.com", + self.user_id, + ) + + self.assertIsInstance(job_id, int) + self.assertGreater(job_id, 0) + + def test_get_pending_jobs(self) -> None: + """Retrieve jobs in correct order.""" + # Add multiple jobs + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) # Ensure different timestamps + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + # Get pending jobs + jobs = Database.get_pending_jobs(limit=10) + + self.assertEqual(len(jobs), 3) + # Should be in order of creation (oldest first) + self.assertEqual(jobs[0]["id"], job1) + self.assertEqual(jobs[1]["id"], job2) + self.assertEqual(jobs[2]["id"], job3) + + def test_update_job_status(self) -> None: + """Update status and error messages.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + + # Update to processing + Database.update_job_status(job_id, "processing") + job = Database.get_job_by_id(job_id) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "processing") + + # Update to error with message + Database.update_job_status( + job_id, + "error", + "Network timeout", + ) + job = Database.get_job_by_id(job_id) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "error") + self.assertEqual(job["error_message"], "Network timeout") + self.assertEqual(job["retry_count"], 1) + + def test_retry_job(self) -> None: + """Reset failed jobs for retry.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + + # Set to error + 
Database.update_job_status(job_id, "error", "Failed") + + # Retry + Database.retry_job(job_id) + job = Database.get_job_by_id(job_id) + + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "pending") + self.assertIsNone(job["error_message"]) + + def test_delete_job(self) -> None: + """Remove jobs from queue.""" + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + + # Delete job + Database.delete_job(job_id) + + # Should not exist + job = Database.get_job_by_id(job_id) + self.assertIsNone(job) + + def test_get_retryable_jobs(self) -> None: + """Find jobs eligible for retry.""" + # Add job and mark as error + job_id = Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + Database.update_job_status(job_id, "error", "Failed") + + # Should be retryable + retryable = Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 1) + self.assertEqual(retryable[0]["id"], job_id) + + # Exceed retry limit + Database.update_job_status( + job_id, + "error", + "Failed again", + ) + Database.update_job_status( + job_id, + "error", + "Failed yet again", + ) + + # Should not be retryable anymore + retryable = Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 0) + + def test_user_queue_isolation(self) -> None: + """Ensure users only see their own jobs.""" + # Create second user + user2_id, _ = Database.create_user("user2@example.com") + + # Add jobs for both users + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + ) + job2 = Database.add_to_queue( + "https://example.com/2", + "user2@example.com", + user2_id, + ) + + # Get user-specific queue status + user1_jobs = Database.get_user_queue_status(self.user_id) + user2_jobs = Database.get_user_queue_status(user2_id) + + self.assertEqual(len(user1_jobs), 1) + 
self.assertEqual(user1_jobs[0]["id"], job1) + + self.assertEqual(len(user2_jobs), 1) + self.assertEqual(user2_jobs[0]["id"], job2) + + def test_status_counts(self) -> None: + """Verify status aggregation queries.""" + # Add jobs with different statuses + Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + ) + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + Database.update_job_status(job2, "processing") + Database.update_job_status(job3, "error", "Failed") + + # Get status counts + counts = Database.get_user_status_counts(self.user_id) + + self.assertEqual(counts.get("pending", 0), 1) + self.assertEqual(counts.get("processing", 0), 1) + self.assertEqual(counts.get("error", 0), 1) + + def test_queue_position(self) -> None: + """Verify queue position calculation.""" + # Add multiple pending jobs + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + # Check positions + self.assertEqual(Database.get_queue_position(job1), 1) + self.assertEqual(Database.get_queue_position(job2), 2) + self.assertEqual(Database.get_queue_position(job3), 3) + + # Move job 2 to processing + Database.update_job_status(job2, "processing") + + # Check positions (job 3 should now be 2nd pending job) + self.assertEqual(Database.get_queue_position(job1), 1) + self.assertIsNone(Database.get_queue_position(job2)) + self.assertEqual(Database.get_queue_position(job3), 2) + + +class TestEpisodeManagement(Test.TestCase): + """Test episode management functionality.""" + + def setUp(self) -> None: + """Set up test database with a 
user.""" + Database.init_db() + self.user_id, _ = Database.create_user("test@example.com") + + @staticmethod + def tearDown() -> None: + """Clean up test database.""" + Database.teardown() + + def test_create_episode(self) -> None: + """Create episode with user association.""" + episode_id = Database.create_episode( + title="Test Article", + audio_url="https://example.com/audio.mp3", + duration=300, + content_length=5000, + user_id=self.user_id, + ) + + self.assertIsInstance(episode_id, int) + self.assertGreater(episode_id, 0) + + def test_get_recent_episodes(self) -> None: + """Retrieve episodes in reverse chronological order.""" + # Create multiple episodes + ep1 = Database.create_episode( + "Article 1", + "url1", + 100, + 1000, + self.user_id, + ) + time.sleep(0.01) + ep2 = Database.create_episode( + "Article 2", + "url2", + 200, + 2000, + self.user_id, + ) + time.sleep(0.01) + ep3 = Database.create_episode( + "Article 3", + "url3", + 300, + 3000, + self.user_id, + ) + + # Get recent episodes + episodes = Database.get_recent_episodes(limit=10) + + self.assertEqual(len(episodes), 3) + # Should be in reverse chronological order + self.assertEqual(episodes[0]["id"], ep3) + self.assertEqual(episodes[1]["id"], ep2) + self.assertEqual(episodes[2]["id"], ep1) + + def test_get_user_episodes(self) -> None: + """Ensure user isolation for episodes.""" + # Create second user + user2_id, _ = Database.create_user("user2@example.com") + + # Create episodes for both users + ep1 = Database.create_episode( + "User1 Article", + "url1", + 100, + 1000, + self.user_id, + ) + ep2 = Database.create_episode( + "User2 Article", + "url2", + 200, + 2000, + user2_id, + ) + + # Get user-specific episodes + user1_episodes = Database.get_user_all_episodes( + self.user_id, + ) + user2_episodes = Database.get_user_all_episodes(user2_id) + + self.assertEqual(len(user1_episodes), 1) + self.assertEqual(user1_episodes[0]["id"], ep1) + + self.assertEqual(len(user2_episodes), 1) + 
self.assertEqual(user2_episodes[0]["id"], ep2) + + def test_episode_metadata(self) -> None: + """Verify duration and content_length storage.""" + Database.create_episode( + title="Test Article", + audio_url="https://example.com/audio.mp3", + duration=12345, + content_length=98765, + user_id=self.user_id, + ) + + episodes = Database.get_user_all_episodes(self.user_id) + episode = episodes[0] + + self.assertEqual(episode["duration"], 12345) + self.assertEqual(episode["content_length"], 98765) + + +def test() -> None: + """Run the tests.""" + Test.run( + App.Area.Test, + [ + TestDatabase, + TestUserManagement, + TestQueueOperations, + TestEpisodeManagement, + ], + ) + + +def main() -> None: + """Run all PodcastItLater.Core tests.""" + if "test" in sys.argv: + test() diff --git a/Biz/PodcastItLater/DESIGN.md b/Biz/PodcastItLater/DESIGN.md new file mode 100644 index 0000000..29c4837 --- /dev/null +++ b/Biz/PodcastItLater/DESIGN.md @@ -0,0 +1,43 @@ +# PodcastItLater Design & Architecture + +## Overview +Service converting web articles to podcast episodes via email/web submission. + +## Architecture +- **Web**: `Biz/PodcastItLater/Web.py` (Ludic + HTMX + Starlette) +- **Worker**: `Biz/PodcastItLater/Worker.py` (Background processing) +- **Core**: `Biz/PodcastItLater/Core.py` (DB & Shared Logic) +- **Billing**: `Biz/PodcastItLater/Billing.py` (Stripe Integration) + +## Key Features +1. **User Management**: Email-based magic links, RSS tokens. +2. **Article Processing**: Trafilatura extraction -> LLM cleanup -> TTS. +3. **Billing (In Progress)**: Stripe Checkout/Portal, Freemium model. + +## Path to Paid Product (Epic: t-143KQl2) + +### 1. Billing Infrastructure +- **Stripe**: Use Stripe Checkout for subs, Portal for management. +- **Webhooks**: Handle `checkout.session.completed`, `customer.subscription.*`. +- **Tiers**: + - `Free`: 10 articles/month. + - `Paid`: Unlimited (initially). + +### 2. 
Usage Tracking +- **Table**: `users` table needs `plan_tier`, `subscription_status`, `stripe_customer_id`. +- **Logic**: Check usage count vs tier limit before allowing submission. +- **Reset**: Usage counters reset at billing period boundary. + +### 3. Admin Dashboard +- View all users and their status. +- Manually retry/delete jobs. +- View metrics (signups, conversions). + +## UX Polish (Epic: t-1vIPJYG) +- **Mobile**: Ensure all pages work on mobile. +- **Feedback**: Better error messages for failed URLs. +- **Navigation**: Clean up navbar, account management access. + +## Audio Improvements +- **Intro/Outro**: Add metadata-rich intro ("Title by Author"). +- **Sound Design**: Crossfade intro music. diff --git a/Biz/PodcastItLater/Episode.py b/Biz/PodcastItLater/Episode.py new file mode 100644 index 0000000..7090c70 --- /dev/null +++ b/Biz/PodcastItLater/Episode.py @@ -0,0 +1,390 @@ +""" +PodcastItLater Episode Detail Components. + +Components for displaying individual episode pages with media player, +share functionality, and signup prompts for non-authenticated users. 
+""" + +# : out podcastitlater-episode +# : dep ludic +import Biz.PodcastItLater.UI as UI +import ludic.html as html +import sys +import typing +from ludic.attrs import Attrs +from ludic.components import Component +from ludic.types import AnyChildren +from typing import override + + +class EpisodePlayerAttrs(Attrs): + """Attributes for EpisodePlayer component.""" + + audio_url: str + title: str + episode_id: int + + +class EpisodePlayer(Component[AnyChildren, EpisodePlayerAttrs]): + """HTML5 audio player for episode playback.""" + + @override + def render(self) -> html.div: + audio_url = self.attrs["audio_url"] + episode_id = self.attrs["episode_id"] + + return html.div( + html.div( + html.div( + html.h5( + html.i(classes=["bi", "bi-play-circle", "me-2"]), + "Listen", + classes=["card-title", "mb-3"], + ), + html.audio( + html.source(src=audio_url, type="audio/mpeg"), + "Your browser does not support the audio element.", + controls=True, + preload="metadata", + id=f"audio-player-{episode_id}", + classes=["w-100"], + style={"max-width": "100%"}, + ), + # JavaScript to track play events + html.script( + f""" + (function() {{ + var player = document.getElementById( + 'audio-player-{episode_id}' + ); + var hasTrackedPlay = false; + + player.addEventListener('play', function() {{ + // Track first play only + if (!hasTrackedPlay) {{ + hasTrackedPlay = true; + + // Send play event to server + fetch('/episode/{episode_id}/track', {{ + method: 'POST', + headers: {{ + 'Content-Type': + 'application/x-www-form-urlencoded' + }}, + body: 'event_type=played' + }}); + }} + }}); + }})(); + """, + ), + classes=["card-body"], + ), + classes=["card", "mb-4"], + ), + ) + + +class ShareButtonAttrs(Attrs): + """Attributes for ShareButton component.""" + + share_url: str + + +class ShareButton(Component[AnyChildren, ShareButtonAttrs]): + """Button to copy episode URL to clipboard.""" + + @override + def render(self) -> html.div: + share_url = self.attrs["share_url"] + + return html.div( 
+ html.div( + html.div( + html.h5( + html.i(classes=["bi", "bi-share", "me-2"]), + "Share Episode", + classes=["card-title", "mb-3"], + ), + html.div( + html.div( + html.button( + html.i(classes=["bi", "bi-copy", "me-1"]), + "Copy", + type="button", + id="share-button", + on_click=f"navigator.clipboard.writeText('{share_url}'); " # noqa: E501 + "const btn = document.getElementById('share-button'); " # noqa: E501 + "const originalHTML = btn.innerHTML; " + "btn.innerHTML = '<i class=\"bi bi-check me-1\"></i>Copied!'; " # noqa: E501 + "btn.classList.remove('btn-outline-secondary'); " # noqa: E501 + "btn.classList.add('btn-success'); " + "setTimeout(() => {{ " + "btn.innerHTML = originalHTML; " + "btn.classList.remove('btn-success'); " + "btn.classList.add('btn-outline-secondary'); " + "}}, 2000);", + classes=["btn", "btn-outline-secondary"], + ), + html.input( + type="text", + value=share_url, + readonly=True, + on_focus="this.select()", + classes=["form-control"], + ), + classes=["input-group"], + ), + ), + classes=["card-body"], + ), + classes=["card"], + ), + classes=["mb-4"], + ) + + +class SignupBannerAttrs(Attrs): + """Attributes for SignupBanner component.""" + + creator_email: str + base_url: str + + +class SignupBanner(Component[AnyChildren, SignupBannerAttrs]): + """Banner prompting non-authenticated users to sign up.""" + + @override + def render(self) -> html.div: + return html.div( + html.div( + html.div( + html.div( + html.i( + classes=[ + "bi", + "bi-info-circle-fill", + "me-2", + ], + ), + html.strong( + "This episode was created using PodcastItLater.", + ), + classes=["mb-3"], + ), + html.div( + html.p( + "Want to convert your own articles " + "to podcast episodes?", + classes=["mb-2"], + ), + html.form( + html.div( + html.input( + type="email", + name="email", + placeholder="Enter your email to start", + required=True, + classes=["form-control"], + ), + html.button( + html.i( + classes=[ + "bi", + "bi-arrow-right-circle", + "me-2", + ], + ), + "Sign 
Up", + type="submit", + classes=["btn", "btn-primary"], + ), + classes=["input-group"], + ), + hx_post="/login", + hx_target="#signup-result", + hx_swap="innerHTML", + ), + html.div(id="signup-result", classes=["mt-2"]), + ), + classes=["card-body"], + ), + classes=["card", "border-primary"], + ), + classes=["mb-4"], + ) + + +class EpisodeDetailPageAttrs(Attrs): + """Attributes for EpisodeDetailPage component.""" + + episode: dict[str, typing.Any] + episode_sqid: str + creator_email: str | None + user: dict[str, typing.Any] | None + base_url: str + user_has_episode: bool + + +class EpisodeDetailPage(Component[AnyChildren, EpisodeDetailPageAttrs]): + """Full page view for a single episode.""" + + @override + def render(self) -> UI.PageLayout: + episode = self.attrs["episode"] + episode_sqid = self.attrs["episode_sqid"] + creator_email = self.attrs.get("creator_email") + user = self.attrs.get("user") + base_url = self.attrs["base_url"] + user_has_episode = self.attrs.get("user_has_episode", False) + + share_url = f"{base_url}/episode/{episode_sqid}" + duration_str = UI.format_duration(episode.get("duration")) + + # Build page title + page_title = f"{episode['title']} - PodcastItLater" + + # Build meta tags for Open Graph + meta_tags = [ + html.meta(property="og:title", content=episode["title"]), + html.meta(property="og:type", content="website"), + html.meta(property="og:url", content=share_url), + html.meta( + property="og:description", + content=f"Listen to this article read aloud. 
" + f"Duration: {duration_str}" + + (f" by {episode['author']}" if episode.get("author") else ""), + ), + html.meta( + property="og:site_name", + content="PodcastItLater", + ), + html.meta(property="og:audio", content=episode["audio_url"]), + html.meta(property="og:audio:type", content="audio/mpeg"), + ] + + # Add Twitter Card tags + meta_tags.extend([ + html.meta(name="twitter:card", content="summary"), + html.meta(name="twitter:title", content=episode["title"]), + html.meta( + name="twitter:description", + content=f"Listen to this article. Duration: {duration_str}", + ), + ]) + + # Add author if available + if episode.get("author"): + meta_tags.append( + html.meta(property="article:author", content=episode["author"]), + ) + + return UI.PageLayout( + # Show signup banner if user is not logged in + SignupBanner( + creator_email=creator_email or "a user", + base_url=base_url, + ) + if not user and creator_email + else html.div(), + # Episode title and metadata + html.div( + html.h2( + episode["title"], + classes=["display-6", "mb-3"], + ), + html.div( + html.span( + html.i(classes=["bi", "bi-person", "me-1"]), + f"by {episode['author']}", + classes=["text-muted", "me-3"], + ) + if episode.get("author") + else html.span(), + html.span( + html.i(classes=["bi", "bi-clock", "me-1"]), + f"Duration: {duration_str}", + classes=["text-muted", "me-3"], + ), + html.span( + html.i(classes=["bi", "bi-calendar", "me-1"]), + f"Created: {episode['created_at']}", + classes=["text-muted"], + ), + classes=["mb-3"], + ), + html.div( + html.a( + html.i(classes=["bi", "bi-link-45deg", "me-1"]), + "View original article", + href=episode["original_url"], + target="_blank", + rel="noopener", + classes=["btn", "btn-sm", "btn-outline-secondary"], + ), + ) + if episode.get("original_url") + else html.div(), + classes=["mb-4"], + ), + # Audio player + EpisodePlayer( + audio_url=episode["audio_url"], + title=episode["title"], + episode_id=episode["id"], + ), + # Share button + 
ShareButton(share_url=share_url), + # Add to feed button (logged-in users without episode) + html.div( + html.div( + html.div( + html.h5( + html.i(classes=["bi", "bi-plus-circle", "me-2"]), + "Add to Your Feed", + classes=["card-title", "mb-3"], + ), + html.p( + "Save this episode to your personal feed " + "to listen later.", + classes=["text-muted", "mb-3"], + ), + html.button( + html.i(classes=["bi", "bi-plus-lg", "me-1"]), + "Add to My Feed", + hx_post=f"/episode/{episode['id']}/add-to-feed", + hx_target="#add-to-feed-result", + hx_swap="innerHTML", + classes=["btn", "btn-primary"], + ), + html.div(id="add-to-feed-result", classes=["mt-2"]), + classes=["card-body"], + ), + classes=["card"], + ), + classes=["mb-4"], + ) + if user and not user_has_episode + else html.div(), + # Back to home link + html.div( + html.a( + html.i(classes=["bi", "bi-arrow-left", "me-1"]), + "Back to Home", + href="/", + classes=["btn", "btn-link"], + ), + classes=["mt-4"], + ), + user=user, + current_page="", + error=None, + page_title=page_title, + meta_tags=meta_tags, + ) + + +def main() -> None: + """Episode module has no tests currently.""" + if "test" in sys.argv: + sys.exit(0) diff --git a/Biz/PodcastItLater/INFRASTRUCTURE.md b/Biz/PodcastItLater/INFRASTRUCTURE.md new file mode 100644 index 0000000..1c61618 --- /dev/null +++ b/Biz/PodcastItLater/INFRASTRUCTURE.md @@ -0,0 +1,38 @@ +# Infrastructure Setup for PodcastItLater + +## Mailgun Setup + +Since PodcastItLater requires sending transactional emails (magic links), we use Mailgun. + +### 1. Sign up for Mailgun +Sign up at [mailgun.com](https://www.mailgun.com/). + +### 2. Add Domain +Add `podcastitlater.com` (or `mg.podcastitlater.com`) to Mailgun. +We recommend using the root domain `podcastitlater.com` if you want emails to come from `@podcastitlater.com`. + +### 3. Configure DNS +Mailgun will provide DNS records to verify the domain and authorize email sending. 
You must add these to your DNS provider (e.g., Cloudflare, Namecheap). + +Required records usually include: +- **TXT** (SPF): `v=spf1 include:mailgun.org ~all` +- **TXT** (DKIM): `k=rsa; p=...` (Provided by Mailgun) +- **MX** (if receiving email, optional for just sending): `10 mxa.mailgun.org`, `10 mxb.mailgun.org` +- **CNAME** (for tracking, optional): `email.podcastitlater.com` -> `mailgun.org` + +### 4. Verify Domain +Click "Verify DNS Settings" in Mailgun dashboard. This may take up to 24 hours but is usually instant. + +### 5. Generate API Key / SMTP Credentials +Go to "Sending" -> "Domain Settings" -> "SMTP Credentials". +Create a new SMTP user (e.g., `postmaster@podcastitlater.com`). +**Save the password immediately.** + +### 6. Update Secrets +Update the production secrets file on the server (`/run/podcastitlater/env`): + +```bash +SMTP_SERVER=smtp.mailgun.org +SMTP_PASSWORD=your-new-smtp-password +EMAIL_FROM=noreply@podcastitlater.com +``` diff --git a/Biz/PodcastItLater/STRIPE_TESTING.md b/Biz/PodcastItLater/STRIPE_TESTING.md new file mode 100644 index 0000000..1461c06 --- /dev/null +++ b/Biz/PodcastItLater/STRIPE_TESTING.md @@ -0,0 +1,114 @@ +# Stripe Testing Guide + +## Testing Stripe Integration Without Real Transactions + +### 1. Use Stripe Test Mode + +Stripe provides test API keys that allow you to simulate payments without real money: + +1. Get your test keys from https://dashboard.stripe.com/test/apikeys +2. Set environment variables with test keys: + ```bash + export STRIPE_SECRET_KEY="sk_test_..." + export STRIPE_WEBHOOK_SECRET="whsec_test_..." + export STRIPE_PRICE_ID_PRO="price_test_..." + ``` + +### 2. Use Stripe Test Cards + +In test mode, use these test card numbers: +- **Success**: `4242 4242 4242 4242` +- **Decline**: `4000 0000 0000 0002` +- **3D Secure**: `4000 0025 0000 3155` + +Any future expiry date and any 3-digit CVC will work. + +### 3. 
Trigger Test Webhooks + +Use Stripe CLI to trigger webhook events locally: + +```bash +# Install Stripe CLI +# https://stripe.com/docs/stripe-cli + +# Login +stripe login + +# Forward webhooks to local server +stripe listen --forward-to localhost:8000/stripe/webhook + +# Trigger specific events +stripe trigger checkout.session.completed +stripe trigger customer.subscription.created +stripe trigger customer.subscription.updated +stripe trigger invoice.payment_failed +``` + +### 4. Run Unit Tests + +The billing module includes unit tests that mock Stripe webhooks: + +```bash +# Run billing tests +AREA=Test python3 Biz/PodcastItLater/Billing.py test + +# Or use bild +bild --test Biz/PodcastItLater/Billing.py +``` + +### 5. Test Migration on Production + +To fix the production database missing columns issue, you need to trigger the migration. + +The migration runs automatically when `Database.init_db()` is called, but production may have an old database. + +**Option A: Restart the web service** +The init_db() runs on startup, so restarting should apply migrations. + +**Option B: Run migration manually** +```bash +# SSH to production +# Run Python REPL with proper environment +python3 +>>> import os +>>> os.environ["AREA"] = "Live" +>>> os.environ["DATA_DIR"] = "/var/podcastitlater" +>>> import Biz.PodcastItLater.Core as Core +>>> Core.Database.init_db() +``` + +### 6. Verify Database Schema + +Check that billing columns exist: + +```bash +sqlite3 /var/podcastitlater/podcast.db +.schema users +``` + +Should show: +- `stripe_customer_id TEXT` +- `stripe_subscription_id TEXT` +- `subscription_status TEXT` +- `current_period_start TEXT` +- `current_period_end TEXT` +- `plan_tier TEXT NOT NULL DEFAULT 'free'` +- `cancel_at_period_end INTEGER DEFAULT 0` + +### 7. End-to-End Test Flow + +1. Start in test mode: `AREA=Test PORT=8000 python3 Biz/PodcastItLater/Web.py` +2. Login with test account +3. Go to /billing +4. Click "Upgrade Now" +5. 
Use test card: 4242 4242 4242 4242 +6. Stripe CLI will forward webhook to your local server +7. Verify subscription updated in database + +### 8. Common Issues + +**KeyError in webhook**: Make sure you're using safe `.get()` access for all Stripe object fields, as the structure can vary. + +**Database column missing**: Run migrations by restarting the service or calling `Database.init_db()`. + +**Webhook signature verification fails**: Make sure `STRIPE_WEBHOOK_SECRET` matches your endpoint secret from Stripe dashboard. diff --git a/Biz/PodcastItLater/TESTING.md b/Biz/PodcastItLater/TESTING.md new file mode 100644 index 0000000..2911610 --- /dev/null +++ b/Biz/PodcastItLater/TESTING.md @@ -0,0 +1,45 @@ +# PodcastItLater Testing Strategy + +## Overview +We use `pytest` with `Omni.Test` integration. Tests are co-located with code or in `Biz/PodcastItLater/Test.py` for E2E. + +## Test Categories + +### 1. Core (Database/Logic) +- **Location**: `Biz/PodcastItLater/Core.py` +- **Scope**: User creation, Job queue ops, Episode management. +- **Key Tests**: + - `test_create_user`: Unique tokens. + - `test_queue_isolation`: Users see only their jobs. + +### 2. Web (HTTP/UI) +- **Location**: `Biz/PodcastItLater/Web.py` +- **Scope**: Routes, Auth, HTMX responses. +- **Key Tests**: + - `test_submit_requires_auth`. + - `test_rss_feed_xml`. + - `test_admin_access_control`. + +### 3. Worker (Processing) +- **Location**: `Biz/PodcastItLater/Worker.py` +- **Scope**: Extraction, TTS, S3 upload. +- **Key Tests**: + - `test_extract_content`: Mocked network calls. + - `test_tts_chunking`: Handle long text. + - **Error Handling**: Ensure retries work and errors are logged. + +### 4. Billing (Stripe) +- **Location**: `Biz/PodcastItLater/Billing.py` +- **Scope**: Webhook processing, Entitlement checks. +- **Key Tests**: + - `test_webhook_subscription_update`: Update local DB. + - `test_enforce_limits`: Block submission if over limit. 
class BaseWebTest(Test.TestCase):
    """Shared fixture for web tests: app handle, HTTP client and database."""

    def setUp(self) -> None:
        """Expose the web app and a test client, then initialise the DB."""
        app = Web.app
        self.app = app
        self.client = TestClient(app)

        # The database is initialised again before every single test.
        Core.Database.init_db()

    @staticmethod
    def tearDown() -> None:
        """Tear the database down once a test has finished."""
        Core.Database.teardown()
self.client.post( + "/submit", + data={"url": "https://example.com/great-article"}, + ) + + self.assertEqual(response.status_code, 200) + self.assertIn("Article submitted successfully", response.text) + + # Extract job ID from response + match = re.search(r"Job ID: (\d+)", response.text) + self.assertIsNotNone(match) + if match is None: + self.fail("Job ID not found in response") + job_id = int(match.group(1)) + + # Verify job was created + job = Core.Database.get_job_by_id(job_id) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "pending") + self.assertEqual(job["user_id"], self.user_id) + + # Step 2: Process the job with mocked external services + shutdown_handler = Worker.ShutdownHandler() + processor = Worker.ArticleProcessor(shutdown_handler) + + # Mock external dependencies + mock_audio_data = b"fake-mp3-audio-content-12345" + + with ( + unittest.mock.patch.object( + Worker.ArticleProcessor, + "extract_article_content", + return_value=( + "Great Article Title", + "This is the article content.", + ), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=["This is the article content."], + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=50.0, + ), + unittest.mock.patch.object( + processor.openai_client.audio.speech, + "create", + ) as mock_tts, + unittest.mock.patch.object( + processor, + "upload_to_s3", + return_value="https://cdn.example.com/episode_123_Great_Article.mp3", + ), + unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + ) as mock_audio_segment, + unittest.mock.patch( + "pathlib.Path.read_bytes", + return_value=mock_audio_data, + ), + ): + # Mock TTS response + mock_tts_response = unittest.mock.MagicMock() + mock_tts_response.content = mock_audio_data + mock_tts.return_value = mock_tts_response + + # Mock audio segment + mock_segment = unittest.mock.MagicMock() + mock_segment.export = lambda 
path, **_kwargs: pathlib.Path( + path, + ).write_bytes( + mock_audio_data, + ) + mock_audio_segment.return_value = mock_segment + + # Process the pending job + Worker.process_pending_jobs(processor) + + # Step 3: Verify job was marked completed + job = Core.Database.get_job_by_id(job_id) + self.assertIsNotNone(job) + if job is None: + self.fail("Job should not be None") + self.assertEqual(job["status"], "completed") + + # Step 4: Verify episode was created + episodes = Core.Database.get_user_all_episodes(self.user_id) + self.assertEqual(len(episodes), 1) + + episode = episodes[0] + self.assertEqual(episode["title"], "Great Article Title") + self.assertEqual( + episode["audio_url"], + "https://cdn.example.com/episode_123_Great_Article.mp3", + ) + self.assertGreater(episode["duration"], 0) + self.assertEqual(episode["user_id"], self.user_id) + + # Step 5: Verify episode appears in RSS feed + response = self.client.get(f"/feed/{self.token}.rss") + + self.assertEqual(response.status_code, 200) + self.assertEqual( + response.headers["content-type"], + "application/rss+xml; charset=utf-8", + ) + + # Check RSS contains the episode + self.assertIn("Great Article Title", response.text) + self.assertIn( + "https://cdn.example.com/episode_123_Great_Article.mp3", + response.text, + ) + self.assertIn("<enclosure", response.text) + self.assertIn('type="audio/mpeg"', response.text) + + # Step 6: Verify only this user's episode is in their feed + # Create another user with their own episode + user2_id, token2 = Core.Database.create_user("other@example.com") + Core.Database.create_episode( + "Other User's Article", + "https://cdn.example.com/other.mp3", + 200, + 3000, + user2_id, + ) + + # Original user's feed should not contain other user's episode + response = self.client.get(f"/feed/{self.token}.rss") + self.assertIn("Great Article Title", response.text) + self.assertNotIn("Other User's Article", response.text) + + # Other user's feed should only contain their episode + response 
def main() -> None:
    """Run the tests.

    The suite is executed whether or not ``test`` is present in ``sys.argv``.
    """
    # NOTE(review): both branches call test(), so the argv check is currently
    # redundant. Collapsing it would leave the module-level `import sys`
    # unused; decide whether a distinct non-test mode is intended before
    # simplifying.
    if "test" in sys.argv:
        test()
    else:
        test()
class TestMetricsView(BaseWebTest):
    """Access-control and content checks for the ``/admin/metrics`` page."""

    def test_admin_metrics_view_access(self) -> None:
        """A logged-in admin gets a 200 and the metrics headings."""
        admin_email = "ben@bensima.com"
        _admin_id, _ = Core.Database.create_user(admin_email)
        self.client.post("/login", data={"email": admin_email})

        page = self.client.get("/admin/metrics")

        self.assertEqual(page.status_code, 200)
        for heading in ("Growth & Usage", "Total Users"):
            self.assertIn(heading, page.text)

    def test_admin_metrics_data(self) -> None:
        """Seeded users, subscription and submission appear in the page."""
        admin_email = "ben@bensima.com"
        admin_id, _ = Core.Database.create_user(admin_email)
        self.client.post("/login", data={"email": admin_email})

        # Two ordinary users; the second is flagged as an active subscriber
        # by writing subscription_status directly.
        Core.Database.create_user("user1@example.com")
        subscriber_id, _ = Core.Database.create_user("user2@example.com")
        with Core.Database.get_connection() as conn:
            conn.execute(
                "UPDATE users SET subscription_status = 'active' WHERE id = ?",
                (subscriber_id,),
            )
            conn.commit()

        # One queued submission, passed the admin's id.
        Core.Database.add_to_queue(
            "http://example.com/1",
            "user1@example.com",
            admin_id,
        )

        page = self.client.get("/admin/metrics")
        self.assertEqual(page.status_code, 200)

        for label in ("Total Users", "Active Subs", "Submissions (24h)"):
            self.assertIn(label, page.text)

        # Expected figures: 3 users (admin + user1 + user2), 1 active
        # subscription, 1 submission in the last 24h. Matching the rendered
        # markup is brittle but quick; the single "1" assertion is satisfied
        # by either the subscriptions or the submissions card.
        self.assertIn('<h3 class="mb-0">3</h3>', page.text)
        self.assertIn('<h3 class="mb-0">1</h3>', page.text)

    def test_non_admin_access_denied(self) -> None:
        """A logged-in non-admin is redirected with ``error=forbidden``."""
        email = "regular@example.com"
        Core.Database.create_user(email)
        self.client.post("/login", data={"email": email})

        page = self.client.get("/admin/metrics")

        self.assertEqual(page.status_code, 302)
        self.assertIn("error=forbidden", page.headers["Location"])

    def test_anonymous_access_redirect(self) -> None:
        """Without a session the page redirects to ``/``."""
        page = self.client.get("/admin/metrics")

        self.assertEqual(page.status_code, 302)
        self.assertEqual(page.headers["Location"], "/")
+""" + +# : lib +# : dep ludic +import Biz.PodcastItLater.Core as Core +import ludic.html as html +import typing +from ludic.attrs import Attrs +from ludic.components import Component +from ludic.types import AnyChildren +from typing import override + + +def format_duration(seconds: int | None) -> str: + """Format duration from seconds to human-readable format. + + Examples: + 300 -> "5m" + 3840 -> "1h 4m" + 11520 -> "3h 12m" + """ + if seconds is None or seconds <= 0: + return "Unknown" + + # Constants for time conversion + seconds_per_minute = 60 + minutes_per_hour = 60 + seconds_per_hour = 3600 + + # Round up to nearest minute + minutes = (seconds + seconds_per_minute - 1) // seconds_per_minute + + # Show as minutes only if under 60 minutes (exclusive) + # 3599 seconds rounds up to 60 minutes, which we keep as "60m" + if minutes <= minutes_per_hour: + # If exactly 3600 seconds (already 60 full minutes without rounding) + if seconds >= seconds_per_hour: + return "1h" + return f"{minutes}m" + + hours = minutes // minutes_per_hour + remaining_minutes = minutes % minutes_per_hour + + if remaining_minutes == 0: + return f"{hours}h" + + return f"{hours}h {remaining_minutes}m" + + +def create_bootstrap_styles() -> html.style: + """Load Bootstrap CSS and icons.""" + return html.style( + "@import url('https://cdn.jsdelivr.net/npm/bootstrap@5.3.2" + "/dist/css/bootstrap.min.css');" + "@import url('https://cdn.jsdelivr.net/npm/bootstrap-icons" + "@1.11.3/font/bootstrap-icons.min.css');", + ) + + +def create_auto_dark_mode_style() -> html.style: + """Create CSS for automatic dark mode based on prefers-color-scheme.""" + return html.style( + """ + /* Auto dark mode - applies Bootstrap dark theme via media query */ + @media (prefers-color-scheme: dark) { + :root { + color-scheme: dark; + --bs-body-color: #dee2e6; + --bs-body-color-rgb: 222, 226, 230; + --bs-body-bg: #212529; + --bs-body-bg-rgb: 33, 37, 41; + --bs-emphasis-color: #fff; + --bs-emphasis-color-rgb: 255, 255, 255; 
class PageLayout(Component[AnyChildren, PageLayoutAttrs]):
    """Reusable page layout with header and navbar.

    Wraps the component's children in a full HTML document: a ``<head>``
    with meta tags and the HTMX script, and a ``<body>`` holding the
    Bootstrap styles, the site header, an optional error alert, the navbar
    (only when a user is present) and finally the Bootstrap JS bundle.

    Attributes read from ``self.attrs``: ``user``, ``current_page``,
    ``error``, ``page_title`` and ``meta_tags``.
    """

    @staticmethod
    def _render_nav_item(
        label: str,
        href: str,
        icon: str,
        *,
        is_active: bool,
    ) -> html.li:
        """Build one ``nav-item`` list entry: an icon plus a labelled link.

        ``icon`` is the suffix of the ``bi-*`` icon class. When ``is_active``
        is false an empty string is placed in the class list instead of
        ``"active"``.
        """
        return html.li(
            html.a(
                html.i(classes=["bi", f"bi-{icon}", "me-1"]),
                label,
                href=href,
                classes=[
                    "nav-link",
                    "active" if is_active else "",
                ],
            ),
            classes=["nav-item"],
        )

    @staticmethod
    def _render_admin_dropdown(
        is_active_func: typing.Callable[[str], bool],
    ) -> html.li:
        """Build the "Admin" dropdown with Queue Status, Manage Users, Metrics.

        ``is_active_func`` answers whether a given page key is the current
        page; the toggle is marked active for ``"admin"`` or ``"admin-users"``.
        """
        # NOTE(review): the Metrics entry links to /admin/metrics but no
        # metrics page key is checked here; confirm which current_page value
        # the metrics page passes.
        is_active = is_active_func("admin") or is_active_func("admin-users")
        return html.li(
            html.a(  # type: ignore[call-arg]
                html.i(classes=["bi", "bi-gear-fill", "me-1"]),
                "Admin",
                href="#",
                id="adminDropdown",
                role="button",
                data_bs_toggle="dropdown",
                aria_expanded="false",
                classes=[
                    "nav-link",
                    "dropdown-toggle",
                    "active" if is_active else "",
                ],
            ),
            html.ul(  # type: ignore[call-arg]
                html.li(
                    html.a(
                        html.i(classes=["bi", "bi-list-task", "me-2"]),
                        "Queue Status",
                        href="/admin",
                        classes=["dropdown-item"],
                    ),
                ),
                html.li(
                    html.a(
                        html.i(classes=["bi", "bi-people-fill", "me-2"]),
                        "Manage Users",
                        href="/admin/users",
                        classes=["dropdown-item"],
                    ),
                ),
                html.li(
                    html.a(
                        html.i(classes=["bi", "bi-graph-up", "me-2"]),
                        "Metrics",
                        href="/admin/metrics",
                        classes=["dropdown-item"],
                    ),
                ),
                classes=["dropdown-menu"],
                # Ties the menu to the toggle link's id above.
                aria_labelledby="adminDropdown",
            ),
            classes=["nav-item", "dropdown"],
        )

    @staticmethod
    def _render_navbar(
        user: dict[str, typing.Any] | None,
        current_page: str,
    ) -> html.nav:
        """Render navigation bar.

        Contains a collapse toggler button and the Home, Public Feed,
        Pricing and Manage Account links. The admin dropdown is added only
        when ``user`` is truthy and ``Core.is_admin(user)`` holds; otherwise
        an empty ``<span>`` takes its place.
        """

        def is_active(page: str) -> bool:
            # A nav entry is active when its key equals the current page key.
            return current_page == page

        return html.nav(
            html.div(
                # Toggler targets the collapsible container with id "navbarNav".
                html.button(  # type: ignore[call-arg]
                    html.span(classes=["navbar-toggler-icon"]),
                    classes=["navbar-toggler", "ms-auto"],
                    type="button",
                    data_bs_toggle="collapse",
                    data_bs_target="#navbarNav",
                    aria_controls="navbarNav",
                    aria_expanded="false",
                    aria_label="Toggle navigation",
                ),
                html.div(
                    html.ul(
                        PageLayout._render_nav_item(
                            "Home",
                            "/",
                            "house-fill",
                            is_active=is_active("home"),
                        ),
                        PageLayout._render_nav_item(
                            "Public Feed",
                            "/public",
                            "globe",
                            is_active=is_active("public"),
                        ),
                        PageLayout._render_nav_item(
                            "Pricing",
                            "/pricing",
                            "stars",
                            is_active=is_active("pricing"),
                        ),
                        PageLayout._render_nav_item(
                            "Manage Account",
                            "/account",
                            "person-circle",
                            is_active=is_active("account"),
                        ),
                        # Admin menu is shown to admins only.
                        PageLayout._render_admin_dropdown(is_active)
                        if user and Core.is_admin(user)
                        else html.span(),
                        classes=["navbar-nav"],
                    ),
                    id="navbarNav",
                    classes=["collapse", "navbar-collapse"],
                ),
                classes=["container-fluid"],
            ),
            classes=[
                "navbar",
                "navbar-expand-lg",
                "bg-body-tertiary",
                "rounded",
                "mb-4",
            ],
        )

    @override
    def render(self) -> html.html:
        """Assemble the complete HTML document around ``self.children``.

        ``page_title`` falls back to ``"PodcastItLater"`` and ``meta_tags``
        to an empty list when missing or falsy. The error alert is rendered
        only for a truthy ``error``; the navbar only for a truthy ``user``.
        """
        user = self.attrs.get("user")
        current_page = self.attrs.get("current_page", "")
        error = self.attrs.get("error")
        page_title = self.attrs.get("page_title") or "PodcastItLater"
        meta_tags = self.attrs.get("meta_tags") or []

        return html.html(
            html.head(
                html.meta(charset="utf-8"),
                html.meta(
                    name="viewport",
                    content="width=device-width, initial-scale=1",
                ),
                html.meta(
                    name="color-scheme",
                    content="light dark",
                ),
                html.title(page_title),
                # Caller-supplied meta tags are emitted after the title.
                *meta_tags,
                create_htmx_script(),
            ),
            html.body(
                # Stylesheets are emitted as <style> elements inside the body.
                create_bootstrap_styles(),
                create_auto_dark_mode_style(),
                html.div(
                    # Site header: title and tagline.
                    html.div(
                        html.h1(
                            "PodcastItLater",
                            classes=["display-4", "mb-2"],
                        ),
                        html.p(
                            "Convert web articles to podcast episodes",
                            classes=["lead", "text-muted"],
                        ),
                        classes=["text-center", "mb-4", "pt-4"],
                    ),
                    # Error alert, or an empty div when there is no error.
                    html.div(
                        html.div(
                            html.i(
                                classes=[
                                    "bi",
                                    "bi-exclamation-triangle-fill",
                                    "me-2",
                                ],
                            ),
                            error or "",
                            classes=[
                                "alert",
                                "alert-danger",
                                "d-flex",
                                "align-items-center",
                            ],
                            role="alert",  # type: ignore[call-arg]
                        ),
                    )
                    if error
                    else html.div(),
                    # Anonymous visitors get an empty div instead of the navbar.
                    self._render_navbar(user, current_page)
                    if user
                    else html.div(),
                    *self.children,
                    classes=["container", "px-3", "px-md-4"],
                    style={"max-width": "900px"},
                ),
                create_bootstrap_js(),
            ),
        )
portal_url: str | None + + +class AccountPage(Component[AnyChildren, AccountPageAttrs]): + """Account management page component.""" + + @override + def render(self) -> PageLayout: + user = self.attrs["user"] + usage = self.attrs["usage"] + limits = self.attrs["limits"] + portal_url = self.attrs["portal_url"] + + plan_tier = user.get("plan_tier", "free") + is_paid = plan_tier == "paid" + + article_limit = limits.get("articles_per_period") + article_usage = usage.get("articles", 0) + + limit_text = ( + "Unlimited" if article_limit is None else str(article_limit) + ) + + usage_percent = 0 + if article_limit: + usage_percent = min(100, int((article_usage / article_limit) * 100)) + + progress_style = ( + {"width": f"{usage_percent}%"} if article_limit else {"width": "0%"} + ) + + return PageLayout( + html.div( + html.div( + html.div( + html.div( + html.div( + html.h2( + html.i( + classes=[ + "bi", + "bi-person-circle", + "me-2", + ], + ), + "My Account", + classes=["card-title", "mb-4"], + ), + # User Info Section + html.div( + html.h5("Profile", classes=["mb-3"]), + html.div( + html.strong("Email: "), + html.span(user.get("email", "")), + html.button( + "Change", + classes=[ + "btn", + "btn-sm", + "btn-outline-secondary", + "ms-2", + "py-0", + ], + hx_get="/settings/email/edit", + hx_target="closest div", + hx_swap="outerHTML", + ), + classes=[ + "mb-2", + "d-flex", + "align-items-center", + ], + ), + html.p( + html.strong("Member since: "), + user.get("created_at", "").split("T")[ + 0 + ], + classes=["mb-4"], + ), + classes=["mb-5"], + ), + # Subscription Section + html.div( + html.h5("Subscription", classes=["mb-3"]), + html.div( + html.div( + html.strong("Current Plan"), + html.span( + plan_tier.title(), + classes=[ + "badge", + "bg-success" + if is_paid + else "bg-secondary", + "ms-2", + ], + ), + classes=[ + "d-flex", + "align-items-center", + "mb-3", + ], + ), + # Usage Stats + html.div( + html.p( + "Usage this period:", + classes=["mb-2", "text-muted"], + ), + 
html.div( + html.div( + f"{article_usage} / " + f"{limit_text}", + classes=["mb-1"], + ), + html.div( + html.div( + classes=[ + "progress-bar", + ], + role="progressbar", # type: ignore[call-arg] + style=progress_style, # type: ignore[arg-type] + ), + classes=[ + "progress", + "mb-3", + ], + style={"height": "10px"}, + ) + if article_limit + else html.div(), + classes=["mb-3"], + ), + ), + # Actions + html.div( + html.form( + html.button( + html.i( + classes=[ + "bi", + "bi-credit-card", + "me-2", + ], + ), + "Manage Subscription", + type="submit", + classes=[ + "btn", + "btn-outline-primary", + ], + ), + method="post", + action=portal_url, + ) + if is_paid and portal_url + else html.a( + html.i( + classes=[ + "bi", + "bi-star-fill", + "me-2", + ], + ), + "Upgrade to Pro", + href="/pricing", + classes=["btn", "btn-primary"], + ), + classes=["d-flex", "gap-2"], + ), + classes=[ + "card", + "card-body", + "bg-light", + ], + ), + classes=["mb-5"], + ), + # Logout Section + html.div( + html.form( + html.button( + html.i( + classes=[ + "bi", + "bi-box-arrow-right", + "me-2", + ], + ), + "Log Out", + type="submit", + classes=[ + "btn", + "btn-outline-danger", + ], + ), + action="/logout", + method="post", + ), + classes=["border-top", "pt-4"], + ), + # Delete Account Section + html.div( + html.h5( + "Danger Zone", + classes=["text-danger", "mb-3"], + ), + html.div( + html.h6("Delete Account"), + html.p( + "Once you delete your account, " + "there is no going back. " + "Please be certain.", + classes=["card-text"], + ), + html.button( + html.i( + classes=[ + "bi", + "bi-trash", + "me-2", + ], + ), + "Delete Account", + hx_delete="/account", + hx_confirm=( + "Are you absolutely sure you " + "want to delete your account? " + "This action cannot be undone." 
+ ), + classes=["btn", "btn-danger"], + ), + classes=[ + "card", + "card-body", + "border-danger", + ], + ), + classes=["mt-5", "pt-4", "border-top"], + ), + classes=["card-body", "p-4"], + ), + classes=["card", "shadow-sm"], + ), + classes=["col-lg-8", "mx-auto"], + ), + classes=["row"], + ), + ), + user=user, + current_page="account", + page_title="Account - PodcastItLater", + error=None, + meta_tags=[], + ) + + +class PricingPageAttrs(Attrs): + """Attributes for PricingPage component.""" + + user: dict[str, typing.Any] | None + + +class PricingPage(Component[AnyChildren, PricingPageAttrs]): + """Pricing page component.""" + + @override + def render(self) -> PageLayout: + user = self.attrs.get("user") + current_tier = user.get("plan_tier", "free") if user else "free" + + return PageLayout( + html.div( + html.div( + # Free Tier + html.div( + html.div( + html.div( + html.h3("Free", classes=["card-title"]), + html.h4( + "$0", + classes=[ + "card-subtitle", + "mb-3", + "text-muted", + ], + ), + html.p( + "10 articles total", + classes=["card-text"], + ), + html.ul( + html.li("Convert 10 articles"), + html.li("Basic features"), + classes=["list-unstyled", "mb-4"], + ), + html.button( + "Current Plan", + classes=[ + "btn", + "btn-outline-primary", + "w-100", + ], + disabled=True, + ) + if current_tier == "free" + else html.div(), + classes=["card-body"], + ), + classes=["card", "mb-4", "shadow-sm", "h-100"], + ), + classes=["col-md-6"], + ), + # Paid Tier + html.div( + html.div( + html.div( + html.h3( + "Unlimited", + classes=["card-title"], + ), + html.h4( + "$12/mo", + classes=[ + "card-subtitle", + "mb-3", + "text-muted", + ], + ), + html.p( + "Unlimited articles", + classes=["card-text"], + ), + html.ul( + html.li("Unlimited conversions"), + html.li("Priority processing"), + html.li("Support independent software"), + classes=["list-unstyled", "mb-4"], + ), + html.form( + html.button( + "Upgrade Now", + type="submit", + classes=[ + "btn", + "btn-primary", + "w-100", 
+ ], + ), + action="/upgrade", + method="post", + ) + if user and current_tier == "free" + else ( + html.button( + "Current Plan", + classes=[ + "btn", + "btn-success", + "w-100", + ], + disabled=True, + ) + if user and current_tier == "paid" + else html.a( + "Login to Upgrade", + href="/", + classes=[ + "btn", + "btn-primary", + "w-100", + ], + ) + ), + classes=["card-body"], + ), + classes=[ + "card", + "mb-4", + "shadow-sm", + "border-primary", + "h-100", + ], + ), + classes=["col-md-6"], + ), + classes=["row"], + ), + ), + user=user, + current_page="pricing", + page_title="Pricing - PodcastItLater", + error=None, + meta_tags=[], + ) diff --git a/Biz/PodcastItLater/Web.nix b/Biz/PodcastItLater/Web.nix new file mode 100644 index 0000000..7533ca4 --- /dev/null +++ b/Biz/PodcastItLater/Web.nix @@ -0,0 +1,93 @@ +{ + options, + lib, + config, + ... +}: let + cfg = config.services.podcastitlater-web; + rootDomain = "podcastitlater.com"; + ports = import ../../Omni/Cloud/Ports.nix; +in { + options.services.podcastitlater-web = { + enable = lib.mkEnableOption "Enable the PodcastItLater web service"; + port = lib.mkOption { + type = lib.types.int; + default = 8000; + description = '' + The port on which PodcastItLater web will listen for + incoming HTTP traffic. 
+ ''; + }; + dataDir = lib.mkOption { + type = lib.types.path; + default = "/var/podcastitlater"; + description = "Data directory for PodcastItLater (shared with worker)"; + }; + package = lib.mkOption { + type = lib.types.package; + description = "PodcastItLater web package to use"; + }; + }; + config = lib.mkIf cfg.enable { + systemd.services.podcastitlater-web = { + path = [cfg.package]; + wantedBy = ["multi-user.target"]; + preStart = '' + # Create data directory if it doesn't exist + mkdir -p ${cfg.dataDir} + + # Manual step: create this file with secrets + # SECRET_KEY=your-secret-key-for-sessions + # SESSION_SECRET=your-session-secret + # EMAIL_FROM=noreply@podcastitlater.com + # SMTP_SERVER=smtp.mailgun.org + # SMTP_PASSWORD=your-smtp-password + # STRIPE_SECRET_KEY=sk_live_your_stripe_secret_key + # STRIPE_WEBHOOK_SECRET=whsec_your_webhook_secret + # STRIPE_PRICE_ID_PRO=price_your_pro_price_id + test -f /run/podcastitlater/env + ''; + script = '' + ${cfg.package}/bin/podcastitlater-web + ''; + description = '' + PodcastItLater Web Service + ''; + serviceConfig = { + Environment = [ + "PORT=${toString cfg.port}" + "AREA=Live" + "DATA_DIR=${cfg.dataDir}" + "BASE_URL=https://${rootDomain}" + ]; + EnvironmentFile = "/run/podcastitlater/env"; + KillSignal = "INT"; + Type = "simple"; + Restart = "on-abort"; + RestartSec = "1"; + }; + }; + + # Nginx configuration + services.nginx = { + enable = true; + recommendedGzipSettings = true; + recommendedOptimisation = true; + recommendedProxySettings = true; + recommendedTlsSettings = true; + statusPage = true; + + virtualHosts."${rootDomain}" = { + forceSSL = true; + enableACME = true; + locations."/" = { + proxyPass = "http://127.0.0.1:${toString cfg.port}"; + proxyWebsockets = true; + }; + }; + }; + + # Ensure firewall allows web traffic + networking.firewall.allowedTCPPorts = [ports.ssh ports.http ports.https]; + }; +} diff --git a/Biz/PodcastItLater/Web.py b/Biz/PodcastItLater/Web.py new file mode 100644 index 
0000000..30b5236 --- /dev/null +++ b/Biz/PodcastItLater/Web.py @@ -0,0 +1,3480 @@ +""" +PodcastItLater Web Service. + +Web frontend for converting articles to podcast episodes. +Provides ludic + htmx interface and RSS feed generation. +""" + +# : out podcastitlater-web +# : dep ludic +# : dep feedgen +# : dep httpx +# : dep itsdangerous +# : dep uvicorn +# : dep pytest +# : dep pytest-asyncio +# : dep pytest-mock +# : dep starlette +# : dep stripe +# : dep sqids +import Biz.EmailAgent +import Biz.PodcastItLater.Admin as Admin +import Biz.PodcastItLater.Billing as Billing +import Biz.PodcastItLater.Core as Core +import Biz.PodcastItLater.Episode as Episode +import Biz.PodcastItLater.UI as UI +import html as html_module +import httpx +import logging +import ludic.html as html +import Omni.App as App +import Omni.Log as Log +import Omni.Test as Test +import os +import pathlib +import re +import sys +import tempfile +import typing +import urllib.parse +import uvicorn +from datetime import datetime +from datetime import timezone +from feedgen.feed import FeedGenerator # type: ignore[import-untyped] +from itsdangerous import URLSafeTimedSerializer +from ludic.attrs import Attrs +from ludic.components import Component +from ludic.types import AnyChildren +from ludic.web import LudicApp +from ludic.web import Request +from ludic.web.datastructures import FormData +from ludic.web.responses import Response +from sqids import Sqids +from starlette.middleware.sessions import SessionMiddleware +from starlette.responses import RedirectResponse +from starlette.testclient import TestClient +from typing import override +from unittest.mock import patch + +logger = logging.getLogger(__name__) +Log.setup(logger) + + +# Configuration +area = App.from_env() +BASE_URL = os.getenv("BASE_URL", "http://localhost:8000") +PORT = int(os.getenv("PORT", "8000")) + +# Initialize sqids for episode URL encoding +sqids = Sqids(min_length=8) + + +def encode_episode_id(episode_id: int) -> str: + 
"""Encode episode ID to sqid for URLs.""" + return str(sqids.encode([episode_id])) + + +def decode_episode_id(sqid: str) -> int | None: + """Decode sqid to episode ID. Returns None if invalid.""" + try: + decoded = sqids.decode(sqid) + return decoded[0] if decoded else None + except (ValueError, IndexError): + return None + + +# Authentication configuration +MAGIC_LINK_MAX_AGE = 3600 # 1 hour +SESSION_MAX_AGE = 30 * 24 * 3600 # 30 days +EMAIL_FROM = os.getenv("EMAIL_FROM", "noreply@podcastitlater.com") +SMTP_SERVER = os.getenv("SMTP_SERVER", "smtp.mailgun.org") +SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "") + +# Initialize serializer for magic links +magic_link_serializer = URLSafeTimedSerializer( + os.getenv("SECRET_KEY", "dev-secret-key"), +) + + +RSS_CONFIG = { + "author": "PodcastItLater", + "language": "en-US", + "base_url": BASE_URL, +} + + +def extract_og_metadata(url: str) -> tuple[str | None, str | None]: + """Extract Open Graph title and author from URL. + + Returns: + tuple: (title, author) - both may be None if extraction fails + """ + try: + # Use httpx to fetch the page with a timeout + response = httpx.get(url, timeout=10.0, follow_redirects=True) + response.raise_for_status() + + # Simple regex-based extraction to avoid heavy dependencies + html_content = response.text + + # Extract og:title + title_match = re.search( + r'<meta\s+(?:property|name)=["\']og:title["\']\s+content=["\'](.*?)["\']', + html_content, + re.IGNORECASE, + ) + title = title_match.group(1) if title_match else None + + # Extract author - try article:author first, then og:site_name + author_match = re.search( + r'<meta\s+(?:property|name)=["\']article:author["\']\s+content=["\'](.*?)["\']', + html_content, + re.IGNORECASE, + ) + if not author_match: + author_match = re.search( + r'<meta\s+(?:property|name)=["\']og:site_name["\']\s+content=["\'](.*?)["\']', + html_content, + re.IGNORECASE, + ) + author = author_match.group(1) if author_match else None + + # Clean up HTML entities 
def send_magic_link(email: str, token: str) -> None:
    """Email a login link containing ``token`` to ``email``.

    The message body is written to a temporary text file whose path is
    handed to ``Biz.EmailAgent.send_email`` as ``body_text``. The file is
    removed afterwards, whether or not writing or sending succeeded.

    Args:
        email: Recipient address.
        token: Signed magic-link token embedded in the verification URL.
    """
    subject = "Login to PodcastItLater"
    magic_link = f"{BASE_URL}/auth/verify?token={token}"
    body = f"""
Hello,

Click this link to login to PodcastItLater:
{magic_link}

This link will expire in 1 hour.

If you didn't request this, please ignore this email.

Best,
PodcastItLater
"""

    # Reserve a named file that outlives the ``with`` block (delete=False);
    # we are responsible for removing it ourselves below.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".txt",
        delete=False,
        encoding="utf-8",
    ) as f:
        body_text_path = pathlib.Path(f.name)

    try:
        # Write inside the try so a failed write cannot leak the file, and
        # pin the encoding so it matches the one the file was created with
        # instead of falling back to the locale default.
        body_text_path.write_text(body, encoding="utf-8")
        Biz.EmailAgent.send_email(
            to_addrs=[email],
            from_addr=EMAIL_FROM,
            smtp_server=SMTP_SERVER,
            password=SMTP_PASSWORD,
            subject=subject,
            body_text=body_text_path,
        )
    finally:
        # Clean up temporary file
        body_text_path.unlink(missing_ok=True)
"bi-envelope-fill", "me-2"]), + "Login / Register", + classes=["card-title", "mb-3"], + ), + html.form( + html.div( + html.label( + "Email address", + for_="email", + classes=["form-label"], + ), + html.input( + type="email", + id="email", + name="email", + placeholder="your@email.com", + value="demo@example.com" if is_dev_mode else "", + required=True, + classes=["form-control", "mb-3"], + ), + ), + html.button( + html.i( + classes=["bi", "bi-arrow-right-circle", "me-2"], + ), + "Continue", + type="submit", + classes=["btn", "btn-primary", "w-100"], + ), + hx_post="/login", + hx_target="#login-result", + hx_swap="innerHTML", + ), + html.div( + error or "", + id="login-result", + classes=["mt-3"], + ), + classes=["card-body"], + ), + classes=["card"], + ), + classes=["mb-4"], + ) + + +class SubmitForm(Component[AnyChildren, Attrs]): + """Article submission form with HTMX.""" + + @override + def render(self) -> html.div: + return html.div( + html.div( + html.div( + html.h4( + html.i(classes=["bi", "bi-file-earmark-plus", "me-2"]), + "Submit Article", + classes=["card-title", "mb-3"], + ), + html.form( + html.div( + html.label( + "Article URL", + for_="url", + classes=["form-label"], + ), + html.div( + html.input( + type="url", + id="url", + name="url", + placeholder="https://example.com/article", + required=True, + classes=["form-control"], + on_focus="this.select()", + ), + html.button( + html.i(classes=["bi", "bi-send-fill"]), + type="submit", + classes=["btn", "btn-primary"], + ), + classes=["input-group", "mb-3"], + ), + ), + hx_post="/submit", + hx_target="#submit-result", + hx_swap="innerHTML", + hx_on=( + "htmx:afterRequest: " + "if(event.detail.successful) " + "document.getElementById('url').value = ''" + ), + ), + html.div(id="submit-result", classes=["mt-2"]), + classes=["card-body"], + ), + classes=["card"], + ), + classes=["mb-4"], + ) + + +class QueueStatusAttrs(Attrs): + """Attributes for QueueStatus component.""" + + items: list[dict[str, typing.Any]] 
+ + +class QueueStatus(Component[AnyChildren, QueueStatusAttrs]): + """Display queue items with auto-refresh.""" + + @override + def render(self) -> html.div: + items = self.attrs["items"] + if not items: + return html.div( + html.h4( + html.i(classes=["bi", "bi-list-check", "me-2"]), + "Queue Status", + classes=["mb-3"], + ), + html.p("No items in queue", classes=["text-muted"]), + ) + + # Map status to Bootstrap badge classes + status_classes = { + "pending": "bg-warning text-dark", + "processing": "bg-primary", + "extracting": "bg-info text-dark", + "synthesizing": "bg-primary", + "uploading": "bg-success", + "error": "bg-danger", + "cancelled": "bg-secondary", + } + + status_icons = { + "pending": "bi-clock", + "processing": "bi-arrow-repeat", + "extracting": "bi-file-text", + "synthesizing": "bi-mic", + "uploading": "bi-cloud-arrow-up", + "error": "bi-exclamation-triangle", + "cancelled": "bi-x-circle", + } + + queue_items = [] + for item in items: + badge_class = status_classes.get(item["status"], "bg-secondary") + icon_class = status_icons.get(item["status"], "bi-question-circle") + + # Get queue position for pending items + queue_pos = None + if item["status"] == "pending": + queue_pos = Core.Database.get_queue_position(item["id"]) + + queue_items.append( + html.div( + html.div( + html.div( + html.strong(f"#{item['id']}", classes=["me-2"]), + html.span( + html.i(classes=["bi", icon_class, "me-1"]), + item["status"].upper(), + classes=["badge", badge_class], + ), + classes=[ + "d-flex", + "align-items-center", + "justify-content-between", + ], + ), + # Add title and author if available + *( + [ + html.div( + html.strong( + item["title"], + classes=["d-block"], + ), + html.small( + f"by {item['author']}", + classes=["text-muted"], + ) + if item.get("author") + else html.span(), + classes=["mt-2"], + ), + ] + if item.get("title") + else [] + ), + html.small( + html.i(classes=["bi", "bi-link-45deg", "me-1"]), + item["url"][: Core.URL_TRUNCATE_LENGTH] + + ( + 
"..." + if len(item["url"]) > Core.URL_TRUNCATE_LENGTH + else "" + ), + classes=["text-muted", "d-block", "mt-2"], + ), + html.small( + html.i(classes=["bi", "bi-calendar", "me-1"]), + f"Created: {item['created_at']}", + classes=["text-muted", "d-block", "mt-1"], + ), + # Display queue position if available + html.small( + html.i( + classes=["bi", "bi-hourglass-split", "me-1"], + ), + f"Position in queue: #{queue_pos}", + classes=["text-info", "d-block", "mt-1"], + ) + if queue_pos + else html.span(), + *( + [ + html.div( + html.i( + classes=[ + "bi", + "bi-exclamation-circle", + "me-1", + ], + ), + f"Error: {item['error_message']}", + classes=[ + "alert", + "alert-danger", + "mt-2", + "mb-0", + "py-1", + "px-2", + "small", + ], + ), + ] + if item["error_message"] + else [] + ), + # Add cancel button for pending jobs, remove for others + html.div( + # Retry button for error items + html.button( + html.i( + classes=[ + "bi", + "bi-arrow-clockwise", + "me-1", + ], + ), + "Retry", + hx_post=f"/queue/{item['id']}/retry", + hx_trigger="click", + hx_on=( + "htmx:afterRequest: " + "if(event.detail.successful) " + "htmx.trigger('body', 'queue-updated')" + ), + classes=[ + "btn", + "btn-sm", + "btn-outline-primary", + "mt-2", + "me-2", + ], + ) + if item["status"] == "error" + else html.span(), + html.button( + html.i(classes=["bi", "bi-x-lg", "me-1"]), + "Cancel", + hx_post=f"/queue/{item['id']}/cancel", + hx_trigger="click", + hx_on=( + "htmx:afterRequest: " + "if(event.detail.successful) " + "htmx.trigger('body', 'queue-updated')" + ), + classes=[ + "btn", + "btn-sm", + "btn-outline-danger", + "mt-2", + ], + ) + if item["status"] == "pending" + else html.button( + html.i(classes=["bi", "bi-trash", "me-1"]), + "Remove", + hx_delete=f"/queue/{item['id']}", + hx_trigger="click", + hx_confirm="Remove this item from the queue?", + hx_on=( + "htmx:afterRequest: " + "if(event.detail.successful) " + "htmx.trigger('body', 'queue-updated')" + ), + classes=[ + "btn", + "btn-sm", + 
"btn-outline-secondary", + "mt-2", + ], + ), + classes=["mt-2"], + ), + classes=["card-body"], + ), + classes=["card", "mb-2"], + ), + ) + + return html.div( + html.h4( + html.i(classes=["bi", "bi-list-check", "me-2"]), + "Queue Status", + classes=["mb-3"], + ), + *queue_items, + ) + + +class EpisodeListAttrs(Attrs): + """Attributes for EpisodeList component.""" + + episodes: list[dict[str, typing.Any]] + rss_url: str | None + user: dict[str, typing.Any] | None + viewing_own_feed: bool + + +class EpisodeList(Component[AnyChildren, EpisodeListAttrs]): + """List recent episodes (no audio player - use podcast app).""" + + @override + def render(self) -> html.div: + episodes = self.attrs["episodes"] + rss_url = self.attrs.get("rss_url") + user = self.attrs.get("user") + viewing_own_feed = self.attrs.get("viewing_own_feed", False) + + if not episodes: + return html.div( + html.h4( + html.i(classes=["bi", "bi-broadcast", "me-2"]), + "Recent Episodes", + classes=["mb-3"], + ), + html.p("No episodes yet", classes=["text-muted"]), + ) + + episode_items = [] + for episode in episodes: + duration_str = UI.format_duration(episode.get("duration")) + episode_sqid = encode_episode_id(episode["id"]) + is_public = episode.get("is_public", 0) == 1 + + # Admin "Add to public feed" button at bottom of card + admin_button: html.div | html.button = html.div() + if user and Core.is_admin(user): + if is_public: + admin_button = html.button( + html.i(classes=["bi", "bi-check-circle-fill", "me-1"]), + "Added to public feed", + hx_post=f"/admin/episode/{episode['id']}/toggle-public", + hx_target="body", + hx_swap="outerHTML", + classes=["btn", "btn-sm", "btn-success", "mt-2"], + ) + else: + admin_button = html.button( + html.i(classes=["bi", "bi-plus-circle", "me-1"]), + "Add to public feed", + hx_post=f"/admin/episode/{episode['id']}/toggle-public", + hx_target="body", + hx_swap="outerHTML", + classes=[ + "btn", + "btn-sm", + "btn-outline-success", + "mt-2", + ], + ) + + # "Add to my feed" 
button for logged-in users + # (only when NOT viewing own feed) + user_button: html.div | html.button = html.div() + if user and not viewing_own_feed: + # Check if user already has this episode + user_has_episode = Core.Database.user_has_episode( + user["id"], + episode["id"], + ) + if user_has_episode: + user_button = html.button( + html.i(classes=["bi", "bi-check-circle-fill", "me-1"]), + "In your feed", + disabled=True, + classes=[ + "btn", + "btn-sm", + "btn-secondary", + "mt-2", + "ms-2", + ], + ) + else: + user_button = html.button( + html.i(classes=["bi", "bi-plus-circle", "me-1"]), + "Add to my feed", + hx_post=f"/episode/{episode['id']}/add-to-feed", + hx_swap="none", + classes=[ + "btn", + "btn-sm", + "btn-outline-primary", + "mt-2", + "ms-2", + ], + ) + + episode_items.append( + html.div( + html.div( + html.h5( + html.a( + episode["title"], + href=f"/episode/{episode_sqid}", + classes=["text-decoration-none"], + ), + classes=["card-title", "mb-2"], + ), + # Show author if available + html.p( + html.i(classes=["bi", "bi-person", "me-1"]), + f"by {episode['author']}", + classes=["text-muted", "small", "mb-3"], + ) + if episode.get("author") + else html.div(), + html.div( + html.small( + html.i(classes=["bi", "bi-clock", "me-1"]), + f"Duration: {duration_str}", + classes=["text-muted", "me-3"], + ), + html.small( + html.i(classes=["bi", "bi-calendar", "me-1"]), + f"Created: {episode['created_at']}", + classes=["text-muted"], + ), + classes=["mb-2"], + ), + # Show link to original article if available + html.div( + html.a( + html.i(classes=["bi", "bi-link-45deg", "me-1"]), + "View original article", + href=episode["original_url"], + target="_blank", + classes=[ + "btn", + "btn-sm", + "btn-outline-primary", + ], + ), + ) + if episode.get("original_url") + else html.div(), + # Buttons row (admin and user buttons) + html.div( + admin_button, + user_button, + classes=["d-flex", "flex-wrap"], + ), + classes=["card-body"], + ), + classes=["card", "mb-3"], + ), + ) 
+ + return html.div( + html.h4( + html.i(classes=["bi", "bi-broadcast", "me-2"]), + "Recent Episodes", + classes=["mb-3"], + ), + # RSS feed link with copy-to-clipboard + html.div( + html.div( + html.label( + html.i(classes=["bi", "bi-rss-fill", "me-2"]), + "Subscribe in your podcast app:", + classes=["form-label", "fw-bold"], + ), + html.div( + html.button( + html.i(classes=["bi", "bi-copy", "me-1"]), + "Copy", + type="button", + id="rss-copy-button", + on_click=f"navigator.clipboard.writeText('{rss_url}'); " # noqa: E501 + "const btn = document.getElementById('rss-copy-button'); " # noqa: E501 + "const originalHTML = btn.innerHTML; " + "btn.innerHTML = '<i class=\"bi bi-check me-1\"></i>Copied!'; " # noqa: E501 + "btn.classList.remove('btn-outline-secondary'); " + "btn.classList.add('btn-success'); " + "setTimeout(() => {{ " + "btn.innerHTML = originalHTML; " + "btn.classList.remove('btn-success'); " + "btn.classList.add('btn-outline-secondary'); " + "}}, 2000);", + classes=["btn", "btn-outline-secondary"], + ), + html.input( + type="text", + value=rss_url or "", + readonly=True, + on_focus="this.select()", + classes=["form-control"], + ), + classes=["input-group", "mb-3"], + ), + ), + ) + if rss_url + else html.div(), + *episode_items, + ) + + +class HomePageAttrs(Attrs): + """Attributes for HomePage component.""" + + queue_items: list[dict[str, typing.Any]] + episodes: list[dict[str, typing.Any]] + user: dict[str, typing.Any] | None + error: str | None + + +class PublicFeedPageAttrs(Attrs): + """Attributes for PublicFeedPage component.""" + + episodes: list[dict[str, typing.Any]] + user: dict[str, typing.Any] | None + + +class PublicFeedPage(Component[AnyChildren, PublicFeedPageAttrs]): + """Public feed page without auto-refresh.""" + + @override + def render(self) -> UI.PageLayout: + episodes = self.attrs["episodes"] + user = self.attrs.get("user") + + return UI.PageLayout( + html.div( + html.h2( + html.i(classes=["bi", "bi-globe", "me-2"]), + "Public Feed", + 
classes=["mb-3"], + ), + html.p( + "Featured articles converted to audio by our community. " + "Subscribe to get new episodes in your podcast app!", + classes=["lead", "text-muted", "mb-4"], + ), + EpisodeList( + episodes=episodes, + rss_url=f"{BASE_URL}/public.rss", + user=user, + viewing_own_feed=False, + ), + ), + user=user, + current_page="public", + error=None, + ) + + +class HomePage(Component[AnyChildren, HomePageAttrs]): + """Main page combining all components.""" + + @staticmethod + def _render_plan_callout( + user: dict[str, typing.Any], + ) -> html.div: + """Render plan info callout box below navbar.""" + tier = user.get("plan_tier", "free") + + if tier == "free": + # Get usage and show quota + period_start, period_end = Billing.get_period_boundaries(user) + usage = Billing.get_usage(user["id"], period_start, period_end) + articles_used = usage["articles"] + articles_limit = 10 + articles_left = max(0, articles_limit - articles_used) + + return html.div( + html.div( + html.div( + html.i( + classes=[ + "bi", + "bi-info-circle-fill", + "me-2", + ], + ), + html.strong(f"{articles_left} articles remaining"), + " of your free plan limit. 
", + html.br(), + "Upgrade to ", + html.strong("Paid Plan"), + " for unlimited articles at $12/month.", + ), + html.form( + html.input( + type="hidden", + name="tier", + value="paid", + ), + html.button( + html.i( + classes=[ + "bi", + "bi-arrow-up-circle", + "me-1", + ], + ), + "Upgrade Now", + type="submit", + classes=[ + "btn", + "btn-success", + "btn-sm", + "mt-2", + ], + ), + method="post", + action="/billing/checkout", + ), + classes=[ + "alert", + "alert-info", + "d-flex", + "justify-content-between", + "align-items-center", + "mb-4", + ], + ), + classes=["mb-4"], + ) + # Paid user - no callout needed + return html.div() + + @override + def render(self) -> UI.PageLayout | html.html: + queue_items = self.attrs["queue_items"] + episodes = self.attrs["episodes"] + user = self.attrs.get("user") + error = self.attrs.get("error") + + if not user: + # Show public feed with login form for logged-out users + return UI.PageLayout( + LoginForm(error=error), + html.div( + html.h4( + html.i(classes=["bi", "bi-broadcast", "me-2"]), + "Public Feed", + classes=["mb-3", "mt-4"], + ), + html.p( + "Featured articles converted to audio. 
" + "Sign up to create your own personal feed!", + classes=["text-muted", "mb-3"], + ), + EpisodeList( + episodes=episodes, + rss_url=None, + user=None, + viewing_own_feed=False, + ), + ), + user=None, + current_page="home", + error=error, + ) + + return UI.PageLayout( + self._render_plan_callout(user), + SubmitForm(), + html.div( + QueueStatus(items=queue_items), + EpisodeList( + episodes=episodes, + rss_url=f"{BASE_URL}/feed/{user['token']}.rss", + user=user, + viewing_own_feed=True, + ), + id="dashboard-content", + hx_get="/dashboard-updates", + hx_trigger="every 3s, queue-updated from:body", + hx_swap="innerHTML", + ), + user=user, + current_page="home", + error=error, + ) + + +# Create ludic app with session support +app = LudicApp() +app.add_middleware( + SessionMiddleware, + secret_key=os.getenv("SESSION_SECRET", "dev-secret-key"), + max_age=SESSION_MAX_AGE, # 30 days + same_site="lax", + https_only=App.from_env() == App.Area.Live, # HTTPS only in production +) + + +@app.get("/") +def index(request: Request) -> HomePage: + """Display main page with form and status.""" + user_id = request.session.get("user_id") + user = None + queue_items = [] + episodes = [] + error = request.query_params.get("error") + status = request.query_params.get("status") + + # Map error codes to user-friendly messages + error_messages = { + "invalid_link": "Invalid login link", + "expired_link": "Login link has expired. Please request a new one.", + "user_not_found": "User not found. Please try logging in again.", + "forbidden": "Access denied. 
@app.post("/upgrade")
def upgrade(request: Request) -> RedirectResponse:
    """Start the upgrade checkout flow for the logged-in user.

    Returns:
        A 303 redirect to the checkout URL produced by
        ``Billing.create_checkout_session``; to ``/?error=login_required``
        when no user is in the session; or to
        ``/pricing?error=checkout_failed`` when session creation raises
        ``ValueError``.
    """
    user_id = request.session.get("user_id")
    if not user_id:
        # 303 (not the default 307) so the browser follows with GET rather
        # than re-sending this POST to a GET-only page.
        return RedirectResponse(
            url="/?error=login_required",
            status_code=303,
        )

    try:
        checkout_url = Billing.create_checkout_session(
            user_id,
            "paid",
            BASE_URL,
        )
    except ValueError:
        logger.exception("Failed to create checkout session")
        return RedirectResponse(
            url="/pricing?error=checkout_failed",
            status_code=303,
        )
    return RedirectResponse(url=checkout_url, status_code=303)
@app.post("/billing/portal")
def billing_portal(request: Request) -> RedirectResponse:
    """Redirect the logged-in user to the Stripe billing portal.

    Returns:
        A 303 redirect to the URL produced by
        ``Billing.create_portal_session``; to ``/?error=login_required``
        when no user is in the session; or to ``/pricing`` when portal
        creation raises ``ValueError``.
    """
    user_id = request.session.get("user_id")
    if not user_id:
        # 303 (not the default 307) so the browser follows with GET rather
        # than re-sending this POST to a GET-only page.
        return RedirectResponse(
            url="/?error=login_required",
            status_code=303,
        )

    try:
        portal_url = Billing.create_portal_session(user_id, BASE_URL)
    except ValueError as e:
        logger.warning("Failed to create portal session: %s", e)
        # If user has no customer ID (e.g. free tier), redirect to pricing
        return RedirectResponse(url="/pricing", status_code=303)
    return RedirectResponse(url=portal_url, status_code=303)
" + 'Email <a href="mailto:ben@bensima.com" ' + 'class="alert-link">ben@bensima.com</a> ' + 'or message <a href="https://x.com/bensima" ' + 'target="_blank" class="alert-link">@bensima</a> ' + "to get your account activated.</div>" + ) + return Response(pending_message, status_code=200) + + # Set session with extended lifetime + request.session["user_id"] = user["id"] + request.session["permanent"] = True + + return Response( + '<div class="alert alert-success">✓ Logged in (dev mode)</div>', + status_code=200, + headers={"HX-Redirect": "/"}, + ) + + +def _handle_production_login(email: str) -> Response: + """Handle login in production mode.""" + pending_message = ( + '<div class="alert alert-warning">' + "Account created, currently pending. " + 'Email <a href="mailto:ben@bensima.com" ' + 'class="alert-link">ben@bensima.com</a> ' + 'or message <a href="https://x.com/bensima" ' + 'target="_blank" class="alert-link">@bensima</a> ' + "to get your account activated.</div>" + ) + + # Get or create user + user = Core.Database.get_user_by_email(email) + if not user: + user_id, token = Core.Database.create_user(email) + user = { + "id": user_id, + "email": email, + "token": token, + "status": "active", + } + + # Check if user is active + if user.get("status") != "active": + return Response(pending_message, status_code=200) + + # Generate magic link token + magic_token = magic_link_serializer.dumps({ + "user_id": user["id"], + "email": email, + }) + + # Send email + send_magic_link(email, magic_token) + + return Response( + f'<div class="alert alert-success">✓ Magic link sent to {email}. 
@app.get("/auth/verify")
def verify_magic_link(request: Request) -> Response:
    """Verify a magic-link token and log the user in.

    Reads the ``token`` query parameter and validates it with
    ``magic_link_serializer``, rejecting tokens older than
    ``MAGIC_LINK_MAX_AGE`` seconds. When the referenced user still exists,
    the user id is stored in the session.

    Returns:
        A redirect to ``/`` on success, otherwise a redirect to
        ``/?error=<code>`` where ``<code>`` is ``invalid_link`` (missing or
        tampered token), ``expired_link`` (token too old, or payload
        without a ``user_id``) or ``user_not_found``.
    """
    # Function-scope import: itsdangerous is already a dependency of this
    # module, but only the serializer class is imported at the top.
    from itsdangerous import BadData  # noqa: PLC0415
    from itsdangerous import SignatureExpired  # noqa: PLC0415

    token = request.query_params.get("token")
    if not token:
        return RedirectResponse("/?error=invalid_link")

    try:
        data = magic_link_serializer.loads(token, max_age=MAGIC_LINK_MAX_AGE)
        user_id = data["user_id"]
    except SignatureExpired:
        # Must precede BadData: SignatureExpired is one of its subclasses.
        return RedirectResponse("/?error=expired_link")
    except BadData:
        # itsdangerous errors do not derive from ValueError, so without this
        # branch a tampered token would surface as an unhandled exception.
        return RedirectResponse("/?error=invalid_link")
    except (ValueError, KeyError):
        # Kept from the original handler: malformed payloads map to the
        # same error code as before.
        return RedirectResponse("/?error=expired_link")

    # Verify user still exists
    user = Core.Database.get_user_by_id(user_id)
    if not user:
        return RedirectResponse("/?error=user_not_found")

    # Set session with extended lifetime
    request.session["user_id"] = user_id
    request.session["permanent"] = True

    return RedirectResponse("/")
@app.get("/settings/email/cancel")
def cancel_edit_email(request: Request) -> typing.Any:
    """Render the read-only email row, replacing the inline edit form.

    Returns a 401 response when the session has no user id and a 404
    response when that user cannot be loaded; otherwise the email row with
    a "Change" button that swaps the edit form back in.
    """
    if not (session_user_id := request.session.get("user_id")):
        return Response("Unauthorized", status_code=401)

    account = Core.Database.get_user_by_id(session_user_id)
    if not account:
        return Response("User not found", status_code=404)

    # Button that asks htmx to replace this row with the edit form.
    change_button = html.button(
        "Change",
        classes=[
            "btn",
            "btn-sm",
            "btn-outline-secondary",
            "ms-2",
            "py-0",
        ],
        hx_get="/settings/email/edit",
        hx_target="closest div",
        hx_swap="outerHTML",
    )
    email_label = html.strong("Email: ")
    email_value = html.span(account["email"])

    return html.div(
        email_label,
        email_value,
        change_button,
        classes=["mb-2", "d-flex", "align-items-center"],
    )
@app.delete("/account")
def delete_account(request: Request) -> Response:
    """Delete the logged-in user's account and clear the session.

    Returns:
        A 303 redirect to ``/?error=login_required`` when no user is in the
        session; otherwise a plain response whose ``HX-Redirect`` header
        points at ``/?message=account_deleted``.
    """
    user_id = request.session.get("user_id")
    if not user_id:
        # 303 (not the default 307) so the client follows with GET instead
        # of replaying this DELETE against the home page.
        return RedirectResponse(
            url="/?error=login_required",
            status_code=303,
        )

    Core.Database.delete_user(user_id)
    request.session.clear()

    return Response(
        "Account deleted",
        headers={"HX-Redirect": "/?message=account_deleted"},
    )
+ user_id = request.session.get("user_id") + if not user_id: + return html.div( + html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]), + "Error: Please login first", + classes=["alert", "alert-danger"], + ) + + user = Core.Database.get_user_by_id(user_id) + if not user: + return html.div( + html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]), + "Error: Invalid session", + classes=["alert", "alert-danger"], + ) + + url_raw = data.get("url", "") + url = url_raw.strip() if isinstance(url_raw, str) else "" + + if not url: + return html.div( + html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]), + "Error: URL is required", + classes=["alert", "alert-danger"], + ) + + # Basic URL validation + parsed = urllib.parse.urlparse(url) + if not parsed.scheme or not parsed.netloc: + return html.div( + html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]), + "Error: Invalid URL format", + classes=["alert", "alert-danger"], + ) + + # Check usage limits + allowed, _msg, usage = Billing.can_submit(user_id) + if not allowed: + tier = user.get("plan_tier", "free") + tier_info = Billing.get_tier_info(tier) + limit = tier_info.get("articles_limit", 0) + return html.div( + html.i(classes=["bi", "bi-exclamation-circle", "me-2"]), + html.strong("Limit reached: "), + f"You've used {usage['articles']}/{limit} articles " + "this period. 
", + html.a( + "Upgrade your plan", + href="/billing", + classes=["alert-link"], + ), + " to continue.", + classes=["alert", "alert-warning"], + ) + + # Check if episode already exists for this URL + url_hash = Core.hash_url(url) + existing_episode = Core.Database.get_episode_by_url_hash(url_hash) + + if existing_episode: + # Episode already processed - check if user has it + episode_id = existing_episode["id"] + if Core.Database.user_has_episode(user_id, episode_id): + return html.div( + html.i(classes=["bi", "bi-info-circle", "me-2"]), + "This episode is already in your feed.", + classes=["alert", "alert-info"], + ) + # Add existing episode to user's feed + Core.Database.add_episode_to_user(user_id, episode_id) + Core.Database.track_episode_event( + episode_id, + "added", + user_id, + ) + return html.div( + html.i(classes=["bi", "bi-check-circle", "me-2"]), + "✓ Episode added to your feed! ", + html.a( + "View episode", + href=f"/episode/{encode_episode_id(episode_id)}", + classes=["alert-link"], + ), + classes=["alert", "alert-success"], + ) + + # Episode doesn't exist yet - extract metadata and queue for processing + title, author = extract_og_metadata(url) + + job_id = Core.Database.add_to_queue( + url, + user["email"], + user_id, + title=title, + author=author, + ) + return html.div( + html.i(classes=["bi", "bi-check-circle", "me-2"]), + f"✓ Article submitted successfully! 
Job ID: {job_id}", + classes=["alert", "alert-success"], + ) + + except (httpx.HTTPError, httpx.TimeoutException, ValueError) as e: + return html.div( + html.i(classes=["bi", "bi-exclamation-triangle", "me-2"]), + f"Error: {e!s}", + classes=["alert", "alert-danger"], + ) + + +@app.get("/feed/{token}.rss") +def rss_feed(request: Request, token: str) -> Response: # noqa: ARG001 + """Generate user-specific RSS podcast feed.""" + try: + # Validate token and get user + user = Core.Database.get_user_by_token(token) + if not user: + return Response("Invalid feed token", status_code=404) + + # Get episodes for this user only + episodes = Core.Database.get_user_all_episodes( + user["id"], + ) + + # Extract first name from email for personalization + email_name = user["email"].split("@")[0].split(".")[0].title() + + fg = FeedGenerator() + fg.title(f"{email_name}'s Article Podcast") + fg.description(f"Web articles converted to audio for {user['email']}") + fg.author(name=RSS_CONFIG["author"]) + fg.language(RSS_CONFIG["language"]) + fg.link(href=f"{RSS_CONFIG['base_url']}/feed/{token}.rss") + fg.id(f"{RSS_CONFIG['base_url']}/feed/{token}.rss") + + for episode in episodes: + fe = fg.add_entry() + episode_sqid = encode_episode_id(episode["id"]) + fe.id(f"{RSS_CONFIG['base_url']}/episode/{episode_sqid}") + fe.title(episode["title"]) + fe.description(episode["title"]) + fe.enclosure( + episode["audio_url"], + str(episode.get("content_length", 0)), + "audio/mpeg", + ) + # SQLite timestamps don't have timezone info, so add UTC + created_at = datetime.fromisoformat(episode["created_at"]) + if created_at.tzinfo is None: + created_at = created_at.replace(tzinfo=timezone.utc) + fe.pubDate(created_at) + + rss_str = fg.rss_str(pretty=True) + return Response( + rss_str, + media_type="application/rss+xml; charset=utf-8", + ) + + except (ValueError, KeyError, AttributeError) as e: + return Response(f"Error generating feed: {e}", status_code=500) + + +# Backwards compatibility: .xml extension 
@app.get("/feed/{token}.xml")
def rss_feed_xml_alias(request: Request, token: str) -> Response:
    """Alias for .rss feed (backwards compatibility)."""
    return rss_feed(request, token)


@app.get("/public.rss")
def public_rss_feed(request: Request) -> Response:  # noqa: ARG001
    """Generate public RSS podcast feed.

    Returns 500 (with the failure logged) when building the feed raises one
    of the caught errors.
    """
    try:
        # Get public episodes
        episodes = Core.Database.get_public_episodes(50)

        fg = FeedGenerator()
        fg.title("PodcastItLater Public Feed")
        fg.description("Curated articles converted to audio")
        fg.author(name=RSS_CONFIG["author"])
        fg.language(RSS_CONFIG["language"])
        fg.link(href=f"{RSS_CONFIG['base_url']}/public.rss")
        fg.id(f"{RSS_CONFIG['base_url']}/public.rss")

        for episode in episodes:
            fe = fg.add_entry()
            episode_sqid = encode_episode_id(episode["id"])
            fe.id(f"{RSS_CONFIG['base_url']}/episode/{episode_sqid}")
            fe.title(episode["title"])
            fe.description(episode["title"])
            fe.enclosure(
                episode["audio_url"],
                # `.get(key, 0)` only covers a missing key; a present-but-None
                # value would be rendered as the string "None", so coerce
                # falsy values to 0.
                str(episode.get("content_length") or 0),
                "audio/mpeg",
            )
            # SQLite timestamps don't have timezone info, so add UTC
            created_at = datetime.fromisoformat(episode["created_at"])
            if created_at.tzinfo is None:
                created_at = created_at.replace(tzinfo=timezone.utc)
            fe.pubDate(created_at)

        rss_str = fg.rss_str(pretty=True)
        return Response(
            rss_str,
            media_type="application/rss+xml; charset=utf-8",
        )

    except (ValueError, KeyError, AttributeError) as e:
        # Log with traceback, as the other handlers in this file do, so a
        # broken feed is visible in the server logs and not only to clients.
        logger.exception("Error generating feed")
        return Response(f"Error generating feed: {e}", status_code=500)


# Backwards compatibility: .xml extension
@app.get("/public.xml")
def public_rss_feed_xml_alias(request: Request) -> Response:
    """Alias for .rss feed (backwards compatibility)."""
    return public_rss_feed(request)


@app.get("/episode/{episode_id:int}")
def episode_detail_legacy(
    request: Request,  # noqa: ARG001
    episode_id: int,
) -> RedirectResponse:
    """Redirect legacy integer episode IDs to sqid URLs.

    Deprecated: This route exists for backward compatibility.
    Will be removed in a future version.
    """
    episode_sqid = encode_episode_id(episode_id)
    return RedirectResponse(
        url=f"/episode/{episode_sqid}",
        status_code=301,  # Permanent redirect
    )


@app.get("/episode/{episode_sqid}")
def episode_detail(
    request: Request,
    episode_sqid: str,
) -> Episode.EpisodeDetailPage | Response:
    """Display individual episode page (public, no auth required).

    Returns 404 when the sqid cannot be decoded or no episode matches, and
    500 (with the failure logged) on a ValueError or KeyError.
    """
    try:
        # Decode sqid to episode ID
        episode_id = decode_episode_id(episode_sqid)
        if episode_id is None:
            return Response("Invalid episode ID", status_code=404)

        # Get episode from database
        episode = Core.Database.get_episode_by_id(episode_id)

        if not episode:
            return Response("Episode not found", status_code=404)

        # Get creator email if episode has user_id
        creator_email = None
        if episode.get("user_id"):
            creator = Core.Database.get_user_by_id(episode["user_id"])
            creator_email = creator["email"] if creator else None

        # Check if current user is logged in
        user_id = request.session.get("user_id")
        user = None
        user_has_episode = False
        if user_id:
            user = Core.Database.get_user_by_id(user_id)
            user_has_episode = Core.Database.user_has_episode(
                user_id,
                episode_id,
            )

        return Episode.EpisodeDetailPage(
            episode=episode,
            episode_sqid=episode_sqid,
            creator_email=creator_email,
            user=user,
            base_url=BASE_URL,
            user_has_episode=user_has_episode,
        )

    except (ValueError, KeyError) as e:
        logger.exception("Error loading episode")
        return Response(f"Error loading episode: {e}", status_code=500)


@app.get("/status")
def queue_status(request: Request) -> QueueStatus:
    """Return HTMX endpoint for live queue updates.

    Anonymous callers get an empty queue component.
    """
    # Check if user is logged in
    user_id = request.session.get("user_id")
    if not user_id:
        return QueueStatus(items=[])

    # Get user-specific queue items
    queue_items = Core.Database.get_user_queue_status(
        user_id,
    )
    return QueueStatus(items=queue_items)


@app.get("/dashboard-updates")
def dashboard_updates(request: Request) -> Response:
    """Return both queue status and recent episodes for dashboard updates.

    The two components are rendered to strings and concatenated into a single
    text/html response; anonymous callers get both components empty.
    """
    # Check if user is logged in
    user_id = request.session.get("user_id")
    if not user_id:
        # Locals are deliberately not named `queue_status`: that name is the
        # "/status" route handler defined at module level.
        queue_component = QueueStatus(items=[])
        episodes_component = EpisodeList(
            episodes=[],
            rss_url=None,
            user=None,
            viewing_own_feed=False,
        )
        # Return HTML as string with both components
        return Response(
            str(queue_component) + str(episodes_component),
            media_type="text/html",
        )

    # Get user info for RSS URL
    user = Core.Database.get_user_by_id(user_id)
    rss_url = f"{BASE_URL}/feed/{user['token']}.rss" if user else None

    # Get user-specific queue items and episodes
    queue_items = Core.Database.get_user_queue_status(user_id)
    episodes = Core.Database.get_user_recent_episodes(user_id, 10)

    # Return just the content components, not the wrapper div
    # The wrapper div with HTMX attributes is in HomePage
    queue_component = QueueStatus(items=queue_items)
    episodes_component = EpisodeList(
        episodes=episodes,
        rss_url=rss_url,
        user=user,
        viewing_own_feed=True,
    )
    return Response(
        str(queue_component) + str(episodes_component),
        media_type="text/html",
    )


# Register admin routes
app.get("/admin")(Admin.admin_queue_status)
app.post("/queue/{job_id}/retry")(Admin.retry_queue_item)


@app.post("/billing/checkout")
def billing_checkout(request: Request, data: FormData) -> Response:
    """Create Stripe Checkout session.

    Responds 401 without a session user, 400 for any tier other than "paid"
    or when session creation raises ValueError, and otherwise redirects (303)
    to the checkout URL.
    """
    user_id = request.session.get("user_id")
    if not user_id:
        return Response("Unauthorized", status_code=401)

    # Non-string form values fall back to the default tier
    tier_raw = data.get("tier", "paid")
    tier = tier_raw if isinstance(tier_raw, str) else "paid"
    if tier != "paid":
        return Response("Invalid tier", status_code=400)

    try:
        checkout_url = Billing.create_checkout_session(user_id, tier, BASE_URL)
        return RedirectResponse(url=checkout_url, status_code=303)
    except ValueError as e:
        logger.exception("Checkout error")
        return Response(f"Error: {e!s}", status_code=400)


@app.post("/stripe/webhook")
async def stripe_webhook(request: Request) -> Response:
    """Handle Stripe webhook events.

    Passes the raw request body and the ``stripe-signature`` header to
    ``Billing.handle_webhook_event``.
    """
    payload = await request.body()
    sig_header = request.headers.get("stripe-signature", "")

    try:
        result = Billing.handle_webhook_event(payload, sig_header)
        return Response(f"OK: {result['status']}", status_code=200)
    except Exception as e:
        # Boundary handler: whatever handle_webhook_event raises is logged
        # and reported back to the caller as a 400.
        logger.exception("Webhook error")
        return Response(f"Error: {e!s}", status_code=400)


@app.post("/queue/{job_id}/cancel")
def cancel_queue_item(request: Request, job_id: int) -> Response:
    """Cancel a pending queue item.

    Responds 401 without a session user, 403 when the job is missing or
    owned by someone else, 400 when the job is not pending, and 200 with an
    ``HX-Trigger: queue-updated`` header on success.
    """
    try:
        # Check if user is logged in
        user_id = request.session.get("user_id")
        if not user_id:
            return Response("Unauthorized", status_code=401)

        # Get job and verify ownership
        job = Core.Database.get_job_by_id(job_id)
        if job is None or job.get("user_id") != user_id:
            return Response("Forbidden", status_code=403)

        # Only allow canceling pending jobs
        if job.get("status") != "pending":
            return Response("Can only cancel pending jobs", status_code=400)

        # Update status to cancelled
        Core.Database.update_job_status(
            job_id,
            "cancelled",
            error="Cancelled by user",
        )

        # Return success with HTMX trigger to refresh
        return Response(
            "",
            status_code=200,
            headers={"HX-Trigger": "queue-updated"},
        )
    except (ValueError, KeyError) as e:
        return Response(
            f"Error cancelling job: {e!s}",
            status_code=500,
        )


app.delete("/queue/{job_id}")(Admin.delete_queue_item)
app.get("/admin/users")(Admin.admin_users)
app.get("/admin/metrics")(Admin.admin_metrics)
app.post("/admin/users/{user_id}/status")(Admin.update_user_status)
app.post("/admin/episode/{episode_id}/toggle-public")(
    Admin.toggle_episode_public,
)


@app.post("/episode/{episode_id}/add-to-feed")
def add_episode_to_feed(request: Request, episode_id: int) -> Response:
    """Add an episode to the user's feed.

    Error and no-op cases return a small alert fragment; a successful add
    returns an empty body with an ``HX-Redirect`` header.
    """
    # Check if user is logged in
    user_id = request.session.get("user_id")
    if not user_id:
        return Response(
            '<div class="alert alert-warning">Please login first</div>',
            status_code=200,
        )

    # Check if episode exists
    episode = Core.Database.get_episode_by_id(episode_id)
    if not episode:
        return Response(
            '<div class="alert alert-danger">Episode not found</div>',
            status_code=404,
        )

    # Check if user already has this episode
    if Core.Database.user_has_episode(user_id, episode_id):
        return Response(
            '<div class="alert alert-info">Already in your feed</div>',
            status_code=200,
        )

    # Add episode to user's feed
    Core.Database.add_episode_to_user(user_id, episode_id)

    # Track the "added" event
    Core.Database.track_episode_event(episode_id, "added", user_id)

    # Reload the current page to show updated button state
    # Check referer to determine where to redirect
    referer = request.headers.get("referer", "/")
    return Response(
        "",
        status_code=200,
        headers={"HX-Redirect": referer},
    )


@app.post("/episode/{episode_id}/track")
def track_episode(
    request: Request,
    episode_id: int,
    data: FormData,
) -> Response:
    """Track an episode metric event (play, download).

    Responds 400 for an event type other than "played" or "downloaded" and
    404 when the episode does not exist; otherwise records the event and
    returns an empty 200. No login is required.
    """
    # Get event type from form data
    event_type_raw = data.get("event_type", "")
    event_type = event_type_raw if isinstance(event_type_raw, str) else ""

    # Validate event type
    if event_type not in {"played", "downloaded"}:
        return Response("Invalid event type", status_code=400)

    # This endpoint accepts anonymous callers, so refuse ids that do not
    # match an episode instead of recording events for them (same check as
    # add_episode_to_feed).
    if not Core.Database.get_episode_by_id(episode_id):
        return Response("Episode not found", status_code=404)

    # Get user ID if logged in (None for anonymous)
    user_id = request.session.get("user_id")

    # Track the event
    Core.Database.track_episode_event(episode_id, event_type, user_id)

    return Response("", status_code=200)


class BaseWebTest(Test.TestCase):
    """Base class for web tests with database setup."""

    def setUp(self) -> None:
        """Set up test database and client."""
        Core.Database.init_db()
        # Create test client
        self.client = TestClient(app)

    @staticmethod
    def tearDown() -> None:
        """Clean up test database."""
        Core.Database.teardown()


class TestDurationFormatting(Test.TestCase):
    """Test duration formatting functionality."""

    def test_format_duration_minutes_only(self) -> None:
        """Test formatting durations less than an hour."""
        self.assertEqual(UI.format_duration(60), "1m")
        self.assertEqual(UI.format_duration(240), "4m")
        self.assertEqual(UI.format_duration(300), "5m")
        self.assertEqual(UI.format_duration(3599), "60m")

    def test_format_duration_hours_and_minutes(self) -> None:
        """Test formatting durations with hours and minutes."""
        self.assertEqual(UI.format_duration(3600), "1h")
        self.assertEqual(UI.format_duration(3840), "1h 4m")
        self.assertEqual(UI.format_duration(11520), "3h 12m")
        self.assertEqual(UI.format_duration(7320), "2h 2m")

    def test_format_duration_round_up(self) -> None:
        """Test that seconds are rounded up to nearest minute."""
        self.assertEqual(UI.format_duration(61), "2m")
        self.assertEqual(UI.format_duration(119), "2m")
        self.assertEqual(UI.format_duration(121), "3m")
        self.assertEqual(UI.format_duration(3601), "1h 1m")

    def test_format_duration_edge_cases(self) -> None:
        """Test edge cases for duration formatting."""
        self.assertEqual(UI.format_duration(None), "Unknown")
        self.assertEqual(UI.format_duration(0), "Unknown")
        self.assertEqual(UI.format_duration(-100), "Unknown")


class TestAuthentication(BaseWebTest):
    """Test authentication functionality."""

    def test_login_new_user_active(self) -> None:
        """New users should be created with active status."""
        response = self.client.post("/login", data={"email": "new@example.com"})
        self.assertEqual(response.status_code, 200)

        # Verify user was created with active status
        user = Core.Database.get_user_by_email(
            "new@example.com",
        )
        self.assertIsNotNone(user)
        if user is None:
            msg = "no user found"
            raise Test.TestError(msg)
        self.assertEqual(user.get("status"), "active")

    def test_login_active_user(self) -> None:
        """Active users should be able to login."""
        # Create user and set to active
        user_id, _ = Core.Database.create_user(
            "active@example.com",
        )
        Core.Database.update_user_status(user_id, "active")

        response = self.client.post(
            "/login",
            data={"email": "active@example.com"},
        )

        self.assertEqual(response.status_code, 200)
        self.assertIn("HX-Redirect", response.headers)

    def test_login_existing_pending_user(self) -> None:
        """Existing pending users should see the pending message."""
        # Create a pending user
        _user_id, _ = Core.Database.create_user(
            "pending@example.com",
            status="pending",
        )

        response = self.client.post(
            "/login",
            data={"email": "pending@example.com"},
        )

        self.assertEqual(response.status_code, 200)
        self.assertIn("Account created, currently pending", response.text)
        self.assertIn("ben@bensima.com", response.text)
        self.assertIn("@bensima", response.text)

    def test_login_disabled_user(self) -> None:
        """Disabled users should not be able to login."""
        # Create user and set to disabled
        user_id, _ = Core.Database.create_user(
            "disabled@example.com",
        )
        Core.Database.update_user_status(
            user_id,
            "disabled",
        )

        response = self.client.post(
            "/login",
            data={"email": "disabled@example.com"},
        )

        self.assertEqual(response.status_code, 200)
        self.assertIn("Account created, currently pending", response.text)

    def test_login_invalid_email(self) -> None:
        """Reject malformed emails."""
        response = self.client.post("/login", data={"email": ""})

        self.assertEqual(response.status_code, 400)
        self.assertIn("Email is required", response.text)

    def test_session_persistence(self) -> None:
        """Verify session across requests."""
        # Create active user
        _user_id, _ = Core.Database.create_user(
            "test@example.com",
            status="active",
        )
        # Login
        self.client.post("/login", data={"email": "test@example.com"})

        # Access protected page
        response = self.client.get("/")

        # Should see logged-in content (navbar with Manage Account link)
        self.assertIn("Manage Account", response.text)
        self.assertIn("Home", response.text)

    def test_protected_routes_pending_user(self) -> None:
        """Pending users should not access protected routes."""
        # Create pending user
        Core.Database.create_user("pending@example.com", status="pending")

        # Try to login
        response = self.client.post(
            "/login",
            data={"email": "pending@example.com"},
        )
        self.assertEqual(response.status_code, 200)

        # Should not have session
        response = self.client.get("/")
        self.assertNotIn("Logged in as:", response.text)

    def test_protected_routes(self) -> None:
        """Ensure auth required for user actions."""
        # Try to submit without login
        response = self.client.post(
            "/submit",
            data={"url": "https://example.com"},
        )

        self.assertIn("Please login first", response.text)


class TestArticleSubmission(BaseWebTest):
    """Test article submission functionality."""

    def setUp(self) -> None:
        """Set up test client with logged-in user."""
        super().setUp()
        # Create active user and login
        user_id, _ = Core.Database.create_user(
            "test@example.com",
        )
        Core.Database.update_user_status(user_id, "active")
        self.client.post("/login", data={"email": "test@example.com"})

    def test_submit_valid_url(self) -> None:
        """Accept well-formed URLs."""
        response = self.client.post(
            "/submit",
            data={"url": "https://example.com/article"},
        )

        self.assertEqual(response.status_code, 200)
        self.assertIn("Article submitted successfully", response.text)
        self.assertIn("Job ID:", response.text)

    def test_submit_invalid_url(self) -> None:
        """Reject malformed URLs."""
        response = self.client.post("/submit", data={"url": "not-a-url"})

        self.assertIn("Invalid URL format", response.text)

    def test_submit_without_auth(self) -> None:
        """Reject unauthenticated submissions."""
        # Clear session
        self.client.get("/logout")

        response = self.client.post(
            "/submit",
            data={"url": "https://example.com"},
        )

        self.assertIn("Please login first", response.text)

    def test_submit_creates_job(self) -> None:
        """Verify job creation in database."""
        response = self.client.post(
            "/submit",
            data={"url": "https://example.com/test"},
        )

        # Extract job ID from response
        match = re.search(r"Job ID: (\d+)", response.text)
        self.assertIsNotNone(match)
        if match is None:
            self.fail("Job ID not found in response")
        job_id = int(match.group(1))

        # Verify job in database
        job = Core.Database.get_job_by_id(job_id)
        self.assertIsNotNone(job)
        if job is None:  # Type guard for mypy
            self.fail("Job should not be None")
        self.assertEqual(job["url"], "https://example.com/test")
        self.assertEqual(job["status"], "pending")

    def test_htmx_response(self) -> None:
        """Ensure proper HTMX response format."""
        response = self.client.post(
            "/submit",
            data={"url": "https://example.com"},
        )

        # Should return HTML fragment, not full page
        self.assertNotIn("<!DOCTYPE", response.text)
        self.assertIn("<div", response.text)


class TestRSSFeed(BaseWebTest):
    """Test RSS feed generation."""

    def setUp(self) -> None:
        """Set up test client and create test data."""
        super().setUp()

        # Create user and episodes
        self.user_id, self.token = Core.Database.create_user(
            "test@example.com",
        )
        Core.Database.update_user_status(
            self.user_id,
            "active",
        )

        # Create test episodes
        ep1_id = Core.Database.create_episode(
            "Episode 1",
            "https://example.com/ep1.mp3",
            300,
            5000,
            self.user_id,
        )
        ep2_id = Core.Database.create_episode(
            "Episode 2",
            "https://example.com/ep2.mp3",
            600,
            10000,
            self.user_id,
        )
        Core.Database.add_episode_to_user(self.user_id, ep1_id)
        Core.Database.add_episode_to_user(self.user_id, ep2_id)

    def test_feed_generation(self) -> None:
        """Generate valid RSS XML."""
        response = self.client.get(f"/feed/{self.token}.rss")

        self.assertEqual(response.status_code, 200)
        self.assertEqual(
            response.headers["content-type"],
            "application/rss+xml; charset=utf-8",
        )

        # Verify RSS structure
        self.assertIn("<?xml", response.text)
        self.assertIn("<rss", response.text)
        self.assertIn("<channel>", response.text)
        self.assertIn("<item>", response.text)

    def test_feed_user_isolation(self) -> None:
        """Only show user's episodes."""
        # Create another user with episodes
        user2_id, _ = Core.Database.create_user(
            "other@example.com",
        )
        other_ep_id = Core.Database.create_episode(
            "Other Episode",
            "https://example.com/other.mp3",
            400,
            6000,
            user2_id,
        )
        Core.Database.add_episode_to_user(user2_id, other_ep_id)

        # Get first user's feed
        response = self.client.get(f"/feed/{self.token}.rss")

        # Should only have user's episodes
        self.assertIn("Episode 1", response.text)
        self.assertIn("Episode 2", response.text)
        self.assertNotIn("Other Episode", response.text)

    def test_feed_invalid_token(self) -> None:
        """Return 404 for bad tokens."""
        response = self.client.get("/feed/invalid-token.rss")

        self.assertEqual(response.status_code, 404)

    def test_feed_metadata(self) -> None:
        """Verify personalized feed titles."""
        response = self.client.get(f"/feed/{self.token}.rss")

        # Should personalize based on email
        self.assertIn("Test's Article Podcast", response.text)
        self.assertIn("test@example.com", response.text)

    def test_feed_episode_order(self) -> None:
        """Ensure reverse chronological order."""
        response = self.client.get(f"/feed/{self.token}.rss")

        # Episode 2 should appear before Episode 1
        ep2_pos = response.text.find("Episode 2")
        ep1_pos = response.text.find("Episode 1")
        # str.find returns -1 when the text is absent, which would let the
        # ordering check pass vacuously; require both titles to be present.
        self.assertNotEqual(ep2_pos, -1)
        self.assertNotEqual(ep1_pos, -1)
        self.assertLess(ep2_pos, ep1_pos)

    def test_feed_enclosures(self) -> None:
        """Verify audio URLs and metadata."""
        response = self.client.get(f"/feed/{self.token}.rss")

        # Check enclosure tags
        self.assertIn("<enclosure", response.text)
        self.assertIn('type="audio/mpeg"', response.text)

    def test_feed_xml_alias_works(self) -> None:
        """Test .xml extension works for backwards compatibility."""
        # Get feed with .xml extension
        response_xml = self.client.get(f"/feed/{self.token}.xml")
        # Get feed with .rss extension
        response_rss = self.client.get(f"/feed/{self.token}.rss")

        # Both should work and return same content
        self.assertEqual(response_xml.status_code, 200)
        self.assertEqual(response_rss.status_code, 200)
        self.assertEqual(response_xml.text, response_rss.text)

    def test_public_feed_xml_alias_works(self) -> None:
        """Test .xml extension works for public feed."""
        # Get feed with .xml extension
        response_xml = self.client.get("/public.xml")
        # Get feed with .rss extension
        response_rss = self.client.get("/public.rss")

        # Both should work and return same content
        self.assertEqual(response_xml.status_code, 200)
        self.assertEqual(response_rss.status_code, 200)
        self.assertEqual(response_xml.text, response_rss.text)


class TestAdminInterface(BaseWebTest):
    """Test admin interface functionality."""

    def setUp(self) -> None:
        """Set up test client with logged-in user."""
        super().setUp()

        # Create and login admin user
        self.user_id, _ = Core.Database.create_user(
            "ben@bensima.com",
        )
        Core.Database.update_user_status(
            self.user_id,
            "active",
        )
        self.client.post("/login", data={"email": "ben@bensima.com"})

        # Create test data
        self.job_id = Core.Database.add_to_queue(
            "https://example.com/test",
            "ben@bensima.com",
            self.user_id,
        )

    def test_queue_status_view(self) -> None:
        """Verify queue display."""
        response = self.client.get("/admin")

        self.assertEqual(response.status_code, 200)
        self.assertIn("Queue Status", response.text)
        self.assertIn("https://example.com/test", response.text)

    def test_retry_action(self) -> None:
        """Test retry button functionality."""
        # Set job to error state
        Core.Database.update_job_status(
            self.job_id,
            "error",
            "Failed",
        )

        # Retry
        response = self.client.post(f"/queue/{self.job_id}/retry")

        self.assertEqual(response.status_code, 200)
        self.assertIn("HX-Redirect", response.headers)

        # Job should be pending again
        job = Core.Database.get_job_by_id(self.job_id)
        self.assertIsNotNone(job)
        if job is not None:
            self.assertEqual(job["status"], "pending")

    def test_delete_action(self) -> None:
        """Test delete button functionality."""
        response = self.client.delete(
            f"/queue/{self.job_id}",
            headers={"referer": "/admin"},
        )

        self.assertEqual(response.status_code, 200)
        self.assertIn("HX-Redirect", response.headers)

        # Job should be gone
        job = Core.Database.get_job_by_id(self.job_id)
        self.assertIsNone(job)

    def test_user_data_isolation(self) -> None:
        """Ensure admin sees all data."""
        # Create another user's job
        user2_id, _ = Core.Database.create_user(
            "other@example.com",
        )
        Core.Database.add_to_queue(
            "https://example.com/other",
            "other@example.com",
            user2_id,
        )

        # View queue status as admin
        response = self.client.get("/admin")

        # Admin should see all jobs
        self.assertIn("https://example.com/test", response.text)
        self.assertIn("https://example.com/other", response.text)

    def test_status_summary(self) -> None:
        """Verify status counts display."""
        # Create jobs with different statuses
        Core.Database.update_job_status(
            self.job_id,
            "error",
            "Failed",
        )
        job2 = Core.Database.add_to_queue(
            "https://example.com/2",
            "test@example.com",
            self.user_id,
        )
        Core.Database.update_job_status(
            job2,
            "processing",
        )

        response = self.client.get("/admin")

        # Should show status counts
        self.assertIn("ERROR: 1", response.text)
        self.assertIn("PROCESSING: 1", response.text)


class TestMetricsDashboard(BaseWebTest):
    """Test metrics dashboard functionality."""

    def setUp(self) -> None:
        """Set up test client with logged-in admin user."""
        super().setUp()

        # Create and login admin user
        self.user_id, _ = Core.Database.create_user(
            "ben@bensima.com",
        )
        Core.Database.update_user_status(
            self.user_id,
            "active",
        )
        self.client.post("/login", data={"email": "ben@bensima.com"})

    def test_metrics_page_requires_admin(self) -> None:
        """Verify non-admin users cannot access metrics."""
        # Create non-admin user
        user_id, _ = Core.Database.create_user("user@example.com")
        Core.Database.update_user_status(user_id, "active")

        # Login as non-admin
        self.client.get("/logout")
        self.client.post("/login", data={"email": "user@example.com"})

        # Try to access metrics
        response = self.client.get("/admin/metrics", follow_redirects=False)

        # Should redirect
        self.assertEqual(response.status_code, 302)
        self.assertEqual(response.headers["Location"], "/?error=forbidden")

    def test_metrics_page_requires_login(self) -> None:
        """Verify unauthenticated users are redirected."""
        self.client.get("/logout")

        response = self.client.get("/admin/metrics", follow_redirects=False)

        self.assertEqual(response.status_code, 302)
        self.assertEqual(response.headers["Location"], "/")

    def test_metrics_displays_summary(self) -> None:
        """Verify metrics summary is displayed."""
        # Create test episode
        episode_id = Core.Database.create_episode(
            title="Test Episode",
            audio_url="http://example.com/audio.mp3",
            content_length=1000,
            duration=300,
        )
        Core.Database.add_episode_to_user(self.user_id, episode_id)

        # Track some events
        Core.Database.track_episode_event(episode_id, "played")
        Core.Database.track_episode_event(episode_id, "played")
        Core.Database.track_episode_event(episode_id, "downloaded")
        Core.Database.track_episode_event(episode_id, "added", self.user_id)

        # Get metrics page
        response = self.client.get("/admin/metrics")

        self.assertEqual(response.status_code, 200)
        self.assertIn("Episode Metrics", response.text)
        self.assertIn("Total Episodes", response.text)
        self.assertIn("Total Plays", response.text)

    def test_growth_metrics_display(self) -> None:
        """Verify growth and usage metrics are displayed."""
        # Create an active subscriber
        user2_id, _ = Core.Database.create_user("active@example.com")
        Core.Database.update_user_subscription(
            user2_id,
            subscription_id="sub_test",
            status="active",
            period_start=datetime.now(timezone.utc),
            period_end=datetime.now(timezone.utc),
            tier="paid",
            cancel_at_period_end=False,
        )

        # Create a queue item
        Core.Database.add_to_queue(
            "https://example.com/new",
            "active@example.com",
            user2_id,
        )

        # Get metrics page
        response = self.client.get("/admin/metrics")

        self.assertEqual(response.status_code, 200)
        self.assertIn("Growth & Usage", response.text)
        self.assertIn("Total Users", response.text)
        self.assertIn("Active Subs", response.text)
        self.assertIn("Submissions (24h)", response.text)

        self.assertIn("Total Downloads", response.text)
        self.assertIn("Total Adds", response.text)

    def test_metrics_shows_top_episodes(self) -> None:
        """Verify top episodes tables are displayed."""
        # Create test episodes
        episode1 = Core.Database.create_episode(
            title="Popular Episode",
            audio_url="http://example.com/popular.mp3",
            content_length=1000,
            duration=300,
            author="Test Author",
        )
        Core.Database.add_episode_to_user(self.user_id, episode1)

        episode2 = Core.Database.create_episode(
            title="Less Popular Episode",
            audio_url="http://example.com/less.mp3",
            content_length=1000,
            duration=300,
        )
        Core.Database.add_episode_to_user(self.user_id, episode2)

        # Track events - more for episode1
        for _ in range(5):
            Core.Database.track_episode_event(episode1, "played")
        for _ in range(2):
            Core.Database.track_episode_event(episode2, "played")

        # NOTE(review): the flattened source does not show whether the
        # episode2 download sits inside the loop; it is placed after it here
        # to match "more for episode1" - confirm against the original file.
        for _ in range(3):
            Core.Database.track_episode_event(episode1, "downloaded")
        Core.Database.track_episode_event(episode2, "downloaded")

        # Get metrics page
        response = self.client.get("/admin/metrics")

        self.assertEqual(response.status_code, 200)
        self.assertIn("Most Played", response.text)
        self.assertIn("Most Downloaded", response.text)
        self.assertIn("Popular Episode", response.text)

    def test_metrics_empty_state(self) -> None:
        """Verify metrics page works with no data."""
        response = self.client.get("/admin/metrics")

        self.assertEqual(response.status_code, 200)
        self.assertIn("Episode Metrics", response.text)
        # Should show 0 for counts
        self.assertIn("Total Episodes", response.text)


class TestJobCancellation(BaseWebTest):
    """Test job cancellation functionality."""

    def setUp(self) -> None:
        """Set up test client with logged-in user and pending job."""
        super().setUp()

        # Create and login user
        self.user_id, _ = Core.Database.create_user(
            "test@example.com",
        )
        Core.Database.update_user_status(
            self.user_id,
            "active",
        )
        self.client.post("/login", data={"email": "test@example.com"})

        # Create pending job
        self.job_id = Core.Database.add_to_queue(
            "https://example.com/test",
            "test@example.com",
            self.user_id,
        )

    def test_cancel_pending_job(self) -> None:
        """Successfully cancel a pending job."""
        response = self.client.post(f"/queue/{self.job_id}/cancel")

        self.assertEqual(response.status_code, 200)
        self.assertIn("HX-Trigger", response.headers)
        self.assertEqual(response.headers["HX-Trigger"], "queue-updated")

        # Verify job status is cancelled
        job = Core.Database.get_job_by_id(self.job_id)
        self.assertIsNotNone(job)
        if job is not None:
            self.assertEqual(job["status"], "cancelled")
            self.assertEqual(job.get("error_message", ""), "Cancelled by user")

    def test_cannot_cancel_processing_job(self) -> None:
        """Prevent cancelling jobs that are already processing."""
        # Set job to processing
        Core.Database.update_job_status(
            self.job_id,
            "processing",
        )

        response = self.client.post(f"/queue/{self.job_id}/cancel")

        self.assertEqual(response.status_code, 400)
        self.assertIn("Can only cancel pending jobs", response.text)

    def test_cannot_cancel_completed_job(self) -> None:
        """Prevent cancelling completed jobs."""
        # Set job to completed
        Core.Database.update_job_status(
            self.job_id,
            "completed",
        )

        response = self.client.post(f"/queue/{self.job_id}/cancel")

        self.assertEqual(response.status_code, 400)

    def test_cannot_cancel_other_users_job(self) -> None:
        """Prevent users from cancelling other users' jobs."""
        # Create another user's job
        user2_id, _ = Core.Database.create_user(
            "other@example.com",
        )
        other_job_id = Core.Database.add_to_queue(
            "https://example.com/other",
            "other@example.com",
            user2_id,
        )

        # Try to cancel it
        response = self.client.post(f"/queue/{other_job_id}/cancel")

        self.assertEqual(response.status_code, 403)

    def test_cancel_without_auth(self) -> None:
        """Require authentication to cancel jobs."""
        # Logout
        self.client.get("/logout")

        response = self.client.post(f"/queue/{self.job_id}/cancel")

        self.assertEqual(response.status_code, 401)

    def test_cancel_button_visibility(self) -> None:
        """Cancel button only shows for pending jobs."""
        # Create jobs with different statuses
        processing_job = Core.Database.add_to_queue(
            "https://example.com/processing",
            "test@example.com",
            self.user_id,
        )
        Core.Database.update_job_status(
            processing_job,
            "processing",
        )

        # Get status view
        response = self.client.get("/status")

        # Should have cancel button for pending job
        self.assertIn(f'hx-post="/queue/{self.job_id}/cancel"', response.text)
        self.assertIn("Cancel", response.text)

        # Should NOT have cancel button for processing job
        self.assertNotIn(
            f'hx-post="/queue/{processing_job}/cancel"',
            response.text,
        )


# NOTE(review): this chunk ends in the middle of the next definition, which
# continues past it. Its opening lines are preserved below, commented out, so
# that no content is lost.
# class TestEpisodeDetailPage(BaseWebTest):
#     """Test episode detail page functionality."""
#
#     def setUp(self)
-> None: + """Set up test client with user and episode.""" + super().setUp() + + # Create user and episode + self.user_id, self.token = Core.Database.create_user( + "creator@example.com", + status="active", + ) + self.episode_id = Core.Database.create_episode( + title="Test Episode", + audio_url="https://example.com/audio.mp3", + duration=300, + content_length=5000, + user_id=self.user_id, + author="Test Author", + original_url="https://example.com/article", + ) + Core.Database.add_episode_to_user(self.user_id, self.episode_id) + self.episode_sqid = encode_episode_id(self.episode_id) + + def test_episode_page_loads(self) -> None: + """Episode page should load successfully.""" + response = self.client.get(f"/episode/{self.episode_sqid}") + + self.assertEqual(response.status_code, 200) + self.assertIn("Test Episode", response.text) + self.assertIn("Test Author", response.text) + + def test_episode_not_found(self) -> None: + """Non-existent episode should return 404.""" + response = self.client.get("/episode/invalidcode") + + self.assertEqual(response.status_code, 404) + + def test_audio_player_present(self) -> None: + """Audio player should be present on episode page.""" + response = self.client.get(f"/episode/{self.episode_sqid}") + + self.assertIn("<audio", response.text) + self.assertIn("controls", response.text) + self.assertIn("https://example.com/audio.mp3", response.text) + + def test_share_button_present(self) -> None: + """Share button should be present.""" + response = self.client.get(f"/episode/{self.episode_sqid}") + + self.assertIn("Share Episode", response.text) + self.assertIn("navigator.clipboard.writeText", response.text) + + def test_original_article_link(self) -> None: + """Original article link should be present.""" + response = self.client.get(f"/episode/{self.episode_sqid}") + + self.assertIn("View original article", response.text) + self.assertIn("https://example.com/article", response.text) + + def 
test_signup_banner_for_non_authenticated(self) -> None: + """Non-authenticated users should see signup banner.""" + response = self.client.get(f"/episode/{self.episode_sqid}") + + self.assertIn("This episode was created by", response.text) + self.assertIn("creator@example.com", response.text) + self.assertIn("Sign Up", response.text) + + def test_no_signup_banner_for_authenticated(self) -> None: + """Authenticated users should not see signup banner.""" + # Login + self.client.post("/login", data={"email": "creator@example.com"}) + + response = self.client.get(f"/episode/{self.episode_sqid}") + + self.assertNotIn("This episode was created by", response.text) + + def test_episode_links_from_home_page(self) -> None: + """Episode titles on home page should link to detail page.""" + # Login to see episodes + self.client.post("/login", data={"email": "creator@example.com"}) + + response = self.client.get("/") + + self.assertIn(f'href="/episode/{self.episode_sqid}"', response.text) + self.assertIn("Test Episode", response.text) + + def test_legacy_integer_id_redirects(self) -> None: + """Legacy integer episode IDs should redirect to sqid URLs.""" + response = self.client.get( + f"/episode/{self.episode_id}", + follow_redirects=False, + ) + + self.assertEqual(response.status_code, 301) + self.assertEqual( + response.headers["location"], + f"/episode/{self.episode_sqid}", + ) + + +class TestPublicFeed(BaseWebTest): + """Test public feed functionality.""" + + def setUp(self) -> None: + """Set up test database, client, and create sample episodes.""" + super().setUp() + + # Create admin user + self.admin_id, _ = Core.Database.create_user( + "ben@bensima.com", + status="active", + ) + + # Create some episodes, some public, some private + self.public_episode_id = Core.Database.create_episode( + title="Public Episode", + audio_url="https://example.com/public.mp3", + duration=300, + content_length=1000, + user_id=self.admin_id, + author="Test Author", + 
original_url="https://example.com/public", + original_url_hash=Core.hash_url("https://example.com/public"), + ) + Core.Database.mark_episode_public(self.public_episode_id) + + self.private_episode_id = Core.Database.create_episode( + title="Private Episode", + audio_url="https://example.com/private.mp3", + duration=200, + content_length=800, + user_id=self.admin_id, + author="Test Author", + original_url="https://example.com/private", + original_url_hash=Core.hash_url("https://example.com/private"), + ) + + def test_public_feed_page(self) -> None: + """Public feed page should show only public episodes.""" + response = self.client.get("/public") + + self.assertEqual(response.status_code, 200) + self.assertIn("Public Episode", response.text) + self.assertNotIn("Private Episode", response.text) + + def test_home_page_shows_public_feed_when_logged_out(self) -> None: + """Home page should show public episodes when user is not logged in.""" + response = self.client.get("/") + + self.assertEqual(response.status_code, 200) + self.assertIn("Public Episode", response.text) + self.assertNotIn("Private Episode", response.text) + + def test_admin_can_toggle_episode_public(self) -> None: + """Admin should be able to toggle episode public/private status.""" + # Login as admin + self.client.post("/login", data={"email": "ben@bensima.com"}) + + # Toggle private episode to public + response = self.client.post( + f"/admin/episode/{self.private_episode_id}/toggle-public", + ) + + self.assertEqual(response.status_code, 200) + + # Verify it's now public + episode = Core.Database.get_episode_by_id(self.private_episode_id) + self.assertEqual(episode["is_public"], 1) # type: ignore[index] + + def test_non_admin_cannot_toggle_public(self) -> None: + """Non-admin users should not be able to toggle public status.""" + # Create and login as regular user + _user_id, _ = Core.Database.create_user("user@example.com") + self.client.post("/login", data={"email": "user@example.com"}) + + # Try to 
toggle + response = self.client.post( + f"/admin/episode/{self.private_episode_id}/toggle-public", + ) + + self.assertEqual(response.status_code, 403) + + def test_admin_can_add_user_episode_to_own_feed(self) -> None: + """Admin can add another user's episode to their own feed.""" + # Create regular user and their episode + user_id, _ = Core.Database.create_user( + "user@example.com", + status="active", + ) + user_episode_id = Core.Database.create_episode( + title="User Episode", + audio_url="https://example.com/user.mp3", + duration=400, + content_length=1200, + user_id=user_id, + author="User Author", + original_url="https://example.com/user-article", + original_url_hash=Core.hash_url("https://example.com/user-article"), + ) + Core.Database.add_episode_to_user(user_id, user_episode_id) + + # Login as admin + self.client.post("/login", data={"email": "ben@bensima.com"}) + + # Admin adds user's episode to their own feed + response = self.client.post(f"/episode/{user_episode_id}/add-to-feed") + + self.assertEqual(response.status_code, 200) + + # Verify episode is now in admin's feed + admin_episodes = Core.Database.get_user_episodes(self.admin_id) + episode_ids = [e["id"] for e in admin_episodes] + self.assertIn(user_episode_id, episode_ids) + + # Verify "added" event was tracked + metrics = Core.Database.get_episode_metric_events(user_episode_id) + added_events = [m for m in metrics if m["event_type"] == "added"] + self.assertEqual(len(added_events), 1) + self.assertEqual(added_events[0]["user_id"], self.admin_id) + + def test_admin_can_add_user_episode_to_public_feed(self) -> None: + """Admin should be able to add another user's episode to public feed.""" + # Create regular user and their episode + user_id, _ = Core.Database.create_user( + "user@example.com", + status="active", + ) + user_episode_id = Core.Database.create_episode( + title="User Episode for Public", + audio_url="https://example.com/user-public.mp3", + duration=500, + content_length=1500, + 
user_id=user_id, + author="User Author", + original_url="https://example.com/user-public-article", + original_url_hash=Core.hash_url( + "https://example.com/user-public-article", + ), + ) + Core.Database.add_episode_to_user(user_id, user_episode_id) + + # Verify episode is private initially + episode = Core.Database.get_episode_by_id(user_episode_id) + self.assertEqual(episode["is_public"], 0) # type: ignore[index] + + # Login as admin + self.client.post("/login", data={"email": "ben@bensima.com"}) + + # Admin toggles episode to public + response = self.client.post( + f"/admin/episode/{user_episode_id}/toggle-public", + ) + + self.assertEqual(response.status_code, 200) + + # Verify episode is now public + episode = Core.Database.get_episode_by_id(user_episode_id) + self.assertEqual(episode["is_public"], 1) # type: ignore[index] + + # Verify episode appears in public feed + public_episodes = Core.Database.get_public_episodes() + episode_ids = [e["id"] for e in public_episodes] + self.assertIn(user_episode_id, episode_ids) + + +class TestEpisodeDeduplication(BaseWebTest): + """Test episode deduplication functionality.""" + + def setUp(self) -> None: + """Set up test database, client, and create test user.""" + super().setUp() + + self.user_id, self.token = Core.Database.create_user( + "test@example.com", + status="active", + ) + + # Create an existing episode + self.existing_url = "https://example.com/article" + self.url_hash = Core.hash_url(self.existing_url) + + self.episode_id = Core.Database.create_episode( + title="Existing Article", + audio_url="https://example.com/audio.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test Author", + original_url=self.existing_url, + original_url_hash=self.url_hash, + ) + + def test_url_normalization(self) -> None: + """URLs should be normalized for deduplication.""" + # Different URL variations that should be normalized to same hash + urls = [ + "http://example.com/article", + 
"https://example.com/article", + "https://www.example.com/article", + "https://EXAMPLE.COM/article", + "https://example.com/article/", + ] + + hashes = [Core.hash_url(url) for url in urls] + + # All should produce the same hash + self.assertEqual(len(set(hashes)), 1) + + def test_find_existing_episode_by_hash(self) -> None: + """Should find existing episode by normalized URL hash.""" + # Try different URL variations + similar_urls = [ + "http://example.com/article", + "https://www.example.com/article", + ] + + for url in similar_urls: + url_hash = Core.hash_url(url) + episode = Core.Database.get_episode_by_url_hash(url_hash) + + self.assertIsNotNone(episode) + if episode is not None: + self.assertEqual(episode["id"], self.episode_id) + + def test_add_existing_episode_to_user_feed(self) -> None: + """Should add existing episode to new user's feed.""" + # Create second user + user2_id, _ = Core.Database.create_user("user2@example.com") + + # Add existing episode to their feed + Core.Database.add_episode_to_user(user2_id, self.episode_id) + + # Verify it appears in their feed + episodes = Core.Database.get_user_episodes(user2_id) + episode_ids = [e["id"] for e in episodes] + + self.assertIn(self.episode_id, episode_ids) + + +class TestMetricsTracking(BaseWebTest): + """Test episode metrics tracking.""" + + def setUp(self) -> None: + """Set up test database, client, and create test episode.""" + super().setUp() + + self.user_id, _ = Core.Database.create_user( + "test@example.com", + status="active", + ) + + self.episode_id = Core.Database.create_episode( + title="Test Episode", + audio_url="https://example.com/audio.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test Author", + original_url="https://example.com/article", + original_url_hash=Core.hash_url("https://example.com/article"), + ) + + def test_track_episode_added(self) -> None: + """Should track when episode is added to feed.""" + Core.Database.track_episode_event( + 
self.episode_id, + "added", + self.user_id, + ) + + # Verify metric was recorded + metrics = Core.Database.get_episode_metric_events(self.episode_id) + self.assertEqual(len(metrics), 1) + self.assertEqual(metrics[0]["event_type"], "added") + self.assertEqual(metrics[0]["user_id"], self.user_id) + + def test_track_episode_played(self) -> None: + """Should track when episode is played.""" + Core.Database.track_episode_event( + self.episode_id, + "played", + self.user_id, + ) + + metrics = Core.Database.get_episode_metric_events(self.episode_id) + self.assertEqual(len(metrics), 1) + self.assertEqual(metrics[0]["event_type"], "played") + + def test_track_anonymous_play(self) -> None: + """Should track plays from anonymous users.""" + Core.Database.track_episode_event( + self.episode_id, + "played", + user_id=None, + ) + + metrics = Core.Database.get_episode_metric_events(self.episode_id) + self.assertEqual(len(metrics), 1) + self.assertEqual(metrics[0]["event_type"], "played") + self.assertIsNone(metrics[0]["user_id"]) + + def test_track_endpoint(self) -> None: + """POST /episode/{id}/track should record metrics.""" + # Login as user + self.client.post("/login", data={"email": "test@example.com"}) + + response = self.client.post( + f"/episode/{self.episode_id}/track", + data={"event_type": "played"}, + ) + + self.assertEqual(response.status_code, 200) + + # Verify metric was recorded + metrics = Core.Database.get_episode_metric_events(self.episode_id) + played_metrics = [m for m in metrics if m["event_type"] == "played"] + self.assertGreater(len(played_metrics), 0) + + +class TestUsageLimits(BaseWebTest): + """Test usage tracking and limit enforcement.""" + + def setUp(self) -> None: + """Set up test with free tier user.""" + super().setUp() + + # Create free tier user + self.user_id, self.token = Core.Database.create_user( + "free@example.com", + status="active", + ) + # Login + self.client.post("/login", data={"email": "free@example.com"}) + + def 
test_usage_counts_episodes_added_to_feed(self) -> None: + """Usage should count episodes added via user_episodes table.""" + user = Core.Database.get_user_by_id(self.user_id) + self.assertIsNotNone(user) + assert user is not None # type narrowing # noqa: S101 + period_start, period_end = Billing.get_period_boundaries(user) + + # Initially no usage + usage = Billing.get_usage(self.user_id, period_start, period_end) + self.assertEqual(usage["articles"], 0) + + # Add an episode to user's feed + ep_id = Core.Database.create_episode( + title="Test Episode", + audio_url="https://example.com/test.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test", + original_url="https://example.com/article", + original_url_hash=Core.hash_url("https://example.com/article"), + ) + Core.Database.add_episode_to_user(self.user_id, ep_id) + + # Usage should now be 1 + usage = Billing.get_usage(self.user_id, period_start, period_end) + self.assertEqual(usage["articles"], 1) + + def test_usage_counts_existing_episodes_correctly(self) -> None: + """Adding existing episodes should count toward usage.""" + # Create another user who creates an episode + other_user_id, _ = Core.Database.create_user("other@example.com") + ep_id = Core.Database.create_episode( + title="Other User Episode", + audio_url="https://example.com/other.mp3", + duration=400, + content_length=1200, + user_id=other_user_id, + author="Other", + original_url="https://example.com/other-article", + original_url_hash=Core.hash_url( + "https://example.com/other-article", + ), + ) + Core.Database.add_episode_to_user(other_user_id, ep_id) + + # Free user adds it to their feed + Core.Database.add_episode_to_user(self.user_id, ep_id) + + # Check usage for free user + user = Core.Database.get_user_by_id(self.user_id) + self.assertIsNotNone(user) + assert user is not None # type narrowing # noqa: S101 + period_start, period_end = Billing.get_period_boundaries(user) + usage = Billing.get_usage(self.user_id, 
period_start, period_end) + + # Should count as 1 article for free user + self.assertEqual(usage["articles"], 1) + + def test_free_tier_limit_enforcement(self) -> None: + """Free tier users should be blocked at 10 articles.""" + # Add 10 episodes (the free tier limit) + for i in range(10): + ep_id = Core.Database.create_episode( + title=f"Episode {i}", + audio_url=f"https://example.com/ep{i}.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test", + original_url=f"https://example.com/article{i}", + original_url_hash=Core.hash_url( + f"https://example.com/article{i}", + ), + ) + Core.Database.add_episode_to_user(self.user_id, ep_id) + + # Try to submit 11th article + response = self.client.post( + "/submit", + data={"url": "https://example.com/article11"}, + ) + + # Should be blocked + self.assertEqual(response.status_code, 200) + self.assertIn("Limit reached", response.text) + self.assertIn("10", response.text) + self.assertIn("Upgrade", response.text) + + def test_can_submit_blocks_at_limit(self) -> None: + """can_submit should return False at limit.""" + # Add 10 episodes + for i in range(10): + ep_id = Core.Database.create_episode( + title=f"Episode {i}", + audio_url=f"https://example.com/ep{i}.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test", + original_url=f"https://example.com/article{i}", + original_url_hash=Core.hash_url( + f"https://example.com/article{i}", + ), + ) + Core.Database.add_episode_to_user(self.user_id, ep_id) + + # Check can_submit + allowed, msg, usage = Billing.can_submit(self.user_id) + + self.assertFalse(allowed) + self.assertIn("10", msg) + self.assertIn("limit", msg.lower()) + self.assertEqual(usage["articles"], 10) + + def test_paid_tier_unlimited(self) -> None: + """Paid tier should have no article limits.""" + # Create a paid tier user directly + paid_user_id, _ = Core.Database.create_user("paid@example.com") + + # Simulate paid subscription via update_user_subscription 
+ now = datetime.now(timezone.utc) + period_start = now + december = 12 + january = 1 + period_end = now.replace( + month=now.month + 1 if now.month < december else january, + ) + + Core.Database.update_user_subscription( + paid_user_id, + subscription_id="sub_test123", + status="active", + period_start=period_start, + period_end=period_end, + tier="paid", + cancel_at_period_end=False, + ) + + # Add 20 episodes (more than free limit) + for i in range(20): + ep_id = Core.Database.create_episode( + title=f"Episode {i}", + audio_url=f"https://example.com/ep{i}.mp3", + duration=300, + content_length=1000, + user_id=paid_user_id, + author="Test", + original_url=f"https://example.com/article{i}", + original_url_hash=Core.hash_url( + f"https://example.com/article{i}", + ), + ) + Core.Database.add_episode_to_user(paid_user_id, ep_id) + + # Should still be allowed to submit + allowed, msg, usage = Billing.can_submit(paid_user_id) + + self.assertTrue(allowed) + self.assertEqual(msg, "") + self.assertEqual(usage["articles"], 20) + + +class TestAccountPage(BaseWebTest): + """Test account page functionality.""" + + def setUp(self) -> None: + """Set up test with user.""" + super().setUp() + self.user_id, _ = Core.Database.create_user( + "test@example.com", + status="active", + ) + self.client.post("/login", data={"email": "test@example.com"}) + + def test_account_page_logged_in(self) -> None: + """Account page should render for logged-in users.""" + # Create some usage to verify stats are shown + ep_id = Core.Database.create_episode( + title="Test Episode", + audio_url="https://example.com/audio.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test Author", + original_url="https://example.com/article", + original_url_hash=Core.hash_url("https://example.com/article"), + ) + Core.Database.add_episode_to_user(self.user_id, ep_id) + + response = self.client.get("/account") + + self.assertEqual(response.status_code, 200) + self.assertIn("My Account", 
response.text) + self.assertIn("test@example.com", response.text) + self.assertIn("1 / 10", response.text) # Usage / Limit for free tier + + def test_account_page_login_required(self) -> None: + """Should redirect to login if not logged in.""" + self.client.post("/logout") + response = self.client.get("/account", follow_redirects=False) + self.assertEqual(response.status_code, 307) + self.assertEqual(response.headers["location"], "/?error=login_required") + + def test_logout(self) -> None: + """Logout should clear session.""" + response = self.client.post("/logout", follow_redirects=False) + self.assertEqual(response.status_code, 303) + self.assertEqual(response.headers["location"], "/") + + # Verify session cleared + response = self.client.get("/account", follow_redirects=False) + self.assertEqual(response.status_code, 307) + + def test_billing_portal_redirect(self) -> None: + """Billing portal should redirect to Stripe.""" + # First set a customer ID + Core.Database.set_user_stripe_customer(self.user_id, "cus_test") + + # Mock the create_portal_session method + with patch( + "Biz.PodcastItLater.Billing.create_portal_session", + ) as mock_portal: + mock_portal.return_value = "https://billing.stripe.com/test" + + response = self.client.post( + "/billing/portal", + follow_redirects=False, + ) + + self.assertEqual(response.status_code, 303) + self.assertEqual( + response.headers["location"], + "https://billing.stripe.com/test", + ) + + def test_update_email_success(self) -> None: + """Should allow updating email.""" + # POST new email + response = self.client.post( + "/settings/email", + data={"email": "new@example.com"}, + ) + self.assertEqual(response.status_code, 200) + + # Verify update in DB + user = Core.Database.get_user_by_id(self.user_id) + self.assertEqual(user["email"], "new@example.com") # type: ignore[index] + + def test_update_email_duplicate(self) -> None: + """Should prevent updating to existing email.""" + # Create another user + 
Core.Database.create_user("other@example.com") + + # Try to update to their email + response = self.client.post( + "/settings/email", + data={"email": "other@example.com"}, + ) + + # Should show error (return 200 with error message in form) + self.assertEqual(response.status_code, 200) + self.assertIn("already taken", response.text.lower()) + + def test_delete_account(self) -> None: + """Should allow user to delete their account.""" + # Delete account + response = self.client.delete("/account") + self.assertEqual(response.status_code, 200) + self.assertIn("HX-Redirect", response.headers) + + # Verify user gone + user = Core.Database.get_user_by_id(self.user_id) + self.assertIsNone(user) + + # Verify session cleared + response = self.client.get("/account", follow_redirects=False) + self.assertEqual(response.status_code, 307) + + +class TestAdminUsers(BaseWebTest): + """Test admin user management functionality.""" + + def setUp(self) -> None: + """Set up test client with logged-in admin user.""" + super().setUp() + + # Create and login admin user + self.user_id, _ = Core.Database.create_user( + "ben@bensima.com", + ) + Core.Database.update_user_status( + self.user_id, + "active", + ) + self.client.post("/login", data={"email": "ben@bensima.com"}) + + # Create another regular user + self.other_user_id, _ = Core.Database.create_user("user@example.com") + Core.Database.update_user_status(self.other_user_id, "active") + + def test_admin_users_page_access(self) -> None: + """Admin can access users page.""" + response = self.client.get("/admin/users") + self.assertEqual(response.status_code, 200) + self.assertIn("User Management", response.text) + self.assertIn("user@example.com", response.text) + + def test_non_admin_users_page_access(self) -> None: + """Non-admin cannot access users page.""" + # Login as regular user + self.client.get("/logout") + self.client.post("/login", data={"email": "user@example.com"}) + + response = self.client.get("/admin/users") + 
self.assertEqual(response.status_code, 302) + self.assertIn("error=forbidden", response.headers["Location"]) + + def test_admin_can_update_user_status(self) -> None: + """Admin can update user status.""" + response = self.client.post( + f"/admin/users/{self.other_user_id}/status", + data={"status": "disabled"}, + ) + self.assertEqual(response.status_code, 200) + + user = Core.Database.get_user_by_id(self.other_user_id) + assert user is not None # noqa: S101 + self.assertEqual(user["status"], "disabled") + + def test_non_admin_cannot_update_user_status(self) -> None: + """Non-admin cannot update user status.""" + # Login as regular user + self.client.get("/logout") + self.client.post("/login", data={"email": "user@example.com"}) + + response = self.client.post( + f"/admin/users/{self.other_user_id}/status", + data={"status": "disabled"}, + ) + self.assertEqual(response.status_code, 403) + + user = Core.Database.get_user_by_id(self.other_user_id) + assert user is not None # noqa: S101 + self.assertEqual(user["status"], "active") + + def test_update_user_status_invalid_status(self) -> None: + """Invalid status validation.""" + response = self.client.post( + f"/admin/users/{self.other_user_id}/status", + data={"status": "invalid_status"}, + ) + self.assertEqual(response.status_code, 400) + + user = Core.Database.get_user_by_id(self.other_user_id) + assert user is not None # noqa: S101 + self.assertEqual(user["status"], "active") + + +def test() -> None: + """Run all tests for the web module.""" + Test.run( + App.Area.Test, + [ + TestDurationFormatting, + TestAuthentication, + TestArticleSubmission, + TestRSSFeed, + TestAdminInterface, + TestJobCancellation, + TestEpisodeDetailPage, + TestPublicFeed, + TestEpisodeDeduplication, + TestMetricsTracking, + TestUsageLimits, + TestAccountPage, + TestAdminUsers, + ], + ) + + +def main() -> None: + """Run the web server.""" + if "test" in sys.argv: + test() + else: + # Initialize database on startup + Core.Database.init_db() + 
uvicorn.run(app, host="0.0.0.0", port=PORT) # noqa: S104 diff --git a/Biz/PodcastItLater/Worker.nix b/Biz/PodcastItLater/Worker.nix new file mode 100644 index 0000000..974a3ba --- /dev/null +++ b/Biz/PodcastItLater/Worker.nix @@ -0,0 +1,63 @@ +{ + options, + lib, + config, + pkgs, + ... +}: let + cfg = config.services.podcastitlater-worker; +in { + options.services.podcastitlater-worker = { + enable = lib.mkEnableOption "Enable the PodcastItLater worker service"; + dataDir = lib.mkOption { + type = lib.types.path; + default = "/var/podcastitlater"; + description = "Data directory for PodcastItLater (shared with web)"; + }; + package = lib.mkOption { + type = lib.types.package; + description = "PodcastItLater worker package to use"; + }; + }; + config = lib.mkIf cfg.enable { + systemd.services.podcastitlater-worker = { + path = [cfg.package pkgs.ffmpeg]; # ffmpeg needed for pydub + wantedBy = ["multi-user.target"]; + after = ["network.target"]; + preStart = '' + # Create data directory if it doesn't exist + mkdir -p ${cfg.dataDir} + + # Manual step: create this file with secrets + # OPENAI_API_KEY=your-openai-api-key + # S3_ENDPOINT=https://your-s3-endpoint.digitaloceanspaces.com + # S3_BUCKET=your-bucket-name + # S3_ACCESS_KEY=your-s3-access-key + # S3_SECRET_KEY=your-s3-secret-key + test -f /run/podcastitlater/worker-env + ''; + script = '' + ${cfg.package}/bin/podcastitlater-worker + ''; + description = '' + PodcastItLater Worker Service - processes articles to podcasts + ''; + serviceConfig = { + Environment = [ + "AREA=Live" + "DATA_DIR=${cfg.dataDir}" + ]; + EnvironmentFile = "/run/podcastitlater/worker-env"; + KillSignal = "TERM"; + KillMode = "mixed"; + Type = "simple"; + Restart = "always"; + RestartSec = "10"; + # Give the worker time to finish current job + TimeoutStopSec = "300"; # 5 minutes + # Send SIGTERM first, then SIGKILL after timeout + SendSIGKILL = "yes"; + }; + }; + }; +} diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py 
new file mode 100644 index 0000000..bf6ef9e --- /dev/null +++ b/Biz/PodcastItLater/Worker.py @@ -0,0 +1,2199 @@ +"""Background worker for processing article-to-podcast conversions.""" + +# : dep boto3 +# : dep botocore +# : dep openai +# : dep psutil +# : dep pydub +# : dep pytest +# : dep pytest-asyncio +# : dep pytest-mock +# : dep trafilatura +# : out podcastitlater-worker +# : run ffmpeg +import Biz.PodcastItLater.Core as Core +import boto3 # type: ignore[import-untyped] +import concurrent.futures +import io +import json +import logging +import Omni.App as App +import Omni.Log as Log +import Omni.Test as Test +import openai +import operator +import os +import psutil # type: ignore[import-untyped] +import pytest +import signal +import sys +import tempfile +import threading +import time +import trafilatura +import typing +import unittest.mock +from botocore.exceptions import ClientError # type: ignore[import-untyped] +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from pathlib import Path +from pydub import AudioSegment # type: ignore[import-untyped] +from typing import Any + +logger = logging.getLogger(__name__) +Log.setup(logger) + +# Configuration from environment variables +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +S3_ENDPOINT = os.getenv("S3_ENDPOINT") +S3_BUCKET = os.getenv("S3_BUCKET") +S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY") +S3_SECRET_KEY = os.getenv("S3_SECRET_KEY") +area = App.from_env() + +# Worker configuration +MAX_CONTENT_LENGTH = 5000 # characters for TTS +MAX_ARTICLE_SIZE = 500_000 # 500KB character limit for articles +POLL_INTERVAL = 30 # seconds +MAX_RETRIES = 3 +TTS_MODEL = "tts-1" +TTS_VOICE = "alloy" +MEMORY_THRESHOLD = 80 # Percentage threshold for memory usage +CROSSFADE_DURATION = 500 # ms for crossfading segments +PAUSE_DURATION = 1000 # ms for silence between segments + + +class ShutdownHandler: + """Handles graceful shutdown of the worker.""" + + def __init__(self) -> None: + 
"""Initialize shutdown handler.""" + self.shutdown_requested = threading.Event() + self.current_job_id: int | None = None + self.lock = threading.Lock() + + # Register signal handlers + signal.signal(signal.SIGTERM, self._handle_signal) + signal.signal(signal.SIGINT, self._handle_signal) + + def _handle_signal(self, signum: int, _frame: Any) -> None: + """Handle shutdown signals.""" + logger.info( + "Received signal %d, initiating graceful shutdown...", + signum, + ) + self.shutdown_requested.set() + + def is_shutdown_requested(self) -> bool: + """Check if shutdown has been requested.""" + return self.shutdown_requested.is_set() + + def set_current_job(self, job_id: int | None) -> None: + """Set the currently processing job.""" + with self.lock: + self.current_job_id = job_id + + def get_current_job(self) -> int | None: + """Get the currently processing job.""" + with self.lock: + return self.current_job_id + + +class ArticleProcessor: + """Handles the complete article-to-podcast conversion pipeline.""" + + def __init__(self, shutdown_handler: ShutdownHandler) -> None: + """Initialize the processor with required services. + + Raises: + ValueError: If OPENAI_API_KEY environment variable is not set. + """ + if not OPENAI_API_KEY: + msg = "OPENAI_API_KEY environment variable is required" + raise ValueError(msg) + + self.openai_client: openai.OpenAI = openai.OpenAI( + api_key=OPENAI_API_KEY, + ) + self.shutdown_handler = shutdown_handler + + # Initialize S3 client for Digital Ocean Spaces + if all([S3_ENDPOINT, S3_BUCKET, S3_ACCESS_KEY, S3_SECRET_KEY]): + self.s3_client: Any = boto3.client( + "s3", + endpoint_url=S3_ENDPOINT, + aws_access_key_id=S3_ACCESS_KEY, + aws_secret_access_key=S3_SECRET_KEY, + ) + else: + logger.warning("S3 configuration incomplete, uploads will fail") + self.s3_client = None + + @staticmethod + def extract_article_content( + url: str, + ) -> tuple[str, str, str | None, str | None]: + """Extract title, content, author, and date from article URL. 
+ + Returns: + tuple: (title, content, author, publication_date) + + Raises: + ValueError: If content cannot be downloaded, extracted, or large. + """ + try: + downloaded = trafilatura.fetch_url(url) + if not downloaded: + msg = f"Failed to download content from {url}" + raise ValueError(msg) # noqa: TRY301 + + # Check size before processing + if ( + len(downloaded) > MAX_ARTICLE_SIZE * 4 + ): # Rough HTML to text ratio + msg = f"Article too large: {len(downloaded)} bytes" + raise ValueError(msg) # noqa: TRY301 + + # Extract with metadata + result = trafilatura.extract( + downloaded, + include_comments=False, + include_tables=False, + with_metadata=True, + output_format="json", + ) + + if not result: + msg = f"Failed to extract content from {url}" + raise ValueError(msg) # noqa: TRY301 + + data = json.loads(result) + + title = data.get("title", "Untitled Article") + content = data.get("text", "") + author = data.get("author") + pub_date = data.get("date") + + if not content: + msg = f"No content extracted from {url}" + raise ValueError(msg) # noqa: TRY301 + + # Enforce content size limit + if len(content) > MAX_ARTICLE_SIZE: + logger.warning( + "Article content truncated from %d to %d characters", + len(content), + MAX_ARTICLE_SIZE, + ) + content = content[:MAX_ARTICLE_SIZE] + + logger.info( + "Extracted article: %s (%d chars, author: %s, date: %s)", + title, + len(content), + author or "unknown", + pub_date or "unknown", + ) + except Exception: + logger.exception("Failed to extract content from %s", url) + raise + else: + return title, content, author, pub_date + + def text_to_speech( + self, + text: str, + title: str, + author: str | None = None, + pub_date: str | None = None, + ) -> bytes: + """Convert text to speech with intro/outro using OpenAI TTS API. + + Uses parallel processing for chunks while maintaining order. + Adds intro with metadata and outro with attribution. 
+ + Args: + text: Article content to convert + title: Article title + author: Article author (optional) + pub_date: Publication date (optional) + + Raises: + ValueError: If no chunks are generated from text. + """ + try: + # Generate intro audio + intro_text = self._create_intro_text(title, author, pub_date) + intro_audio = self._generate_tts_segment(intro_text) + + # Generate outro audio + outro_text = self._create_outro_text(title, author) + outro_audio = self._generate_tts_segment(outro_text) + + # Use LLM to prepare and chunk the main content + chunks = prepare_text_for_tts(text, title) + + if not chunks: + msg = "No chunks generated from text" + raise ValueError(msg) # noqa: TRY301 + + logger.info("Processing %d chunks for TTS", len(chunks)) + + # Check memory before parallel processing + mem_usage = check_memory_usage() + if mem_usage > MEMORY_THRESHOLD - 20: # Leave 20% buffer + logger.warning( + "High memory usage (%.1f%%), falling back to serial " + "processing", + mem_usage, + ) + content_audio_bytes = self._text_to_speech_serial(chunks) + else: + # Determine max workers + max_workers = min( + 4, # Reasonable limit to avoid rate limiting + len(chunks), # No more workers than chunks + max(1, psutil.cpu_count() // 2), # Use half of CPU cores + ) + + logger.info( + "Using %d workers for parallel TTS processing", + max_workers, + ) + + # Process chunks in parallel + chunk_results: list[tuple[int, bytes]] = [] + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers, + ) as executor: + # Submit all chunks for processing + future_to_index = { + executor.submit(self._process_tts_chunk, chunk, i): i + for i, chunk in enumerate(chunks) + } + + # Collect results as they complete + for future in concurrent.futures.as_completed( + future_to_index, + ): + index = future_to_index[future] + try: + audio_data = future.result() + chunk_results.append((index, audio_data)) + except Exception: + logger.exception( + "Failed to process chunk %d", + index, + ) 
+ raise + + # Sort results by index to maintain order + chunk_results.sort(key=operator.itemgetter(0)) + + # Combine audio chunks + content_audio_bytes = self._combine_audio_chunks([ + data for _, data in chunk_results + ]) + + # Combine intro, content, and outro with pauses + return ArticleProcessor._combine_intro_content_outro( + intro_audio, + content_audio_bytes, + outro_audio, + ) + + except Exception: + logger.exception("TTS generation failed") + raise + + @staticmethod + def _create_intro_text( + title: str, + author: str | None, + pub_date: str | None, + ) -> str: + """Create intro text with available metadata.""" + parts = [f"Title: {title}"] + + if author: + parts.append(f"Author: {author}") + + if pub_date: + parts.append(f"Published: {pub_date}") + + return ". ".join(parts) + "." + + @staticmethod + def _create_outro_text(title: str, author: str | None) -> str: + """Create outro text with attribution.""" + if author: + return ( + f"This has been an audio version of {title} " + f"by {author}, created using Podcast It Later." + ) + return ( + f"This has been an audio version of {title}, " + "created using Podcast It Later." + ) + + def _generate_tts_segment(self, text: str) -> bytes: + """Generate TTS audio for a single segment (intro/outro). + + Args: + text: Text to convert to speech + + Returns: + MP3 audio bytes + """ + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=text, + ) + return response.content + + @staticmethod + def _combine_intro_content_outro( + intro_audio: bytes, + content_audio: bytes, + outro_audio: bytes, + ) -> bytes: + """Combine intro, content, and outro with crossfades. 
+ + Args: + intro_audio: MP3 bytes for intro + content_audio: MP3 bytes for main content + outro_audio: MP3 bytes for outro + + Returns: + Combined MP3 audio bytes + """ + # Load audio segments + intro = AudioSegment.from_mp3(io.BytesIO(intro_audio)) + content = AudioSegment.from_mp3(io.BytesIO(content_audio)) + outro = AudioSegment.from_mp3(io.BytesIO(outro_audio)) + + # Create bridge silence (pause + 2 * crossfade to account for overlap) + bridge = AudioSegment.silent( + duration=PAUSE_DURATION + 2 * CROSSFADE_DURATION + ) + + def safe_append( + seg1: AudioSegment, seg2: AudioSegment, crossfade: int + ) -> AudioSegment: + if len(seg1) < crossfade or len(seg2) < crossfade: + logger.warning( + "Segment too short for crossfade (%dms vs %dms/%dms), using concatenation", + crossfade, + len(seg1), + len(seg2), + ) + return seg1 + seg2 + return seg1.append(seg2, crossfade=crossfade) + + # Combine segments with crossfades + # Intro -> Bridge -> Content -> Bridge -> Outro + # This effectively fades out the previous segment and fades in the next one + combined = safe_append(intro, bridge, CROSSFADE_DURATION) + combined = safe_append(combined, content, CROSSFADE_DURATION) + combined = safe_append(combined, bridge, CROSSFADE_DURATION) + combined = safe_append(combined, outro, CROSSFADE_DURATION) + + # Export to bytes + output = io.BytesIO() + combined.export(output, format="mp3") + return output.getvalue() + + def _process_tts_chunk(self, chunk: str, index: int) -> bytes: + """Process a single TTS chunk. 
+ + Args: + chunk: Text to convert to speech + index: Chunk index for logging + + Returns: + Audio data as bytes + """ + logger.info( + "Generating TTS for chunk %d (%d chars)", + index + 1, + len(chunk), + ) + + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunk, + response_format="mp3", + ) + + return response.content + + @staticmethod + def _combine_audio_chunks(audio_chunks: list[bytes]) -> bytes: + """Combine multiple audio chunks with silence gaps. + + Args: + audio_chunks: List of audio data in order + + Returns: + Combined audio data + """ + if not audio_chunks: + msg = "No audio chunks to combine" + raise ValueError(msg) + + # Create a temporary file for the combined audio + with tempfile.NamedTemporaryFile( + suffix=".mp3", + delete=False, + ) as temp_file: + temp_path = temp_file.name + + try: + # Start with the first chunk + combined_audio = AudioSegment.from_mp3(io.BytesIO(audio_chunks[0])) + + # Add remaining chunks with silence gaps + for chunk_data in audio_chunks[1:]: + chunk_audio = AudioSegment.from_mp3(io.BytesIO(chunk_data)) + silence = AudioSegment.silent(duration=300) # 300ms gap + combined_audio = combined_audio + silence + chunk_audio + + # Export to file + combined_audio.export(temp_path, format="mp3", bitrate="128k") + + # Read back the combined audio + return Path(temp_path).read_bytes() + + finally: + # Clean up temp file + if Path(temp_path).exists(): + Path(temp_path).unlink() + + def _text_to_speech_serial(self, chunks: list[str]) -> bytes: + """Fallback serial processing for high memory situations. + + This is the original serial implementation. 
+ """ + # Create a temporary file for streaming audio concatenation + with tempfile.NamedTemporaryFile( + suffix=".mp3", + delete=False, + ) as temp_file: + temp_path = temp_file.name + + try: + # Process first chunk + logger.info("Generating TTS for chunk 1/%d", len(chunks)) + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunks[0], + response_format="mp3", + ) + + # Write first chunk directly to file + Path(temp_path).write_bytes(response.content) + + # Process remaining chunks + for i, chunk in enumerate(chunks[1:], 1): + logger.info( + "Generating TTS for chunk %d/%d (%d chars)", + i + 1, + len(chunks), + len(chunk), + ) + + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunk, + response_format="mp3", + ) + + # Append to existing file with silence gap + # Load only the current segment + current_segment = AudioSegment.from_mp3( + io.BytesIO(response.content), + ) + + # Load existing audio, append, and save back + existing_audio = AudioSegment.from_mp3(temp_path) + silence = AudioSegment.silent(duration=300) + combined = existing_audio + silence + current_segment + + # Export back to the same file + combined.export(temp_path, format="mp3", bitrate="128k") + + # Force garbage collection to free memory + del existing_audio, current_segment, combined + + # Small delay between API calls + if i < len(chunks) - 1: + time.sleep(0.5) + + # Read final result + audio_data = Path(temp_path).read_bytes() + + logger.info( + "Generated combined TTS audio: %d bytes", + len(audio_data), + ) + return audio_data + + finally: + # Clean up temp file + temp_file_path = Path(temp_path) + if temp_file_path.exists(): + temp_file_path.unlink() + + def upload_to_s3(self, audio_data: bytes, filename: str) -> str: + """Upload audio file to S3-compatible storage and return public URL. + + Raises: + ValueError: If S3 client is not configured. + ClientError: If S3 upload fails. 
+ """ + if not self.s3_client: + msg = "S3 client not configured" + raise ValueError(msg) + + try: + # Upload file using streaming to minimize memory usage + audio_stream = io.BytesIO(audio_data) + self.s3_client.upload_fileobj( + audio_stream, + S3_BUCKET, + filename, + ExtraArgs={ + "ContentType": "audio/mpeg", + "ACL": "public-read", + }, + ) + + # Construct public URL + audio_url = f"{S3_ENDPOINT}/{S3_BUCKET}/{filename}" + logger.info( + "Uploaded audio to: %s (%d bytes)", + audio_url, + len(audio_data), + ) + except ClientError: + logger.exception("S3 upload failed") + raise + else: + return audio_url + + @staticmethod + def estimate_duration(audio_data: bytes) -> int: + """Estimate audio duration in seconds based on file size and bitrate.""" + # Rough estimation: MP3 at 128kbps = ~16KB per second + estimated_seconds = len(audio_data) // 16000 + return max(1, estimated_seconds) # Minimum 1 second + + @staticmethod + def generate_filename(job_id: int, title: str) -> str: + """Generate unique filename for audio file.""" + timestamp = int(datetime.now(tz=timezone.utc).timestamp()) + # Create safe filename from title + safe_title = "".join( + c for c in title if c.isalnum() or c in {" ", "-", "_"} + ).rstrip() + safe_title = safe_title.replace(" ", "_")[:50] # Limit length + return f"episode_{timestamp}_{job_id}_{safe_title}.mp3" + + def process_job( + self, + job: dict[str, Any], + ) -> None: + """Process a single job through the complete pipeline.""" + job_id = job["id"] + url = job["url"] + + # Check memory before starting + mem_usage = check_memory_usage() + if mem_usage > MEMORY_THRESHOLD: + logger.warning( + "High memory usage (%.1f%%), deferring job %d", + mem_usage, + job_id, + ) + return + + # Track current job for graceful shutdown + self.shutdown_handler.set_current_job(job_id) + + try: + logger.info("Processing job %d: %s", job_id, url) + + # Update status to processing + Core.Database.update_job_status( + job_id, + "processing", + ) + + # Check for 
shutdown before each major step + if self.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, aborting job %d", job_id) + Core.Database.update_job_status(job_id, "pending") + return + + # Step 1: Extract article content + Core.Database.update_job_status(job_id, "extracting") + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content(url) + ) + + if self.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, aborting job %d", job_id) + Core.Database.update_job_status(job_id, "pending") + return + + # Step 2: Generate audio with metadata + Core.Database.update_job_status(job_id, "synthesizing") + audio_data = self.text_to_speech(content, title, author, pub_date) + + if self.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, aborting job %d", job_id) + Core.Database.update_job_status(job_id, "pending") + return + + # Step 3: Upload to S3 + Core.Database.update_job_status(job_id, "uploading") + filename = ArticleProcessor.generate_filename(job_id, title) + audio_url = self.upload_to_s3(audio_data, filename) + + # Step 4: Calculate duration + duration = ArticleProcessor.estimate_duration(audio_data) + + # Step 5: Create episode record + url_hash = Core.hash_url(url) + episode_id = Core.Database.create_episode( + title=title, + audio_url=audio_url, + duration=duration, + content_length=len(content), + user_id=job.get("user_id"), + author=job.get("author"), # Pass author from job + original_url=url, # Pass the original article URL + original_url_hash=url_hash, + ) + + # Add episode to user's feed + user_id = job.get("user_id") + if user_id: + Core.Database.add_episode_to_user(user_id, episode_id) + Core.Database.track_episode_event( + episode_id, + "added", + user_id, + ) + + # Step 6: Mark job as complete + Core.Database.update_job_status( + job_id, + "completed", + ) + + logger.info( + "Successfully processed job %d -> episode %d", + job_id, + episode_id, + ) + + except 
Exception as e: + error_msg = str(e) + logger.exception("Job %d failed: %s", job_id, error_msg) + Core.Database.update_job_status( + job_id, + "error", + error_msg, + ) + raise + finally: + # Clear current job + self.shutdown_handler.set_current_job(None) + + +def prepare_text_for_tts(text: str, title: str) -> list[str]: + """Use LLM to prepare text for TTS, returning chunks ready for speech. + + First splits text mechanically, then has LLM edit each chunk. + """ + # First, split the text into manageable chunks + raw_chunks = split_text_into_chunks(text, max_chars=3000) + + logger.info("Split article into %d raw chunks", len(raw_chunks)) + + # Prepare the first chunk with intro + edited_chunks = [] + + for i, chunk in enumerate(raw_chunks): + is_first = i == 0 + is_last = i == len(raw_chunks) - 1 + + try: + edited_chunk = edit_chunk_for_speech( + chunk, + title=title if is_first else None, + is_first=is_first, + is_last=is_last, + ) + edited_chunks.append(edited_chunk) + except Exception: + logger.exception("Failed to edit chunk %d", i + 1) + # Fall back to raw chunk if LLM fails + if is_first: + edited_chunks.append( + f"This is an audio version of {title}. {chunk}", + ) + elif is_last: + edited_chunks.append(f"{chunk} This concludes the article.") + else: + edited_chunks.append(chunk) + + return edited_chunks + + +def split_text_into_chunks(text: str, max_chars: int = 3000) -> list[str]: + """Split text into chunks at sentence boundaries.""" + chunks = [] + current_chunk = "" + + # Split into paragraphs first + paragraphs = text.split("\n\n") + + for para in paragraphs: + para_stripped = para.strip() + if not para_stripped: + continue + + # If paragraph itself is too long, split by sentences + if len(para_stripped) > max_chars: + sentences = para_stripped.split(". ") + for sentence in sentences: + if len(current_chunk) + len(sentence) + 2 < max_chars: + current_chunk += sentence + ". 
" + else: + if current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = sentence + ". " + # If adding this paragraph would exceed limit, start new chunk + elif len(current_chunk) + len(para_stripped) + 2 > max_chars: + if current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = para_stripped + " " + else: + current_chunk += para_stripped + " " + + # Don't forget the last chunk + if current_chunk: + chunks.append(current_chunk.strip()) + + return chunks + + +def edit_chunk_for_speech( + chunk: str, + title: str | None = None, + *, + is_first: bool = False, + is_last: bool = False, +) -> str: + """Use LLM to lightly edit a single chunk for speech. + + Raises: + ValueError: If no content is returned from LLM. + """ + system_prompt = ( + "You are a podcast script editor. Your job is to lightly edit text " + "to make it sound natural when spoken aloud.\n\n" + "Guidelines:\n" + ) + system_prompt += """ +- Remove URLs and email addresses, replacing with descriptive phrases +- Convert bullet points and lists into flowing sentences +- Fix any awkward phrasing for speech +- Remove references like "click here" or "see below" +- Keep edits minimal - preserve the original content and style +- Do NOT add commentary or explanations +- Return ONLY the edited text, no JSON or formatting +""" + + user_prompt = chunk + + # Add intro/outro if needed + if is_first and title: + user_prompt = ( + f"Add a brief intro mentioning this is an audio version of " + f"'{title}', then edit this text:\n\n{chunk}" + ) + elif is_last: + user_prompt = f"Edit this text and add a brief closing:\n\n{chunk}" + + try: + client: openai.OpenAI = openai.OpenAI(api_key=OPENAI_API_KEY) + response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + temperature=0.3, # Lower temperature for more consistent edits + max_tokens=4000, + ) + + content = 
response.choices[0].message.content + if not content: + msg = "No content returned from LLM" + raise ValueError(msg) # noqa: TRY301 + + # Ensure the chunk isn't too long + max_chunk_length = 4000 + if len(content) > max_chunk_length: + # Truncate at sentence boundary + sentences = content.split(". ") + truncated = "" + for sentence in sentences: + if len(truncated) + len(sentence) + 2 < max_chunk_length: + truncated += sentence + ". " + else: + break + content = truncated.strip() + + except Exception: + logger.exception("LLM chunk editing failed") + raise + else: + return content + + +def parse_datetime_with_timezone(created_at: str | datetime) -> datetime: + """Parse datetime string and ensure it has timezone info.""" + if isinstance(created_at, str): + # Handle timezone-aware datetime strings + if created_at.endswith("Z"): + created_at = created_at[:-1] + "+00:00" + last_attempt = datetime.fromisoformat(created_at) + if last_attempt.tzinfo is None: + last_attempt = last_attempt.replace(tzinfo=timezone.utc) + else: + last_attempt = created_at + if last_attempt.tzinfo is None: + last_attempt = last_attempt.replace(tzinfo=timezone.utc) + return last_attempt + + +def should_retry_job(job: dict[str, Any], max_retries: int) -> bool: + """Check if a job should be retried based on retry count and backoff time. + + Uses exponential backoff to determine if enough time has passed. 
+ """ + retry_count = job["retry_count"] + if retry_count >= max_retries: + return False + + # Exponential backoff: 30s, 60s, 120s + backoff_time = 30 * (2**retry_count) + last_attempt = parse_datetime_with_timezone(job["created_at"]) + time_since_attempt = datetime.now(tz=timezone.utc) - last_attempt + + return time_since_attempt > timedelta(seconds=backoff_time) + + +def process_pending_jobs( + processor: ArticleProcessor, +) -> None: + """Process all pending jobs.""" + pending_jobs = Core.Database.get_pending_jobs( + limit=5, + ) + + for job in pending_jobs: + if processor.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, stopping job processing") + break + + current_job = job["id"] + try: + processor.process_job(job) + except Exception as e: + # Ensure job is marked as error even if process_job didn't handle it + logger.exception("Failed to process job: %d", current_job) + # Check if job is still in processing state + current_status = Core.Database.get_job_by_id( + current_job, + ) + if current_status and current_status.get("status") == "processing": + Core.Database.update_job_status( + current_job, + "error", + str(e), + ) + continue + + +def process_retryable_jobs() -> None: + """Check and retry failed jobs with exponential backoff.""" + retryable_jobs = Core.Database.get_retryable_jobs( + MAX_RETRIES, + ) + + for job in retryable_jobs: + if should_retry_job(job, MAX_RETRIES): + logger.info( + "Retrying job %d (attempt %d)", + job["id"], + job["retry_count"] + 1, + ) + Core.Database.update_job_status( + job["id"], + "pending", + ) + + +def check_memory_usage() -> int | Any: + """Check current memory usage percentage.""" + try: + process = psutil.Process() + # this returns an int but psutil is untyped + return process.memory_percent() + except (psutil.Error, OSError): + logger.warning("Failed to check memory usage") + return 0.0 + + +def cleanup_stale_jobs() -> None: + """Reset jobs stuck in processing state on startup.""" + with 
Core.Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE queue + SET status = 'pending' + WHERE status = 'processing' + """, + ) + affected = cursor.rowcount + conn.commit() + + if affected > 0: + logger.info( + "Reset %d stale jobs from processing to pending", + affected, + ) + + +def main_loop() -> None: + """Poll for jobs and process them in a continuous loop.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + + # Clean up any stale jobs from previous runs + cleanup_stale_jobs() + + logger.info("Worker started, polling for jobs...") + + while not shutdown_handler.is_shutdown_requested(): + try: + # Process pending jobs + process_pending_jobs(processor) + process_retryable_jobs() + + # Check if there's any work + pending_jobs = Core.Database.get_pending_jobs( + limit=1, + ) + retryable_jobs = Core.Database.get_retryable_jobs( + MAX_RETRIES, + ) + + if not pending_jobs and not retryable_jobs: + logger.debug("No jobs to process, sleeping...") + + except Exception: + logger.exception("Error in main loop") + + # Use interruptible sleep + if not shutdown_handler.is_shutdown_requested(): + shutdown_handler.shutdown_requested.wait(timeout=POLL_INTERVAL) + + # Graceful shutdown + current_job = shutdown_handler.get_current_job() + if current_job: + logger.info( + "Waiting for job %d to complete before shutdown...", + current_job, + ) + # The job will complete or be reset to pending + + logger.info("Worker shutdown complete") + + +def move() -> None: + """Make the worker move.""" + try: + # Initialize database + Core.Database.init_db() + + # Start main processing loop + main_loop() + + except KeyboardInterrupt: + logger.info("Worker stopped by user") + except Exception: + logger.exception("Worker crashed") + raise + + +class TestArticleExtraction(Test.TestCase): + """Test article extraction functionality.""" + + def test_extract_valid_article(self) -> None: + """Extract from well-formed 
HTML.""" + # Mock trafilatura.fetch_url and extract + mock_html = ( + "<html><body><h1>Test Article</h1><p>Content here</p></body></html>" + ) + mock_result = json.dumps({ + "title": "Test Article", + "text": "Content here", + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(title, "Test Article") + self.assertEqual(content, "Content here") + self.assertIsNone(author) + self.assertIsNone(pub_date) + + def test_extract_missing_title(self) -> None: + """Handle articles without titles.""" + mock_html = "<html><body><p>Content without title</p></body></html>" + mock_result = json.dumps({"text": "Content without title"}) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(title, "Untitled Article") + self.assertEqual(content, "Content without title") + self.assertIsNone(author) + self.assertIsNone(pub_date) + + def test_extract_empty_content(self) -> None: + """Handle empty articles.""" + mock_html = "<html><body></body></html>" + mock_result = json.dumps({"title": "Empty Article", "text": ""}) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + pytest.raises(ValueError, match="No content extracted") as cm, + ): + ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertIn("No content extracted", str(cm.value)) + + def test_extract_network_error(self) -> None: + """Handle connection failures.""" + 
with ( + unittest.mock.patch("trafilatura.fetch_url", return_value=None), + pytest.raises(ValueError, match="Failed to download") as cm, + ): + ArticleProcessor.extract_article_content("https://example.com") + + self.assertIn("Failed to download", str(cm.value)) + + @staticmethod + def test_extract_timeout() -> None: + """Handle slow responses.""" + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + side_effect=TimeoutError("Timeout"), + ), + pytest.raises(TimeoutError), + ): + ArticleProcessor.extract_article_content("https://example.com") + + def test_content_sanitization(self) -> None: + """Remove unwanted elements.""" + mock_html = """ + <html><body> + <h1>Article</h1> + <p>Good content</p> + <script>alert('bad')</script> + <table><tr><td>data</td></tr></table> + </body></html> + """ + mock_result = json.dumps({ + "title": "Article", + "text": "Good content", # Tables and scripts removed + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + _title, content, _author, _pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(content, "Good content") + self.assertNotIn("script", content) + self.assertNotIn("table", content) + + +class TestTextToSpeech(Test.TestCase): + """Test text-to-speech functionality.""" + + def setUp(self) -> None: + """Set up mocks.""" + # Mock OpenAI API key + self.env_patcher = unittest.mock.patch.dict( + os.environ, + {"OPENAI_API_KEY": "test-key"}, + ) + self.env_patcher.start() + + # Mock OpenAI response + self.mock_audio_response: unittest.mock.MagicMock = ( + unittest.mock.MagicMock() + ) + self.mock_audio_response.content = b"fake-audio-data" + + # Mock AudioSegment to avoid ffmpeg issues in tests + self.mock_audio_segment: unittest.mock.MagicMock = ( + unittest.mock.MagicMock() + ) + self.mock_audio_segment.export.return_value = None + 
self.audio_segment_patcher = unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + return_value=self.mock_audio_segment, + ) + self.audio_segment_patcher.start() + + # Mock the concatenation operations + self.mock_audio_segment.__add__.return_value = self.mock_audio_segment + + def tearDown(self) -> None: + """Clean up mocks.""" + self.env_patcher.stop() + self.audio_segment_patcher.stop() + + def test_tts_generation(self) -> None: + """Generate audio from text.""" + + # Mock the export to write test audio data + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=["Test content"], + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech( + "Test content", + "Test Title", + ) + + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_tts_chunking(self) -> None: + """Handle long articles with chunking.""" + long_text = "Long content " * 1000 + chunks = ["Chunk 1", "Chunk 2", "Chunk 3"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock AudioSegment.silent + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + 
mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech( + long_text, + "Long Article", + ) + + # Should have called TTS for each chunk + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_tts_empty_text(self) -> None: + """Handle empty input.""" + with unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=[], + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + with pytest.raises(ValueError, match="No chunks generated") as cm: + processor.text_to_speech("", "Empty") + + self.assertIn("No chunks generated", str(cm.value)) + + def test_tts_special_characters(self) -> None: + """Handle unicode and special chars.""" + special_text = 'Unicode: 你好世界 Émojis: 🎙️📰 Special: <>&"' + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=[special_text], + ), + ): + shutdown_handler = ShutdownHandler() + 
processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech( + special_text, + "Special", + ) + + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_llm_text_preparation(self) -> None: + """Verify LLM editing.""" + # Test the actual text preparation functions + chunks = split_text_into_chunks("Short text", max_chars=100) + self.assertEqual(len(chunks), 1) + self.assertEqual(chunks[0], "Short text") + + # Test long text splitting + long_text = "Sentence one. " * 100 + chunks = split_text_into_chunks(long_text, max_chars=100) + self.assertGreater(len(chunks), 1) + for chunk in chunks: + self.assertLessEqual(len(chunk), 100) + + @staticmethod + def test_llm_failure_fallback() -> None: + """Handle LLM API failures.""" + # Mock LLM failure + with unittest.mock.patch("openai.OpenAI") as mock_openai: + mock_client = mock_openai.return_value + mock_client.chat.completions.create.side_effect = Exception( + "API Error", + ) + + # Should fall back to raw text + with pytest.raises(Exception, match="API Error"): + edit_chunk_for_speech("Test chunk", "Title", is_first=True) + + def test_chunk_concatenation(self) -> None: + """Verify audio joining.""" + # Mock multiple audio segments + chunks = ["Chunk 1", "Chunk 2"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + 
"pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test", "Title") + + # Should produce combined audio + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_parallel_tts_generation(self) -> None: + """Test parallel TTS processing.""" + chunks = ["Chunk 1", "Chunk 2", "Chunk 3", "Chunk 4"] + + # Mock responses for each chunk + mock_responses = [] + for i in range(len(chunks)): + mock_resp = unittest.mock.MagicMock() + mock_resp.content = f"audio-{i}".encode() + mock_responses.append(mock_resp) + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + + # Make create return different responses for each call + mock_speech.create.side_effect = mock_responses + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + # Mock AudioSegment operations + mock_segment = unittest.mock.MagicMock() + mock_segment.__add__.return_value = mock_segment + + def mock_export(path: str, **_kwargs: typing.Any) -> None: + Path(path).write_bytes(b"combined-audio") + + mock_segment.export = mock_export + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + return_value=mock_segment, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=mock_segment, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=50.0, # Normal memory usage + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test content", "Test Title") + + # Verify all chunks were processed + 
self.assertEqual(mock_speech.create.call_count, len(chunks)) + self.assertEqual(audio_data, b"combined-audio") + + def test_parallel_tts_high_memory_fallback(self) -> None: + """Test fallback to serial processing when memory is high.""" + chunks = ["Chunk 1", "Chunk 2"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"serial-audio") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=65.0, # High memory usage + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test content", "Test Title") + + # Should use serial processing + self.assertEqual(audio_data, b"serial-audio") + + @staticmethod + def test_parallel_tts_error_handling() -> None: + """Test error handling in parallel TTS processing.""" + chunks = ["Chunk 1", "Chunk 2"] + + # Mock OpenAI client with one failure + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + + # First call succeeds, second fails + mock_resp1 = unittest.mock.MagicMock() + mock_resp1.content = b"audio-1" + mock_speech.create.side_effect = [mock_resp1, Exception("API Error")] + + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + # Set up the test context + 
shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=50.0, + ), + pytest.raises(Exception, match="API Error"), + ): + processor.text_to_speech("Test content", "Test Title") + + def test_parallel_tts_order_preservation(self) -> None: + """Test that chunks are combined in the correct order.""" + chunks = ["First", "Second", "Third", "Fourth", "Fifth"] + + # Create mock responses with identifiable content + mock_responses = [] + for chunk in chunks: + mock_resp = unittest.mock.MagicMock() + mock_resp.content = f"audio-{chunk}".encode() + mock_responses.append(mock_resp) + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.side_effect = mock_responses + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + # Track the order of segments being combined + combined_order = [] + + def mock_from_mp3(data: io.BytesIO) -> unittest.mock.MagicMock: + content = data.read() + combined_order.append(content.decode()) + segment = unittest.mock.MagicMock() + segment.__add__.return_value = segment + return segment + + mock_segment = unittest.mock.MagicMock() + mock_segment.__add__.return_value = mock_segment + + def mock_export(path: str, **_kwargs: typing.Any) -> None: + # Verify order is preserved + expected_order = [f"audio-{chunk}" for chunk in chunks] + if combined_order != expected_order: + msg = f"Order mismatch: {combined_order} != {expected_order}" + raise AssertionError(msg) + Path(path).write_bytes(b"ordered-audio") + + mock_segment.export = mock_export + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + 
unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + side_effect=mock_from_mp3, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=mock_segment, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=50.0, + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test content", "Test Title") + + self.assertEqual(audio_data, b"ordered-audio") + + +class TestIntroOutro(Test.TestCase): + """Test intro and outro generation with metadata.""" + + def test_create_intro_text_full_metadata(self) -> None: + """Test intro text creation with all metadata.""" + intro = ArticleProcessor._create_intro_text( # noqa: SLF001 + title="Test Article", + author="John Doe", + pub_date="2024-01-15", + ) + self.assertIn("Title: Test Article", intro) + self.assertIn("Author: John Doe", intro) + self.assertIn("Published: 2024-01-15", intro) + + def test_create_intro_text_no_author(self) -> None: + """Test intro text without author.""" + intro = ArticleProcessor._create_intro_text( # noqa: SLF001 + title="Test Article", + author=None, + pub_date="2024-01-15", + ) + self.assertIn("Title: Test Article", intro) + self.assertNotIn("Author:", intro) + self.assertIn("Published: 2024-01-15", intro) + + def test_create_intro_text_minimal(self) -> None: + """Test intro text with only title.""" + intro = ArticleProcessor._create_intro_text( # noqa: SLF001 + title="Test Article", + author=None, + pub_date=None, + ) + self.assertEqual(intro, "Title: Test Article.") + + def test_create_outro_text_with_author(self) -> None: + """Test outro text with author.""" + outro = ArticleProcessor._create_outro_text( # noqa: SLF001 + title="Test Article", + author="Jane Smith", + ) + self.assertIn("Test Article", outro) + self.assertIn("Jane Smith", outro) + 
self.assertIn("Podcast It Later", outro) + + def test_create_outro_text_no_author(self) -> None: + """Test outro text without author.""" + outro = ArticleProcessor._create_outro_text( # noqa: SLF001 + title="Test Article", + author=None, + ) + self.assertIn("Test Article", outro) + self.assertNotIn("by", outro) + self.assertIn("Podcast It Later", outro) + + def test_extract_with_metadata(self) -> None: + """Test that extraction returns metadata.""" + mock_html = "<html><body><p>Content</p></body></html>" + mock_result = json.dumps({ + "title": "Test Article", + "text": "Article content", + "author": "Test Author", + "date": "2024-01-15", + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(title, "Test Article") + self.assertEqual(content, "Article content") + self.assertEqual(author, "Test Author") + self.assertEqual(pub_date, "2024-01-15") + + +class TestMemoryEfficiency(Test.TestCase): + """Test memory-efficient processing.""" + + def test_large_article_size_limit(self) -> None: + """Test that articles exceeding size limits are rejected.""" + huge_text = "x" * (MAX_ARTICLE_SIZE + 1000) # Exceed limit + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=huge_text * 4, # Simulate large HTML + ), + pytest.raises(ValueError, match="Article too large") as cm, + ): + ArticleProcessor.extract_article_content("https://example.com") + + self.assertIn("Article too large", str(cm.value)) + + def test_content_truncation(self) -> None: + """Test that oversized content is truncated.""" + large_content = "Content " * 100_000 # Create large content + mock_result = json.dumps({ + "title": "Large Article", + "text": large_content, + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", 
+ return_value="<html><body>content</body></html>", + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content, _author, _pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(title, "Large Article") + self.assertLessEqual(len(content), MAX_ARTICLE_SIZE) + + def test_memory_usage_check(self) -> None: + """Test memory usage monitoring.""" + usage = check_memory_usage() + self.assertIsInstance(usage, float) + self.assertGreaterEqual(usage, 0.0) + self.assertLessEqual(usage, 100.0) + + +class TestJobProcessing(Test.TestCase): + """Test job processing functionality.""" + + def setUp(self) -> None: + """Set up test environment.""" + Core.Database.init_db() + + # Create test user and job + self.user_id, _ = Core.Database.create_user( + "test@example.com", + ) + self.job_id = Core.Database.add_to_queue( + "https://example.com/article", + "test@example.com", + self.user_id, + ) + + # Mock environment + self.env_patcher = unittest.mock.patch.dict( + os.environ, + { + "OPENAI_API_KEY": "test-key", + "S3_ENDPOINT": "https://s3.example.com", + "S3_BUCKET": "test-bucket", + "S3_ACCESS_KEY": "test-access", + "S3_SECRET_KEY": "test-secret", + }, + ) + self.env_patcher.start() + + def tearDown(self) -> None: + """Clean up.""" + self.env_patcher.stop() + Core.Database.teardown() + + def test_process_job_success(self) -> None: + """Complete pipeline execution.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + # Mock all external calls + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=( + "Test Title", + "Test content", + "Test Author", + "2024-01-15", + ), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + 
return_value=b"audio-data", + ), + unittest.mock.patch.object( + processor, + "upload_to_s3", + return_value="https://s3.example.com/audio.mp3", + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.create_episode", + ) as mock_create, + ): + mock_create.return_value = 1 + processor.process_job(job) + + # Verify job was marked complete + mock_update.assert_called_with(self.job_id, "completed") + mock_create.assert_called_once() + + def test_process_job_extraction_failure(self) -> None: + """Handle bad URLs.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + side_effect=ValueError("Bad URL"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + pytest.raises(ValueError, match="Bad URL"), + ): + processor.process_job(job) + + # Job should be marked as error + mock_update.assert_called_with(self.job_id, "error", "Bad URL") + + def test_process_job_tts_failure(self) -> None: + """Handle TTS errors.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=("Title", "Content"), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + side_effect=Exception("TTS Error"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + pytest.raises(Exception, match="TTS Error"), + ): + processor.process_job(job) + + 
mock_update.assert_called_with(self.job_id, "error", "TTS Error") + + def test_process_job_s3_failure(self) -> None: + """Handle upload errors.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=("Title", "Content"), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + return_value=b"audio", + ), + unittest.mock.patch.object( + processor, + "upload_to_s3", + side_effect=ClientError({}, "PutObject"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ), + pytest.raises(ClientError), + ): + processor.process_job(job) + + def test_job_retry_logic(self) -> None: + """Verify exponential backoff.""" + # Set job to error with retry count + Core.Database.update_job_status( + self.job_id, + "error", + "First failure", + ) + Core.Database.update_job_status( + self.job_id, + "error", + "Second failure", + ) + + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + self.assertEqual(job["retry_count"], 2) + + # Should be retryable + retryable = Core.Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 1) + + def test_max_retries(self) -> None: + """Stop after max attempts.""" + # Exceed retry limit + for i in range(4): + Core.Database.update_job_status( + self.job_id, + "error", + f"Failure {i}", + ) + + # Should not be retryable + retryable = Core.Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 0) + + def test_graceful_shutdown(self) -> None: + """Test graceful shutdown during job processing.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = 
Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + def mock_tts(*_args: Any) -> bytes: + shutdown_handler.shutdown_requested.set() + return b"audio-data" + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=( + "Test Title", + "Test content", + "Test Author", + "2024-01-15", + ), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + side_effect=mock_tts, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + ): + processor.process_job(job) + + # Job should be reset to pending due to shutdown + mock_update.assert_any_call(self.job_id, "pending") + + def test_cleanup_stale_jobs(self) -> None: + """Test cleanup of stale processing jobs.""" + # Manually set job to processing + Core.Database.update_job_status(self.job_id, "processing") + + # Run cleanup + cleanup_stale_jobs() + + # Job should be back to pending + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + self.assertEqual(job["status"], "pending") + + def test_concurrent_processing(self) -> None: + """Handle multiple jobs.""" + # Create multiple jobs + job2 = Core.Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + job3 = Core.Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + # Get pending jobs + jobs = Core.Database.get_pending_jobs(limit=5) + + self.assertEqual(len(jobs), 3) + self.assertEqual({j["id"] for j in jobs}, {self.job_id, job2, job3}) + + def test_memory_threshold_deferral(self) -> None: + """Test that jobs are deferred when memory usage is high.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found 
for %s" + raise Test.TestError(msg, self.job_id) + + # Mock high memory usage + with ( + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=90.0, # High memory usage + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + ): + processor.process_job(job) + + # Job should not be processed (no status updates) + mock_update.assert_not_called() + + +class TestWorkerErrorHandling(Test.TestCase): + """Test worker error handling and recovery.""" + + def setUp(self) -> None: + """Set up test environment.""" + Core.Database.init_db() + self.user_id, _ = Core.Database.create_user("test@example.com") + self.job_id = Core.Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + self.shutdown_handler = ShutdownHandler() + self.processor = ArticleProcessor(self.shutdown_handler) + + @staticmethod + def tearDown() -> None: + """Clean up.""" + Core.Database.teardown() + + def test_process_pending_jobs_exception_handling(self) -> None: + """Test that process_pending_jobs handles exceptions.""" + + def side_effect(job: dict[str, Any]) -> None: + # Simulate process_job starting and setting status to processing + Core.Database.update_job_status(job["id"], "processing") + msg = "Unexpected Error" + raise ValueError(msg) + + with ( + unittest.mock.patch.object( + self.processor, + "process_job", + side_effect=side_effect, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + side_effect=Core.Database.update_job_status, + ) as _mock_update, + ): + process_pending_jobs(self.processor) + + # Job should be marked as error + job = Core.Database.get_job_by_id(self.job_id) + self.assertIsNotNone(job) + if job: + self.assertEqual(job["status"], "error") + self.assertIn("Unexpected Error", job["error_message"]) + + def test_process_retryable_jobs_success(self) -> None: + """Test processing of retryable jobs.""" + # Set up a retryable job + 
Core.Database.update_job_status(self.job_id, "error", "Fail 1") + + # Modify created_at to be in the past to satisfy backoff + with Core.Database.get_connection() as conn: + conn.execute( + "UPDATE queue SET created_at = ? WHERE id = ?", + ( + ( + datetime.now(tz=timezone.utc) - timedelta(minutes=5) + ).isoformat(), + self.job_id, + ), + ) + conn.commit() + + process_retryable_jobs() + + job = Core.Database.get_job_by_id(self.job_id) + self.assertIsNotNone(job) + if job: + self.assertEqual(job["status"], "pending") + + def test_process_retryable_jobs_not_ready(self) -> None: + """Test that jobs are not retried before backoff period.""" + # Set up a retryable job that just failed + Core.Database.update_job_status(self.job_id, "error", "Fail 1") + + # created_at is now, so backoff should prevent retry + process_retryable_jobs() + + job = Core.Database.get_job_by_id(self.job_id) + self.assertIsNotNone(job) + if job: + self.assertEqual(job["status"], "error") + + +class TestTextChunking(Test.TestCase): + """Test text chunking edge cases.""" + + def test_split_text_single_long_word(self) -> None: + """Handle text with a single word exceeding limit.""" + long_word = "a" * 4000 + chunks = split_text_into_chunks(long_word, max_chars=3000) + + # Should keep it as one chunk or split? + # The current implementation does not split words + self.assertEqual(len(chunks), 1) + self.assertEqual(len(chunks[0]), 4000) + + def test_split_text_no_sentence_boundaries(self) -> None: + """Handle long text with no sentence boundaries.""" + text = "word " * 1000 # 5000 chars + chunks = split_text_into_chunks(text, max_chars=3000) + + # Should keep it as one chunk as it can't split by ". 
" + self.assertEqual(len(chunks), 1) + self.assertGreater(len(chunks[0]), 3000) + + +def test() -> None: + """Run the tests.""" + Test.run( + App.Area.Test, + [ + TestArticleExtraction, + TestTextToSpeech, + TestMemoryEfficiency, + TestJobProcessing, + TestWorkerErrorHandling, + TestTextChunking, + ], + ) + + +def main() -> None: + """Entry point for the worker.""" + if "test" in sys.argv: + test() + else: + move() diff --git a/Biz/Que/Host.hs b/Biz/Que/Host.hs index 834ce0e..8d826b4 100755 --- a/Biz/Que/Host.hs +++ b/Biz/Que/Host.hs @@ -33,6 +33,7 @@ import qualified Control.Exception as Exception import Data.HashMap.Lazy (HashMap) import qualified Data.HashMap.Lazy as HashMap import Network.HTTP.Media ((//), (/:)) +import Network.Socket (SockAddr (..)) import qualified Network.Wai.Handler.Warp as Warp import qualified Omni.Cli as Cli import qualified Omni.Log as Log @@ -75,7 +76,30 @@ Usage: |] test :: Test.Tree -test = Test.group "Biz.Que.Host" [Test.unit "id" <| 1 @=? (1 :: Integer)] +test = + Test.group + "Biz.Que.Host" + [ Test.unit "id" <| 1 @=? 
(1 :: Integer), + Test.unit "putQue requires auth for '_'" <| do + st <- atomically <| STM.newTVar mempty + let cfg = Envy.defConfig + let handlers = paths cfg + + -- Case 1: No auth, should fail + let nonLocalHost = SockAddrInet 0 0 + let handler1 = putQue handlers nonLocalHost Nothing "_" "testq" "body" + res1 <- Servant.runHandler (runReaderT handler1 st) + case res1 of + Left err -> if errHTTPCode err == 401 then pure () else Test.assertFailure ("Expected 401, got " <> show err) + Right _ -> Test.assertFailure "Expected failure, got success" + + -- Case 2: Correct auth, should succeed + let handler2 = putQue handlers nonLocalHost (Just "admin-key") "_" "testq" "body" + res2 <- Servant.runHandler (runReaderT handler2 st) + case res2 of + Left err -> Test.assertFailure (show err) + Right _ -> pure () + ] type App = ReaderT AppState Servant.Handler @@ -125,23 +149,31 @@ data Paths path = Paths :- Get '[JSON] NoContent, dash :: path - :- "_" + :- RemoteHost + :> Header "Authorization" Text + :> "_" :> "dash" :> Get '[JSON] Ques, getQue :: path - :- Capture "ns" Text + :- RemoteHost + :> Header "Authorization" Text + :> Capture "ns" Text :> Capture "quename" Text :> Get '[PlainText, HTML, OctetStream] Message, getStream :: path - :- Capture "ns" Text + :- RemoteHost + :> Header "Authorization" Text + :> Capture "ns" Text :> Capture "quename" Text :> "stream" :> StreamGet NoFraming OctetStream (SourceIO Message), putQue :: path - :- Capture "ns" Text + :- RemoteHost + :> Header "Authorization" Text + :> Capture "ns" Text :> Capture "quepath" Text :> ReqBody '[PlainText, HTML, OctetStream] Text :> Post '[PlainText, HTML, OctetStream] NoContent @@ -149,15 +181,15 @@ data Paths path = Paths deriving (Generic) paths :: Config -> Paths (AsServerT App) -paths _ = - -- TODO revive authkey stuff - -- - read Authorization header, compare with queSkey - -- - Only allow my IP or localhost to publish to '_' namespace +paths Config {..} = Paths { home = throwError <| err301 
{errHeaders = [("Location", "/_/index")]}, - dash = gets, - getQue = \ns qn -> do + dash = \rh mAuth -> do + checkAuth queSkey rh mAuth "_" + gets, + getQue = \rh mAuth ns qn -> do + checkAuth queSkey rh mAuth ns guardNs ns ["pub", "_"] modify <| upsertNamespace ns q <- que ns qn @@ -165,7 +197,8 @@ paths _ = |> liftIO +> Go.tap |> liftIO, - getStream = \ns qn -> do + getStream = \rh mAuth ns qn -> do + checkAuth queSkey rh mAuth ns guardNs ns ["pub", "_"] modify <| upsertNamespace ns q <- que ns qn @@ -174,7 +207,8 @@ paths _ = +> Go.tap |> Source.fromAction (const False) -- peek chan instead of False? |> pure, - putQue = \ns qp body -> do + putQue = \rh mAuth ns qp body -> do + checkAuth queSkey rh mAuth ns guardNs ns ["pub", "_"] modify <| upsertNamespace ns q <- que ns qp @@ -188,6 +222,19 @@ paths _ = >> pure NoContent } +checkAuth :: Text -> SockAddr -> Maybe Text -> Text -> App () +checkAuth skey rh mAuth ns = do + let authorized = mAuth == Just skey + let isLocal = isLocalhost rh + when (ns == "_" && not (authorized || isLocal)) <| do + throwError err401 {errBody = "Authorized access only for '_' namespace"} + +isLocalhost :: SockAddr -> Bool +isLocalhost (SockAddrInet _ h) = h == 0x0100007f -- 127.0.0.1 +isLocalhost (SockAddrInet6 _ _ (0, 0, 0, 1) _) = True -- ::1 +isLocalhost (SockAddrUnix _) = True +isLocalhost _ = False + -- | Given `guardNs ns whitelist`, if `ns` is not in the `whitelist` -- list, return a 405 error. 
guardNs :: (Applicative a, MonadError ServerError a) => Text -> [Text] -> a () diff --git a/Biz/Storybook.py b/Biz/Storybook.py index dbaf82a..164e845 100755 --- a/Biz/Storybook.py +++ b/Biz/Storybook.py @@ -56,7 +56,8 @@ PORT = int(os.environ.get("PORT", "3000")) area = App.from_env() app = ludic.web.LudicApp(debug=area == App.Area.Test) -log = Log.setup(logging.DEBUG if area == App.Area.Test else logging.ERROR) +log = logging.getLogger(__name__) +Log.setup(log, logging.DEBUG if area == App.Area.Test else logging.ERROR) Sqids = sqids.Sqids() @@ -309,7 +310,12 @@ def generate_image( size="1024x1024", quality="standard", ) - url = image_response.data[0].url + data = image_response.data + if data is None: + msg = "error getting data from OpenAI" + log.error(msg) + raise ludic.web.exceptions.InternalServerError(msg) + url = data[0].url if url is None: msg = "error getting image from OpenAI" log.error(msg) |
