summaryrefslogtreecommitdiff
path: root/Biz/PodcastItLater/Worker.py
diff options
context:
space:
mode:
Diffstat (limited to 'Biz/PodcastItLater/Worker.py')
-rw-r--r--Biz/PodcastItLater/Worker.py1194
1 files changed, 1194 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py
new file mode 100644
index 0000000..af51260
--- /dev/null
+++ b/Biz/PodcastItLater/Worker.py
@@ -0,0 +1,1194 @@
+"""Background worker for processing article-to-podcast conversions."""
+
+# : dep boto3
+# : dep botocore
+# : dep openai
+# : dep pydub
+# : dep pytest
+# : dep pytest-asyncio
+# : dep pytest-mock
+# : dep trafilatura
+# : out podcastitlater-worker
+# : run ffmpeg
+import Biz.PodcastItLater.Core as Core
+import boto3 # type: ignore[import-untyped]
+import io
+import json
+import Omni.App as App
+import Omni.Log as Log
+import Omni.Test as Test
+import openai
+import os
+import pathlib
+import pytest
+import sys
+import time
+import trafilatura
+import typing
+import unittest.mock
+from botocore.exceptions import ClientError # type: ignore[import-untyped]
+from datetime import datetime
+from datetime import timedelta
+from datetime import timezone
+from pydub import AudioSegment # type: ignore[import-untyped]
+from typing import Any
+
+logger = Log.setup()
+
+# Configuration from environment variables
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+S3_ENDPOINT = os.getenv("S3_ENDPOINT")
+S3_BUCKET = os.getenv("S3_BUCKET")
+S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY")
+S3_SECRET_KEY = os.getenv("S3_SECRET_KEY")
+DATABASE_PATH = os.getenv("DATABASE_PATH", "podcast.db")
+
+# Worker configuration
+MAX_CONTENT_LENGTH = 5000 # characters for TTS
+POLL_INTERVAL = 30 # seconds
+MAX_RETRIES = 3
+TTS_MODEL = "tts-1"
+TTS_VOICE = "alloy"
+
+
+class ArticleProcessor:
+ """Handles the complete article-to-podcast conversion pipeline."""
+
+ def __init__(self) -> None:
+ """Initialize the processor with required services.
+
+ Raises:
+ ValueError: If OPENAI_API_KEY environment variable is not set.
+ """
+ if not OPENAI_API_KEY:
+ msg = "OPENAI_API_KEY environment variable is required"
+ raise ValueError(msg)
+
+ self.openai_client: openai.OpenAI = openai.OpenAI(
+ api_key=OPENAI_API_KEY,
+ )
+
+ # Initialize S3 client for Digital Ocean Spaces
+ if all([S3_ENDPOINT, S3_BUCKET, S3_ACCESS_KEY, S3_SECRET_KEY]):
+ self.s3_client: Any = boto3.client(
+ "s3",
+ endpoint_url=S3_ENDPOINT,
+ aws_access_key_id=S3_ACCESS_KEY,
+ aws_secret_access_key=S3_SECRET_KEY,
+ )
+ else:
+ logger.warning("S3 configuration incomplete, uploads will fail")
+ self.s3_client = None
+
+ @staticmethod
+ def extract_article_content(url: str) -> tuple[str, str]:
+ """Extract title and content from article URL using trafilatura.
+
+ Raises:
+ ValueError: If content cannot be downloaded or extracted.
+ """
+ try:
+ downloaded = trafilatura.fetch_url(url)
+ if not downloaded:
+ msg = f"Failed to download content from {url}"
+ raise ValueError(msg) # noqa: TRY301
+
+ # Extract with metadata
+ result = trafilatura.extract(
+ downloaded,
+ include_comments=False,
+ include_tables=False,
+ with_metadata=True,
+ output_format="json",
+ )
+
+ if not result:
+ msg = f"Failed to extract content from {url}"
+ raise ValueError(msg) # noqa: TRY301
+
+ data = json.loads(result)
+
+ title = data.get("title", "Untitled Article")
+ content = data.get("text", "")
+
+ if not content:
+ msg = f"No content extracted from {url}"
+ raise ValueError(msg) # noqa: TRY301
+
+ # Don't truncate - we'll handle length in text_to_speech
+ logger.info("Extracted article: %s (%d chars)", title, len(content))
+ except Exception:
+ logger.exception("Failed to extract content from %s", url)
+ raise
+ else:
+ return title, content
+
+ def text_to_speech(self, text: str, title: str) -> bytes:
+ """Convert text to speech using OpenAI TTS API.
+
+ Uses LLM to prepare text, then handles chunking and concatenation.
+
+ Raises:
+ ValueError: If no chunks are generated from text.
+ """
+ try:
+ # Use LLM to prepare and chunk the text
+ chunks = prepare_text_for_tts(text, title)
+
+ if not chunks:
+ msg = "No chunks generated from text"
+ raise ValueError(msg) # noqa: TRY301
+
+ logger.info("Processing %d chunks for TTS", len(chunks))
+
+ # Generate audio for each chunk
+ audio_segments = []
+ for i, chunk in enumerate(chunks):
+ logger.info(
+ "Generating TTS for chunk %d/%d (%d chars)",
+ i + 1,
+ len(chunks),
+ len(chunk),
+ )
+
+ response = self.openai_client.audio.speech.create(
+ model=TTS_MODEL,
+ voice=TTS_VOICE,
+ input=chunk,
+ response_format="mp3",
+ )
+
+ # Convert bytes to AudioSegment
+ audio_segment = AudioSegment.from_mp3(
+ io.BytesIO(response.content),
+ )
+ audio_segments.append(audio_segment)
+
+ # Small delay between API calls to be respectful
+ if i < len(chunks) - 1:
+ time.sleep(0.5)
+
+ # Concatenate all audio segments
+ combined_audio = audio_segments[0]
+ for segment in audio_segments[1:]:
+ # Add a small silence between chunks for natural pacing
+ silence = AudioSegment.silent(duration=300)
+ combined_audio = combined_audio + silence + segment
+
+ # Export combined audio to bytes
+ output_buffer = io.BytesIO()
+ combined_audio.export(output_buffer, format="mp3", bitrate="128k")
+ audio_data = output_buffer.getvalue()
+
+ logger.info(
+ "Generated combined TTS audio: %d bytes",
+ len(audio_data),
+ )
+ except Exception:
+ logger.exception("TTS generation failed")
+ raise
+ else:
+ return audio_data
+
+ def upload_to_s3(self, audio_data: bytes, filename: str) -> str:
+ """Upload audio file to S3-compatible storage and return public URL.
+
+ Raises:
+ ValueError: If S3 client is not configured.
+ ClientError: If S3 upload fails.
+ """
+ if not self.s3_client:
+ msg = "S3 client not configured"
+ raise ValueError(msg)
+
+ try:
+ # Upload file
+ self.s3_client.put_object(
+ Bucket=S3_BUCKET,
+ Key=filename,
+ Body=audio_data,
+ ContentType="audio/mpeg",
+ ACL="public-read",
+ )
+
+ # Construct public URL
+ audio_url = f"{S3_ENDPOINT}/{S3_BUCKET}/{filename}"
+ logger.info("Uploaded audio to: %s", audio_url)
+ except ClientError:
+ logger.exception("S3 upload failed")
+ raise
+ else:
+ return audio_url
+
+ @staticmethod
+ def estimate_duration(audio_data: bytes) -> int:
+ """Estimate audio duration in seconds based on file size and bitrate."""
+ # Rough estimation: MP3 at 128kbps = ~16KB per second
+ estimated_seconds = len(audio_data) // 16000
+ return max(1, estimated_seconds) # Minimum 1 second
+
+ @staticmethod
+ def generate_filename(job_id: int, title: str) -> str:
+ """Generate unique filename for audio file."""
+ timestamp = int(datetime.now(tz=timezone.utc).timestamp())
+ # Create safe filename from title
+ safe_title = "".join(
+ c for c in title if c.isalnum() or c in {" ", "-", "_"}
+ ).rstrip()
+ safe_title = safe_title.replace(" ", "_")[:50] # Limit length
+ return f"episode_{timestamp}_{job_id}_{safe_title}.mp3"
+
+ def process_job(self, job: dict[str, Any]) -> None:
+ """Process a single job through the complete pipeline."""
+ job_id = job["id"]
+ url = job["url"]
+
+ try:
+ logger.info("Processing job %d: %s", job_id, url)
+
+ # Update status to processing
+ Core.Database.update_job_status(
+ job_id,
+ "processing",
+ db_path=DATABASE_PATH,
+ )
+
+ # Step 1: Extract article content
+ title, content = ArticleProcessor.extract_article_content(url)
+
+ # Step 2: Generate audio
+ audio_data = self.text_to_speech(content, title)
+
+ # Step 3: Upload to S3
+ filename = ArticleProcessor.generate_filename(job_id, title)
+ audio_url = self.upload_to_s3(audio_data, filename)
+
+ # Step 4: Calculate duration
+ duration = ArticleProcessor.estimate_duration(audio_data)
+
+ # Step 5: Create episode record
+ episode_id = Core.Database.create_episode(
+ title=title,
+ audio_url=audio_url,
+ duration=duration,
+ content_length=len(content),
+ user_id=job.get("user_id"),
+ db_path=DATABASE_PATH,
+ )
+
+ # Step 6: Mark job as complete
+ Core.Database.update_job_status(
+ job_id,
+ "completed",
+ db_path=DATABASE_PATH,
+ )
+
+ logger.info(
+ "Successfully processed job %d -> episode %d",
+ job_id,
+ episode_id,
+ )
+
+ except Exception as e:
+ error_msg = str(e)
+ logger.exception("Job %d failed: %s", job_id, error_msg)
+ Core.Database.update_job_status(
+ job_id,
+ "error",
+ error_msg,
+ DATABASE_PATH,
+ )
+ raise
+
+
+def prepare_text_for_tts(text: str, title: str) -> list[str]:
+ """Use LLM to prepare text for TTS, returning chunks ready for speech.
+
+ First splits text mechanically, then has LLM edit each chunk.
+ """
+ # First, split the text into manageable chunks
+ raw_chunks = split_text_into_chunks(text, max_chars=3000)
+
+ logger.info("Split article into %d raw chunks", len(raw_chunks))
+
+ # Prepare the first chunk with intro
+ edited_chunks = []
+
+ for i, chunk in enumerate(raw_chunks):
+ is_first = i == 0
+ is_last = i == len(raw_chunks) - 1
+
+ try:
+ edited_chunk = edit_chunk_for_speech(
+ chunk,
+ title=title if is_first else None,
+ is_first=is_first,
+ is_last=is_last,
+ )
+ edited_chunks.append(edited_chunk)
+ except Exception:
+ logger.exception("Failed to edit chunk %d", i + 1)
+ # Fall back to raw chunk if LLM fails
+ if is_first:
+ edited_chunks.append(
+ f"This is an audio version of {title}. {chunk}",
+ )
+ elif is_last:
+ edited_chunks.append(f"{chunk} This concludes the article.")
+ else:
+ edited_chunks.append(chunk)
+
+ return edited_chunks
+
+
+def split_text_into_chunks(text: str, max_chars: int = 3000) -> list[str]:
+ """Split text into chunks at sentence boundaries."""
+ chunks = []
+ current_chunk = ""
+
+ # Split into paragraphs first
+ paragraphs = text.split("\n\n")
+
+ for para in paragraphs:
+ para_stripped = para.strip()
+ if not para_stripped:
+ continue
+
+ # If paragraph itself is too long, split by sentences
+ if len(para_stripped) > max_chars:
+ sentences = para_stripped.split(". ")
+ for sentence in sentences:
+ if len(current_chunk) + len(sentence) + 2 < max_chars:
+ current_chunk += sentence + ". "
+ else:
+ if current_chunk:
+ chunks.append(current_chunk.strip())
+ current_chunk = sentence + ". "
+ # If adding this paragraph would exceed limit, start new chunk
+ elif len(current_chunk) + len(para_stripped) + 2 > max_chars:
+ if current_chunk:
+ chunks.append(current_chunk.strip())
+ current_chunk = para_stripped + " "
+ else:
+ current_chunk += para_stripped + " "
+
+ # Don't forget the last chunk
+ if current_chunk:
+ chunks.append(current_chunk.strip())
+
+ return chunks
+
+
+def edit_chunk_for_speech(
+ chunk: str,
+ title: str | None = None,
+ *,
+ is_first: bool = False,
+ is_last: bool = False,
+) -> str:
+ """Use LLM to lightly edit a single chunk for speech.
+
+ Raises:
+ ValueError: If no content is returned from LLM.
+ """
+ system_prompt = (
+ "You are a podcast script editor. Your job is to lightly edit text "
+ "to make it sound natural when spoken aloud.\n\n"
+ "Guidelines:\n"
+ )
+ system_prompt += """
+- Remove URLs and email addresses, replacing with descriptive phrases
+- Convert bullet points and lists into flowing sentences
+- Fix any awkward phrasing for speech
+- Remove references like "click here" or "see below"
+- Keep edits minimal - preserve the original content and style
+- Do NOT add commentary or explanations
+- Return ONLY the edited text, no JSON or formatting
+"""
+
+ user_prompt = chunk
+
+ # Add intro/outro if needed
+ if is_first and title:
+ user_prompt = (
+ f"Add a brief intro mentioning this is an audio version of "
+ f"'{title}', then edit this text:\n\n{chunk}"
+ )
+ elif is_last:
+ user_prompt = f"Edit this text and add a brief closing:\n\n{chunk}"
+
+ try:
+ client: openai.OpenAI = openai.OpenAI(api_key=OPENAI_API_KEY)
+ response = client.chat.completions.create(
+ model="gpt-4o-mini",
+ messages=[
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ],
+ temperature=0.3, # Lower temperature for more consistent edits
+ max_tokens=4000,
+ )
+
+ content = response.choices[0].message.content
+ if not content:
+ msg = "No content returned from LLM"
+ raise ValueError(msg) # noqa: TRY301
+
+ # Ensure the chunk isn't too long
+ max_chunk_length = 4000
+ if len(content) > max_chunk_length:
+ # Truncate at sentence boundary
+ sentences = content.split(". ")
+ truncated = ""
+ for sentence in sentences:
+ if len(truncated) + len(sentence) + 2 < max_chunk_length:
+ truncated += sentence + ". "
+ else:
+ break
+ content = truncated.strip()
+
+ except Exception:
+ logger.exception("LLM chunk editing failed")
+ raise
+ else:
+ return content
+
+
+def parse_datetime_with_timezone(created_at: str | datetime) -> datetime:
+ """Parse datetime string and ensure it has timezone info."""
+ if isinstance(created_at, str):
+ # Handle timezone-aware datetime strings
+ if created_at.endswith("Z"):
+ created_at = created_at[:-1] + "+00:00"
+ last_attempt = datetime.fromisoformat(created_at)
+ if last_attempt.tzinfo is None:
+ last_attempt = last_attempt.replace(tzinfo=timezone.utc)
+ else:
+ last_attempt = created_at
+ if last_attempt.tzinfo is None:
+ last_attempt = last_attempt.replace(tzinfo=timezone.utc)
+ return last_attempt
+
+
+def should_retry_job(job: dict[str, Any], max_retries: int) -> bool:
+ """Check if a job should be retried based on retry count and backoff time.
+
+ Uses exponential backoff to determine if enough time has passed.
+ """
+ retry_count = job["retry_count"]
+ if retry_count >= max_retries:
+ return False
+
+ # Exponential backoff: 30s, 60s, 120s
+ backoff_time = 30 * (2**retry_count)
+ last_attempt = parse_datetime_with_timezone(job["created_at"])
+ time_since_attempt = datetime.now(tz=timezone.utc) - last_attempt
+
+ return time_since_attempt > timedelta(seconds=backoff_time)
+
+
+def process_pending_jobs(processor: ArticleProcessor) -> None:
+ """Process all pending jobs."""
+ pending_jobs = Core.Database.get_pending_jobs(
+ limit=5,
+ db_path=DATABASE_PATH,
+ )
+
+ for job in pending_jobs:
+ current_job = job["id"]
+ try:
+ processor.process_job(job)
+ except Exception as e:
+ # Ensure job is marked as error even if process_job didn't handle it
+ logger.exception("Failed to process job: %d", current_job)
+ # Check if job is still in processing state
+ current_status = Core.Database.get_job_by_id(
+ current_job,
+ DATABASE_PATH,
+ )
+ if current_status and current_status.get("status") == "processing":
+ Core.Database.update_job_status(
+ current_job,
+ "error",
+ str(e),
+ DATABASE_PATH,
+ )
+ continue
+
+
+def process_retryable_jobs() -> None:
+ """Check and retry failed jobs with exponential backoff."""
+ retryable_jobs = Core.Database.get_retryable_jobs(
+ MAX_RETRIES,
+ DATABASE_PATH,
+ )
+
+ for job in retryable_jobs:
+ if should_retry_job(job, MAX_RETRIES):
+ logger.info(
+ "Retrying job %d (attempt %d)",
+ job["id"],
+ job["retry_count"] + 1,
+ )
+ Core.Database.update_job_status(
+ job["id"],
+ "pending",
+ db_path=DATABASE_PATH,
+ )
+
+
+def main_loop() -> None:
+ """Poll for jobs and process them in a continuous loop."""
+ processor = ArticleProcessor()
+ logger.info("Worker started, polling for jobs...")
+
+ while True:
+ try:
+ process_pending_jobs(processor)
+ process_retryable_jobs()
+
+ # Check if there's any work
+ pending_jobs = Core.Database.get_pending_jobs(
+ limit=1,
+ db_path=DATABASE_PATH,
+ )
+ retryable_jobs = Core.Database.get_retryable_jobs(
+ MAX_RETRIES,
+ DATABASE_PATH,
+ )
+
+ if not pending_jobs and not retryable_jobs:
+ logger.debug("No jobs to process, sleeping...")
+
+ except Exception:
+ logger.exception("Error in main loop")
+
+ time.sleep(POLL_INTERVAL)
+
+
+def move() -> None:
+ """Make the worker move."""
+ try:
+ # Initialize database
+ Core.Database.init_db(DATABASE_PATH)
+
+ # Start main processing loop
+ main_loop()
+
+ except KeyboardInterrupt:
+ logger.info("Worker stopped by user")
+ except Exception:
+ logger.exception("Worker crashed")
+ raise
+
+
+class TestArticleExtraction(Test.TestCase):
+ """Test article extraction functionality."""
+
+ def test_extract_valid_article(self) -> None:
+ """Extract from well-formed HTML."""
+ # Mock trafilatura.fetch_url and extract
+ mock_html = (
+ "<html><body><h1>Test Article</h1><p>Content here</p></body></html>"
+ )
+ mock_result = json.dumps({
+ "title": "Test Article",
+ "text": "Content here",
+ })
+
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ return_value=mock_html,
+ ),
+ unittest.mock.patch(
+ "trafilatura.extract",
+ return_value=mock_result,
+ ),
+ ):
+ title, content = ArticleProcessor.extract_article_content(
+ "https://example.com",
+ )
+
+ self.assertEqual(title, "Test Article")
+ self.assertEqual(content, "Content here")
+
+ def test_extract_missing_title(self) -> None:
+ """Handle articles without titles."""
+ mock_html = "<html><body><p>Content without title</p></body></html>"
+ mock_result = json.dumps({"text": "Content without title"})
+
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ return_value=mock_html,
+ ),
+ unittest.mock.patch(
+ "trafilatura.extract",
+ return_value=mock_result,
+ ),
+ ):
+ title, content = ArticleProcessor.extract_article_content(
+ "https://example.com",
+ )
+
+ self.assertEqual(title, "Untitled Article")
+ self.assertEqual(content, "Content without title")
+
+ def test_extract_empty_content(self) -> None:
+ """Handle empty articles."""
+ mock_html = "<html><body></body></html>"
+ mock_result = json.dumps({"title": "Empty Article", "text": ""})
+
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ return_value=mock_html,
+ ),
+ unittest.mock.patch(
+ "trafilatura.extract",
+ return_value=mock_result,
+ ),
+ pytest.raises(ValueError, match="No content extracted") as cm,
+ ):
+ ArticleProcessor.extract_article_content(
+ "https://example.com",
+ )
+
+ self.assertIn("No content extracted", str(cm.value))
+
+ def test_extract_network_error(self) -> None:
+ """Handle connection failures."""
+ with (
+ unittest.mock.patch("trafilatura.fetch_url", return_value=None),
+ pytest.raises(ValueError, match="Failed to download") as cm,
+ ):
+ ArticleProcessor.extract_article_content("https://example.com")
+
+ self.assertIn("Failed to download", str(cm.value))
+
+ @staticmethod
+ def test_extract_timeout() -> None:
+ """Handle slow responses."""
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ side_effect=TimeoutError("Timeout"),
+ ),
+ pytest.raises(TimeoutError),
+ ):
+ ArticleProcessor.extract_article_content("https://example.com")
+
+ def test_content_sanitization(self) -> None:
+ """Remove unwanted elements."""
+ mock_html = """
+ <html><body>
+ <h1>Article</h1>
+ <p>Good content</p>
+ <script>alert('bad')</script>
+ <table><tr><td>data</td></tr></table>
+ </body></html>
+ """
+ mock_result = json.dumps({
+ "title": "Article",
+ "text": "Good content", # Tables and scripts removed
+ })
+
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ return_value=mock_html,
+ ),
+ unittest.mock.patch(
+ "trafilatura.extract",
+ return_value=mock_result,
+ ),
+ ):
+ _title, content = ArticleProcessor.extract_article_content(
+ "https://example.com",
+ )
+
+ self.assertEqual(content, "Good content")
+ self.assertNotIn("script", content)
+ self.assertNotIn("table", content)
+
+
+class TestTextToSpeech(Test.TestCase):
+ """Test text-to-speech functionality."""
+
+ def setUp(self) -> None:
+ """Set up mocks."""
+ # Mock OpenAI API key
+ self.env_patcher = unittest.mock.patch.dict(
+ os.environ,
+ {"OPENAI_API_KEY": "test-key"},
+ )
+ self.env_patcher.start()
+
+ # Mock OpenAI response
+ self.mock_audio_response: unittest.mock.MagicMock = (
+ unittest.mock.MagicMock()
+ )
+ self.mock_audio_response.content = b"fake-audio-data"
+
+ # Mock AudioSegment to avoid ffmpeg issues in tests
+ self.mock_audio_segment: unittest.mock.MagicMock = (
+ unittest.mock.MagicMock()
+ )
+ self.mock_audio_segment.export.return_value = None
+ self.audio_segment_patcher = unittest.mock.patch(
+ "pydub.AudioSegment.from_mp3",
+ return_value=self.mock_audio_segment,
+ )
+ self.audio_segment_patcher.start()
+
+ # Mock the concatenation operations
+ self.mock_audio_segment.__add__.return_value = self.mock_audio_segment
+
+ def tearDown(self) -> None:
+ """Clean up mocks."""
+ self.env_patcher.stop()
+ self.audio_segment_patcher.stop()
+
+ def test_tts_generation(self) -> None:
+ """Generate audio from text."""
+
+ # Mock the export to write test audio data
+ def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
+ buffer.write(b"test-audio-output")
+ buffer.seek(0)
+
+ self.mock_audio_segment.export.side_effect = mock_export
+
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+ mock_speech.create.return_value = self.mock_audio_response
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=["Test content"],
+ ),
+ ):
+ processor = ArticleProcessor()
+ audio_data = processor.text_to_speech(
+ "Test content",
+ "Test Title",
+ )
+
+ self.assertIsInstance(audio_data, bytes)
+ self.assertEqual(audio_data, b"test-audio-output")
+
+ def test_tts_chunking(self) -> None:
+ """Handle long articles with chunking."""
+ long_text = "Long content " * 1000
+ chunks = ["Chunk 1", "Chunk 2", "Chunk 3"]
+
+ def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
+ buffer.write(b"test-audio-output")
+ buffer.seek(0)
+
+ self.mock_audio_segment.export.side_effect = mock_export
+
+ # Mock AudioSegment.silent
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+ mock_speech.create.return_value = self.mock_audio_response
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=chunks,
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.silent",
+ return_value=self.mock_audio_segment,
+ ),
+ ):
+ processor = ArticleProcessor()
+ audio_data = processor.text_to_speech(
+ long_text,
+ "Long Article",
+ )
+
+ # Should have called TTS for each chunk
+ self.assertIsInstance(audio_data, bytes)
+ self.assertEqual(audio_data, b"test-audio-output")
+
+ def test_tts_empty_text(self) -> None:
+ """Handle empty input."""
+ with unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=[],
+ ):
+ processor = ArticleProcessor()
+ with pytest.raises(ValueError, match="No chunks generated") as cm:
+ processor.text_to_speech("", "Empty")
+
+ self.assertIn("No chunks generated", str(cm.value))
+
+ def test_tts_special_characters(self) -> None:
+ """Handle unicode and special chars."""
+ special_text = 'Unicode: 你好世界 Émojis: 🎙️📰 Special: <>&"'
+
+ def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
+ buffer.write(b"test-audio-output")
+ buffer.seek(0)
+
+ self.mock_audio_segment.export.side_effect = mock_export
+
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+ mock_speech.create.return_value = self.mock_audio_response
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=[special_text],
+ ),
+ ):
+ processor = ArticleProcessor()
+ audio_data = processor.text_to_speech(
+ special_text,
+ "Special",
+ )
+
+ self.assertIsInstance(audio_data, bytes)
+ self.assertEqual(audio_data, b"test-audio-output")
+
+ def test_llm_text_preparation(self) -> None:
+ """Verify LLM editing."""
+ # Test the actual text preparation functions
+ chunks = split_text_into_chunks("Short text", max_chars=100)
+ self.assertEqual(len(chunks), 1)
+ self.assertEqual(chunks[0], "Short text")
+
+ # Test long text splitting
+ long_text = "Sentence one. " * 100
+ chunks = split_text_into_chunks(long_text, max_chars=100)
+ self.assertGreater(len(chunks), 1)
+ for chunk in chunks:
+ self.assertLessEqual(len(chunk), 100)
+
+ @staticmethod
+ def test_llm_failure_fallback() -> None:
+ """Handle LLM API failures."""
+ # Mock LLM failure
+ with unittest.mock.patch("openai.OpenAI") as mock_openai:
+ mock_client = mock_openai.return_value
+ mock_client.chat.completions.create.side_effect = Exception(
+ "API Error",
+ )
+
+ # Should fall back to raw text
+ with pytest.raises(Exception, match="API Error"):
+ edit_chunk_for_speech("Test chunk", "Title", is_first=True)
+
+ def test_chunk_concatenation(self) -> None:
+ """Verify audio joining."""
+ # Mock multiple audio segments
+ chunks = ["Chunk 1", "Chunk 2"]
+
+ def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
+ buffer.write(b"test-audio-output")
+ buffer.seek(0)
+
+ self.mock_audio_segment.export.side_effect = mock_export
+
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+ mock_speech.create.return_value = self.mock_audio_response
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=chunks,
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.silent",
+ return_value=self.mock_audio_segment,
+ ),
+ ):
+ processor = ArticleProcessor()
+ audio_data = processor.text_to_speech("Test", "Title")
+
+ # Should produce combined audio
+ self.assertIsInstance(audio_data, bytes)
+ self.assertEqual(audio_data, b"test-audio-output")
+
+
+class TestJobProcessing(Test.TestCase):
+ """Test job processing functionality."""
+
+ def setUp(self) -> None:
+ """Set up test environment."""
+ self.test_db = "test_podcast_worker.db"
+ # Clean up any existing test database
+ test_db_path = pathlib.Path(self.test_db)
+ if test_db_path.exists():
+ test_db_path.unlink()
+ Core.Database.init_db(self.test_db)
+
+ # Create test user and job
+ self.user_id, _ = Core.Database.create_user(
+ "test@example.com",
+ self.test_db,
+ )
+ self.job_id = Core.Database.add_to_queue(
+ "https://example.com/article",
+ "test@example.com",
+ self.user_id,
+ self.test_db,
+ )
+
+ # Mock environment
+ self.env_patcher = unittest.mock.patch.dict(
+ os.environ,
+ {
+ "OPENAI_API_KEY": "test-key",
+ "S3_ENDPOINT": "https://s3.example.com",
+ "S3_BUCKET": "test-bucket",
+ "S3_ACCESS_KEY": "test-access",
+ "S3_SECRET_KEY": "test-secret",
+ },
+ )
+ self.env_patcher.start()
+
+ def tearDown(self) -> None:
+ """Clean up."""
+ self.env_patcher.stop()
+ test_db_path = pathlib.Path(self.test_db)
+ if test_db_path.exists():
+ test_db_path.unlink()
+
+ def test_process_job_success(self) -> None:
+ """Complete pipeline execution."""
+ processor = ArticleProcessor()
+ job = Core.Database.get_job_by_id(self.job_id, self.test_db)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+
+ # Mock all external calls
+ with (
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "extract_article_content",
+ return_value=("Test Title", "Test content"),
+ ),
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "text_to_speech",
+ return_value=b"audio-data",
+ ),
+ unittest.mock.patch.object(
+ processor,
+ "upload_to_s3",
+ return_value="https://s3.example.com/audio.mp3",
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.update_job_status",
+ ) as mock_update,
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.create_episode",
+ ) as mock_create,
+ ):
+ mock_create.return_value = 1
+ processor.process_job(job)
+
+ # Verify job was marked complete
+ mock_update.assert_called_with(self.job_id, "completed")
+ mock_create.assert_called_once()
+
+ def test_process_job_extraction_failure(self) -> None:
+ """Handle bad URLs."""
+ processor = ArticleProcessor()
+ job = Core.Database.get_job_by_id(self.job_id, self.test_db)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+
+ with (
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "extract_article_content",
+ side_effect=ValueError("Bad URL"),
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.update_job_status",
+ ) as mock_update,
+ pytest.raises(ValueError, match="Bad URL"),
+ ):
+ processor.process_job(job)
+
+ # Job should be marked as error
+ mock_update.assert_called_with(self.job_id, "error", "Bad URL")
+
+ def test_process_job_tts_failure(self) -> None:
+ """Handle TTS errors."""
+ processor = ArticleProcessor()
+ job = Core.Database.get_job_by_id(self.job_id, self.test_db)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+
+ with (
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "extract_article_content",
+ return_value=("Title", "Content"),
+ ),
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "text_to_speech",
+ side_effect=Exception("TTS Error"),
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.update_job_status",
+ ) as mock_update,
+ pytest.raises(Exception, match="TTS Error"),
+ ):
+ processor.process_job(job)
+
+ mock_update.assert_called_with(self.job_id, "error", "TTS Error")
+
+ def test_process_job_s3_failure(self) -> None:
+ """Handle upload errors."""
+ processor = ArticleProcessor()
+ job = Core.Database.get_job_by_id(self.job_id, self.test_db)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+
+ with (
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "extract_article_content",
+ return_value=("Title", "Content"),
+ ),
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "text_to_speech",
+ return_value=b"audio",
+ ),
+ unittest.mock.patch.object(
+ processor,
+ "upload_to_s3",
+ side_effect=ClientError({}, "PutObject"),
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.update_job_status",
+ ),
+ pytest.raises(ClientError),
+ ):
+ processor.process_job(job)
+
+ def test_job_retry_logic(self) -> None:
+ """Verify exponential backoff."""
+ # Set job to error with retry count
+ Core.Database.update_job_status(
+ self.job_id,
+ "error",
+ "First failure",
+ self.test_db,
+ )
+ Core.Database.update_job_status(
+ self.job_id,
+ "error",
+ "Second failure",
+ self.test_db,
+ )
+
+ job = Core.Database.get_job_by_id(self.job_id, self.test_db)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+
+ self.assertEqual(job["retry_count"], 2)
+
+ # Should be retryable
+ retryable = Core.Database.get_retryable_jobs(
+ max_retries=3,
+ db_path=self.test_db,
+ )
+ self.assertEqual(len(retryable), 1)
+
+ def test_max_retries(self) -> None:
+ """Stop after max attempts."""
+ # Exceed retry limit
+ for i in range(4):
+ Core.Database.update_job_status(
+ self.job_id,
+ "error",
+ f"Failure {i}",
+ self.test_db,
+ )
+
+ # Should not be retryable
+ retryable = Core.Database.get_retryable_jobs(
+ max_retries=3,
+ db_path=self.test_db,
+ )
+ self.assertEqual(len(retryable), 0)
+
+ def test_concurrent_processing(self) -> None:
+ """Handle multiple jobs."""
+ # Create multiple jobs
+ job2 = Core.Database.add_to_queue(
+ "https://example.com/2",
+ "test@example.com",
+ self.user_id,
+ self.test_db,
+ )
+ job3 = Core.Database.add_to_queue(
+ "https://example.com/3",
+ "test@example.com",
+ self.user_id,
+ self.test_db,
+ )
+
+ # Get pending jobs
+ jobs = Core.Database.get_pending_jobs(limit=5, db_path=self.test_db)
+
+ self.assertEqual(len(jobs), 3)
+ self.assertEqual({j["id"] for j in jobs}, {self.job_id, job2, job3})
+
+
+def test() -> None:
+ """Run the tests."""
+ Test.run(
+ App.Area.Test,
+ [
+ TestArticleExtraction,
+ TestTextToSpeech,
+ TestJobProcessing,
+ ],
+ )
+
+
+def main() -> None:
+ """Entry point for the worker."""
+ if "test" in sys.argv:
+ test()
+ else:
+ move()