diff options
| author | Ben Sima <ben@bensima.com> | 2025-12-02 14:55:10 -0500 |
|---|---|---|
| committer | Ben Sima <ben@bensima.com> | 2025-12-02 14:55:10 -0500 |
| commit | 8329b760082e07364a6f6c3e8e0b240802838316 (patch) | |
| tree | 96374d56651900a3c78dddbdc9234569a042b738 /Biz | |
| parent | 32f1f3e863a4844ad29285425749405d91f34662 (diff) | |
Ignore PLC0415 in ruff (late imports for circular deps)
Diffstat (limited to 'Biz')
| -rw-r--r-- | Biz/PodcastItLater/Worker.py | 1982 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Worker/Jobs.py | 506 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Worker/Processor.py | 1382 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Worker/TextProcessing.py | 211 |
4 files changed, 2126 insertions, 1955 deletions
diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py index bf6ef9e..ecef2c0 100644 --- a/Biz/PodcastItLater/Worker.py +++ b/Biz/PodcastItLater/Worker.py @@ -12,33 +12,23 @@ # : out podcastitlater-worker # : run ffmpeg import Biz.PodcastItLater.Core as Core -import boto3 # type: ignore[import-untyped] -import concurrent.futures -import io +import Biz.PodcastItLater.Worker.Jobs as Jobs +import Biz.PodcastItLater.Worker.Processor as Processor +import Biz.PodcastItLater.Worker.TextProcessing as TextProcessing import json import logging import Omni.App as App import Omni.Log as Log import Omni.Test as Test -import openai -import operator import os -import psutil # type: ignore[import-untyped] import pytest import signal import sys -import tempfile import threading -import time -import trafilatura -import typing import unittest.mock -from botocore.exceptions import ClientError # type: ignore[import-untyped] from datetime import datetime from datetime import timedelta from datetime import timezone -from pathlib import Path -from pydub import AudioSegment # type: ignore[import-untyped] from typing import Any logger = logging.getLogger(__name__) @@ -46,22 +36,11 @@ Log.setup(logger) # Configuration from environment variables OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") -S3_ENDPOINT = os.getenv("S3_ENDPOINT") -S3_BUCKET = os.getenv("S3_BUCKET") -S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY") -S3_SECRET_KEY = os.getenv("S3_SECRET_KEY") area = App.from_env() # Worker configuration -MAX_CONTENT_LENGTH = 5000 # characters for TTS MAX_ARTICLE_SIZE = 500_000 # 500KB character limit for articles -POLL_INTERVAL = 30 # seconds -MAX_RETRIES = 3 -TTS_MODEL = "tts-1" -TTS_VOICE = "alloy" MEMORY_THRESHOLD = 80 # Percentage threshold for memory usage -CROSSFADE_DURATION = 500 # ms for crossfading segments -PAUSE_DURATION = 1000 # ms for silence between segments class ShutdownHandler: @@ -100,939 +79,6 @@ class ShutdownHandler: return self.current_job_id -class 
ArticleProcessor: - """Handles the complete article-to-podcast conversion pipeline.""" - - def __init__(self, shutdown_handler: ShutdownHandler) -> None: - """Initialize the processor with required services. - - Raises: - ValueError: If OPENAI_API_KEY environment variable is not set. - """ - if not OPENAI_API_KEY: - msg = "OPENAI_API_KEY environment variable is required" - raise ValueError(msg) - - self.openai_client: openai.OpenAI = openai.OpenAI( - api_key=OPENAI_API_KEY, - ) - self.shutdown_handler = shutdown_handler - - # Initialize S3 client for Digital Ocean Spaces - if all([S3_ENDPOINT, S3_BUCKET, S3_ACCESS_KEY, S3_SECRET_KEY]): - self.s3_client: Any = boto3.client( - "s3", - endpoint_url=S3_ENDPOINT, - aws_access_key_id=S3_ACCESS_KEY, - aws_secret_access_key=S3_SECRET_KEY, - ) - else: - logger.warning("S3 configuration incomplete, uploads will fail") - self.s3_client = None - - @staticmethod - def extract_article_content( - url: str, - ) -> tuple[str, str, str | None, str | None]: - """Extract title, content, author, and date from article URL. - - Returns: - tuple: (title, content, author, publication_date) - - Raises: - ValueError: If content cannot be downloaded, extracted, or large. 
- """ - try: - downloaded = trafilatura.fetch_url(url) - if not downloaded: - msg = f"Failed to download content from {url}" - raise ValueError(msg) # noqa: TRY301 - - # Check size before processing - if ( - len(downloaded) > MAX_ARTICLE_SIZE * 4 - ): # Rough HTML to text ratio - msg = f"Article too large: {len(downloaded)} bytes" - raise ValueError(msg) # noqa: TRY301 - - # Extract with metadata - result = trafilatura.extract( - downloaded, - include_comments=False, - include_tables=False, - with_metadata=True, - output_format="json", - ) - - if not result: - msg = f"Failed to extract content from {url}" - raise ValueError(msg) # noqa: TRY301 - - data = json.loads(result) - - title = data.get("title", "Untitled Article") - content = data.get("text", "") - author = data.get("author") - pub_date = data.get("date") - - if not content: - msg = f"No content extracted from {url}" - raise ValueError(msg) # noqa: TRY301 - - # Enforce content size limit - if len(content) > MAX_ARTICLE_SIZE: - logger.warning( - "Article content truncated from %d to %d characters", - len(content), - MAX_ARTICLE_SIZE, - ) - content = content[:MAX_ARTICLE_SIZE] - - logger.info( - "Extracted article: %s (%d chars, author: %s, date: %s)", - title, - len(content), - author or "unknown", - pub_date or "unknown", - ) - except Exception: - logger.exception("Failed to extract content from %s", url) - raise - else: - return title, content, author, pub_date - - def text_to_speech( - self, - text: str, - title: str, - author: str | None = None, - pub_date: str | None = None, - ) -> bytes: - """Convert text to speech with intro/outro using OpenAI TTS API. - - Uses parallel processing for chunks while maintaining order. - Adds intro with metadata and outro with attribution. - - Args: - text: Article content to convert - title: Article title - author: Article author (optional) - pub_date: Publication date (optional) - - Raises: - ValueError: If no chunks are generated from text. 
- """ - try: - # Generate intro audio - intro_text = self._create_intro_text(title, author, pub_date) - intro_audio = self._generate_tts_segment(intro_text) - - # Generate outro audio - outro_text = self._create_outro_text(title, author) - outro_audio = self._generate_tts_segment(outro_text) - - # Use LLM to prepare and chunk the main content - chunks = prepare_text_for_tts(text, title) - - if not chunks: - msg = "No chunks generated from text" - raise ValueError(msg) # noqa: TRY301 - - logger.info("Processing %d chunks for TTS", len(chunks)) - - # Check memory before parallel processing - mem_usage = check_memory_usage() - if mem_usage > MEMORY_THRESHOLD - 20: # Leave 20% buffer - logger.warning( - "High memory usage (%.1f%%), falling back to serial " - "processing", - mem_usage, - ) - content_audio_bytes = self._text_to_speech_serial(chunks) - else: - # Determine max workers - max_workers = min( - 4, # Reasonable limit to avoid rate limiting - len(chunks), # No more workers than chunks - max(1, psutil.cpu_count() // 2), # Use half of CPU cores - ) - - logger.info( - "Using %d workers for parallel TTS processing", - max_workers, - ) - - # Process chunks in parallel - chunk_results: list[tuple[int, bytes]] = [] - - with concurrent.futures.ThreadPoolExecutor( - max_workers=max_workers, - ) as executor: - # Submit all chunks for processing - future_to_index = { - executor.submit(self._process_tts_chunk, chunk, i): i - for i, chunk in enumerate(chunks) - } - - # Collect results as they complete - for future in concurrent.futures.as_completed( - future_to_index, - ): - index = future_to_index[future] - try: - audio_data = future.result() - chunk_results.append((index, audio_data)) - except Exception: - logger.exception( - "Failed to process chunk %d", - index, - ) - raise - - # Sort results by index to maintain order - chunk_results.sort(key=operator.itemgetter(0)) - - # Combine audio chunks - content_audio_bytes = self._combine_audio_chunks([ - data for _, data in 
chunk_results - ]) - - # Combine intro, content, and outro with pauses - return ArticleProcessor._combine_intro_content_outro( - intro_audio, - content_audio_bytes, - outro_audio, - ) - - except Exception: - logger.exception("TTS generation failed") - raise - - @staticmethod - def _create_intro_text( - title: str, - author: str | None, - pub_date: str | None, - ) -> str: - """Create intro text with available metadata.""" - parts = [f"Title: {title}"] - - if author: - parts.append(f"Author: {author}") - - if pub_date: - parts.append(f"Published: {pub_date}") - - return ". ".join(parts) + "." - - @staticmethod - def _create_outro_text(title: str, author: str | None) -> str: - """Create outro text with attribution.""" - if author: - return ( - f"This has been an audio version of {title} " - f"by {author}, created using Podcast It Later." - ) - return ( - f"This has been an audio version of {title}, " - "created using Podcast It Later." - ) - - def _generate_tts_segment(self, text: str) -> bytes: - """Generate TTS audio for a single segment (intro/outro). - - Args: - text: Text to convert to speech - - Returns: - MP3 audio bytes - """ - response = self.openai_client.audio.speech.create( - model=TTS_MODEL, - voice=TTS_VOICE, - input=text, - ) - return response.content - - @staticmethod - def _combine_intro_content_outro( - intro_audio: bytes, - content_audio: bytes, - outro_audio: bytes, - ) -> bytes: - """Combine intro, content, and outro with crossfades. 
- - Args: - intro_audio: MP3 bytes for intro - content_audio: MP3 bytes for main content - outro_audio: MP3 bytes for outro - - Returns: - Combined MP3 audio bytes - """ - # Load audio segments - intro = AudioSegment.from_mp3(io.BytesIO(intro_audio)) - content = AudioSegment.from_mp3(io.BytesIO(content_audio)) - outro = AudioSegment.from_mp3(io.BytesIO(outro_audio)) - - # Create bridge silence (pause + 2 * crossfade to account for overlap) - bridge = AudioSegment.silent( - duration=PAUSE_DURATION + 2 * CROSSFADE_DURATION - ) - - def safe_append( - seg1: AudioSegment, seg2: AudioSegment, crossfade: int - ) -> AudioSegment: - if len(seg1) < crossfade or len(seg2) < crossfade: - logger.warning( - "Segment too short for crossfade (%dms vs %dms/%dms), using concatenation", - crossfade, - len(seg1), - len(seg2), - ) - return seg1 + seg2 - return seg1.append(seg2, crossfade=crossfade) - - # Combine segments with crossfades - # Intro -> Bridge -> Content -> Bridge -> Outro - # This effectively fades out the previous segment and fades in the next one - combined = safe_append(intro, bridge, CROSSFADE_DURATION) - combined = safe_append(combined, content, CROSSFADE_DURATION) - combined = safe_append(combined, bridge, CROSSFADE_DURATION) - combined = safe_append(combined, outro, CROSSFADE_DURATION) - - # Export to bytes - output = io.BytesIO() - combined.export(output, format="mp3") - return output.getvalue() - - def _process_tts_chunk(self, chunk: str, index: int) -> bytes: - """Process a single TTS chunk. 
- - Args: - chunk: Text to convert to speech - index: Chunk index for logging - - Returns: - Audio data as bytes - """ - logger.info( - "Generating TTS for chunk %d (%d chars)", - index + 1, - len(chunk), - ) - - response = self.openai_client.audio.speech.create( - model=TTS_MODEL, - voice=TTS_VOICE, - input=chunk, - response_format="mp3", - ) - - return response.content - - @staticmethod - def _combine_audio_chunks(audio_chunks: list[bytes]) -> bytes: - """Combine multiple audio chunks with silence gaps. - - Args: - audio_chunks: List of audio data in order - - Returns: - Combined audio data - """ - if not audio_chunks: - msg = "No audio chunks to combine" - raise ValueError(msg) - - # Create a temporary file for the combined audio - with tempfile.NamedTemporaryFile( - suffix=".mp3", - delete=False, - ) as temp_file: - temp_path = temp_file.name - - try: - # Start with the first chunk - combined_audio = AudioSegment.from_mp3(io.BytesIO(audio_chunks[0])) - - # Add remaining chunks with silence gaps - for chunk_data in audio_chunks[1:]: - chunk_audio = AudioSegment.from_mp3(io.BytesIO(chunk_data)) - silence = AudioSegment.silent(duration=300) # 300ms gap - combined_audio = combined_audio + silence + chunk_audio - - # Export to file - combined_audio.export(temp_path, format="mp3", bitrate="128k") - - # Read back the combined audio - return Path(temp_path).read_bytes() - - finally: - # Clean up temp file - if Path(temp_path).exists(): - Path(temp_path).unlink() - - def _text_to_speech_serial(self, chunks: list[str]) -> bytes: - """Fallback serial processing for high memory situations. - - This is the original serial implementation. 
- """ - # Create a temporary file for streaming audio concatenation - with tempfile.NamedTemporaryFile( - suffix=".mp3", - delete=False, - ) as temp_file: - temp_path = temp_file.name - - try: - # Process first chunk - logger.info("Generating TTS for chunk 1/%d", len(chunks)) - response = self.openai_client.audio.speech.create( - model=TTS_MODEL, - voice=TTS_VOICE, - input=chunks[0], - response_format="mp3", - ) - - # Write first chunk directly to file - Path(temp_path).write_bytes(response.content) - - # Process remaining chunks - for i, chunk in enumerate(chunks[1:], 1): - logger.info( - "Generating TTS for chunk %d/%d (%d chars)", - i + 1, - len(chunks), - len(chunk), - ) - - response = self.openai_client.audio.speech.create( - model=TTS_MODEL, - voice=TTS_VOICE, - input=chunk, - response_format="mp3", - ) - - # Append to existing file with silence gap - # Load only the current segment - current_segment = AudioSegment.from_mp3( - io.BytesIO(response.content), - ) - - # Load existing audio, append, and save back - existing_audio = AudioSegment.from_mp3(temp_path) - silence = AudioSegment.silent(duration=300) - combined = existing_audio + silence + current_segment - - # Export back to the same file - combined.export(temp_path, format="mp3", bitrate="128k") - - # Force garbage collection to free memory - del existing_audio, current_segment, combined - - # Small delay between API calls - if i < len(chunks) - 1: - time.sleep(0.5) - - # Read final result - audio_data = Path(temp_path).read_bytes() - - logger.info( - "Generated combined TTS audio: %d bytes", - len(audio_data), - ) - return audio_data - - finally: - # Clean up temp file - temp_file_path = Path(temp_path) - if temp_file_path.exists(): - temp_file_path.unlink() - - def upload_to_s3(self, audio_data: bytes, filename: str) -> str: - """Upload audio file to S3-compatible storage and return public URL. - - Raises: - ValueError: If S3 client is not configured. - ClientError: If S3 upload fails. 
- """ - if not self.s3_client: - msg = "S3 client not configured" - raise ValueError(msg) - - try: - # Upload file using streaming to minimize memory usage - audio_stream = io.BytesIO(audio_data) - self.s3_client.upload_fileobj( - audio_stream, - S3_BUCKET, - filename, - ExtraArgs={ - "ContentType": "audio/mpeg", - "ACL": "public-read", - }, - ) - - # Construct public URL - audio_url = f"{S3_ENDPOINT}/{S3_BUCKET}/{filename}" - logger.info( - "Uploaded audio to: %s (%d bytes)", - audio_url, - len(audio_data), - ) - except ClientError: - logger.exception("S3 upload failed") - raise - else: - return audio_url - - @staticmethod - def estimate_duration(audio_data: bytes) -> int: - """Estimate audio duration in seconds based on file size and bitrate.""" - # Rough estimation: MP3 at 128kbps = ~16KB per second - estimated_seconds = len(audio_data) // 16000 - return max(1, estimated_seconds) # Minimum 1 second - - @staticmethod - def generate_filename(job_id: int, title: str) -> str: - """Generate unique filename for audio file.""" - timestamp = int(datetime.now(tz=timezone.utc).timestamp()) - # Create safe filename from title - safe_title = "".join( - c for c in title if c.isalnum() or c in {" ", "-", "_"} - ).rstrip() - safe_title = safe_title.replace(" ", "_")[:50] # Limit length - return f"episode_{timestamp}_{job_id}_{safe_title}.mp3" - - def process_job( - self, - job: dict[str, Any], - ) -> None: - """Process a single job through the complete pipeline.""" - job_id = job["id"] - url = job["url"] - - # Check memory before starting - mem_usage = check_memory_usage() - if mem_usage > MEMORY_THRESHOLD: - logger.warning( - "High memory usage (%.1f%%), deferring job %d", - mem_usage, - job_id, - ) - return - - # Track current job for graceful shutdown - self.shutdown_handler.set_current_job(job_id) - - try: - logger.info("Processing job %d: %s", job_id, url) - - # Update status to processing - Core.Database.update_job_status( - job_id, - "processing", - ) - - # Check for 
shutdown before each major step - if self.shutdown_handler.is_shutdown_requested(): - logger.info("Shutdown requested, aborting job %d", job_id) - Core.Database.update_job_status(job_id, "pending") - return - - # Step 1: Extract article content - Core.Database.update_job_status(job_id, "extracting") - title, content, author, pub_date = ( - ArticleProcessor.extract_article_content(url) - ) - - if self.shutdown_handler.is_shutdown_requested(): - logger.info("Shutdown requested, aborting job %d", job_id) - Core.Database.update_job_status(job_id, "pending") - return - - # Step 2: Generate audio with metadata - Core.Database.update_job_status(job_id, "synthesizing") - audio_data = self.text_to_speech(content, title, author, pub_date) - - if self.shutdown_handler.is_shutdown_requested(): - logger.info("Shutdown requested, aborting job %d", job_id) - Core.Database.update_job_status(job_id, "pending") - return - - # Step 3: Upload to S3 - Core.Database.update_job_status(job_id, "uploading") - filename = ArticleProcessor.generate_filename(job_id, title) - audio_url = self.upload_to_s3(audio_data, filename) - - # Step 4: Calculate duration - duration = ArticleProcessor.estimate_duration(audio_data) - - # Step 5: Create episode record - url_hash = Core.hash_url(url) - episode_id = Core.Database.create_episode( - title=title, - audio_url=audio_url, - duration=duration, - content_length=len(content), - user_id=job.get("user_id"), - author=job.get("author"), # Pass author from job - original_url=url, # Pass the original article URL - original_url_hash=url_hash, - ) - - # Add episode to user's feed - user_id = job.get("user_id") - if user_id: - Core.Database.add_episode_to_user(user_id, episode_id) - Core.Database.track_episode_event( - episode_id, - "added", - user_id, - ) - - # Step 6: Mark job as complete - Core.Database.update_job_status( - job_id, - "completed", - ) - - logger.info( - "Successfully processed job %d -> episode %d", - job_id, - episode_id, - ) - - except 
Exception as e: - error_msg = str(e) - logger.exception("Job %d failed: %s", job_id, error_msg) - Core.Database.update_job_status( - job_id, - "error", - error_msg, - ) - raise - finally: - # Clear current job - self.shutdown_handler.set_current_job(None) - - -def prepare_text_for_tts(text: str, title: str) -> list[str]: - """Use LLM to prepare text for TTS, returning chunks ready for speech. - - First splits text mechanically, then has LLM edit each chunk. - """ - # First, split the text into manageable chunks - raw_chunks = split_text_into_chunks(text, max_chars=3000) - - logger.info("Split article into %d raw chunks", len(raw_chunks)) - - # Prepare the first chunk with intro - edited_chunks = [] - - for i, chunk in enumerate(raw_chunks): - is_first = i == 0 - is_last = i == len(raw_chunks) - 1 - - try: - edited_chunk = edit_chunk_for_speech( - chunk, - title=title if is_first else None, - is_first=is_first, - is_last=is_last, - ) - edited_chunks.append(edited_chunk) - except Exception: - logger.exception("Failed to edit chunk %d", i + 1) - # Fall back to raw chunk if LLM fails - if is_first: - edited_chunks.append( - f"This is an audio version of {title}. {chunk}", - ) - elif is_last: - edited_chunks.append(f"{chunk} This concludes the article.") - else: - edited_chunks.append(chunk) - - return edited_chunks - - -def split_text_into_chunks(text: str, max_chars: int = 3000) -> list[str]: - """Split text into chunks at sentence boundaries.""" - chunks = [] - current_chunk = "" - - # Split into paragraphs first - paragraphs = text.split("\n\n") - - for para in paragraphs: - para_stripped = para.strip() - if not para_stripped: - continue - - # If paragraph itself is too long, split by sentences - if len(para_stripped) > max_chars: - sentences = para_stripped.split(". ") - for sentence in sentences: - if len(current_chunk) + len(sentence) + 2 < max_chars: - current_chunk += sentence + ". 
" - else: - if current_chunk: - chunks.append(current_chunk.strip()) - current_chunk = sentence + ". " - # If adding this paragraph would exceed limit, start new chunk - elif len(current_chunk) + len(para_stripped) + 2 > max_chars: - if current_chunk: - chunks.append(current_chunk.strip()) - current_chunk = para_stripped + " " - else: - current_chunk += para_stripped + " " - - # Don't forget the last chunk - if current_chunk: - chunks.append(current_chunk.strip()) - - return chunks - - -def edit_chunk_for_speech( - chunk: str, - title: str | None = None, - *, - is_first: bool = False, - is_last: bool = False, -) -> str: - """Use LLM to lightly edit a single chunk for speech. - - Raises: - ValueError: If no content is returned from LLM. - """ - system_prompt = ( - "You are a podcast script editor. Your job is to lightly edit text " - "to make it sound natural when spoken aloud.\n\n" - "Guidelines:\n" - ) - system_prompt += """ -- Remove URLs and email addresses, replacing with descriptive phrases -- Convert bullet points and lists into flowing sentences -- Fix any awkward phrasing for speech -- Remove references like "click here" or "see below" -- Keep edits minimal - preserve the original content and style -- Do NOT add commentary or explanations -- Return ONLY the edited text, no JSON or formatting -""" - - user_prompt = chunk - - # Add intro/outro if needed - if is_first and title: - user_prompt = ( - f"Add a brief intro mentioning this is an audio version of " - f"'{title}', then edit this text:\n\n{chunk}" - ) - elif is_last: - user_prompt = f"Edit this text and add a brief closing:\n\n{chunk}" - - try: - client: openai.OpenAI = openai.OpenAI(api_key=OPENAI_API_KEY) - response = client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - temperature=0.3, # Lower temperature for more consistent edits - max_tokens=4000, - ) - - content = 
response.choices[0].message.content - if not content: - msg = "No content returned from LLM" - raise ValueError(msg) # noqa: TRY301 - - # Ensure the chunk isn't too long - max_chunk_length = 4000 - if len(content) > max_chunk_length: - # Truncate at sentence boundary - sentences = content.split(". ") - truncated = "" - for sentence in sentences: - if len(truncated) + len(sentence) + 2 < max_chunk_length: - truncated += sentence + ". " - else: - break - content = truncated.strip() - - except Exception: - logger.exception("LLM chunk editing failed") - raise - else: - return content - - -def parse_datetime_with_timezone(created_at: str | datetime) -> datetime: - """Parse datetime string and ensure it has timezone info.""" - if isinstance(created_at, str): - # Handle timezone-aware datetime strings - if created_at.endswith("Z"): - created_at = created_at[:-1] + "+00:00" - last_attempt = datetime.fromisoformat(created_at) - if last_attempt.tzinfo is None: - last_attempt = last_attempt.replace(tzinfo=timezone.utc) - else: - last_attempt = created_at - if last_attempt.tzinfo is None: - last_attempt = last_attempt.replace(tzinfo=timezone.utc) - return last_attempt - - -def should_retry_job(job: dict[str, Any], max_retries: int) -> bool: - """Check if a job should be retried based on retry count and backoff time. - - Uses exponential backoff to determine if enough time has passed. 
- """ - retry_count = job["retry_count"] - if retry_count >= max_retries: - return False - - # Exponential backoff: 30s, 60s, 120s - backoff_time = 30 * (2**retry_count) - last_attempt = parse_datetime_with_timezone(job["created_at"]) - time_since_attempt = datetime.now(tz=timezone.utc) - last_attempt - - return time_since_attempt > timedelta(seconds=backoff_time) - - -def process_pending_jobs( - processor: ArticleProcessor, -) -> None: - """Process all pending jobs.""" - pending_jobs = Core.Database.get_pending_jobs( - limit=5, - ) - - for job in pending_jobs: - if processor.shutdown_handler.is_shutdown_requested(): - logger.info("Shutdown requested, stopping job processing") - break - - current_job = job["id"] - try: - processor.process_job(job) - except Exception as e: - # Ensure job is marked as error even if process_job didn't handle it - logger.exception("Failed to process job: %d", current_job) - # Check if job is still in processing state - current_status = Core.Database.get_job_by_id( - current_job, - ) - if current_status and current_status.get("status") == "processing": - Core.Database.update_job_status( - current_job, - "error", - str(e), - ) - continue - - -def process_retryable_jobs() -> None: - """Check and retry failed jobs with exponential backoff.""" - retryable_jobs = Core.Database.get_retryable_jobs( - MAX_RETRIES, - ) - - for job in retryable_jobs: - if should_retry_job(job, MAX_RETRIES): - logger.info( - "Retrying job %d (attempt %d)", - job["id"], - job["retry_count"] + 1, - ) - Core.Database.update_job_status( - job["id"], - "pending", - ) - - -def check_memory_usage() -> int | Any: - """Check current memory usage percentage.""" - try: - process = psutil.Process() - # this returns an int but psutil is untyped - return process.memory_percent() - except (psutil.Error, OSError): - logger.warning("Failed to check memory usage") - return 0.0 - - -def cleanup_stale_jobs() -> None: - """Reset jobs stuck in processing state on startup.""" - with 
Core.Database.get_connection() as conn: - cursor = conn.cursor() - cursor.execute( - """ - UPDATE queue - SET status = 'pending' - WHERE status = 'processing' - """, - ) - affected = cursor.rowcount - conn.commit() - - if affected > 0: - logger.info( - "Reset %d stale jobs from processing to pending", - affected, - ) - - -def main_loop() -> None: - """Poll for jobs and process them in a continuous loop.""" - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - - # Clean up any stale jobs from previous runs - cleanup_stale_jobs() - - logger.info("Worker started, polling for jobs...") - - while not shutdown_handler.is_shutdown_requested(): - try: - # Process pending jobs - process_pending_jobs(processor) - process_retryable_jobs() - - # Check if there's any work - pending_jobs = Core.Database.get_pending_jobs( - limit=1, - ) - retryable_jobs = Core.Database.get_retryable_jobs( - MAX_RETRIES, - ) - - if not pending_jobs and not retryable_jobs: - logger.debug("No jobs to process, sleeping...") - - except Exception: - logger.exception("Error in main loop") - - # Use interruptible sleep - if not shutdown_handler.is_shutdown_requested(): - shutdown_handler.shutdown_requested.wait(timeout=POLL_INTERVAL) - - # Graceful shutdown - current_job = shutdown_handler.get_current_job() - if current_job: - logger.info( - "Waiting for job %d to complete before shutdown...", - current_job, - ) - # The job will complete or be reset to pending - - logger.info("Worker shutdown complete") - - def move() -> None: """Make the worker move.""" try: @@ -1040,7 +86,9 @@ def move() -> None: Core.Database.init_db() # Start main processing loop - main_loop() + shutdown_handler = ShutdownHandler() + processor = Processor.ArticleProcessor(shutdown_handler) + Jobs.main_loop(shutdown_handler, processor) except KeyboardInterrupt: logger.info("Worker stopped by user") @@ -1049,662 +97,6 @@ def move() -> None: raise -class TestArticleExtraction(Test.TestCase): - """Test 
article extraction functionality.""" - - def test_extract_valid_article(self) -> None: - """Extract from well-formed HTML.""" - # Mock trafilatura.fetch_url and extract - mock_html = ( - "<html><body><h1>Test Article</h1><p>Content here</p></body></html>" - ) - mock_result = json.dumps({ - "title": "Test Article", - "text": "Content here", - }) - - with ( - unittest.mock.patch( - "trafilatura.fetch_url", - return_value=mock_html, - ), - unittest.mock.patch( - "trafilatura.extract", - return_value=mock_result, - ), - ): - title, content, author, pub_date = ( - ArticleProcessor.extract_article_content( - "https://example.com", - ) - ) - - self.assertEqual(title, "Test Article") - self.assertEqual(content, "Content here") - self.assertIsNone(author) - self.assertIsNone(pub_date) - - def test_extract_missing_title(self) -> None: - """Handle articles without titles.""" - mock_html = "<html><body><p>Content without title</p></body></html>" - mock_result = json.dumps({"text": "Content without title"}) - - with ( - unittest.mock.patch( - "trafilatura.fetch_url", - return_value=mock_html, - ), - unittest.mock.patch( - "trafilatura.extract", - return_value=mock_result, - ), - ): - title, content, author, pub_date = ( - ArticleProcessor.extract_article_content( - "https://example.com", - ) - ) - - self.assertEqual(title, "Untitled Article") - self.assertEqual(content, "Content without title") - self.assertIsNone(author) - self.assertIsNone(pub_date) - - def test_extract_empty_content(self) -> None: - """Handle empty articles.""" - mock_html = "<html><body></body></html>" - mock_result = json.dumps({"title": "Empty Article", "text": ""}) - - with ( - unittest.mock.patch( - "trafilatura.fetch_url", - return_value=mock_html, - ), - unittest.mock.patch( - "trafilatura.extract", - return_value=mock_result, - ), - pytest.raises(ValueError, match="No content extracted") as cm, - ): - ArticleProcessor.extract_article_content( - "https://example.com", - ) - - self.assertIn("No content 
extracted", str(cm.value)) - - def test_extract_network_error(self) -> None: - """Handle connection failures.""" - with ( - unittest.mock.patch("trafilatura.fetch_url", return_value=None), - pytest.raises(ValueError, match="Failed to download") as cm, - ): - ArticleProcessor.extract_article_content("https://example.com") - - self.assertIn("Failed to download", str(cm.value)) - - @staticmethod - def test_extract_timeout() -> None: - """Handle slow responses.""" - with ( - unittest.mock.patch( - "trafilatura.fetch_url", - side_effect=TimeoutError("Timeout"), - ), - pytest.raises(TimeoutError), - ): - ArticleProcessor.extract_article_content("https://example.com") - - def test_content_sanitization(self) -> None: - """Remove unwanted elements.""" - mock_html = """ - <html><body> - <h1>Article</h1> - <p>Good content</p> - <script>alert('bad')</script> - <table><tr><td>data</td></tr></table> - </body></html> - """ - mock_result = json.dumps({ - "title": "Article", - "text": "Good content", # Tables and scripts removed - }) - - with ( - unittest.mock.patch( - "trafilatura.fetch_url", - return_value=mock_html, - ), - unittest.mock.patch( - "trafilatura.extract", - return_value=mock_result, - ), - ): - _title, content, _author, _pub_date = ( - ArticleProcessor.extract_article_content( - "https://example.com", - ) - ) - - self.assertEqual(content, "Good content") - self.assertNotIn("script", content) - self.assertNotIn("table", content) - - -class TestTextToSpeech(Test.TestCase): - """Test text-to-speech functionality.""" - - def setUp(self) -> None: - """Set up mocks.""" - # Mock OpenAI API key - self.env_patcher = unittest.mock.patch.dict( - os.environ, - {"OPENAI_API_KEY": "test-key"}, - ) - self.env_patcher.start() - - # Mock OpenAI response - self.mock_audio_response: unittest.mock.MagicMock = ( - unittest.mock.MagicMock() - ) - self.mock_audio_response.content = b"fake-audio-data" - - # Mock AudioSegment to avoid ffmpeg issues in tests - self.mock_audio_segment: 
unittest.mock.MagicMock = ( - unittest.mock.MagicMock() - ) - self.mock_audio_segment.export.return_value = None - self.audio_segment_patcher = unittest.mock.patch( - "pydub.AudioSegment.from_mp3", - return_value=self.mock_audio_segment, - ) - self.audio_segment_patcher.start() - - # Mock the concatenation operations - self.mock_audio_segment.__add__.return_value = self.mock_audio_segment - - def tearDown(self) -> None: - """Clean up mocks.""" - self.env_patcher.stop() - self.audio_segment_patcher.stop() - - def test_tts_generation(self) -> None: - """Generate audio from text.""" - - # Mock the export to write test audio data - def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: - buffer.write(b"test-audio-output") - buffer.seek(0) - - self.mock_audio_segment.export.side_effect = mock_export - - # Mock OpenAI client - mock_client = unittest.mock.MagicMock() - mock_audio = unittest.mock.MagicMock() - mock_speech = unittest.mock.MagicMock() - mock_speech.create.return_value = self.mock_audio_response - mock_audio.speech = mock_speech - mock_client.audio = mock_audio - - with ( - unittest.mock.patch("openai.OpenAI", return_value=mock_client), - unittest.mock.patch( - "Biz.PodcastItLater.Worker.prepare_text_for_tts", - return_value=["Test content"], - ), - ): - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - audio_data = processor.text_to_speech( - "Test content", - "Test Title", - ) - - self.assertIsInstance(audio_data, bytes) - self.assertEqual(audio_data, b"test-audio-output") - - def test_tts_chunking(self) -> None: - """Handle long articles with chunking.""" - long_text = "Long content " * 1000 - chunks = ["Chunk 1", "Chunk 2", "Chunk 3"] - - def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: - buffer.write(b"test-audio-output") - buffer.seek(0) - - self.mock_audio_segment.export.side_effect = mock_export - - # Mock AudioSegment.silent - # Mock OpenAI client - mock_client = 
unittest.mock.MagicMock() - mock_audio = unittest.mock.MagicMock() - mock_speech = unittest.mock.MagicMock() - mock_speech.create.return_value = self.mock_audio_response - mock_audio.speech = mock_speech - mock_client.audio = mock_audio - - with ( - unittest.mock.patch("openai.OpenAI", return_value=mock_client), - unittest.mock.patch( - "Biz.PodcastItLater.Worker.prepare_text_for_tts", - return_value=chunks, - ), - unittest.mock.patch( - "pydub.AudioSegment.silent", - return_value=self.mock_audio_segment, - ), - ): - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - audio_data = processor.text_to_speech( - long_text, - "Long Article", - ) - - # Should have called TTS for each chunk - self.assertIsInstance(audio_data, bytes) - self.assertEqual(audio_data, b"test-audio-output") - - def test_tts_empty_text(self) -> None: - """Handle empty input.""" - with unittest.mock.patch( - "Biz.PodcastItLater.Worker.prepare_text_for_tts", - return_value=[], - ): - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - with pytest.raises(ValueError, match="No chunks generated") as cm: - processor.text_to_speech("", "Empty") - - self.assertIn("No chunks generated", str(cm.value)) - - def test_tts_special_characters(self) -> None: - """Handle unicode and special chars.""" - special_text = 'Unicode: 你好世界 Émojis: 🎙️📰 Special: <>&"' - - def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: - buffer.write(b"test-audio-output") - buffer.seek(0) - - self.mock_audio_segment.export.side_effect = mock_export - - # Mock OpenAI client - mock_client = unittest.mock.MagicMock() - mock_audio = unittest.mock.MagicMock() - mock_speech = unittest.mock.MagicMock() - mock_speech.create.return_value = self.mock_audio_response - mock_audio.speech = mock_speech - mock_client.audio = mock_audio - - with ( - unittest.mock.patch("openai.OpenAI", return_value=mock_client), - unittest.mock.patch( - 
"Biz.PodcastItLater.Worker.prepare_text_for_tts", - return_value=[special_text], - ), - ): - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - audio_data = processor.text_to_speech( - special_text, - "Special", - ) - - self.assertIsInstance(audio_data, bytes) - self.assertEqual(audio_data, b"test-audio-output") - - def test_llm_text_preparation(self) -> None: - """Verify LLM editing.""" - # Test the actual text preparation functions - chunks = split_text_into_chunks("Short text", max_chars=100) - self.assertEqual(len(chunks), 1) - self.assertEqual(chunks[0], "Short text") - - # Test long text splitting - long_text = "Sentence one. " * 100 - chunks = split_text_into_chunks(long_text, max_chars=100) - self.assertGreater(len(chunks), 1) - for chunk in chunks: - self.assertLessEqual(len(chunk), 100) - - @staticmethod - def test_llm_failure_fallback() -> None: - """Handle LLM API failures.""" - # Mock LLM failure - with unittest.mock.patch("openai.OpenAI") as mock_openai: - mock_client = mock_openai.return_value - mock_client.chat.completions.create.side_effect = Exception( - "API Error", - ) - - # Should fall back to raw text - with pytest.raises(Exception, match="API Error"): - edit_chunk_for_speech("Test chunk", "Title", is_first=True) - - def test_chunk_concatenation(self) -> None: - """Verify audio joining.""" - # Mock multiple audio segments - chunks = ["Chunk 1", "Chunk 2"] - - def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: - buffer.write(b"test-audio-output") - buffer.seek(0) - - self.mock_audio_segment.export.side_effect = mock_export - - # Mock OpenAI client - mock_client = unittest.mock.MagicMock() - mock_audio = unittest.mock.MagicMock() - mock_speech = unittest.mock.MagicMock() - mock_speech.create.return_value = self.mock_audio_response - mock_audio.speech = mock_speech - mock_client.audio = mock_audio - - with ( - unittest.mock.patch("openai.OpenAI", return_value=mock_client), - 
unittest.mock.patch( - "Biz.PodcastItLater.Worker.prepare_text_for_tts", - return_value=chunks, - ), - unittest.mock.patch( - "pydub.AudioSegment.silent", - return_value=self.mock_audio_segment, - ), - ): - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - audio_data = processor.text_to_speech("Test", "Title") - - # Should produce combined audio - self.assertIsInstance(audio_data, bytes) - self.assertEqual(audio_data, b"test-audio-output") - - def test_parallel_tts_generation(self) -> None: - """Test parallel TTS processing.""" - chunks = ["Chunk 1", "Chunk 2", "Chunk 3", "Chunk 4"] - - # Mock responses for each chunk - mock_responses = [] - for i in range(len(chunks)): - mock_resp = unittest.mock.MagicMock() - mock_resp.content = f"audio-{i}".encode() - mock_responses.append(mock_resp) - - # Mock OpenAI client - mock_client = unittest.mock.MagicMock() - mock_audio = unittest.mock.MagicMock() - mock_speech = unittest.mock.MagicMock() - - # Make create return different responses for each call - mock_speech.create.side_effect = mock_responses - mock_audio.speech = mock_speech - mock_client.audio = mock_audio - - # Mock AudioSegment operations - mock_segment = unittest.mock.MagicMock() - mock_segment.__add__.return_value = mock_segment - - def mock_export(path: str, **_kwargs: typing.Any) -> None: - Path(path).write_bytes(b"combined-audio") - - mock_segment.export = mock_export - - with ( - unittest.mock.patch("openai.OpenAI", return_value=mock_client), - unittest.mock.patch( - "Biz.PodcastItLater.Worker.prepare_text_for_tts", - return_value=chunks, - ), - unittest.mock.patch( - "pydub.AudioSegment.from_mp3", - return_value=mock_segment, - ), - unittest.mock.patch( - "pydub.AudioSegment.silent", - return_value=mock_segment, - ), - unittest.mock.patch( - "Biz.PodcastItLater.Worker.check_memory_usage", - return_value=50.0, # Normal memory usage - ), - ): - shutdown_handler = ShutdownHandler() - processor = 
ArticleProcessor(shutdown_handler) - audio_data = processor.text_to_speech("Test content", "Test Title") - - # Verify all chunks were processed - self.assertEqual(mock_speech.create.call_count, len(chunks)) - self.assertEqual(audio_data, b"combined-audio") - - def test_parallel_tts_high_memory_fallback(self) -> None: - """Test fallback to serial processing when memory is high.""" - chunks = ["Chunk 1", "Chunk 2"] - - def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: - buffer.write(b"serial-audio") - buffer.seek(0) - - self.mock_audio_segment.export.side_effect = mock_export - - # Mock OpenAI client - mock_client = unittest.mock.MagicMock() - mock_audio = unittest.mock.MagicMock() - mock_speech = unittest.mock.MagicMock() - mock_speech.create.return_value = self.mock_audio_response - mock_audio.speech = mock_speech - mock_client.audio = mock_audio - - with ( - unittest.mock.patch("openai.OpenAI", return_value=mock_client), - unittest.mock.patch( - "Biz.PodcastItLater.Worker.prepare_text_for_tts", - return_value=chunks, - ), - unittest.mock.patch( - "Biz.PodcastItLater.Worker.check_memory_usage", - return_value=65.0, # High memory usage - ), - unittest.mock.patch( - "pydub.AudioSegment.silent", - return_value=self.mock_audio_segment, - ), - ): - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - audio_data = processor.text_to_speech("Test content", "Test Title") - - # Should use serial processing - self.assertEqual(audio_data, b"serial-audio") - - @staticmethod - def test_parallel_tts_error_handling() -> None: - """Test error handling in parallel TTS processing.""" - chunks = ["Chunk 1", "Chunk 2"] - - # Mock OpenAI client with one failure - mock_client = unittest.mock.MagicMock() - mock_audio = unittest.mock.MagicMock() - mock_speech = unittest.mock.MagicMock() - - # First call succeeds, second fails - mock_resp1 = unittest.mock.MagicMock() - mock_resp1.content = b"audio-1" - mock_speech.create.side_effect = 
[mock_resp1, Exception("API Error")] - - mock_audio.speech = mock_speech - mock_client.audio = mock_audio - - # Set up the test context - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - - with ( - unittest.mock.patch("openai.OpenAI", return_value=mock_client), - unittest.mock.patch( - "Biz.PodcastItLater.Worker.prepare_text_for_tts", - return_value=chunks, - ), - unittest.mock.patch( - "Biz.PodcastItLater.Worker.check_memory_usage", - return_value=50.0, - ), - pytest.raises(Exception, match="API Error"), - ): - processor.text_to_speech("Test content", "Test Title") - - def test_parallel_tts_order_preservation(self) -> None: - """Test that chunks are combined in the correct order.""" - chunks = ["First", "Second", "Third", "Fourth", "Fifth"] - - # Create mock responses with identifiable content - mock_responses = [] - for chunk in chunks: - mock_resp = unittest.mock.MagicMock() - mock_resp.content = f"audio-{chunk}".encode() - mock_responses.append(mock_resp) - - # Mock OpenAI client - mock_client = unittest.mock.MagicMock() - mock_audio = unittest.mock.MagicMock() - mock_speech = unittest.mock.MagicMock() - mock_speech.create.side_effect = mock_responses - mock_audio.speech = mock_speech - mock_client.audio = mock_audio - - # Track the order of segments being combined - combined_order = [] - - def mock_from_mp3(data: io.BytesIO) -> unittest.mock.MagicMock: - content = data.read() - combined_order.append(content.decode()) - segment = unittest.mock.MagicMock() - segment.__add__.return_value = segment - return segment - - mock_segment = unittest.mock.MagicMock() - mock_segment.__add__.return_value = mock_segment - - def mock_export(path: str, **_kwargs: typing.Any) -> None: - # Verify order is preserved - expected_order = [f"audio-{chunk}" for chunk in chunks] - if combined_order != expected_order: - msg = f"Order mismatch: {combined_order} != {expected_order}" - raise AssertionError(msg) - 
Path(path).write_bytes(b"ordered-audio") - - mock_segment.export = mock_export - - with ( - unittest.mock.patch("openai.OpenAI", return_value=mock_client), - unittest.mock.patch( - "Biz.PodcastItLater.Worker.prepare_text_for_tts", - return_value=chunks, - ), - unittest.mock.patch( - "pydub.AudioSegment.from_mp3", - side_effect=mock_from_mp3, - ), - unittest.mock.patch( - "pydub.AudioSegment.silent", - return_value=mock_segment, - ), - unittest.mock.patch( - "Biz.PodcastItLater.Worker.check_memory_usage", - return_value=50.0, - ), - ): - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - audio_data = processor.text_to_speech("Test content", "Test Title") - - self.assertEqual(audio_data, b"ordered-audio") - - -class TestIntroOutro(Test.TestCase): - """Test intro and outro generation with metadata.""" - - def test_create_intro_text_full_metadata(self) -> None: - """Test intro text creation with all metadata.""" - intro = ArticleProcessor._create_intro_text( # noqa: SLF001 - title="Test Article", - author="John Doe", - pub_date="2024-01-15", - ) - self.assertIn("Title: Test Article", intro) - self.assertIn("Author: John Doe", intro) - self.assertIn("Published: 2024-01-15", intro) - - def test_create_intro_text_no_author(self) -> None: - """Test intro text without author.""" - intro = ArticleProcessor._create_intro_text( # noqa: SLF001 - title="Test Article", - author=None, - pub_date="2024-01-15", - ) - self.assertIn("Title: Test Article", intro) - self.assertNotIn("Author:", intro) - self.assertIn("Published: 2024-01-15", intro) - - def test_create_intro_text_minimal(self) -> None: - """Test intro text with only title.""" - intro = ArticleProcessor._create_intro_text( # noqa: SLF001 - title="Test Article", - author=None, - pub_date=None, - ) - self.assertEqual(intro, "Title: Test Article.") - - def test_create_outro_text_with_author(self) -> None: - """Test outro text with author.""" - outro = ArticleProcessor._create_outro_text( # 
noqa: SLF001 - title="Test Article", - author="Jane Smith", - ) - self.assertIn("Test Article", outro) - self.assertIn("Jane Smith", outro) - self.assertIn("Podcast It Later", outro) - - def test_create_outro_text_no_author(self) -> None: - """Test outro text without author.""" - outro = ArticleProcessor._create_outro_text( # noqa: SLF001 - title="Test Article", - author=None, - ) - self.assertIn("Test Article", outro) - self.assertNotIn("by", outro) - self.assertIn("Podcast It Later", outro) - - def test_extract_with_metadata(self) -> None: - """Test that extraction returns metadata.""" - mock_html = "<html><body><p>Content</p></body></html>" - mock_result = json.dumps({ - "title": "Test Article", - "text": "Article content", - "author": "Test Author", - "date": "2024-01-15", - }) - - with ( - unittest.mock.patch( - "trafilatura.fetch_url", - return_value=mock_html, - ), - unittest.mock.patch( - "trafilatura.extract", - return_value=mock_result, - ), - ): - title, content, author, pub_date = ( - ArticleProcessor.extract_article_content( - "https://example.com", - ) - ) - - self.assertEqual(title, "Test Article") - self.assertEqual(content, "Article content") - self.assertEqual(author, "Test Author") - self.assertEqual(pub_date, "2024-01-15") - - class TestMemoryEfficiency(Test.TestCase): """Test memory-efficient processing.""" @@ -1719,7 +111,9 @@ class TestMemoryEfficiency(Test.TestCase): ), pytest.raises(ValueError, match="Article too large") as cm, ): - ArticleProcessor.extract_article_content("https://example.com") + Processor.ArticleProcessor.extract_article_content( + "https://example.com" + ) self.assertIn("Article too large", str(cm.value)) @@ -1742,7 +136,7 @@ class TestMemoryEfficiency(Test.TestCase): ), ): title, content, _author, _pub_date = ( - ArticleProcessor.extract_article_content( + Processor.ArticleProcessor.extract_article_content( "https://example.com", ) ) @@ -1752,339 +146,39 @@ class TestMemoryEfficiency(Test.TestCase): def 
test_memory_usage_check(self) -> None: """Test memory usage monitoring.""" - usage = check_memory_usage() + usage = Processor.check_memory_usage() self.assertIsInstance(usage, float) self.assertGreaterEqual(usage, 0.0) self.assertLessEqual(usage, 100.0) -class TestJobProcessing(Test.TestCase): - """Test job processing functionality.""" +class TestWorkerErrorHandling(Test.TestCase): + """Test worker error handling and recovery.""" def setUp(self) -> None: """Set up test environment.""" Core.Database.init_db() - - # Create test user and job - self.user_id, _ = Core.Database.create_user( - "test@example.com", - ) + self.user_id, _ = Core.Database.create_user("test@example.com") self.job_id = Core.Database.add_to_queue( - "https://example.com/article", + "https://example.com", "test@example.com", self.user_id, ) + self.shutdown_handler = ShutdownHandler() - # Mock environment + # Mock environment for processor self.env_patcher = unittest.mock.patch.dict( os.environ, - { - "OPENAI_API_KEY": "test-key", - "S3_ENDPOINT": "https://s3.example.com", - "S3_BUCKET": "test-bucket", - "S3_ACCESS_KEY": "test-access", - "S3_SECRET_KEY": "test-secret", - }, + {"OPENAI_API_KEY": "test-key"}, ) self.env_patcher.start() + self.processor = Processor.ArticleProcessor(self.shutdown_handler) def tearDown(self) -> None: """Clean up.""" self.env_patcher.stop() Core.Database.teardown() - def test_process_job_success(self) -> None: - """Complete pipeline execution.""" - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - job = Core.Database.get_job_by_id(self.job_id) - if job is None: - msg = "no job found for %s" - raise Test.TestError(msg, self.job_id) - - # Mock all external calls - with ( - unittest.mock.patch.object( - ArticleProcessor, - "extract_article_content", - return_value=( - "Test Title", - "Test content", - "Test Author", - "2024-01-15", - ), - ), - unittest.mock.patch.object( - ArticleProcessor, - "text_to_speech", - 
return_value=b"audio-data", - ), - unittest.mock.patch.object( - processor, - "upload_to_s3", - return_value="https://s3.example.com/audio.mp3", - ), - unittest.mock.patch( - "Biz.PodcastItLater.Core.Database.update_job_status", - ) as mock_update, - unittest.mock.patch( - "Biz.PodcastItLater.Core.Database.create_episode", - ) as mock_create, - ): - mock_create.return_value = 1 - processor.process_job(job) - - # Verify job was marked complete - mock_update.assert_called_with(self.job_id, "completed") - mock_create.assert_called_once() - - def test_process_job_extraction_failure(self) -> None: - """Handle bad URLs.""" - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - job = Core.Database.get_job_by_id(self.job_id) - if job is None: - msg = "no job found for %s" - raise Test.TestError(msg, self.job_id) - - with ( - unittest.mock.patch.object( - ArticleProcessor, - "extract_article_content", - side_effect=ValueError("Bad URL"), - ), - unittest.mock.patch( - "Biz.PodcastItLater.Core.Database.update_job_status", - ) as mock_update, - pytest.raises(ValueError, match="Bad URL"), - ): - processor.process_job(job) - - # Job should be marked as error - mock_update.assert_called_with(self.job_id, "error", "Bad URL") - - def test_process_job_tts_failure(self) -> None: - """Handle TTS errors.""" - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - job = Core.Database.get_job_by_id(self.job_id) - if job is None: - msg = "no job found for %s" - raise Test.TestError(msg, self.job_id) - - with ( - unittest.mock.patch.object( - ArticleProcessor, - "extract_article_content", - return_value=("Title", "Content"), - ), - unittest.mock.patch.object( - ArticleProcessor, - "text_to_speech", - side_effect=Exception("TTS Error"), - ), - unittest.mock.patch( - "Biz.PodcastItLater.Core.Database.update_job_status", - ) as mock_update, - pytest.raises(Exception, match="TTS Error"), - ): - processor.process_job(job) - - 
mock_update.assert_called_with(self.job_id, "error", "TTS Error") - - def test_process_job_s3_failure(self) -> None: - """Handle upload errors.""" - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - job = Core.Database.get_job_by_id(self.job_id) - if job is None: - msg = "no job found for %s" - raise Test.TestError(msg, self.job_id) - - with ( - unittest.mock.patch.object( - ArticleProcessor, - "extract_article_content", - return_value=("Title", "Content"), - ), - unittest.mock.patch.object( - ArticleProcessor, - "text_to_speech", - return_value=b"audio", - ), - unittest.mock.patch.object( - processor, - "upload_to_s3", - side_effect=ClientError({}, "PutObject"), - ), - unittest.mock.patch( - "Biz.PodcastItLater.Core.Database.update_job_status", - ), - pytest.raises(ClientError), - ): - processor.process_job(job) - - def test_job_retry_logic(self) -> None: - """Verify exponential backoff.""" - # Set job to error with retry count - Core.Database.update_job_status( - self.job_id, - "error", - "First failure", - ) - Core.Database.update_job_status( - self.job_id, - "error", - "Second failure", - ) - - job = Core.Database.get_job_by_id(self.job_id) - if job is None: - msg = "no job found for %s" - raise Test.TestError(msg, self.job_id) - - self.assertEqual(job["retry_count"], 2) - - # Should be retryable - retryable = Core.Database.get_retryable_jobs( - max_retries=3, - ) - self.assertEqual(len(retryable), 1) - - def test_max_retries(self) -> None: - """Stop after max attempts.""" - # Exceed retry limit - for i in range(4): - Core.Database.update_job_status( - self.job_id, - "error", - f"Failure {i}", - ) - - # Should not be retryable - retryable = Core.Database.get_retryable_jobs( - max_retries=3, - ) - self.assertEqual(len(retryable), 0) - - def test_graceful_shutdown(self) -> None: - """Test graceful shutdown during job processing.""" - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - job = 
Core.Database.get_job_by_id(self.job_id) - if job is None: - msg = "no job found for %s" - raise Test.TestError(msg, self.job_id) - - def mock_tts(*_args: Any) -> bytes: - shutdown_handler.shutdown_requested.set() - return b"audio-data" - - with ( - unittest.mock.patch.object( - ArticleProcessor, - "extract_article_content", - return_value=( - "Test Title", - "Test content", - "Test Author", - "2024-01-15", - ), - ), - unittest.mock.patch.object( - ArticleProcessor, - "text_to_speech", - side_effect=mock_tts, - ), - unittest.mock.patch( - "Biz.PodcastItLater.Core.Database.update_job_status", - ) as mock_update, - ): - processor.process_job(job) - - # Job should be reset to pending due to shutdown - mock_update.assert_any_call(self.job_id, "pending") - - def test_cleanup_stale_jobs(self) -> None: - """Test cleanup of stale processing jobs.""" - # Manually set job to processing - Core.Database.update_job_status(self.job_id, "processing") - - # Run cleanup - cleanup_stale_jobs() - - # Job should be back to pending - job = Core.Database.get_job_by_id(self.job_id) - if job is None: - msg = "no job found for %s" - raise Test.TestError(msg, self.job_id) - self.assertEqual(job["status"], "pending") - - def test_concurrent_processing(self) -> None: - """Handle multiple jobs.""" - # Create multiple jobs - job2 = Core.Database.add_to_queue( - "https://example.com/2", - "test@example.com", - self.user_id, - ) - job3 = Core.Database.add_to_queue( - "https://example.com/3", - "test@example.com", - self.user_id, - ) - - # Get pending jobs - jobs = Core.Database.get_pending_jobs(limit=5) - - self.assertEqual(len(jobs), 3) - self.assertEqual({j["id"] for j in jobs}, {self.job_id, job2, job3}) - - def test_memory_threshold_deferral(self) -> None: - """Test that jobs are deferred when memory usage is high.""" - shutdown_handler = ShutdownHandler() - processor = ArticleProcessor(shutdown_handler) - job = Core.Database.get_job_by_id(self.job_id) - if job is None: - msg = "no job found 
for %s" - raise Test.TestError(msg, self.job_id) - - # Mock high memory usage - with ( - unittest.mock.patch( - "Biz.PodcastItLater.Worker.check_memory_usage", - return_value=90.0, # High memory usage - ), - unittest.mock.patch( - "Biz.PodcastItLater.Core.Database.update_job_status", - ) as mock_update, - ): - processor.process_job(job) - - # Job should not be processed (no status updates) - mock_update.assert_not_called() - - -class TestWorkerErrorHandling(Test.TestCase): - """Test worker error handling and recovery.""" - - def setUp(self) -> None: - """Set up test environment.""" - Core.Database.init_db() - self.user_id, _ = Core.Database.create_user("test@example.com") - self.job_id = Core.Database.add_to_queue( - "https://example.com", - "test@example.com", - self.user_id, - ) - self.shutdown_handler = ShutdownHandler() - self.processor = ArticleProcessor(self.shutdown_handler) - - @staticmethod - def tearDown() -> None: - """Clean up.""" - Core.Database.teardown() - def test_process_pending_jobs_exception_handling(self) -> None: """Test that process_pending_jobs handles exceptions.""" @@ -2105,7 +199,7 @@ class TestWorkerErrorHandling(Test.TestCase): side_effect=Core.Database.update_job_status, ) as _mock_update, ): - process_pending_jobs(self.processor) + Jobs.process_pending_jobs(self.processor) # Job should be marked as error job = Core.Database.get_job_by_id(self.job_id) @@ -2132,7 +226,7 @@ class TestWorkerErrorHandling(Test.TestCase): ) conn.commit() - process_retryable_jobs() + Jobs.process_retryable_jobs() job = Core.Database.get_job_by_id(self.job_id) self.assertIsNotNone(job) @@ -2145,7 +239,7 @@ class TestWorkerErrorHandling(Test.TestCase): Core.Database.update_job_status(self.job_id, "error", "Fail 1") # created_at is now, so backoff should prevent retry - process_retryable_jobs() + Jobs.process_retryable_jobs() job = Core.Database.get_job_by_id(self.job_id) self.assertIsNotNone(job) @@ -2153,40 +247,18 @@ class 
TestWorkerErrorHandling(Test.TestCase): self.assertEqual(job["status"], "error") -class TestTextChunking(Test.TestCase): - """Test text chunking edge cases.""" - - def test_split_text_single_long_word(self) -> None: - """Handle text with a single word exceeding limit.""" - long_word = "a" * 4000 - chunks = split_text_into_chunks(long_word, max_chars=3000) - - # Should keep it as one chunk or split? - # The current implementation does not split words - self.assertEqual(len(chunks), 1) - self.assertEqual(len(chunks[0]), 4000) - - def test_split_text_no_sentence_boundaries(self) -> None: - """Handle long text with no sentence boundaries.""" - text = "word " * 1000 # 5000 chars - chunks = split_text_into_chunks(text, max_chars=3000) - - # Should keep it as one chunk as it can't split by ". " - self.assertEqual(len(chunks), 1) - self.assertGreater(len(chunks[0]), 3000) - - def test() -> None: """Run the tests.""" Test.run( App.Area.Test, [ - TestArticleExtraction, - TestTextToSpeech, + Processor.TestArticleExtraction, + Processor.TestTextToSpeech, + Processor.TestIntroOutro, TestMemoryEfficiency, - TestJobProcessing, + Jobs.TestJobProcessing, TestWorkerErrorHandling, - TestTextChunking, + TextProcessing.TestTextChunking, ], ) diff --git a/Biz/PodcastItLater/Worker/Jobs.py b/Biz/PodcastItLater/Worker/Jobs.py new file mode 100644 index 0000000..630aaf0 --- /dev/null +++ b/Biz/PodcastItLater/Worker/Jobs.py @@ -0,0 +1,506 @@ +"""Job management and processing functions.""" + +# : dep pytest +# : dep pytest-mock +import Biz.PodcastItLater.Core as Core +import logging +import Omni.App as App +import Omni.Log as Log +import Omni.Test as Test +import os +import pytest +import sys +import unittest.mock +from botocore.exceptions import ClientError # type: ignore[import-untyped] +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from typing import Any + +logger = logging.getLogger(__name__) +Log.setup(logger) + +# Worker configuration 
+POLL_INTERVAL = 30 # seconds +MAX_RETRIES = 3 + + +def parse_datetime_with_timezone(created_at: str | datetime) -> datetime: + """Parse datetime string and ensure it has timezone info.""" + if isinstance(created_at, str): + # Handle timezone-aware datetime strings + if created_at.endswith("Z"): + created_at = created_at[:-1] + "+00:00" + last_attempt = datetime.fromisoformat(created_at) + if last_attempt.tzinfo is None: + last_attempt = last_attempt.replace(tzinfo=timezone.utc) + else: + last_attempt = created_at + if last_attempt.tzinfo is None: + last_attempt = last_attempt.replace(tzinfo=timezone.utc) + return last_attempt + + +def should_retry_job(job: dict[str, Any], max_retries: int) -> bool: + """Check if a job should be retried based on retry count and backoff time. + + Uses exponential backoff to determine if enough time has passed. + """ + retry_count = job["retry_count"] + if retry_count >= max_retries: + return False + + # Exponential backoff: 30s, 60s, 120s + backoff_time = 30 * (2**retry_count) + last_attempt = parse_datetime_with_timezone(job["created_at"]) + time_since_attempt = datetime.now(tz=timezone.utc) - last_attempt + + return time_since_attempt > timedelta(seconds=backoff_time) + + +def process_pending_jobs( + processor: Any, +) -> None: + """Process all pending jobs.""" + pending_jobs = Core.Database.get_pending_jobs( + limit=5, + ) + + for job in pending_jobs: + if processor.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, stopping job processing") + break + + current_job = job["id"] + try: + processor.process_job(job) + except Exception as e: + # Ensure job is marked as error even if process_job didn't handle it + logger.exception("Failed to process job: %d", current_job) + # Check if job is still in processing state + current_status = Core.Database.get_job_by_id( + current_job, + ) + if current_status and current_status.get("status") == "processing": + Core.Database.update_job_status( + current_job, + 
"error", + str(e), + ) + continue + + +def process_retryable_jobs() -> None: + """Check and retry failed jobs with exponential backoff.""" + retryable_jobs = Core.Database.get_retryable_jobs( + MAX_RETRIES, + ) + + for job in retryable_jobs: + if should_retry_job(job, MAX_RETRIES): + logger.info( + "Retrying job %d (attempt %d)", + job["id"], + job["retry_count"] + 1, + ) + Core.Database.update_job_status( + job["id"], + "pending", + ) + + +def cleanup_stale_jobs() -> None: + """Reset jobs stuck in processing state on startup.""" + with Core.Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE queue + SET status = 'pending' + WHERE status = 'processing' + """, + ) + affected = cursor.rowcount + conn.commit() + + if affected > 0: + logger.info( + "Reset %d stale jobs from processing to pending", + affected, + ) + + +def main_loop(shutdown_handler: Any, processor: Any) -> None: + """Poll for jobs and process them in a continuous loop.""" + # Clean up any stale jobs from previous runs + cleanup_stale_jobs() + + logger.info("Worker started, polling for jobs...") + + while not shutdown_handler.is_shutdown_requested(): + try: + # Process pending jobs + process_pending_jobs(processor) + process_retryable_jobs() + + # Check if there's any work + pending_jobs = Core.Database.get_pending_jobs( + limit=1, + ) + retryable_jobs = Core.Database.get_retryable_jobs( + MAX_RETRIES, + ) + + if not pending_jobs and not retryable_jobs: + logger.debug("No jobs to process, sleeping...") + + except Exception: + logger.exception("Error in main loop") + + # Use interruptible sleep + if not shutdown_handler.is_shutdown_requested(): + shutdown_handler.shutdown_requested.wait(timeout=POLL_INTERVAL) + + # Graceful shutdown + current_job = shutdown_handler.get_current_job() + if current_job: + logger.info( + "Waiting for job %d to complete before shutdown...", + current_job, + ) + # The job will complete or be reset to pending + + logger.info("Worker 
shutdown complete") + + +class TestJobProcessing(Test.TestCase): + """Test job processing functionality.""" + + def setUp(self) -> None: + """Set up test environment.""" + # Import here to avoid circular dependencies + import Biz.PodcastItLater.Worker as Worker + + Core.Database.init_db() + + # Create test user and job + self.user_id, _ = Core.Database.create_user( + "test@example.com", + ) + self.job_id = Core.Database.add_to_queue( + "https://example.com/article", + "test@example.com", + self.user_id, + ) + + # Mock environment + self.env_patcher = unittest.mock.patch.dict( + os.environ, + { + "OPENAI_API_KEY": "test-key", + "S3_ENDPOINT": "https://s3.example.com", + "S3_BUCKET": "test-bucket", + "S3_ACCESS_KEY": "test-access", + "S3_SECRET_KEY": "test-secret", + }, + ) + self.env_patcher.start() + + # Create processor + self.shutdown_handler = Worker.ShutdownHandler() + # Import ArticleProcessor from Processor module + from Biz.PodcastItLater.Worker.Processor import ArticleProcessor + + self.processor = ArticleProcessor(self.shutdown_handler) + + def tearDown(self) -> None: + """Clean up.""" + self.env_patcher.stop() + Core.Database.teardown() + + def test_process_job_success(self) -> None: + """Complete pipeline execution.""" + from Biz.PodcastItLater.Worker.Processor import ArticleProcessor + + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + # Mock all external calls + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=( + "Test Title", + "Test content", + "Test Author", + "2024-01-15", + ), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + return_value=b"audio-data", + ), + unittest.mock.patch.object( + self.processor, + "upload_to_s3", + return_value="https://s3.example.com/audio.mp3", + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + 
unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.create_episode", + ) as mock_create, + ): + mock_create.return_value = 1 + self.processor.process_job(job) + + # Verify job was marked complete + mock_update.assert_called_with(self.job_id, "completed") + mock_create.assert_called_once() + + def test_process_job_extraction_failure(self) -> None: + """Handle bad URLs.""" + from Biz.PodcastItLater.Worker.Processor import ArticleProcessor + + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + side_effect=ValueError("Bad URL"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + pytest.raises(ValueError, match="Bad URL"), + ): + self.processor.process_job(job) + + # Job should be marked as error + mock_update.assert_called_with(self.job_id, "error", "Bad URL") + + def test_process_job_tts_failure(self) -> None: + """Handle TTS errors.""" + from Biz.PodcastItLater.Worker.Processor import ArticleProcessor + + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=("Title", "Content", None, None), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + side_effect=Exception("TTS Error"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + pytest.raises(Exception, match="TTS Error"), + ): + self.processor.process_job(job) + + mock_update.assert_called_with(self.job_id, "error", "TTS Error") + + def test_process_job_s3_failure(self) -> None: + """Handle upload errors.""" + from Biz.PodcastItLater.Worker.Processor import ArticleProcessor + + job = 
Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=("Title", "Content", None, None), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + return_value=b"audio", + ), + unittest.mock.patch.object( + self.processor, + "upload_to_s3", + side_effect=ClientError({}, "PutObject"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ), + pytest.raises(ClientError), + ): + self.processor.process_job(job) + + def test_job_retry_logic(self) -> None: + """Verify exponential backoff.""" + # Set job to error with retry count + Core.Database.update_job_status( + self.job_id, + "error", + "First failure", + ) + Core.Database.update_job_status( + self.job_id, + "error", + "Second failure", + ) + + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + self.assertEqual(job["retry_count"], 2) + + # Should be retryable + retryable = Core.Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 1) + + def test_max_retries(self) -> None: + """Stop after max attempts.""" + # Exceed retry limit + for i in range(4): + Core.Database.update_job_status( + self.job_id, + "error", + f"Failure {i}", + ) + + # Should not be retryable + retryable = Core.Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 0) + + def test_graceful_shutdown(self) -> None: + """Test graceful shutdown during job processing.""" + from Biz.PodcastItLater.Worker.Processor import ArticleProcessor + + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + def mock_tts(*_args: Any, **_kwargs: Any) -> bytes: + self.shutdown_handler.shutdown_requested.set() + 
return b"audio-data" + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=( + "Test Title", + "Test content", + "Test Author", + "2024-01-15", + ), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + side_effect=mock_tts, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + ): + self.processor.process_job(job) + + # Job should be reset to pending due to shutdown + mock_update.assert_any_call(self.job_id, "pending") + + def test_cleanup_stale_jobs(self) -> None: + """Test cleanup of stale processing jobs.""" + # Manually set job to processing + Core.Database.update_job_status(self.job_id, "processing") + + # Run cleanup + cleanup_stale_jobs() + + # Job should be back to pending + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + self.assertEqual(job["status"], "pending") + + def test_concurrent_processing(self) -> None: + """Handle multiple jobs.""" + # Create multiple jobs + job2 = Core.Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + job3 = Core.Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + # Get pending jobs + jobs = Core.Database.get_pending_jobs(limit=5) + + self.assertEqual(len(jobs), 3) + self.assertEqual({j["id"] for j in jobs}, {self.job_id, job2, job3}) + + def test_memory_threshold_deferral(self) -> None: + """Test that jobs are deferred when memory usage is high.""" + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + # Mock high memory usage + with ( + unittest.mock.patch( + "Biz.PodcastItLater.Worker.Processor.check_memory_usage", + return_value=90.0, # High memory usage + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) 
as mock_update, + ): + self.processor.process_job(job) + + # Job should not be processed (no status updates) + mock_update.assert_not_called() + + +def test() -> None: + """Run the tests.""" + Test.run( + App.Area.Test, + [ + TestJobProcessing, + ], + ) + + +def main() -> None: + """Entry point for the module.""" + if "test" in sys.argv: + test() + else: + logger.info("Jobs module loaded") diff --git a/Biz/PodcastItLater/Worker/Processor.py b/Biz/PodcastItLater/Worker/Processor.py new file mode 100644 index 0000000..bdda3e5 --- /dev/null +++ b/Biz/PodcastItLater/Worker/Processor.py @@ -0,0 +1,1382 @@ +"""Article processing for podcast conversion.""" + +# : dep boto3 +# : dep botocore +# : dep openai +# : dep psutil +# : dep pydub +# : dep pytest +# : dep pytest-mock +# : dep trafilatura +import Biz.PodcastItLater.Core as Core +import Biz.PodcastItLater.Worker.TextProcessing as TextProcessing +import boto3 # type: ignore[import-untyped] +import concurrent.futures +import io +import json +import logging +import Omni.App as App +import Omni.Log as Log +import Omni.Test as Test +import openai +import operator +import os +import psutil # type: ignore[import-untyped] +import pytest +import sys +import tempfile +import time +import trafilatura +import typing +import unittest.mock +from botocore.exceptions import ClientError # type: ignore[import-untyped] +from datetime import datetime +from datetime import timezone +from pathlib import Path +from pydub import AudioSegment # type: ignore[import-untyped] +from typing import Any + +logger = logging.getLogger(__name__) +Log.setup(logger) + +# Configuration from environment variables +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +S3_ENDPOINT = os.getenv("S3_ENDPOINT") +S3_BUCKET = os.getenv("S3_BUCKET") +S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY") +S3_SECRET_KEY = os.getenv("S3_SECRET_KEY") + +# Worker configuration +MAX_ARTICLE_SIZE = 500_000 # 500KB character limit for articles +TTS_MODEL = "tts-1" +TTS_VOICE = "alloy" 
+MEMORY_THRESHOLD = 80 # Percentage threshold for memory usage +CROSSFADE_DURATION = 500 # ms for crossfading segments +PAUSE_DURATION = 1000 # ms for silence between segments + + +def check_memory_usage() -> int | Any: + """Check current memory usage percentage.""" + try: + process = psutil.Process() + # this returns an int but psutil is untyped + return process.memory_percent() + except (psutil.Error, OSError): + logger.warning("Failed to check memory usage") + return 0.0 + + +class ArticleProcessor: + """Handles the complete article-to-podcast conversion pipeline.""" + + def __init__(self, shutdown_handler: Any) -> None: + """Initialize the processor with required services. + + Raises: + ValueError: If OPENAI_API_KEY environment variable is not set. + """ + if not OPENAI_API_KEY: + msg = "OPENAI_API_KEY environment variable is required" + raise ValueError(msg) + + self.openai_client: openai.OpenAI = openai.OpenAI( + api_key=OPENAI_API_KEY, + ) + self.shutdown_handler = shutdown_handler + + # Initialize S3 client for Digital Ocean Spaces + if all([S3_ENDPOINT, S3_BUCKET, S3_ACCESS_KEY, S3_SECRET_KEY]): + self.s3_client: Any = boto3.client( + "s3", + endpoint_url=S3_ENDPOINT, + aws_access_key_id=S3_ACCESS_KEY, + aws_secret_access_key=S3_SECRET_KEY, + ) + else: + logger.warning("S3 configuration incomplete, uploads will fail") + self.s3_client = None + + @staticmethod + def extract_article_content( + url: str, + ) -> tuple[str, str, str | None, str | None]: + """Extract title, content, author, and date from article URL. + + Returns: + tuple: (title, content, author, publication_date) + + Raises: + ValueError: If content cannot be downloaded, extracted, or large. 
+ """ + try: + downloaded = trafilatura.fetch_url(url) + if not downloaded: + msg = f"Failed to download content from {url}" + raise ValueError(msg) # noqa: TRY301 + + # Check size before processing + if ( + len(downloaded) > MAX_ARTICLE_SIZE * 4 + ): # Rough HTML to text ratio + msg = f"Article too large: {len(downloaded)} bytes" + raise ValueError(msg) # noqa: TRY301 + + # Extract with metadata + result = trafilatura.extract( + downloaded, + include_comments=False, + include_tables=False, + with_metadata=True, + output_format="json", + ) + + if not result: + msg = f"Failed to extract content from {url}" + raise ValueError(msg) # noqa: TRY301 + + data = json.loads(result) + + title = data.get("title", "Untitled Article") + content = data.get("text", "") + author = data.get("author") + pub_date = data.get("date") + + if not content: + msg = f"No content extracted from {url}" + raise ValueError(msg) # noqa: TRY301 + + # Enforce content size limit + if len(content) > MAX_ARTICLE_SIZE: + logger.warning( + "Article content truncated from %d to %d characters", + len(content), + MAX_ARTICLE_SIZE, + ) + content = content[:MAX_ARTICLE_SIZE] + + logger.info( + "Extracted article: %s (%d chars, author: %s, date: %s)", + title, + len(content), + author or "unknown", + pub_date or "unknown", + ) + except Exception: + logger.exception("Failed to extract content from %s", url) + raise + else: + return title, content, author, pub_date + + def text_to_speech( + self, + text: str, + title: str, + author: str | None = None, + pub_date: str | None = None, + ) -> bytes: + """Convert text to speech with intro/outro using OpenAI TTS API. + + Uses parallel processing for chunks while maintaining order. + Adds intro with metadata and outro with attribution. + + Args: + text: Article content to convert + title: Article title + author: Article author (optional) + pub_date: Publication date (optional) + + Raises: + ValueError: If no chunks are generated from text. 
+ """ + try: + # Generate intro audio + intro_text = self._create_intro_text(title, author, pub_date) + intro_audio = self._generate_tts_segment(intro_text) + + # Generate outro audio + outro_text = self._create_outro_text(title, author) + outro_audio = self._generate_tts_segment(outro_text) + + # Use LLM to prepare and chunk the main content + chunks = TextProcessing.prepare_text_for_tts(text, title) + + if not chunks: + msg = "No chunks generated from text" + raise ValueError(msg) # noqa: TRY301 + + logger.info("Processing %d chunks for TTS", len(chunks)) + + # Check memory before parallel processing + mem_usage = check_memory_usage() + if mem_usage > MEMORY_THRESHOLD - 20: # Leave 20% buffer + logger.warning( + "High memory usage (%.1f%%), falling back to serial " + "processing", + mem_usage, + ) + content_audio_bytes = self._text_to_speech_serial(chunks) + else: + # Determine max workers + max_workers = min( + 4, # Reasonable limit to avoid rate limiting + len(chunks), # No more workers than chunks + max(1, psutil.cpu_count() // 2), # Use half of CPU cores + ) + + logger.info( + "Using %d workers for parallel TTS processing", + max_workers, + ) + + # Process chunks in parallel + chunk_results: list[tuple[int, bytes]] = [] + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers, + ) as executor: + # Submit all chunks for processing + future_to_index = { + executor.submit(self._process_tts_chunk, chunk, i): i + for i, chunk in enumerate(chunks) + } + + # Collect results as they complete + for future in concurrent.futures.as_completed( + future_to_index, + ): + index = future_to_index[future] + try: + audio_data = future.result() + chunk_results.append((index, audio_data)) + except Exception: + logger.exception( + "Failed to process chunk %d", + index, + ) + raise + + # Sort results by index to maintain order + chunk_results.sort(key=operator.itemgetter(0)) + + # Combine audio chunks + content_audio_bytes = self._combine_audio_chunks([ + data for 
_, data in chunk_results + ]) + + # Combine intro, content, and outro with pauses + return ArticleProcessor._combine_intro_content_outro( + intro_audio, + content_audio_bytes, + outro_audio, + ) + + except Exception: + logger.exception("TTS generation failed") + raise + + @staticmethod + def _create_intro_text( + title: str, + author: str | None, + pub_date: str | None, + ) -> str: + """Create intro text with available metadata.""" + parts = [f"Title: {title}"] + + if author: + parts.append(f"Author: {author}") + + if pub_date: + parts.append(f"Published: {pub_date}") + + return ". ".join(parts) + "." + + @staticmethod + def _create_outro_text(title: str, author: str | None) -> str: + """Create outro text with attribution.""" + if author: + return ( + f"This has been an audio version of {title} " + f"by {author}, created using Podcast It Later." + ) + return ( + f"This has been an audio version of {title}, " + "created using Podcast It Later." + ) + + def _generate_tts_segment(self, text: str) -> bytes: + """Generate TTS audio for a single segment (intro/outro). + + Args: + text: Text to convert to speech + + Returns: + MP3 audio bytes + """ + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=text, + ) + return response.content + + @staticmethod + def _combine_intro_content_outro( + intro_audio: bytes, + content_audio: bytes, + outro_audio: bytes, + ) -> bytes: + """Combine intro, content, and outro with crossfades. 
+ + Args: + intro_audio: MP3 bytes for intro + content_audio: MP3 bytes for main content + outro_audio: MP3 bytes for outro + + Returns: + Combined MP3 audio bytes + """ + # Load audio segments + intro = AudioSegment.from_mp3(io.BytesIO(intro_audio)) + content = AudioSegment.from_mp3(io.BytesIO(content_audio)) + outro = AudioSegment.from_mp3(io.BytesIO(outro_audio)) + + # Create bridge silence (pause + 2 * crossfade to account for overlap) + bridge = AudioSegment.silent( + duration=PAUSE_DURATION + 2 * CROSSFADE_DURATION + ) + + def safe_append( + seg1: AudioSegment, seg2: AudioSegment, crossfade: int + ) -> AudioSegment: + if len(seg1) < crossfade or len(seg2) < crossfade: + logger.warning( + "Segment too short for crossfade (%dms vs %dms/%dms), using concatenation", + crossfade, + len(seg1), + len(seg2), + ) + return seg1 + seg2 + return seg1.append(seg2, crossfade=crossfade) + + # Combine segments with crossfades + # Intro -> Bridge -> Content -> Bridge -> Outro + # This effectively fades out the previous segment and fades in the next one + combined = safe_append(intro, bridge, CROSSFADE_DURATION) + combined = safe_append(combined, content, CROSSFADE_DURATION) + combined = safe_append(combined, bridge, CROSSFADE_DURATION) + combined = safe_append(combined, outro, CROSSFADE_DURATION) + + # Export to bytes + output = io.BytesIO() + combined.export(output, format="mp3") + return output.getvalue() + + def _process_tts_chunk(self, chunk: str, index: int) -> bytes: + """Process a single TTS chunk. 
+ + Args: + chunk: Text to convert to speech + index: Chunk index for logging + + Returns: + Audio data as bytes + """ + logger.info( + "Generating TTS for chunk %d (%d chars)", + index + 1, + len(chunk), + ) + + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunk, + response_format="mp3", + ) + + return response.content + + @staticmethod + def _combine_audio_chunks(audio_chunks: list[bytes]) -> bytes: + """Combine multiple audio chunks with silence gaps. + + Args: + audio_chunks: List of audio data in order + + Returns: + Combined audio data + """ + if not audio_chunks: + msg = "No audio chunks to combine" + raise ValueError(msg) + + # Create a temporary file for the combined audio + with tempfile.NamedTemporaryFile( + suffix=".mp3", + delete=False, + ) as temp_file: + temp_path = temp_file.name + + try: + # Start with the first chunk + combined_audio = AudioSegment.from_mp3(io.BytesIO(audio_chunks[0])) + + # Add remaining chunks with silence gaps + for chunk_data in audio_chunks[1:]: + chunk_audio = AudioSegment.from_mp3(io.BytesIO(chunk_data)) + silence = AudioSegment.silent(duration=300) # 300ms gap + combined_audio = combined_audio + silence + chunk_audio + + # Export to file + combined_audio.export(temp_path, format="mp3", bitrate="128k") + + # Read back the combined audio + return Path(temp_path).read_bytes() + + finally: + # Clean up temp file + if Path(temp_path).exists(): + Path(temp_path).unlink() + + def _text_to_speech_serial(self, chunks: list[str]) -> bytes: + """Fallback serial processing for high memory situations. + + This is the original serial implementation. 
+ """ + # Create a temporary file for streaming audio concatenation + with tempfile.NamedTemporaryFile( + suffix=".mp3", + delete=False, + ) as temp_file: + temp_path = temp_file.name + + try: + # Process first chunk + logger.info("Generating TTS for chunk 1/%d", len(chunks)) + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunks[0], + response_format="mp3", + ) + + # Write first chunk directly to file + Path(temp_path).write_bytes(response.content) + + # Process remaining chunks + for i, chunk in enumerate(chunks[1:], 1): + logger.info( + "Generating TTS for chunk %d/%d (%d chars)", + i + 1, + len(chunks), + len(chunk), + ) + + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunk, + response_format="mp3", + ) + + # Append to existing file with silence gap + # Load only the current segment + current_segment = AudioSegment.from_mp3( + io.BytesIO(response.content), + ) + + # Load existing audio, append, and save back + existing_audio = AudioSegment.from_mp3(temp_path) + silence = AudioSegment.silent(duration=300) + combined = existing_audio + silence + current_segment + + # Export back to the same file + combined.export(temp_path, format="mp3", bitrate="128k") + + # Force garbage collection to free memory + del existing_audio, current_segment, combined + + # Small delay between API calls + if i < len(chunks) - 1: + time.sleep(0.5) + + # Read final result + audio_data = Path(temp_path).read_bytes() + + logger.info( + "Generated combined TTS audio: %d bytes", + len(audio_data), + ) + return audio_data + + finally: + # Clean up temp file + temp_file_path = Path(temp_path) + if temp_file_path.exists(): + temp_file_path.unlink() + + def upload_to_s3(self, audio_data: bytes, filename: str) -> str: + """Upload audio file to S3-compatible storage and return public URL. + + Raises: + ValueError: If S3 client is not configured. + ClientError: If S3 upload fails. 
+ """ + if not self.s3_client: + msg = "S3 client not configured" + raise ValueError(msg) + + try: + # Upload file using streaming to minimize memory usage + audio_stream = io.BytesIO(audio_data) + self.s3_client.upload_fileobj( + audio_stream, + S3_BUCKET, + filename, + ExtraArgs={ + "ContentType": "audio/mpeg", + "ACL": "public-read", + }, + ) + + # Construct public URL + audio_url = f"{S3_ENDPOINT}/{S3_BUCKET}/{filename}" + logger.info( + "Uploaded audio to: %s (%d bytes)", + audio_url, + len(audio_data), + ) + except ClientError: + logger.exception("S3 upload failed") + raise + else: + return audio_url + + @staticmethod + def estimate_duration(audio_data: bytes) -> int: + """Estimate audio duration in seconds based on file size and bitrate.""" + # Rough estimation: MP3 at 128kbps = ~16KB per second + estimated_seconds = len(audio_data) // 16000 + return max(1, estimated_seconds) # Minimum 1 second + + @staticmethod + def generate_filename(job_id: int, title: str) -> str: + """Generate unique filename for audio file.""" + timestamp = int(datetime.now(tz=timezone.utc).timestamp()) + # Create safe filename from title + safe_title = "".join( + c for c in title if c.isalnum() or c in {" ", "-", "_"} + ).rstrip() + safe_title = safe_title.replace(" ", "_")[:50] # Limit length + return f"episode_{timestamp}_{job_id}_{safe_title}.mp3" + + def process_job( + self, + job: dict[str, Any], + ) -> None: + """Process a single job through the complete pipeline.""" + job_id = job["id"] + url = job["url"] + + # Check memory before starting + mem_usage = check_memory_usage() + if mem_usage > MEMORY_THRESHOLD: + logger.warning( + "High memory usage (%.1f%%), deferring job %d", + mem_usage, + job_id, + ) + return + + # Track current job for graceful shutdown + self.shutdown_handler.set_current_job(job_id) + + try: + logger.info("Processing job %d: %s", job_id, url) + + # Update status to processing + Core.Database.update_job_status( + job_id, + "processing", + ) + + # Check for 
shutdown before each major step + if self.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, aborting job %d", job_id) + Core.Database.update_job_status(job_id, "pending") + return + + # Step 1: Extract article content + Core.Database.update_job_status(job_id, "extracting") + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content(url) + ) + + if self.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, aborting job %d", job_id) + Core.Database.update_job_status(job_id, "pending") + return + + # Step 2: Generate audio with metadata + Core.Database.update_job_status(job_id, "synthesizing") + audio_data = self.text_to_speech(content, title, author, pub_date) + + if self.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, aborting job %d", job_id) + Core.Database.update_job_status(job_id, "pending") + return + + # Step 3: Upload to S3 + Core.Database.update_job_status(job_id, "uploading") + filename = ArticleProcessor.generate_filename(job_id, title) + audio_url = self.upload_to_s3(audio_data, filename) + + # Step 4: Calculate duration + duration = ArticleProcessor.estimate_duration(audio_data) + + # Step 5: Create episode record + url_hash = Core.hash_url(url) + episode_id = Core.Database.create_episode( + title=title, + audio_url=audio_url, + duration=duration, + content_length=len(content), + user_id=job.get("user_id"), + author=job.get("author"), # Pass author from job + original_url=url, # Pass the original article URL + original_url_hash=url_hash, + ) + + # Add episode to user's feed + user_id = job.get("user_id") + if user_id: + Core.Database.add_episode_to_user(user_id, episode_id) + Core.Database.track_episode_event( + episode_id, + "added", + user_id, + ) + + # Step 6: Mark job as complete + Core.Database.update_job_status( + job_id, + "completed", + ) + + logger.info( + "Successfully processed job %d -> episode %d", + job_id, + episode_id, + ) + + except 
Exception as e: + error_msg = str(e) + logger.exception("Job %d failed: %s", job_id, error_msg) + Core.Database.update_job_status( + job_id, + "error", + error_msg, + ) + raise + finally: + # Clear current job + self.shutdown_handler.set_current_job(None) + + +class TestArticleExtraction(Test.TestCase): + """Test article extraction functionality.""" + + def test_extract_valid_article(self) -> None: + """Extract from well-formed HTML.""" + # Mock trafilatura.fetch_url and extract + mock_html = ( + "<html><body><h1>Test Article</h1><p>Content here</p></body></html>" + ) + mock_result = json.dumps({ + "title": "Test Article", + "text": "Content here", + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(title, "Test Article") + self.assertEqual(content, "Content here") + self.assertIsNone(author) + self.assertIsNone(pub_date) + + def test_extract_missing_title(self) -> None: + """Handle articles without titles.""" + mock_html = "<html><body><p>Content without title</p></body></html>" + mock_result = json.dumps({"text": "Content without title"}) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(title, "Untitled Article") + self.assertEqual(content, "Content without title") + self.assertIsNone(author) + self.assertIsNone(pub_date) + + def test_extract_empty_content(self) -> None: + """Handle empty articles.""" + mock_html = "<html><body></body></html>" + mock_result = json.dumps({"title": "Empty Article", "text": ""}) + + with ( + 
unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + pytest.raises(ValueError, match="No content extracted") as cm, + ): + ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertIn("No content extracted", str(cm.value)) + + def test_extract_network_error(self) -> None: + """Handle connection failures.""" + with ( + unittest.mock.patch("trafilatura.fetch_url", return_value=None), + pytest.raises(ValueError, match="Failed to download") as cm, + ): + ArticleProcessor.extract_article_content("https://example.com") + + self.assertIn("Failed to download", str(cm.value)) + + @staticmethod + def test_extract_timeout() -> None: + """Handle slow responses.""" + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + side_effect=TimeoutError("Timeout"), + ), + pytest.raises(TimeoutError), + ): + ArticleProcessor.extract_article_content("https://example.com") + + def test_content_sanitization(self) -> None: + """Remove unwanted elements.""" + mock_html = """ + <html><body> + <h1>Article</h1> + <p>Good content</p> + <script>alert('bad')</script> + <table><tr><td>data</td></tr></table> + </body></html> + """ + mock_result = json.dumps({ + "title": "Article", + "text": "Good content", # Tables and scripts removed + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + _title, content, _author, _pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(content, "Good content") + self.assertNotIn("script", content) + self.assertNotIn("table", content) + + +class TestTextToSpeech(Test.TestCase): + """Test text-to-speech functionality.""" + + def setUp(self) -> None: + """Set up mocks.""" + # Mock OpenAI API key + self.env_patcher = unittest.mock.patch.dict( + 
os.environ, + {"OPENAI_API_KEY": "test-key"}, + ) + self.env_patcher.start() + + # Mock OpenAI response + self.mock_audio_response: unittest.mock.MagicMock = ( + unittest.mock.MagicMock() + ) + self.mock_audio_response.content = b"fake-audio-data" + + # Mock AudioSegment to avoid ffmpeg issues in tests + self.mock_audio_segment: unittest.mock.MagicMock = ( + unittest.mock.MagicMock() + ) + self.mock_audio_segment.export.return_value = None + self.audio_segment_patcher = unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + return_value=self.mock_audio_segment, + ) + self.audio_segment_patcher.start() + + # Mock the concatenation operations + self.mock_audio_segment.__add__.return_value = self.mock_audio_segment + + def tearDown(self) -> None: + """Clean up mocks.""" + self.env_patcher.stop() + self.audio_segment_patcher.stop() + + def test_tts_generation(self) -> None: + """Generate audio from text.""" + # Import ShutdownHandler dynamically to avoid circular import + import Biz.PodcastItLater.Worker as Worker + + # Mock the export to write test audio data + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.TextProcessing.prepare_text_for_tts", + return_value=["Test content"], + ), + ): + shutdown_handler = Worker.ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech( + "Test content", + "Test Title", + ) + + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, 
b"test-audio-output") + + def test_tts_chunking(self) -> None: + """Handle long articles with chunking.""" + import Biz.PodcastItLater.Worker as Worker + + long_text = "Long content " * 1000 + chunks = ["Chunk 1", "Chunk 2", "Chunk 3"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock AudioSegment.silent + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.TextProcessing.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + shutdown_handler = Worker.ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech( + long_text, + "Long Article", + ) + + # Should have called TTS for each chunk + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_tts_empty_text(self) -> None: + """Handle empty input.""" + import Biz.PodcastItLater.Worker as Worker + + with unittest.mock.patch( + "Biz.PodcastItLater.Worker.TextProcessing.prepare_text_for_tts", + return_value=[], + ): + shutdown_handler = Worker.ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + with pytest.raises(ValueError, match="No chunks generated") as cm: + processor.text_to_speech("", "Empty") + + self.assertIn("No chunks generated", str(cm.value)) + + def test_tts_special_characters(self) -> None: + """Handle unicode and special chars.""" + import Biz.PodcastItLater.Worker as Worker + + special_text = 
'Unicode: 你好世界 Émojis: 🎙️📰 Special: <>&"' + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.TextProcessing.prepare_text_for_tts", + return_value=[special_text], + ), + ): + shutdown_handler = Worker.ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech( + special_text, + "Special", + ) + + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_llm_text_preparation(self) -> None: + """Verify LLM editing.""" + # Test the actual text preparation functions + chunks = TextProcessing.split_text_into_chunks( + "Short text", max_chars=100 + ) + self.assertEqual(len(chunks), 1) + self.assertEqual(chunks[0], "Short text") + + # Test long text splitting + long_text = "Sentence one. 
" * 100 + chunks = TextProcessing.split_text_into_chunks(long_text, max_chars=100) + self.assertGreater(len(chunks), 1) + for chunk in chunks: + self.assertLessEqual(len(chunk), 100) + + @staticmethod + def test_llm_failure_fallback() -> None: + """Handle LLM API failures.""" + # Mock LLM failure + with unittest.mock.patch("openai.OpenAI") as mock_openai: + mock_client = mock_openai.return_value + mock_client.chat.completions.create.side_effect = Exception( + "API Error", + ) + + # Should fall back to raw text + with pytest.raises(Exception, match="API Error"): + TextProcessing.edit_chunk_for_speech( + "Test chunk", "Title", is_first=True + ) + + def test_chunk_concatenation(self) -> None: + """Verify audio joining.""" + import Biz.PodcastItLater.Worker as Worker + + # Mock multiple audio segments + chunks = ["Chunk 1", "Chunk 2"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.TextProcessing.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + shutdown_handler = Worker.ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test", "Title") + + # Should produce combined audio + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_parallel_tts_generation(self) -> None: + """Test parallel TTS processing.""" + import 
def test_parallel_tts_high_memory_fallback(self) -> None:
    """Check that text_to_speech still returns audio under high memory.

    ``check_memory_usage`` is patched to report 65.0 and the shared mock
    audio segment is rigged to export ``b"serial-audio"``; the test passes
    when that payload is what ``text_to_speech`` returns.
    """
    import Biz.PodcastItLater.Worker as Worker

    def write_serial_audio(
        buffer: io.BytesIO, **_kwargs: typing.Any
    ) -> None:
        buffer.write(b"serial-audio")
        buffer.seek(0)

    self.mock_audio_segment.export.side_effect = write_serial_audio

    # Stand-in OpenAI client: every speech request yields the same response.
    openai_client = unittest.mock.MagicMock()
    openai_client.audio.speech.create.return_value = self.mock_audio_response

    patch = unittest.mock.patch
    with (
        patch("openai.OpenAI", return_value=openai_client),
        patch(
            "Biz.PodcastItLater.Worker.TextProcessing.prepare_text_for_tts",
            return_value=["Chunk 1", "Chunk 2"],
        ),
        patch(
            "Biz.PodcastItLater.Worker.Processor.check_memory_usage",
            # Treated as "high" by this test; the threshold itself lives in
            # Processor and is not visible here -- TODO confirm.
            return_value=65.0,
        ),
        patch(
            "pydub.AudioSegment.silent",
            return_value=self.mock_audio_segment,
        ),
    ):
        processor = ArticleProcessor(Worker.ShutdownHandler())
        audio_data = processor.text_to_speech("Test content", "Test Title")

    # The serial path is the one expected to produce this payload.
    self.assertEqual(audio_data, b"serial-audio")
def test_parallel_tts_order_preservation(self) -> None:
    """Check that audio segments are decoded in the same order as chunks.

    Each mocked TTS response carries a payload naming its chunk. The
    ``from_mp3`` stand-in records every payload it is asked to decode, and
    the ``export`` stand-in raises if that record differs from chunk order.
    """
    import Biz.PodcastItLater.Worker as Worker

    chunks = ["First", "Second", "Third", "Fourth", "Fifth"]
    expected_order = [f"audio-{chunk}" for chunk in chunks]

    # One identifiable response per chunk, handed out call by call.
    openai_client = unittest.mock.MagicMock()
    openai_client.audio.speech.create.side_effect = [
        unittest.mock.MagicMock(content=payload.encode())
        for payload in expected_order
    ]

    # Payloads in the order from_mp3 saw them.
    seen_payloads: list[str] = []

    def record_from_mp3(data: io.BytesIO) -> unittest.mock.MagicMock:
        seen_payloads.append(data.read().decode())
        loaded = unittest.mock.MagicMock()
        loaded.__add__.return_value = loaded
        return loaded

    def export_if_ordered(path: str, **_kwargs: typing.Any) -> None:
        # Fail loudly when decoding order does not match chunk order.
        if seen_payloads != expected_order:
            msg = f"Order mismatch: {seen_payloads} != {expected_order}"
            raise AssertionError(msg)
        Path(path).write_bytes(b"ordered-audio")

    silent_segment = unittest.mock.MagicMock()
    silent_segment.__add__.return_value = silent_segment
    silent_segment.export = export_if_ordered

    patch = unittest.mock.patch
    with (
        patch("openai.OpenAI", return_value=openai_client),
        patch(
            "Biz.PodcastItLater.Worker.TextProcessing.prepare_text_for_tts",
            return_value=chunks,
        ),
        patch("pydub.AudioSegment.from_mp3", side_effect=record_from_mp3),
        patch("pydub.AudioSegment.silent", return_value=silent_segment),
        patch(
            "Biz.PodcastItLater.Worker.Processor.check_memory_usage",
            return_value=50.0,
        ),
    ):
        processor = ArticleProcessor(Worker.ShutdownHandler())
        audio_data = processor.text_to_speech("Test content", "Test Title")

    self.assertEqual(audio_data, b"ordered-audio")
def main() -> None:
    """Run the module's tests when "test" is on the command line.

    With any other arguments, only a log line noting the module load is
    emitted.
    """
    if "test" not in sys.argv:
        logger.info("Processor module loaded")
        return
    test()
def split_text_into_chunks(text: str, max_chars: int = 3000) -> list[str]:
    """Split text into chunks at paragraph and sentence boundaries.

    Paragraphs (separated by blank lines) are packed greedily into chunks of
    about ``max_chars`` characters, joined by single spaces. A paragraph that
    is itself longer than ``max_chars`` is broken on ``". "`` and packed one
    sentence at a time.

    A single sentence, or a paragraph containing no ``". "``, is never cut,
    so an individual chunk may exceed ``max_chars``.

    Args:
        text: The text to split.
        max_chars: Target upper bound on chunk length, in characters.

    Returns:
        The whitespace-stripped chunks in their original order; an empty list
        when ``text`` contains no non-blank paragraph.
    """
    chunks: list[str] = []
    current_chunk = ""

    for para in text.split("\n\n"):
        para_stripped = para.strip()
        if not para_stripped:
            continue

        if len(para_stripped) > max_chars:
            # Oversized paragraph: fall back to sentence granularity.
            # str.split drops the ". " separator, so put the period back on
            # every piece that was followed by one. The final piece keeps
            # whatever ending it already had, so no stray "." is appended.
            pieces = para_stripped.split(". ")
            sentences = [f"{piece}." for piece in pieces[:-1]]
            sentences.append(pieces[-1])

            for sentence in sentences:
                # +1 accounts for the joining space added after the sentence.
                if len(current_chunk) + len(sentence) + 1 < max_chars:
                    current_chunk += sentence + " "
                else:
                    if current_chunk:
                        chunks.append(current_chunk.strip())
                    current_chunk = sentence + " "
        elif len(current_chunk) + len(para_stripped) + 2 > max_chars:
            # Adding this paragraph would overflow: close the current chunk.
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = para_stripped + " "
        else:
            current_chunk += para_stripped + " "

    # Flush whatever is still being accumulated.
    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks
class TestTextChunking(Test.TestCase):
    """Edge cases for ``split_text_into_chunks`` with unsplittable input."""

    def test_split_text_single_long_word(self) -> None:
        """A lone word longer than the limit stays in one chunk."""
        word_length = 4000
        chunks = split_text_into_chunks("a" * word_length, max_chars=3000)

        # Words are never broken apart, so the oversized word is expected
        # back whole, at its original length.
        self.assertEqual(len(chunks), 1)
        self.assertEqual(len(chunks[0]), word_length)

    def test_split_text_no_sentence_boundaries(self) -> None:
        """Long text without any ". " separator stays in one chunk."""
        limit = 3000
        unbroken = "word " * 1000  # 5000 characters, no ". " anywhere
        chunks = split_text_into_chunks(unbroken, max_chars=limit)

        # With nothing to split on, the single chunk necessarily overshoots.
        self.assertEqual(len(chunks), 1)
        self.assertGreater(len(chunks[0]), limit)
