diff options
| author | Ben Sima <ben@bsima.me> | 2025-08-13 13:36:30 -0400 |
|---|---|---|
| committer | Ben Sima <ben@bsima.me> | 2025-08-28 12:14:09 -0400 |
| commit | 0b005c192b2c141c7f6c9bff4a0702361814c21d (patch) | |
| tree | 3527a76137f6ee4dd970bba17a93617a311149cb /Biz/PodcastItLater/Worker.py | |
| parent | 7de0a3e0abbf1e152423e148d507e17b752a4982 (diff) | |
Prototype PodcastItLater
This implements a working prototype of PodcastItLater. It basically just works
for a single user currently, but the articles are nice to listen to and this is
something that we can start to build with.
Diffstat (limited to 'Biz/PodcastItLater/Worker.py')
| -rw-r--r-- | Biz/PodcastItLater/Worker.py | 1194 |
1 files changed, 1194 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py new file mode 100644 index 0000000..af51260 --- /dev/null +++ b/Biz/PodcastItLater/Worker.py @@ -0,0 +1,1194 @@ +"""Background worker for processing article-to-podcast conversions.""" + +# : dep boto3 +# : dep botocore +# : dep openai +# : dep pydub +# : dep pytest +# : dep pytest-asyncio +# : dep pytest-mock +# : dep trafilatura +# : out podcastitlater-worker +# : run ffmpeg +import Biz.PodcastItLater.Core as Core +import boto3 # type: ignore[import-untyped] +import io +import json +import Omni.App as App +import Omni.Log as Log +import Omni.Test as Test +import openai +import os +import pathlib +import pytest +import sys +import time +import trafilatura +import typing +import unittest.mock +from botocore.exceptions import ClientError # type: ignore[import-untyped] +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from pydub import AudioSegment # type: ignore[import-untyped] +from typing import Any + +logger = Log.setup() + +# Configuration from environment variables +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +S3_ENDPOINT = os.getenv("S3_ENDPOINT") +S3_BUCKET = os.getenv("S3_BUCKET") +S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY") +S3_SECRET_KEY = os.getenv("S3_SECRET_KEY") +DATABASE_PATH = os.getenv("DATABASE_PATH", "podcast.db") + +# Worker configuration +MAX_CONTENT_LENGTH = 5000 # characters for TTS +POLL_INTERVAL = 30 # seconds +MAX_RETRIES = 3 +TTS_MODEL = "tts-1" +TTS_VOICE = "alloy" + + +class ArticleProcessor: + """Handles the complete article-to-podcast conversion pipeline.""" + + def __init__(self) -> None: + """Initialize the processor with required services. + + Raises: + ValueError: If OPENAI_API_KEY environment variable is not set. + """ + if not OPENAI_API_KEY: + msg = "OPENAI_API_KEY environment variable is required" + raise ValueError(msg) + + self.openai_client: openai.OpenAI = openai.OpenAI( + api_key=OPENAI_API_KEY, + ) + + # Initialize S3 client for Digital Ocean Spaces + if all([S3_ENDPOINT, S3_BUCKET, S3_ACCESS_KEY, S3_SECRET_KEY]): + self.s3_client: Any = boto3.client( + "s3", + endpoint_url=S3_ENDPOINT, + aws_access_key_id=S3_ACCESS_KEY, + aws_secret_access_key=S3_SECRET_KEY, + ) + else: + logger.warning("S3 configuration incomplete, uploads will fail") + self.s3_client = None + + @staticmethod + def extract_article_content(url: str) -> tuple[str, str]: + """Extract title and content from article URL using trafilatura. + + Raises: + ValueError: If content cannot be downloaded or extracted. + """ + try: + downloaded = trafilatura.fetch_url(url) + if not downloaded: + msg = f"Failed to download content from {url}" + raise ValueError(msg) # noqa: TRY301 + + # Extract with metadata + result = trafilatura.extract( + downloaded, + include_comments=False, + include_tables=False, + with_metadata=True, + output_format="json", + ) + + if not result: + msg = f"Failed to extract content from {url}" + raise ValueError(msg) # noqa: TRY301 + + data = json.loads(result) + + title = data.get("title", "Untitled Article") + content = data.get("text", "") + + if not content: + msg = f"No content extracted from {url}" + raise ValueError(msg) # noqa: TRY301 + + # Don't truncate - we'll handle length in text_to_speech + logger.info("Extracted article: %s (%d chars)", title, len(content)) + except Exception: + logger.exception("Failed to extract content from %s", url) + raise + else: + return title, content + + def text_to_speech(self, text: str, title: str) -> bytes: + """Convert text to speech using OpenAI TTS API. + + Uses LLM to prepare text, then handles chunking and concatenation. + + Raises: + ValueError: If no chunks are generated from text. + """ + try: + # Use LLM to prepare and chunk the text + chunks = prepare_text_for_tts(text, title) + + if not chunks: + msg = "No chunks generated from text" + raise ValueError(msg) # noqa: TRY301 + + logger.info("Processing %d chunks for TTS", len(chunks)) + + # Generate audio for each chunk + audio_segments = [] + for i, chunk in enumerate(chunks): + logger.info( + "Generating TTS for chunk %d/%d (%d chars)", + i + 1, + len(chunks), + len(chunk), + ) + + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunk, + response_format="mp3", + ) + + # Convert bytes to AudioSegment + audio_segment = AudioSegment.from_mp3( + io.BytesIO(response.content), + ) + audio_segments.append(audio_segment) + + # Small delay between API calls to be respectful + if i < len(chunks) - 1: + time.sleep(0.5) + + # Concatenate all audio segments + combined_audio = audio_segments[0] + for segment in audio_segments[1:]: + # Add a small silence between chunks for natural pacing + silence = AudioSegment.silent(duration=300) + combined_audio = combined_audio + silence + segment + + # Export combined audio to bytes + output_buffer = io.BytesIO() + combined_audio.export(output_buffer, format="mp3", bitrate="128k") + audio_data = output_buffer.getvalue() + + logger.info( + "Generated combined TTS audio: %d bytes", + len(audio_data), + ) + except Exception: + logger.exception("TTS generation failed") + raise + else: + return audio_data + + def upload_to_s3(self, audio_data: bytes, filename: str) -> str: + """Upload audio file to S3-compatible storage and return public URL. + + Raises: + ValueError: If S3 client is not configured. + ClientError: If S3 upload fails. + """ + if not self.s3_client: + msg = "S3 client not configured" + raise ValueError(msg) + + try: + # Upload file + self.s3_client.put_object( + Bucket=S3_BUCKET, + Key=filename, + Body=audio_data, + ContentType="audio/mpeg", + ACL="public-read", + ) + + # Construct public URL + audio_url = f"{S3_ENDPOINT}/{S3_BUCKET}/{filename}" + logger.info("Uploaded audio to: %s", audio_url) + except ClientError: + logger.exception("S3 upload failed") + raise + else: + return audio_url + + @staticmethod + def estimate_duration(audio_data: bytes) -> int: + """Estimate audio duration in seconds based on file size and bitrate.""" + # Rough estimation: MP3 at 128kbps = ~16KB per second + estimated_seconds = len(audio_data) // 16000 + return max(1, estimated_seconds) # Minimum 1 second + + @staticmethod + def generate_filename(job_id: int, title: str) -> str: + """Generate unique filename for audio file.""" + timestamp = int(datetime.now(tz=timezone.utc).timestamp()) + # Create safe filename from title + safe_title = "".join( + c for c in title if c.isalnum() or c in {" ", "-", "_"} + ).rstrip() + safe_title = safe_title.replace(" ", "_")[:50] # Limit length + return f"episode_{timestamp}_{job_id}_{safe_title}.mp3" + + def process_job(self, job: dict[str, Any]) -> None: + """Process a single job through the complete pipeline.""" + job_id = job["id"] + url = job["url"] + + try: + logger.info("Processing job %d: %s", job_id, url) + + # Update status to processing + Core.Database.update_job_status( + job_id, + "processing", + db_path=DATABASE_PATH, + ) + + # Step 1: Extract article content + title, content = ArticleProcessor.extract_article_content(url) + + # Step 2: Generate audio + audio_data = self.text_to_speech(content, title) + + # Step 3: Upload to S3 + filename = ArticleProcessor.generate_filename(job_id, title) + audio_url = self.upload_to_s3(audio_data, filename) + + # Step 4: Calculate duration + duration = ArticleProcessor.estimate_duration(audio_data) + + # Step 5: Create episode record + episode_id = Core.Database.create_episode( + title=title, + audio_url=audio_url, + duration=duration, + content_length=len(content), + user_id=job.get("user_id"), + db_path=DATABASE_PATH, + ) + + # Step 6: Mark job as complete + Core.Database.update_job_status( + job_id, + "completed", + db_path=DATABASE_PATH, + ) + + logger.info( + "Successfully processed job %d -> episode %d", + job_id, + episode_id, + ) + + except Exception as e: + error_msg = str(e) + logger.exception("Job %d failed: %s", job_id, error_msg) + Core.Database.update_job_status( + job_id, + "error", + error_msg, + DATABASE_PATH, + ) + raise + + +def prepare_text_for_tts(text: str, title: str) -> list[str]: + """Use LLM to prepare text for TTS, returning chunks ready for speech. + + First splits text mechanically, then has LLM edit each chunk. + """ + # First, split the text into manageable chunks + raw_chunks = split_text_into_chunks(text, max_chars=3000) + + logger.info("Split article into %d raw chunks", len(raw_chunks)) + + # Prepare the first chunk with intro + edited_chunks = [] + + for i, chunk in enumerate(raw_chunks): + is_first = i == 0 + is_last = i == len(raw_chunks) - 1 + + try: + edited_chunk = edit_chunk_for_speech( + chunk, + title=title if is_first else None, + is_first=is_first, + is_last=is_last, + ) + edited_chunks.append(edited_chunk) + except Exception: + logger.exception("Failed to edit chunk %d", i + 1) + # Fall back to raw chunk if LLM fails + if is_first: + edited_chunks.append( + f"This is an audio version of {title}. {chunk}", + ) + elif is_last: + edited_chunks.append(f"{chunk} This concludes the article.") + else: + edited_chunks.append(chunk) + + return edited_chunks + + +def split_text_into_chunks(text: str, max_chars: int = 3000) -> list[str]: + """Split text into chunks at sentence boundaries.""" + chunks = [] + current_chunk = "" + + # Split into paragraphs first + paragraphs = text.split("\n\n") + + for para in paragraphs: + para_stripped = para.strip() + if not para_stripped: + continue + + # If paragraph itself is too long, split by sentences + if len(para_stripped) > max_chars: + sentences = para_stripped.split(". ") + for sentence in sentences: + if len(current_chunk) + len(sentence) + 2 < max_chars: + current_chunk += sentence + ". " + else: + if current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = sentence + ". " + # If adding this paragraph would exceed limit, start new chunk + elif len(current_chunk) + len(para_stripped) + 2 > max_chars: + if current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = para_stripped + " " + else: + current_chunk += para_stripped + " " + + # Don't forget the last chunk + if current_chunk: + chunks.append(current_chunk.strip()) + + return chunks + + +def edit_chunk_for_speech( + chunk: str, + title: str | None = None, + *, + is_first: bool = False, + is_last: bool = False, +) -> str: + """Use LLM to lightly edit a single chunk for speech. + + Raises: + ValueError: If no content is returned from LLM. + """ + system_prompt = ( + "You are a podcast script editor. Your job is to lightly edit text " + "to make it sound natural when spoken aloud.\n\n" + "Guidelines:\n" + ) + system_prompt += """ +- Remove URLs and email addresses, replacing with descriptive phrases +- Convert bullet points and lists into flowing sentences +- Fix any awkward phrasing for speech +- Remove references like "click here" or "see below" +- Keep edits minimal - preserve the original content and style +- Do NOT add commentary or explanations +- Return ONLY the edited text, no JSON or formatting +""" + + user_prompt = chunk + + # Add intro/outro if needed + if is_first and title: + user_prompt = ( + f"Add a brief intro mentioning this is an audio version of " + f"'{title}', then edit this text:\n\n{chunk}" + ) + elif is_last: + user_prompt = f"Edit this text and add a brief closing:\n\n{chunk}" + + try: + client: openai.OpenAI = openai.OpenAI(api_key=OPENAI_API_KEY) + response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + temperature=0.3, # Lower temperature for more consistent edits + max_tokens=4000, + ) + + content = response.choices[0].message.content + if not content: + msg = "No content returned from LLM" + raise ValueError(msg) # noqa: TRY301 + + # Ensure the chunk isn't too long + max_chunk_length = 4000 + if len(content) > max_chunk_length: + # Truncate at sentence boundary + sentences = content.split(". ") + truncated = "" + for sentence in sentences: + if len(truncated) + len(sentence) + 2 < max_chunk_length: + truncated += sentence + ". " + else: + break + content = truncated.strip() + + except Exception: + logger.exception("LLM chunk editing failed") + raise + else: + return content + + +def parse_datetime_with_timezone(created_at: str | datetime) -> datetime: + """Parse datetime string and ensure it has timezone info.""" + if isinstance(created_at, str): + # Handle timezone-aware datetime strings + if created_at.endswith("Z"): + created_at = created_at[:-1] + "+00:00" + last_attempt = datetime.fromisoformat(created_at) + if last_attempt.tzinfo is None: + last_attempt = last_attempt.replace(tzinfo=timezone.utc) + else: + last_attempt = created_at + if last_attempt.tzinfo is None: + last_attempt = last_attempt.replace(tzinfo=timezone.utc) + return last_attempt + + +def should_retry_job(job: dict[str, Any], max_retries: int) -> bool: + """Check if a job should be retried based on retry count and backoff time. + + Uses exponential backoff to determine if enough time has passed. + """ + retry_count = job["retry_count"] + if retry_count >= max_retries: + return False + + # Exponential backoff: 30s, 60s, 120s + backoff_time = 30 * (2**retry_count) + last_attempt = parse_datetime_with_timezone(job["created_at"]) + time_since_attempt = datetime.now(tz=timezone.utc) - last_attempt + + return time_since_attempt > timedelta(seconds=backoff_time) + + +def process_pending_jobs(processor: ArticleProcessor) -> None: + """Process all pending jobs.""" + pending_jobs = Core.Database.get_pending_jobs( + limit=5, + db_path=DATABASE_PATH, + ) + + for job in pending_jobs: + current_job = job["id"] + try: + processor.process_job(job) + except Exception as e: + # Ensure job is marked as error even if process_job didn't handle it + logger.exception("Failed to process job: %d", current_job) + # Check if job is still in processing state + current_status = Core.Database.get_job_by_id( + current_job, + DATABASE_PATH, + ) + if current_status and current_status.get("status") == "processing": + Core.Database.update_job_status( + current_job, + "error", + str(e), + DATABASE_PATH, + ) + continue + + +def process_retryable_jobs() -> None: + """Check and retry failed jobs with exponential backoff.""" + retryable_jobs = Core.Database.get_retryable_jobs( + MAX_RETRIES, + DATABASE_PATH, + ) + + for job in retryable_jobs: + if should_retry_job(job, MAX_RETRIES): + logger.info( + "Retrying job %d (attempt %d)", + job["id"], + job["retry_count"] + 1, + ) + Core.Database.update_job_status( + job["id"], + "pending", + db_path=DATABASE_PATH, + ) + + +def main_loop() -> None: + """Poll for jobs and process them in a continuous loop.""" + processor = ArticleProcessor() + logger.info("Worker started, polling for jobs...") + + while True: + try: + process_pending_jobs(processor) + process_retryable_jobs() + + # Check if there's any work + pending_jobs = Core.Database.get_pending_jobs( + limit=1, + db_path=DATABASE_PATH, + ) + retryable_jobs = Core.Database.get_retryable_jobs( + MAX_RETRIES, + DATABASE_PATH, + ) + + if not pending_jobs and not retryable_jobs: + logger.debug("No jobs to process, sleeping...") + + except Exception: + logger.exception("Error in main loop") + + time.sleep(POLL_INTERVAL) + + +def move() -> None: + """Make the worker move.""" + try: + # Initialize database + Core.Database.init_db(DATABASE_PATH) + + # Start main processing loop + main_loop() + + except KeyboardInterrupt: + logger.info("Worker stopped by user") + except Exception: + logger.exception("Worker crashed") + raise + + +class TestArticleExtraction(Test.TestCase): + """Test article extraction functionality.""" + + def test_extract_valid_article(self) -> None: + """Extract from well-formed HTML.""" + # Mock trafilatura.fetch_url and extract + mock_html = ( + "<html><body><h1>Test Article</h1><p>Content here</p></body></html>" + ) + mock_result = json.dumps({ + "title": "Test Article", + "text": "Content here", + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content = ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertEqual(title, "Test Article") + self.assertEqual(content, "Content here") + + def test_extract_missing_title(self) -> None: + """Handle articles without titles.""" + mock_html = "<html><body><p>Content without title</p></body></html>" + mock_result = json.dumps({"text": "Content without title"}) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content = ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertEqual(title, "Untitled Article") + self.assertEqual(content, "Content without title") + + def test_extract_empty_content(self) -> None: + """Handle empty articles.""" + mock_html = "<html><body></body></html>" + mock_result = json.dumps({"title": "Empty Article", "text": ""}) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + pytest.raises(ValueError, match="No content extracted") as cm, + ): + ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertIn("No content extracted", str(cm.value)) + + def test_extract_network_error(self) -> None: + """Handle connection failures.""" + with ( + unittest.mock.patch("trafilatura.fetch_url", return_value=None), + pytest.raises(ValueError, match="Failed to download") as cm, + ): + ArticleProcessor.extract_article_content("https://example.com") + + self.assertIn("Failed to download", str(cm.value)) + + @staticmethod + def test_extract_timeout() -> None: + """Handle slow responses.""" + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + side_effect=TimeoutError("Timeout"), + ), + pytest.raises(TimeoutError), + ): + ArticleProcessor.extract_article_content("https://example.com") + + def test_content_sanitization(self) -> None: + """Remove unwanted elements.""" + mock_html = """ + <html><body> + <h1>Article</h1> + <p>Good content</p> + <script>alert('bad')</script> + <table><tr><td>data</td></tr></table> + </body></html> + """ + mock_result = json.dumps({ + "title": "Article", + "text": "Good content", # Tables and scripts removed + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + _title, content = ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertEqual(content, "Good content") + self.assertNotIn("script", content) + self.assertNotIn("table", content) + + +class TestTextToSpeech(Test.TestCase): + """Test text-to-speech functionality.""" + + def setUp(self) -> None: + """Set up mocks.""" + # Mock OpenAI API key + self.env_patcher = unittest.mock.patch.dict( + os.environ, + {"OPENAI_API_KEY": "test-key"}, + ) + self.env_patcher.start() + + # Mock OpenAI response + self.mock_audio_response: unittest.mock.MagicMock = ( + unittest.mock.MagicMock() + ) + self.mock_audio_response.content = b"fake-audio-data" + + # Mock AudioSegment to avoid ffmpeg issues in tests + self.mock_audio_segment: unittest.mock.MagicMock = ( + unittest.mock.MagicMock() + ) + self.mock_audio_segment.export.return_value = None + self.audio_segment_patcher = unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + return_value=self.mock_audio_segment, + ) + self.audio_segment_patcher.start() + + # Mock the concatenation operations + self.mock_audio_segment.__add__.return_value = self.mock_audio_segment + + def tearDown(self) -> None: + """Clean up mocks.""" + self.env_patcher.stop() + self.audio_segment_patcher.stop() + + def test_tts_generation(self) -> None: + """Generate audio from text.""" + + # Mock the export to write test audio data + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=["Test content"], + ), + ): + processor = ArticleProcessor() + audio_data = processor.text_to_speech( + "Test content", + "Test Title", + ) + + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_tts_chunking(self) -> None: + """Handle long articles with chunking.""" + long_text = "Long content " * 1000 + chunks = ["Chunk 1", "Chunk 2", "Chunk 3"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock AudioSegment.silent + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + processor = ArticleProcessor() + audio_data = processor.text_to_speech( + long_text, + "Long Article", + ) + + # Should have called TTS for each chunk + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_tts_empty_text(self) -> None: + """Handle empty input.""" + with unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=[], + ): + processor = ArticleProcessor() + with pytest.raises(ValueError, match="No chunks generated") as cm: + processor.text_to_speech("", "Empty") + + self.assertIn("No chunks generated", str(cm.value)) + + def test_tts_special_characters(self) -> None: + """Handle unicode and special chars.""" + special_text = 'Unicode: 你好世界 Émojis: 🎙️📰 Special: <>&"' + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=[special_text], + ), + ): + processor = ArticleProcessor() + audio_data = processor.text_to_speech( + special_text, + "Special", + ) + + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_llm_text_preparation(self) -> None: + """Verify LLM editing.""" + # Test the actual text preparation functions + chunks = split_text_into_chunks("Short text", max_chars=100) + self.assertEqual(len(chunks), 1) + self.assertEqual(chunks[0], "Short text") + + # Test long text splitting + long_text = "Sentence one. " * 100 + chunks = split_text_into_chunks(long_text, max_chars=100) + self.assertGreater(len(chunks), 1) + for chunk in chunks: + self.assertLessEqual(len(chunk), 100) + + @staticmethod + def test_llm_failure_fallback() -> None: + """Handle LLM API failures.""" + # Mock LLM failure + with unittest.mock.patch("openai.OpenAI") as mock_openai: + mock_client = mock_openai.return_value + mock_client.chat.completions.create.side_effect = Exception( + "API Error", + ) + + # Should fall back to raw text + with pytest.raises(Exception, match="API Error"): + edit_chunk_for_speech("Test chunk", "Title", is_first=True) + + def test_chunk_concatenation(self) -> None: + """Verify audio joining.""" + # Mock multiple audio segments + chunks = ["Chunk 1", "Chunk 2"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + processor = ArticleProcessor() + audio_data = processor.text_to_speech("Test", "Title") + + # Should produce combined audio + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + +class TestJobProcessing(Test.TestCase): + """Test job processing functionality.""" + + def setUp(self) -> None: + """Set up test environment.""" + self.test_db = "test_podcast_worker.db" + # Clean up any existing test database + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + Core.Database.init_db(self.test_db) + + # Create test user and job + self.user_id, _ = Core.Database.create_user( + "test@example.com", + self.test_db, + ) + self.job_id = Core.Database.add_to_queue( + "https://example.com/article", + "test@example.com", + self.user_id, + self.test_db, + ) + + # Mock environment + self.env_patcher = unittest.mock.patch.dict( + os.environ, + { + "OPENAI_API_KEY": "test-key", + "S3_ENDPOINT": "https://s3.example.com", + "S3_BUCKET": "test-bucket", + "S3_ACCESS_KEY": "test-access", + "S3_SECRET_KEY": "test-secret", + }, + ) + self.env_patcher.start() + + def tearDown(self) -> None: + """Clean up.""" + self.env_patcher.stop() + test_db_path = pathlib.Path(self.test_db) + if test_db_path.exists(): + test_db_path.unlink() + + def test_process_job_success(self) -> None: + """Complete pipeline execution.""" + processor = ArticleProcessor() + job = Core.Database.get_job_by_id(self.job_id, self.test_db) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + # Mock all external calls + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=("Test Title", "Test content"), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + return_value=b"audio-data", + ), + unittest.mock.patch.object( + processor, + "upload_to_s3", + return_value="https://s3.example.com/audio.mp3", + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.create_episode", + ) as mock_create, + ): + mock_create.return_value = 1 + processor.process_job(job) + + # Verify job was marked complete + mock_update.assert_called_with(self.job_id, "completed") + mock_create.assert_called_once() + + def test_process_job_extraction_failure(self) -> None: + """Handle bad URLs.""" + processor = ArticleProcessor() + job = Core.Database.get_job_by_id(self.job_id, self.test_db) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + side_effect=ValueError("Bad URL"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + pytest.raises(ValueError, match="Bad URL"), + ): + processor.process_job(job) + + # Job should be marked as error + mock_update.assert_called_with(self.job_id, "error", "Bad URL") + + def test_process_job_tts_failure(self) -> None: + """Handle TTS errors.""" + processor = ArticleProcessor() + job = Core.Database.get_job_by_id(self.job_id, self.test_db) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=("Title", "Content"), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + side_effect=Exception("TTS Error"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + pytest.raises(Exception, match="TTS Error"), + ): + processor.process_job(job) + + mock_update.assert_called_with(self.job_id, "error", "TTS Error") + + def test_process_job_s3_failure(self) -> None: + """Handle upload errors.""" + processor = ArticleProcessor() + job = Core.Database.get_job_by_id(self.job_id, self.test_db) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=("Title", "Content"), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + return_value=b"audio", + ), + unittest.mock.patch.object( + processor, + "upload_to_s3", + side_effect=ClientError({}, "PutObject"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ), + pytest.raises(ClientError), + ): + processor.process_job(job) + + def test_job_retry_logic(self) -> None: + """Verify exponential backoff.""" + # Set job to error with retry count + Core.Database.update_job_status( + self.job_id, + "error", + "First failure", + self.test_db, + ) + Core.Database.update_job_status( + self.job_id, + "error", + "Second failure", + self.test_db, + ) + + job = Core.Database.get_job_by_id(self.job_id, self.test_db) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + self.assertEqual(job["retry_count"], 2) + + # Should be retryable + retryable = Core.Database.get_retryable_jobs( + max_retries=3, + db_path=self.test_db, + ) + self.assertEqual(len(retryable), 1) + + def test_max_retries(self) -> None: + """Stop after max attempts.""" + # Exceed retry limit + for i in range(4): + Core.Database.update_job_status( + self.job_id, + "error", + f"Failure {i}", + self.test_db, + ) + + # Should not be retryable + retryable = Core.Database.get_retryable_jobs( + max_retries=3, + db_path=self.test_db, + ) + self.assertEqual(len(retryable), 0) + + def test_concurrent_processing(self) -> None: + """Handle multiple jobs.""" + # Create multiple jobs + job2 = Core.Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + self.test_db, + ) + job3 = Core.Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + self.test_db, + ) + + # Get pending jobs + jobs = Core.Database.get_pending_jobs(limit=5, db_path=self.test_db) + + self.assertEqual(len(jobs), 3) + self.assertEqual({j["id"] for j in jobs}, {self.job_id, job2, job3}) + + +def test() -> None: + """Run the tests.""" + Test.run( + App.Area.Test, + [ + TestArticleExtraction, + TestTextToSpeech, + TestJobProcessing, + ], + ) + + +def main() -> None: + """Entry point for the worker.""" + if "test" in sys.argv: + test() + else: + move() |
