diff options
Diffstat (limited to 'Biz/PodcastItLater/Worker.py')
| -rw-r--r-- | Biz/PodcastItLater/Worker.py | 2199 |
1 files changed, 2199 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py new file mode 100644 index 0000000..bf6ef9e --- /dev/null +++ b/Biz/PodcastItLater/Worker.py @@ -0,0 +1,2199 @@ +"""Background worker for processing article-to-podcast conversions.""" + +# : dep boto3 +# : dep botocore +# : dep openai +# : dep psutil +# : dep pydub +# : dep pytest +# : dep pytest-asyncio +# : dep pytest-mock +# : dep trafilatura +# : out podcastitlater-worker +# : run ffmpeg +import Biz.PodcastItLater.Core as Core +import boto3 # type: ignore[import-untyped] +import concurrent.futures +import io +import json +import logging +import Omni.App as App +import Omni.Log as Log +import Omni.Test as Test +import openai +import operator +import os +import psutil # type: ignore[import-untyped] +import pytest +import signal +import sys +import tempfile +import threading +import time +import trafilatura +import typing +import unittest.mock +from botocore.exceptions import ClientError # type: ignore[import-untyped] +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from pathlib import Path +from pydub import AudioSegment # type: ignore[import-untyped] +from typing import Any + +logger = logging.getLogger(__name__) +Log.setup(logger) + +# Configuration from environment variables +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +S3_ENDPOINT = os.getenv("S3_ENDPOINT") +S3_BUCKET = os.getenv("S3_BUCKET") +S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY") +S3_SECRET_KEY = os.getenv("S3_SECRET_KEY") +area = App.from_env() + +# Worker configuration +MAX_CONTENT_LENGTH = 5000 # characters for TTS +MAX_ARTICLE_SIZE = 500_000 # 500KB character limit for articles +POLL_INTERVAL = 30 # seconds +MAX_RETRIES = 3 +TTS_MODEL = "tts-1" +TTS_VOICE = "alloy" +MEMORY_THRESHOLD = 80 # Percentage threshold for memory usage +CROSSFADE_DURATION = 500 # ms for crossfading segments +PAUSE_DURATION = 1000 # ms for silence between segments + + +class ShutdownHandler: + """Handles graceful shutdown of the worker.""" + + def __init__(self) -> None: + """Initialize shutdown handler.""" + self.shutdown_requested = threading.Event() + self.current_job_id: int | None = None + self.lock = threading.Lock() + + # Register signal handlers + signal.signal(signal.SIGTERM, self._handle_signal) + signal.signal(signal.SIGINT, self._handle_signal) + + def _handle_signal(self, signum: int, _frame: Any) -> None: + """Handle shutdown signals.""" + logger.info( + "Received signal %d, initiating graceful shutdown...", + signum, + ) + self.shutdown_requested.set() + + def is_shutdown_requested(self) -> bool: + """Check if shutdown has been requested.""" + return self.shutdown_requested.is_set() + + def set_current_job(self, job_id: int | None) -> None: + """Set the currently processing job.""" + with self.lock: + self.current_job_id = job_id + + def get_current_job(self) -> int | None: + """Get the currently processing job.""" + with self.lock: + return self.current_job_id + + +class ArticleProcessor: + """Handles the complete article-to-podcast conversion pipeline.""" + + def __init__(self, shutdown_handler: ShutdownHandler) -> None: + """Initialize the processor with required services. + + Raises: + ValueError: If OPENAI_API_KEY environment variable is not set. + """ + if not OPENAI_API_KEY: + msg = "OPENAI_API_KEY environment variable is required" + raise ValueError(msg) + + self.openai_client: openai.OpenAI = openai.OpenAI( + api_key=OPENAI_API_KEY, + ) + self.shutdown_handler = shutdown_handler + + # Initialize S3 client for Digital Ocean Spaces + if all([S3_ENDPOINT, S3_BUCKET, S3_ACCESS_KEY, S3_SECRET_KEY]): + self.s3_client: Any = boto3.client( + "s3", + endpoint_url=S3_ENDPOINT, + aws_access_key_id=S3_ACCESS_KEY, + aws_secret_access_key=S3_SECRET_KEY, + ) + else: + logger.warning("S3 configuration incomplete, uploads will fail") + self.s3_client = None + + @staticmethod + def extract_article_content( + url: str, + ) -> tuple[str, str, str | None, str | None]: + """Extract title, content, author, and date from article URL. + + Returns: + tuple: (title, content, author, publication_date) + + Raises: + ValueError: If content cannot be downloaded, extracted, or large. + """ + try: + downloaded = trafilatura.fetch_url(url) + if not downloaded: + msg = f"Failed to download content from {url}" + raise ValueError(msg) # noqa: TRY301 + + # Check size before processing + if ( + len(downloaded) > MAX_ARTICLE_SIZE * 4 + ): # Rough HTML to text ratio + msg = f"Article too large: {len(downloaded)} bytes" + raise ValueError(msg) # noqa: TRY301 + + # Extract with metadata + result = trafilatura.extract( + downloaded, + include_comments=False, + include_tables=False, + with_metadata=True, + output_format="json", + ) + + if not result: + msg = f"Failed to extract content from {url}" + raise ValueError(msg) # noqa: TRY301 + + data = json.loads(result) + + title = data.get("title", "Untitled Article") + content = data.get("text", "") + author = data.get("author") + pub_date = data.get("date") + + if not content: + msg = f"No content extracted from {url}" + raise ValueError(msg) # noqa: TRY301 + + # Enforce content size limit + if len(content) > MAX_ARTICLE_SIZE: + logger.warning( + "Article content truncated from %d to %d characters", + len(content), + MAX_ARTICLE_SIZE, + ) + content = content[:MAX_ARTICLE_SIZE] + + logger.info( + "Extracted article: %s (%d chars, author: %s, date: %s)", + title, + len(content), + author or "unknown", + pub_date or "unknown", + ) + except Exception: + logger.exception("Failed to extract content from %s", url) + raise + else: + return title, content, author, pub_date + + def text_to_speech( + self, + text: str, + title: str, + author: str | None = None, + pub_date: str | None = None, + ) -> bytes: + """Convert text to speech with intro/outro using OpenAI TTS API. + + Uses parallel processing for chunks while maintaining order. + Adds intro with metadata and outro with attribution. + + Args: + text: Article content to convert + title: Article title + author: Article author (optional) + pub_date: Publication date (optional) + + Raises: + ValueError: If no chunks are generated from text. + """ + try: + # Generate intro audio + intro_text = self._create_intro_text(title, author, pub_date) + intro_audio = self._generate_tts_segment(intro_text) + + # Generate outro audio + outro_text = self._create_outro_text(title, author) + outro_audio = self._generate_tts_segment(outro_text) + + # Use LLM to prepare and chunk the main content + chunks = prepare_text_for_tts(text, title) + + if not chunks: + msg = "No chunks generated from text" + raise ValueError(msg) # noqa: TRY301 + + logger.info("Processing %d chunks for TTS", len(chunks)) + + # Check memory before parallel processing + mem_usage = check_memory_usage() + if mem_usage > MEMORY_THRESHOLD - 20: # Leave 20% buffer + logger.warning( + "High memory usage (%.1f%%), falling back to serial " + "processing", + mem_usage, + ) + content_audio_bytes = self._text_to_speech_serial(chunks) + else: + # Determine max workers + max_workers = min( + 4, # Reasonable limit to avoid rate limiting + len(chunks), # No more workers than chunks + max(1, psutil.cpu_count() // 2), # Use half of CPU cores + ) + + logger.info( + "Using %d workers for parallel TTS processing", + max_workers, + ) + + # Process chunks in parallel + chunk_results: list[tuple[int, bytes]] = [] + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers, + ) as executor: + # Submit all chunks for processing + future_to_index = { + executor.submit(self._process_tts_chunk, chunk, i): i + for i, chunk in enumerate(chunks) + } + + # Collect results as they complete + for future in concurrent.futures.as_completed( + future_to_index, + ): + index = future_to_index[future] + try: + audio_data = future.result() + chunk_results.append((index, audio_data)) + except Exception: + logger.exception( + "Failed to process chunk %d", + index, + ) + raise + + # Sort results by index to maintain order + chunk_results.sort(key=operator.itemgetter(0)) + + # Combine audio chunks + content_audio_bytes = self._combine_audio_chunks([ + data for _, data in chunk_results + ]) + + # Combine intro, content, and outro with pauses + return ArticleProcessor._combine_intro_content_outro( + intro_audio, + content_audio_bytes, + outro_audio, + ) + + except Exception: + logger.exception("TTS generation failed") + raise + + @staticmethod + def _create_intro_text( + title: str, + author: str | None, + pub_date: str | None, + ) -> str: + """Create intro text with available metadata.""" + parts = [f"Title: {title}"] + + if author: + parts.append(f"Author: {author}") + + if pub_date: + parts.append(f"Published: {pub_date}") + + return ". ".join(parts) + "." + + @staticmethod + def _create_outro_text(title: str, author: str | None) -> str: + """Create outro text with attribution.""" + if author: + return ( + f"This has been an audio version of {title} " + f"by {author}, created using Podcast It Later." + ) + return ( + f"This has been an audio version of {title}, " + "created using Podcast It Later." + ) + + def _generate_tts_segment(self, text: str) -> bytes: + """Generate TTS audio for a single segment (intro/outro). + + Args: + text: Text to convert to speech + + Returns: + MP3 audio bytes + """ + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=text, + ) + return response.content + + @staticmethod + def _combine_intro_content_outro( + intro_audio: bytes, + content_audio: bytes, + outro_audio: bytes, + ) -> bytes: + """Combine intro, content, and outro with crossfades. + + Args: + intro_audio: MP3 bytes for intro + content_audio: MP3 bytes for main content + outro_audio: MP3 bytes for outro + + Returns: + Combined MP3 audio bytes + """ + # Load audio segments + intro = AudioSegment.from_mp3(io.BytesIO(intro_audio)) + content = AudioSegment.from_mp3(io.BytesIO(content_audio)) + outro = AudioSegment.from_mp3(io.BytesIO(outro_audio)) + + # Create bridge silence (pause + 2 * crossfade to account for overlap) + bridge = AudioSegment.silent( + duration=PAUSE_DURATION + 2 * CROSSFADE_DURATION + ) + + def safe_append( + seg1: AudioSegment, seg2: AudioSegment, crossfade: int + ) -> AudioSegment: + if len(seg1) < crossfade or len(seg2) < crossfade: + logger.warning( + "Segment too short for crossfade (%dms vs %dms/%dms), using concatenation", + crossfade, + len(seg1), + len(seg2), + ) + return seg1 + seg2 + return seg1.append(seg2, crossfade=crossfade) + + # Combine segments with crossfades + # Intro -> Bridge -> Content -> Bridge -> Outro + # This effectively fades out the previous segment and fades in the next one + combined = safe_append(intro, bridge, CROSSFADE_DURATION) + combined = safe_append(combined, content, CROSSFADE_DURATION) + combined = safe_append(combined, bridge, CROSSFADE_DURATION) + combined = safe_append(combined, outro, CROSSFADE_DURATION) + + # Export to bytes + output = io.BytesIO() + combined.export(output, format="mp3") + return output.getvalue() + + def _process_tts_chunk(self, chunk: str, index: int) -> bytes: + """Process a single TTS chunk. + + Args: + chunk: Text to convert to speech + index: Chunk index for logging + + Returns: + Audio data as bytes + """ + logger.info( + "Generating TTS for chunk %d (%d chars)", + index + 1, + len(chunk), + ) + + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunk, + response_format="mp3", + ) + + return response.content + + @staticmethod + def _combine_audio_chunks(audio_chunks: list[bytes]) -> bytes: + """Combine multiple audio chunks with silence gaps. + + Args: + audio_chunks: List of audio data in order + + Returns: + Combined audio data + """ + if not audio_chunks: + msg = "No audio chunks to combine" + raise ValueError(msg) + + # Create a temporary file for the combined audio + with tempfile.NamedTemporaryFile( + suffix=".mp3", + delete=False, + ) as temp_file: + temp_path = temp_file.name + + try: + # Start with the first chunk + combined_audio = AudioSegment.from_mp3(io.BytesIO(audio_chunks[0])) + + # Add remaining chunks with silence gaps + for chunk_data in audio_chunks[1:]: + chunk_audio = AudioSegment.from_mp3(io.BytesIO(chunk_data)) + silence = AudioSegment.silent(duration=300) # 300ms gap + combined_audio = combined_audio + silence + chunk_audio + + # Export to file + combined_audio.export(temp_path, format="mp3", bitrate="128k") + + # Read back the combined audio + return Path(temp_path).read_bytes() + + finally: + # Clean up temp file + if Path(temp_path).exists(): + Path(temp_path).unlink() + + def _text_to_speech_serial(self, chunks: list[str]) -> bytes: + """Fallback serial processing for high memory situations. + + This is the original serial implementation. + """ + # Create a temporary file for streaming audio concatenation + with tempfile.NamedTemporaryFile( + suffix=".mp3", + delete=False, + ) as temp_file: + temp_path = temp_file.name + + try: + # Process first chunk + logger.info("Generating TTS for chunk 1/%d", len(chunks)) + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunks[0], + response_format="mp3", + ) + + # Write first chunk directly to file + Path(temp_path).write_bytes(response.content) + + # Process remaining chunks + for i, chunk in enumerate(chunks[1:], 1): + logger.info( + "Generating TTS for chunk %d/%d (%d chars)", + i + 1, + len(chunks), + len(chunk), + ) + + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunk, + response_format="mp3", + ) + + # Append to existing file with silence gap + # Load only the current segment + current_segment = AudioSegment.from_mp3( + io.BytesIO(response.content), + ) + + # Load existing audio, append, and save back + existing_audio = AudioSegment.from_mp3(temp_path) + silence = AudioSegment.silent(duration=300) + combined = existing_audio + silence + current_segment + + # Export back to the same file + combined.export(temp_path, format="mp3", bitrate="128k") + + # Force garbage collection to free memory + del existing_audio, current_segment, combined + + # Small delay between API calls + if i < len(chunks) - 1: + time.sleep(0.5) + + # Read final result + audio_data = Path(temp_path).read_bytes() + + logger.info( + "Generated combined TTS audio: %d bytes", + len(audio_data), + ) + return audio_data + + finally: + # Clean up temp file + temp_file_path = Path(temp_path) + if temp_file_path.exists(): + temp_file_path.unlink() + + def upload_to_s3(self, audio_data: bytes, filename: str) -> str: + """Upload audio file to S3-compatible storage and return public URL. + + Raises: + ValueError: If S3 client is not configured. + ClientError: If S3 upload fails. + """ + if not self.s3_client: + msg = "S3 client not configured" + raise ValueError(msg) + + try: + # Upload file using streaming to minimize memory usage + audio_stream = io.BytesIO(audio_data) + self.s3_client.upload_fileobj( + audio_stream, + S3_BUCKET, + filename, + ExtraArgs={ + "ContentType": "audio/mpeg", + "ACL": "public-read", + }, + ) + + # Construct public URL + audio_url = f"{S3_ENDPOINT}/{S3_BUCKET}/{filename}" + logger.info( + "Uploaded audio to: %s (%d bytes)", + audio_url, + len(audio_data), + ) + except ClientError: + logger.exception("S3 upload failed") + raise + else: + return audio_url + + @staticmethod + def estimate_duration(audio_data: bytes) -> int: + """Estimate audio duration in seconds based on file size and bitrate.""" + # Rough estimation: MP3 at 128kbps = ~16KB per second + estimated_seconds = len(audio_data) // 16000 + return max(1, estimated_seconds) # Minimum 1 second + + @staticmethod + def generate_filename(job_id: int, title: str) -> str: + """Generate unique filename for audio file.""" + timestamp = int(datetime.now(tz=timezone.utc).timestamp()) + # Create safe filename from title + safe_title = "".join( + c for c in title if c.isalnum() or c in {" ", "-", "_"} + ).rstrip() + safe_title = safe_title.replace(" ", "_")[:50] # Limit length + return f"episode_{timestamp}_{job_id}_{safe_title}.mp3" + + def process_job( + self, + job: dict[str, Any], + ) -> None: + """Process a single job through the complete pipeline.""" + job_id = job["id"] + url = job["url"] + + # Check memory before starting + mem_usage = check_memory_usage() + if mem_usage > MEMORY_THRESHOLD: + logger.warning( + "High memory usage (%.1f%%), deferring job %d", + mem_usage, + job_id, + ) + return + + # Track current job for graceful shutdown + self.shutdown_handler.set_current_job(job_id) + + try: + logger.info("Processing job %d: %s", job_id, url) + + # Update status to processing + Core.Database.update_job_status( + job_id, + "processing", + ) + + # Check for shutdown before each major step + if self.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, aborting job %d", job_id) + Core.Database.update_job_status(job_id, "pending") + return + + # Step 1: Extract article content + Core.Database.update_job_status(job_id, "extracting") + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content(url) + ) + + if self.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, aborting job %d", job_id) + Core.Database.update_job_status(job_id, "pending") + return + + # Step 2: Generate audio with metadata + Core.Database.update_job_status(job_id, "synthesizing") + audio_data = self.text_to_speech(content, title, author, pub_date) + + if self.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, aborting job %d", job_id) + Core.Database.update_job_status(job_id, "pending") + return + + # Step 3: Upload to S3 + Core.Database.update_job_status(job_id, "uploading") + filename = ArticleProcessor.generate_filename(job_id, title) + audio_url = self.upload_to_s3(audio_data, filename) + + # Step 4: Calculate duration + duration = ArticleProcessor.estimate_duration(audio_data) + + # Step 5: Create episode record + url_hash = Core.hash_url(url) + episode_id = Core.Database.create_episode( + title=title, + audio_url=audio_url, + duration=duration, + content_length=len(content), + user_id=job.get("user_id"), + author=job.get("author"), # Pass author from job + original_url=url, # Pass the original article URL + original_url_hash=url_hash, + ) + + # Add episode to user's feed + user_id = job.get("user_id") + if user_id: + Core.Database.add_episode_to_user(user_id, episode_id) + Core.Database.track_episode_event( + episode_id, + "added", + user_id, + ) + + # Step 6: Mark job as complete + Core.Database.update_job_status( + job_id, + "completed", + ) + + logger.info( + "Successfully processed job %d -> episode %d", + job_id, + episode_id, + ) + + except Exception as e: + error_msg = str(e) + logger.exception("Job %d failed: %s", job_id, error_msg) + Core.Database.update_job_status( + job_id, + "error", + error_msg, + ) + raise + finally: + # Clear current job + self.shutdown_handler.set_current_job(None) + + +def prepare_text_for_tts(text: str, title: str) -> list[str]: + """Use LLM to prepare text for TTS, returning chunks ready for speech. + + First splits text mechanically, then has LLM edit each chunk. + """ + # First, split the text into manageable chunks + raw_chunks = split_text_into_chunks(text, max_chars=3000) + + logger.info("Split article into %d raw chunks", len(raw_chunks)) + + # Prepare the first chunk with intro + edited_chunks = [] + + for i, chunk in enumerate(raw_chunks): + is_first = i == 0 + is_last = i == len(raw_chunks) - 1 + + try: + edited_chunk = edit_chunk_for_speech( + chunk, + title=title if is_first else None, + is_first=is_first, + is_last=is_last, + ) + edited_chunks.append(edited_chunk) + except Exception: + logger.exception("Failed to edit chunk %d", i + 1) + # Fall back to raw chunk if LLM fails + if is_first: + edited_chunks.append( + f"This is an audio version of {title}. {chunk}", + ) + elif is_last: + edited_chunks.append(f"{chunk} This concludes the article.") + else: + edited_chunks.append(chunk) + + return edited_chunks + + +def split_text_into_chunks(text: str, max_chars: int = 3000) -> list[str]: + """Split text into chunks at sentence boundaries.""" + chunks = [] + current_chunk = "" + + # Split into paragraphs first + paragraphs = text.split("\n\n") + + for para in paragraphs: + para_stripped = para.strip() + if not para_stripped: + continue + + # If paragraph itself is too long, split by sentences + if len(para_stripped) > max_chars: + sentences = para_stripped.split(". ") + for sentence in sentences: + if len(current_chunk) + len(sentence) + 2 < max_chars: + current_chunk += sentence + ". " + else: + if current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = sentence + ". " + # If adding this paragraph would exceed limit, start new chunk + elif len(current_chunk) + len(para_stripped) + 2 > max_chars: + if current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = para_stripped + " " + else: + current_chunk += para_stripped + " " + + # Don't forget the last chunk + if current_chunk: + chunks.append(current_chunk.strip()) + + return chunks + + +def edit_chunk_for_speech( + chunk: str, + title: str | None = None, + *, + is_first: bool = False, + is_last: bool = False, +) -> str: + """Use LLM to lightly edit a single chunk for speech. + + Raises: + ValueError: If no content is returned from LLM. + """ + system_prompt = ( + "You are a podcast script editor. Your job is to lightly edit text " + "to make it sound natural when spoken aloud.\n\n" + "Guidelines:\n" + ) + system_prompt += """ +- Remove URLs and email addresses, replacing with descriptive phrases +- Convert bullet points and lists into flowing sentences +- Fix any awkward phrasing for speech +- Remove references like "click here" or "see below" +- Keep edits minimal - preserve the original content and style +- Do NOT add commentary or explanations +- Return ONLY the edited text, no JSON or formatting +""" + + user_prompt = chunk + + # Add intro/outro if needed + if is_first and title: + user_prompt = ( + f"Add a brief intro mentioning this is an audio version of " + f"'{title}', then edit this text:\n\n{chunk}" + ) + elif is_last: + user_prompt = f"Edit this text and add a brief closing:\n\n{chunk}" + + try: + client: openai.OpenAI = openai.OpenAI(api_key=OPENAI_API_KEY) + response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + temperature=0.3, # Lower temperature for more consistent edits + max_tokens=4000, + ) + + content = response.choices[0].message.content + if not content: + msg = "No content returned from LLM" + raise ValueError(msg) # noqa: TRY301 + + # Ensure the chunk isn't too long + max_chunk_length = 4000 + if len(content) > max_chunk_length: + # Truncate at sentence boundary + sentences = content.split(". ") + truncated = "" + for sentence in sentences: + if len(truncated) + len(sentence) + 2 < max_chunk_length: + truncated += sentence + ". " + else: + break + content = truncated.strip() + + except Exception: + logger.exception("LLM chunk editing failed") + raise + else: + return content + + +def parse_datetime_with_timezone(created_at: str | datetime) -> datetime: + """Parse datetime string and ensure it has timezone info.""" + if isinstance(created_at, str): + # Handle timezone-aware datetime strings + if created_at.endswith("Z"): + created_at = created_at[:-1] + "+00:00" + last_attempt = datetime.fromisoformat(created_at) + if last_attempt.tzinfo is None: + last_attempt = last_attempt.replace(tzinfo=timezone.utc) + else: + last_attempt = created_at + if last_attempt.tzinfo is None: + last_attempt = last_attempt.replace(tzinfo=timezone.utc) + return last_attempt + + +def should_retry_job(job: dict[str, Any], max_retries: int) -> bool: + """Check if a job should be retried based on retry count and backoff time. + + Uses exponential backoff to determine if enough time has passed. + """ + retry_count = job["retry_count"] + if retry_count >= max_retries: + return False + + # Exponential backoff: 30s, 60s, 120s + backoff_time = 30 * (2**retry_count) + last_attempt = parse_datetime_with_timezone(job["created_at"]) + time_since_attempt = datetime.now(tz=timezone.utc) - last_attempt + + return time_since_attempt > timedelta(seconds=backoff_time) + + +def process_pending_jobs( + processor: ArticleProcessor, +) -> None: + """Process all pending jobs.""" + pending_jobs = Core.Database.get_pending_jobs( + limit=5, + ) + + for job in pending_jobs: + if processor.shutdown_handler.is_shutdown_requested(): + logger.info("Shutdown requested, stopping job processing") + break + + current_job = job["id"] + try: + processor.process_job(job) + except Exception as e: + # Ensure job is marked as error even if process_job didn't handle it + logger.exception("Failed to process job: %d", current_job) + # Check if job is still in processing state + current_status = Core.Database.get_job_by_id( + current_job, + ) + if current_status and current_status.get("status") == "processing": + Core.Database.update_job_status( + current_job, + "error", + str(e), + ) + continue + + +def process_retryable_jobs() -> None: + """Check and retry failed jobs with exponential backoff.""" + retryable_jobs = Core.Database.get_retryable_jobs( + MAX_RETRIES, + ) + + for job in retryable_jobs: + if should_retry_job(job, MAX_RETRIES): + logger.info( + "Retrying job %d (attempt %d)", + job["id"], + job["retry_count"] + 1, + ) + Core.Database.update_job_status( + job["id"], + "pending", + ) + + +def check_memory_usage() -> int | Any: + """Check current memory usage percentage.""" + try: + process = psutil.Process() + # this returns an int but psutil is untyped + return process.memory_percent() + except (psutil.Error, OSError): + logger.warning("Failed to check memory usage") + return 0.0 + + +def cleanup_stale_jobs() -> None: + """Reset jobs stuck in processing state on startup.""" + with Core.Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE queue + SET status = 'pending' + WHERE status = 'processing' + """, + ) + affected = cursor.rowcount + conn.commit() + + if affected > 0: + logger.info( + "Reset %d stale jobs from processing to pending", + affected, + ) + + +def main_loop() -> None: + """Poll for jobs and process them in a continuous loop.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + + # Clean up any stale jobs from previous runs + cleanup_stale_jobs() + + logger.info("Worker started, polling for jobs...") + + while not shutdown_handler.is_shutdown_requested(): + try: + # Process pending jobs + process_pending_jobs(processor) + process_retryable_jobs() + + # Check if there's any work + pending_jobs = Core.Database.get_pending_jobs( + limit=1, + ) + retryable_jobs = Core.Database.get_retryable_jobs( + MAX_RETRIES, + ) + + if not pending_jobs and not retryable_jobs: + logger.debug("No jobs to process, sleeping...") + + except Exception: + logger.exception("Error in main loop") + + # Use interruptible sleep + if not shutdown_handler.is_shutdown_requested(): + shutdown_handler.shutdown_requested.wait(timeout=POLL_INTERVAL) + + # Graceful shutdown + current_job = shutdown_handler.get_current_job() + if current_job: + logger.info( + "Waiting for job %d to complete before shutdown...", + current_job, + ) + # The job will complete or be reset to pending + + logger.info("Worker shutdown complete") + + +def move() -> None: + """Make the worker move.""" + try: + # Initialize database + Core.Database.init_db() + + # Start main processing loop + main_loop() + + except KeyboardInterrupt: + logger.info("Worker stopped by user") + except Exception: + logger.exception("Worker crashed") + raise + + +class TestArticleExtraction(Test.TestCase): + """Test article extraction functionality.""" + + def test_extract_valid_article(self) -> None: + """Extract from well-formed HTML.""" + # Mock trafilatura.fetch_url and extract + mock_html = ( + "<html><body><h1>Test Article</h1><p>Content here</p></body></html>" + ) + mock_result = json.dumps({ + "title": "Test Article", + "text": "Content here", + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(title, "Test Article") + self.assertEqual(content, "Content here") + self.assertIsNone(author) + self.assertIsNone(pub_date) + + def test_extract_missing_title(self) -> None: + """Handle articles without titles.""" + mock_html = "<html><body><p>Content without title</p></body></html>" + mock_result = json.dumps({"text": "Content without title"}) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(title, "Untitled Article") + self.assertEqual(content, "Content without title") + self.assertIsNone(author) + self.assertIsNone(pub_date) + + def test_extract_empty_content(self) -> None: + """Handle empty articles.""" + mock_html = "<html><body></body></html>" + mock_result = json.dumps({"title": "Empty Article", "text": ""}) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + pytest.raises(ValueError, match="No content extracted") as cm, + ): + ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertIn("No content extracted", str(cm.value)) + + def test_extract_network_error(self) -> None: + """Handle connection failures.""" + with ( + unittest.mock.patch("trafilatura.fetch_url", return_value=None), + pytest.raises(ValueError, match="Failed to download") as cm, + ): + ArticleProcessor.extract_article_content("https://example.com") + + self.assertIn("Failed to download", str(cm.value)) + + @staticmethod + def test_extract_timeout() -> None: + """Handle slow responses.""" + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + side_effect=TimeoutError("Timeout"), + ), + pytest.raises(TimeoutError), + ): + ArticleProcessor.extract_article_content("https://example.com") + + def test_content_sanitization(self) -> None: + """Remove unwanted elements.""" + mock_html = """ + <html><body> + <h1>Article</h1> + <p>Good content</p> + <script>alert('bad')</script> + <table><tr><td>data</td></tr></table> + </body></html> + """ + mock_result = json.dumps({ + "title": "Article", + "text": "Good content", # Tables and scripts removed + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + _title, content, _author, _pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(content, "Good content") + self.assertNotIn("script", content) + self.assertNotIn("table", content) + + +class TestTextToSpeech(Test.TestCase): + """Test text-to-speech functionality.""" + + def setUp(self) -> None: + """Set up mocks.""" + # Mock OpenAI API key + self.env_patcher = unittest.mock.patch.dict( + os.environ, + {"OPENAI_API_KEY": "test-key"}, + ) + self.env_patcher.start() + + # Mock OpenAI response + self.mock_audio_response: unittest.mock.MagicMock = ( + unittest.mock.MagicMock() + ) + self.mock_audio_response.content = b"fake-audio-data" + + # Mock AudioSegment to avoid ffmpeg issues in tests + self.mock_audio_segment: unittest.mock.MagicMock = ( + unittest.mock.MagicMock() + ) + self.mock_audio_segment.export.return_value = None + self.audio_segment_patcher = unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + return_value=self.mock_audio_segment, + ) + self.audio_segment_patcher.start() + + # Mock the concatenation operations + self.mock_audio_segment.__add__.return_value = self.mock_audio_segment + + def tearDown(self) -> None: + """Clean up mocks.""" + self.env_patcher.stop() + self.audio_segment_patcher.stop() + + def test_tts_generation(self) -> None: + """Generate audio from text.""" + + # Mock the export to write test audio data + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=["Test content"], + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech( + "Test content", + "Test Title", + ) + + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_tts_chunking(self) -> None: + """Handle long articles with chunking.""" + long_text = "Long content " * 1000 + chunks = ["Chunk 1", "Chunk 2", "Chunk 3"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock AudioSegment.silent + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech( + long_text, + "Long Article", + ) + + # Should have called TTS for each chunk + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_tts_empty_text(self) -> None: + """Handle empty input.""" + with unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=[], + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + with pytest.raises(ValueError, match="No chunks generated") as cm: + processor.text_to_speech("", "Empty") + + self.assertIn("No chunks generated", str(cm.value)) + + def test_tts_special_characters(self) -> None: + """Handle unicode and special chars.""" + special_text = 'Unicode: 你好世界 Émojis: 🎙️📰 Special: <>&"' + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=[special_text], + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech( + special_text, + "Special", + ) + + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_llm_text_preparation(self) -> None: + """Verify LLM editing.""" + # Test the actual text preparation functions + chunks = split_text_into_chunks("Short text", max_chars=100) + self.assertEqual(len(chunks), 1) + self.assertEqual(chunks[0], "Short text") + + # Test long text splitting + long_text = "Sentence one. " * 100 + chunks = split_text_into_chunks(long_text, max_chars=100) + self.assertGreater(len(chunks), 1) + for chunk in chunks: + self.assertLessEqual(len(chunk), 100) + + @staticmethod + def test_llm_failure_fallback() -> None: + """Handle LLM API failures.""" + # Mock LLM failure + with unittest.mock.patch("openai.OpenAI") as mock_openai: + mock_client = mock_openai.return_value + mock_client.chat.completions.create.side_effect = Exception( + "API Error", + ) + + # Should fall back to raw text + with pytest.raises(Exception, match="API Error"): + edit_chunk_for_speech("Test chunk", "Title", is_first=True) + + def test_chunk_concatenation(self) -> None: + """Verify audio joining.""" + # Mock multiple audio segments + chunks = ["Chunk 1", "Chunk 2"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"test-audio-output") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test", "Title") + + # Should produce combined audio + self.assertIsInstance(audio_data, bytes) + self.assertEqual(audio_data, b"test-audio-output") + + def test_parallel_tts_generation(self) -> None: + """Test parallel TTS processing.""" + chunks = ["Chunk 1", "Chunk 2", "Chunk 3", "Chunk 4"] + + # Mock responses for each chunk + mock_responses = [] + for i in range(len(chunks)): + mock_resp = unittest.mock.MagicMock() + mock_resp.content = f"audio-{i}".encode() + mock_responses.append(mock_resp) + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + + # Make create return different responses for each call + mock_speech.create.side_effect = mock_responses + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + # Mock AudioSegment operations + mock_segment = unittest.mock.MagicMock() + mock_segment.__add__.return_value = mock_segment + + def mock_export(path: str, **_kwargs: typing.Any) -> None: + Path(path).write_bytes(b"combined-audio") + + mock_segment.export = mock_export + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + return_value=mock_segment, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=mock_segment, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=50.0, # Normal memory usage + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test content", "Test Title") + + # Verify all chunks were processed + self.assertEqual(mock_speech.create.call_count, len(chunks)) + self.assertEqual(audio_data, b"combined-audio") + + def test_parallel_tts_high_memory_fallback(self) -> None: + """Test fallback to serial processing when memory is high.""" + chunks = ["Chunk 1", "Chunk 2"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"serial-audio") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=65.0, # High memory usage + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test content", "Test Title") + + # Should use serial processing + self.assertEqual(audio_data, b"serial-audio") + + @staticmethod + def test_parallel_tts_error_handling() -> None: + """Test error handling in parallel TTS processing.""" + chunks = ["Chunk 1", "Chunk 2"] + + # Mock OpenAI client with one failure + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + + # First call succeeds, second fails + mock_resp1 = unittest.mock.MagicMock() + mock_resp1.content = b"audio-1" + mock_speech.create.side_effect = [mock_resp1, Exception("API Error")] + + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + # Set up the test context + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=50.0, + ), + pytest.raises(Exception, match="API Error"), + ): + processor.text_to_speech("Test content", "Test Title") + + def test_parallel_tts_order_preservation(self) -> None: + """Test that chunks are combined in the correct order.""" + chunks = ["First", "Second", "Third", "Fourth", "Fifth"] + + # Create mock responses with identifiable content + mock_responses = [] + for chunk in chunks: + mock_resp = unittest.mock.MagicMock() + mock_resp.content = f"audio-{chunk}".encode() + mock_responses.append(mock_resp) + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.side_effect = mock_responses + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + # Track the order of segments being combined + combined_order = [] + + def mock_from_mp3(data: io.BytesIO) -> unittest.mock.MagicMock: + content = data.read() + combined_order.append(content.decode()) + segment = unittest.mock.MagicMock() + segment.__add__.return_value = segment + return segment + + mock_segment = unittest.mock.MagicMock() + mock_segment.__add__.return_value = mock_segment + + def mock_export(path: str, **_kwargs: typing.Any) -> None: + # Verify order is preserved + expected_order = [f"audio-{chunk}" for chunk in chunks] + if combined_order != expected_order: + msg = f"Order mismatch: {combined_order} != {expected_order}" + raise AssertionError(msg) + Path(path).write_bytes(b"ordered-audio") + + mock_segment.export = mock_export + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + side_effect=mock_from_mp3, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=mock_segment, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=50.0, + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test content", "Test Title") + + self.assertEqual(audio_data, b"ordered-audio") + + +class TestIntroOutro(Test.TestCase): + """Test intro and outro generation with metadata.""" + + def test_create_intro_text_full_metadata(self) -> None: + """Test intro text creation with all metadata.""" + intro = ArticleProcessor._create_intro_text( # noqa: SLF001 + title="Test Article", + author="John Doe", + pub_date="2024-01-15", + ) + self.assertIn("Title: Test Article", intro) + self.assertIn("Author: John Doe", intro) + self.assertIn("Published: 2024-01-15", intro) + + def test_create_intro_text_no_author(self) -> None: + """Test intro text without author.""" + intro = ArticleProcessor._create_intro_text( # noqa: SLF001 + title="Test Article", + author=None, + pub_date="2024-01-15", + ) + self.assertIn("Title: Test Article", intro) + self.assertNotIn("Author:", intro) + self.assertIn("Published: 2024-01-15", intro) + + def test_create_intro_text_minimal(self) -> None: + """Test intro text with only title.""" + intro = ArticleProcessor._create_intro_text( # noqa: SLF001 + title="Test Article", + author=None, + pub_date=None, + ) + self.assertEqual(intro, "Title: Test Article.") + + def test_create_outro_text_with_author(self) -> None: + """Test outro text with author.""" + outro = ArticleProcessor._create_outro_text( # noqa: SLF001 + title="Test Article", + author="Jane Smith", + ) + self.assertIn("Test Article", outro) + self.assertIn("Jane Smith", outro) + self.assertIn("Podcast It Later", outro) + + def test_create_outro_text_no_author(self) -> None: + """Test outro text without author.""" + outro = ArticleProcessor._create_outro_text( # noqa: SLF001 + title="Test Article", + author=None, + ) + self.assertIn("Test Article", outro) + self.assertNotIn("by", outro) + self.assertIn("Podcast It Later", outro) + + def test_extract_with_metadata(self) -> None: + """Test that extraction returns metadata.""" + mock_html = "<html><body><p>Content</p></body></html>" + mock_result = json.dumps({ + "title": "Test Article", + "text": "Article content", + "author": "Test Author", + "date": "2024-01-15", + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=mock_html, + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content, author, pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(title, "Test Article") + self.assertEqual(content, "Article content") + self.assertEqual(author, "Test Author") + self.assertEqual(pub_date, "2024-01-15") + + +class TestMemoryEfficiency(Test.TestCase): + """Test memory-efficient processing.""" + + def test_large_article_size_limit(self) -> None: + """Test that articles exceeding size limits are rejected.""" + huge_text = "x" * (MAX_ARTICLE_SIZE + 1000) # Exceed limit + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=huge_text * 4, # Simulate large HTML + ), + pytest.raises(ValueError, match="Article too large") as cm, + ): + ArticleProcessor.extract_article_content("https://example.com") + + self.assertIn("Article too large", str(cm.value)) + + def test_content_truncation(self) -> None: + """Test that oversized content is truncated.""" + large_content = "Content " * 100_000 # Create large content + mock_result = json.dumps({ + "title": "Large Article", + "text": large_content, + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value="<html><body>content</body></html>", + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content, _author, _pub_date = ( + ArticleProcessor.extract_article_content( + "https://example.com", + ) + ) + + self.assertEqual(title, "Large Article") + self.assertLessEqual(len(content), MAX_ARTICLE_SIZE) + + def test_memory_usage_check(self) -> None: + """Test memory usage monitoring.""" + usage = check_memory_usage() + self.assertIsInstance(usage, float) + self.assertGreaterEqual(usage, 0.0) + self.assertLessEqual(usage, 100.0) + + +class TestJobProcessing(Test.TestCase): + """Test job processing functionality.""" + + def setUp(self) -> None: + """Set up test environment.""" + Core.Database.init_db() + + # Create test user and job + self.user_id, _ = Core.Database.create_user( + "test@example.com", + ) + self.job_id = Core.Database.add_to_queue( + "https://example.com/article", + "test@example.com", + self.user_id, + ) + + # Mock environment + self.env_patcher = unittest.mock.patch.dict( + os.environ, + { + "OPENAI_API_KEY": "test-key", + "S3_ENDPOINT": "https://s3.example.com", + "S3_BUCKET": "test-bucket", + "S3_ACCESS_KEY": "test-access", + "S3_SECRET_KEY": "test-secret", + }, + ) + self.env_patcher.start() + + def tearDown(self) -> None: + """Clean up.""" + self.env_patcher.stop() + Core.Database.teardown() + + def test_process_job_success(self) -> None: + """Complete pipeline execution.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + # Mock all external calls + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=( + "Test Title", + "Test content", + "Test Author", + "2024-01-15", + ), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + return_value=b"audio-data", + ), + unittest.mock.patch.object( + processor, + "upload_to_s3", + return_value="https://s3.example.com/audio.mp3", + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.create_episode", + ) as mock_create, + ): + mock_create.return_value = 1 + processor.process_job(job) + + # Verify job was marked complete + mock_update.assert_called_with(self.job_id, "completed") + mock_create.assert_called_once() + + def test_process_job_extraction_failure(self) -> None: + """Handle bad URLs.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + side_effect=ValueError("Bad URL"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + pytest.raises(ValueError, match="Bad URL"), + ): + processor.process_job(job) + + # Job should be marked as error + mock_update.assert_called_with(self.job_id, "error", "Bad URL") + + def test_process_job_tts_failure(self) -> None: + """Handle TTS errors.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=("Title", "Content"), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + side_effect=Exception("TTS Error"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + pytest.raises(Exception, match="TTS Error"), + ): + processor.process_job(job) + + mock_update.assert_called_with(self.job_id, "error", "TTS Error") + + def test_process_job_s3_failure(self) -> None: + """Handle upload errors.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=("Title", "Content"), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + return_value=b"audio", + ), + unittest.mock.patch.object( + processor, + "upload_to_s3", + side_effect=ClientError({}, "PutObject"), + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ), + pytest.raises(ClientError), + ): + processor.process_job(job) + + def test_job_retry_logic(self) -> None: + """Verify exponential backoff.""" + # Set job to error with retry count + Core.Database.update_job_status( + self.job_id, + "error", + "First failure", + ) + Core.Database.update_job_status( + self.job_id, + "error", + "Second failure", + ) + + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + self.assertEqual(job["retry_count"], 2) + + # Should be retryable + retryable = Core.Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 1) + + def test_max_retries(self) -> None: + """Stop after max attempts.""" + # Exceed retry limit + for i in range(4): + Core.Database.update_job_status( + self.job_id, + "error", + f"Failure {i}", + ) + + # Should not be retryable + retryable = Core.Database.get_retryable_jobs( + max_retries=3, + ) + self.assertEqual(len(retryable), 0) + + def test_graceful_shutdown(self) -> None: + """Test graceful shutdown during job processing.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + def mock_tts(*_args: Any) -> bytes: + shutdown_handler.shutdown_requested.set() + return b"audio-data" + + with ( + unittest.mock.patch.object( + ArticleProcessor, + "extract_article_content", + return_value=( + "Test Title", + "Test content", + "Test Author", + "2024-01-15", + ), + ), + unittest.mock.patch.object( + ArticleProcessor, + "text_to_speech", + side_effect=mock_tts, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + ): + processor.process_job(job) + + # Job should be reset to pending due to shutdown + mock_update.assert_any_call(self.job_id, "pending") + + def test_cleanup_stale_jobs(self) -> None: + """Test cleanup of stale processing jobs.""" + # Manually set job to processing + Core.Database.update_job_status(self.job_id, "processing") + + # Run cleanup + cleanup_stale_jobs() + + # Job should be back to pending + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + self.assertEqual(job["status"], "pending") + + def test_concurrent_processing(self) -> None: + """Handle multiple jobs.""" + # Create multiple jobs + job2 = Core.Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + job3 = Core.Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + # Get pending jobs + jobs = Core.Database.get_pending_jobs(limit=5) + + self.assertEqual(len(jobs), 3) + self.assertEqual({j["id"] for j in jobs}, {self.job_id, job2, job3}) + + def test_memory_threshold_deferral(self) -> None: + """Test that jobs are deferred when memory usage is high.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + # Mock high memory usage + with ( + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=90.0, # High memory usage + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + ): + processor.process_job(job) + + # Job should not be processed (no status updates) + mock_update.assert_not_called() + + +class TestWorkerErrorHandling(Test.TestCase): + """Test worker error handling and recovery.""" + + def setUp(self) -> None: + """Set up test environment.""" + Core.Database.init_db() + self.user_id, _ = Core.Database.create_user("test@example.com") + self.job_id = Core.Database.add_to_queue( + "https://example.com", + "test@example.com", + self.user_id, + ) + self.shutdown_handler = ShutdownHandler() + self.processor = ArticleProcessor(self.shutdown_handler) + + @staticmethod + def tearDown() -> None: + """Clean up.""" + Core.Database.teardown() + + def test_process_pending_jobs_exception_handling(self) -> None: + """Test that process_pending_jobs handles exceptions.""" + + def side_effect(job: dict[str, Any]) -> None: + # Simulate process_job starting and setting status to processing + Core.Database.update_job_status(job["id"], "processing") + msg = "Unexpected Error" + raise ValueError(msg) + + with ( + unittest.mock.patch.object( + self.processor, + "process_job", + side_effect=side_effect, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + side_effect=Core.Database.update_job_status, + ) as _mock_update, + ): + process_pending_jobs(self.processor) + + # Job should be marked as error + job = Core.Database.get_job_by_id(self.job_id) + self.assertIsNotNone(job) + if job: + self.assertEqual(job["status"], "error") + self.assertIn("Unexpected Error", job["error_message"]) + + def test_process_retryable_jobs_success(self) -> None: + """Test processing of retryable jobs.""" + # Set up a retryable job + Core.Database.update_job_status(self.job_id, "error", "Fail 1") + + # Modify created_at to be in the past to satisfy backoff + with Core.Database.get_connection() as conn: + conn.execute( + "UPDATE queue SET created_at = ? WHERE id = ?", + ( + ( + datetime.now(tz=timezone.utc) - timedelta(minutes=5) + ).isoformat(), + self.job_id, + ), + ) + conn.commit() + + process_retryable_jobs() + + job = Core.Database.get_job_by_id(self.job_id) + self.assertIsNotNone(job) + if job: + self.assertEqual(job["status"], "pending") + + def test_process_retryable_jobs_not_ready(self) -> None: + """Test that jobs are not retried before backoff period.""" + # Set up a retryable job that just failed + Core.Database.update_job_status(self.job_id, "error", "Fail 1") + + # created_at is now, so backoff should prevent retry + process_retryable_jobs() + + job = Core.Database.get_job_by_id(self.job_id) + self.assertIsNotNone(job) + if job: + self.assertEqual(job["status"], "error") + + +class TestTextChunking(Test.TestCase): + """Test text chunking edge cases.""" + + def test_split_text_single_long_word(self) -> None: + """Handle text with a single word exceeding limit.""" + long_word = "a" * 4000 + chunks = split_text_into_chunks(long_word, max_chars=3000) + + # Should keep it as one chunk or split? + # The current implementation does not split words + self.assertEqual(len(chunks), 1) + self.assertEqual(len(chunks[0]), 4000) + + def test_split_text_no_sentence_boundaries(self) -> None: + """Handle long text with no sentence boundaries.""" + text = "word " * 1000 # 5000 chars + chunks = split_text_into_chunks(text, max_chars=3000) + + # Should keep it as one chunk as it can't split by ". " + self.assertEqual(len(chunks), 1) + self.assertGreater(len(chunks[0]), 3000) + + +def test() -> None: + """Run the tests.""" + Test.run( + App.Area.Test, + [ + TestArticleExtraction, + TestTextToSpeech, + TestMemoryEfficiency, + TestJobProcessing, + TestWorkerErrorHandling, + TestTextChunking, + ], + ) + + +def main() -> None: + """Entry point for the worker.""" + if "test" in sys.argv: + test() + else: + move() |
