summaryrefslogtreecommitdiff
path: root/Biz/PodcastItLater/Worker.py
diff options
context:
space:
mode:
Diffstat (limited to 'Biz/PodcastItLater/Worker.py')
-rw-r--r--Biz/PodcastItLater/Worker.py2199
1 files changed, 2199 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py
new file mode 100644
index 0000000..bf6ef9e
--- /dev/null
+++ b/Biz/PodcastItLater/Worker.py
@@ -0,0 +1,2199 @@
+"""Background worker for processing article-to-podcast conversions."""
+
+# : dep boto3
+# : dep botocore
+# : dep openai
+# : dep psutil
+# : dep pydub
+# : dep pytest
+# : dep pytest-asyncio
+# : dep pytest-mock
+# : dep trafilatura
+# : out podcastitlater-worker
+# : run ffmpeg
+import Biz.PodcastItLater.Core as Core
+import boto3 # type: ignore[import-untyped]
+import concurrent.futures
+import io
+import json
+import logging
+import Omni.App as App
+import Omni.Log as Log
+import Omni.Test as Test
+import openai
+import operator
+import os
+import psutil # type: ignore[import-untyped]
+import pytest
+import signal
+import sys
+import tempfile
+import threading
+import time
+import trafilatura
+import typing
+import unittest.mock
+from botocore.exceptions import ClientError # type: ignore[import-untyped]
+from datetime import datetime
+from datetime import timedelta
+from datetime import timezone
+from pathlib import Path
+from pydub import AudioSegment # type: ignore[import-untyped]
+from typing import Any
+
+logger = logging.getLogger(__name__)
+Log.setup(logger)
+
+# Configuration from environment variables
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+S3_ENDPOINT = os.getenv("S3_ENDPOINT")
+S3_BUCKET = os.getenv("S3_BUCKET")
+S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY")
+S3_SECRET_KEY = os.getenv("S3_SECRET_KEY")
+area = App.from_env()
+
+# Worker configuration
+MAX_CONTENT_LENGTH = 5000 # characters for TTS
+MAX_ARTICLE_SIZE = 500_000 # 500KB character limit for articles
+POLL_INTERVAL = 30 # seconds
+MAX_RETRIES = 3
+TTS_MODEL = "tts-1"
+TTS_VOICE = "alloy"
+MEMORY_THRESHOLD = 80 # Percentage threshold for memory usage
+CROSSFADE_DURATION = 500 # ms for crossfading segments
+PAUSE_DURATION = 1000 # ms for silence between segments
+
+
+class ShutdownHandler:
+ """Handles graceful shutdown of the worker."""
+
+ def __init__(self) -> None:
+ """Initialize shutdown handler."""
+ self.shutdown_requested = threading.Event()
+ self.current_job_id: int | None = None
+ self.lock = threading.Lock()
+
+ # Register signal handlers
+ signal.signal(signal.SIGTERM, self._handle_signal)
+ signal.signal(signal.SIGINT, self._handle_signal)
+
+ def _handle_signal(self, signum: int, _frame: Any) -> None:
+ """Handle shutdown signals."""
+ logger.info(
+ "Received signal %d, initiating graceful shutdown...",
+ signum,
+ )
+ self.shutdown_requested.set()
+
+ def is_shutdown_requested(self) -> bool:
+ """Check if shutdown has been requested."""
+ return self.shutdown_requested.is_set()
+
+ def set_current_job(self, job_id: int | None) -> None:
+ """Set the currently processing job."""
+ with self.lock:
+ self.current_job_id = job_id
+
+ def get_current_job(self) -> int | None:
+ """Get the currently processing job."""
+ with self.lock:
+ return self.current_job_id
+
+
+class ArticleProcessor:
+ """Handles the complete article-to-podcast conversion pipeline."""
+
+ def __init__(self, shutdown_handler: ShutdownHandler) -> None:
+ """Initialize the processor with required services.
+
+ Raises:
+ ValueError: If OPENAI_API_KEY environment variable is not set.
+ """
+ if not OPENAI_API_KEY:
+ msg = "OPENAI_API_KEY environment variable is required"
+ raise ValueError(msg)
+
+ self.openai_client: openai.OpenAI = openai.OpenAI(
+ api_key=OPENAI_API_KEY,
+ )
+ self.shutdown_handler = shutdown_handler
+
+ # Initialize S3 client for Digital Ocean Spaces
+ if all([S3_ENDPOINT, S3_BUCKET, S3_ACCESS_KEY, S3_SECRET_KEY]):
+ self.s3_client: Any = boto3.client(
+ "s3",
+ endpoint_url=S3_ENDPOINT,
+ aws_access_key_id=S3_ACCESS_KEY,
+ aws_secret_access_key=S3_SECRET_KEY,
+ )
+ else:
+ logger.warning("S3 configuration incomplete, uploads will fail")
+ self.s3_client = None
+
+ @staticmethod
+ def extract_article_content(
+ url: str,
+ ) -> tuple[str, str, str | None, str | None]:
+ """Extract title, content, author, and date from article URL.
+
+ Returns:
+ tuple: (title, content, author, publication_date)
+
+ Raises:
+ ValueError: If content cannot be downloaded, extracted, or large.
+ """
+ try:
+ downloaded = trafilatura.fetch_url(url)
+ if not downloaded:
+ msg = f"Failed to download content from {url}"
+ raise ValueError(msg) # noqa: TRY301
+
+ # Check size before processing
+ if (
+ len(downloaded) > MAX_ARTICLE_SIZE * 4
+ ): # Rough HTML to text ratio
+ msg = f"Article too large: {len(downloaded)} bytes"
+ raise ValueError(msg) # noqa: TRY301
+
+ # Extract with metadata
+ result = trafilatura.extract(
+ downloaded,
+ include_comments=False,
+ include_tables=False,
+ with_metadata=True,
+ output_format="json",
+ )
+
+ if not result:
+ msg = f"Failed to extract content from {url}"
+ raise ValueError(msg) # noqa: TRY301
+
+ data = json.loads(result)
+
+ title = data.get("title", "Untitled Article")
+ content = data.get("text", "")
+ author = data.get("author")
+ pub_date = data.get("date")
+
+ if not content:
+ msg = f"No content extracted from {url}"
+ raise ValueError(msg) # noqa: TRY301
+
+ # Enforce content size limit
+ if len(content) > MAX_ARTICLE_SIZE:
+ logger.warning(
+ "Article content truncated from %d to %d characters",
+ len(content),
+ MAX_ARTICLE_SIZE,
+ )
+ content = content[:MAX_ARTICLE_SIZE]
+
+ logger.info(
+ "Extracted article: %s (%d chars, author: %s, date: %s)",
+ title,
+ len(content),
+ author or "unknown",
+ pub_date or "unknown",
+ )
+ except Exception:
+ logger.exception("Failed to extract content from %s", url)
+ raise
+ else:
+ return title, content, author, pub_date
+
+ def text_to_speech(
+ self,
+ text: str,
+ title: str,
+ author: str | None = None,
+ pub_date: str | None = None,
+ ) -> bytes:
+ """Convert text to speech with intro/outro using OpenAI TTS API.
+
+ Uses parallel processing for chunks while maintaining order.
+ Adds intro with metadata and outro with attribution.
+
+ Args:
+ text: Article content to convert
+ title: Article title
+ author: Article author (optional)
+ pub_date: Publication date (optional)
+
+ Raises:
+ ValueError: If no chunks are generated from text.
+ """
+ try:
+ # Generate intro audio
+ intro_text = self._create_intro_text(title, author, pub_date)
+ intro_audio = self._generate_tts_segment(intro_text)
+
+ # Generate outro audio
+ outro_text = self._create_outro_text(title, author)
+ outro_audio = self._generate_tts_segment(outro_text)
+
+ # Use LLM to prepare and chunk the main content
+ chunks = prepare_text_for_tts(text, title)
+
+ if not chunks:
+ msg = "No chunks generated from text"
+ raise ValueError(msg) # noqa: TRY301
+
+ logger.info("Processing %d chunks for TTS", len(chunks))
+
+ # Check memory before parallel processing
+ mem_usage = check_memory_usage()
+ if mem_usage > MEMORY_THRESHOLD - 20: # Leave 20% buffer
+ logger.warning(
+ "High memory usage (%.1f%%), falling back to serial "
+ "processing",
+ mem_usage,
+ )
+ content_audio_bytes = self._text_to_speech_serial(chunks)
+ else:
+ # Determine max workers
+ max_workers = min(
+ 4, # Reasonable limit to avoid rate limiting
+ len(chunks), # No more workers than chunks
+ max(1, psutil.cpu_count() // 2), # Use half of CPU cores
+ )
+
+ logger.info(
+ "Using %d workers for parallel TTS processing",
+ max_workers,
+ )
+
+ # Process chunks in parallel
+ chunk_results: list[tuple[int, bytes]] = []
+
+ with concurrent.futures.ThreadPoolExecutor(
+ max_workers=max_workers,
+ ) as executor:
+ # Submit all chunks for processing
+ future_to_index = {
+ executor.submit(self._process_tts_chunk, chunk, i): i
+ for i, chunk in enumerate(chunks)
+ }
+
+ # Collect results as they complete
+ for future in concurrent.futures.as_completed(
+ future_to_index,
+ ):
+ index = future_to_index[future]
+ try:
+ audio_data = future.result()
+ chunk_results.append((index, audio_data))
+ except Exception:
+ logger.exception(
+ "Failed to process chunk %d",
+ index,
+ )
+ raise
+
+ # Sort results by index to maintain order
+ chunk_results.sort(key=operator.itemgetter(0))
+
+ # Combine audio chunks
+ content_audio_bytes = self._combine_audio_chunks([
+ data for _, data in chunk_results
+ ])
+
+ # Combine intro, content, and outro with pauses
+ return ArticleProcessor._combine_intro_content_outro(
+ intro_audio,
+ content_audio_bytes,
+ outro_audio,
+ )
+
+ except Exception:
+ logger.exception("TTS generation failed")
+ raise
+
+ @staticmethod
+ def _create_intro_text(
+ title: str,
+ author: str | None,
+ pub_date: str | None,
+ ) -> str:
+ """Create intro text with available metadata."""
+ parts = [f"Title: {title}"]
+
+ if author:
+ parts.append(f"Author: {author}")
+
+ if pub_date:
+ parts.append(f"Published: {pub_date}")
+
+ return ". ".join(parts) + "."
+
+ @staticmethod
+ def _create_outro_text(title: str, author: str | None) -> str:
+ """Create outro text with attribution."""
+ if author:
+ return (
+ f"This has been an audio version of {title} "
+ f"by {author}, created using Podcast It Later."
+ )
+ return (
+ f"This has been an audio version of {title}, "
+ "created using Podcast It Later."
+ )
+
+ def _generate_tts_segment(self, text: str) -> bytes:
+ """Generate TTS audio for a single segment (intro/outro).
+
+ Args:
+ text: Text to convert to speech
+
+ Returns:
+ MP3 audio bytes
+ """
+ response = self.openai_client.audio.speech.create(
+ model=TTS_MODEL,
+ voice=TTS_VOICE,
+ input=text,
+ )
+ return response.content
+
+ @staticmethod
+ def _combine_intro_content_outro(
+ intro_audio: bytes,
+ content_audio: bytes,
+ outro_audio: bytes,
+ ) -> bytes:
+ """Combine intro, content, and outro with crossfades.
+
+ Args:
+ intro_audio: MP3 bytes for intro
+ content_audio: MP3 bytes for main content
+ outro_audio: MP3 bytes for outro
+
+ Returns:
+ Combined MP3 audio bytes
+ """
+ # Load audio segments
+ intro = AudioSegment.from_mp3(io.BytesIO(intro_audio))
+ content = AudioSegment.from_mp3(io.BytesIO(content_audio))
+ outro = AudioSegment.from_mp3(io.BytesIO(outro_audio))
+
+ # Create bridge silence (pause + 2 * crossfade to account for overlap)
+ bridge = AudioSegment.silent(
+ duration=PAUSE_DURATION + 2 * CROSSFADE_DURATION
+ )
+
+ def safe_append(
+ seg1: AudioSegment, seg2: AudioSegment, crossfade: int
+ ) -> AudioSegment:
+ if len(seg1) < crossfade or len(seg2) < crossfade:
+ logger.warning(
+ "Segment too short for crossfade (%dms vs %dms/%dms), using concatenation",
+ crossfade,
+ len(seg1),
+ len(seg2),
+ )
+ return seg1 + seg2
+ return seg1.append(seg2, crossfade=crossfade)
+
+ # Combine segments with crossfades
+ # Intro -> Bridge -> Content -> Bridge -> Outro
+ # This effectively fades out the previous segment and fades in the next one
+ combined = safe_append(intro, bridge, CROSSFADE_DURATION)
+ combined = safe_append(combined, content, CROSSFADE_DURATION)
+ combined = safe_append(combined, bridge, CROSSFADE_DURATION)
+ combined = safe_append(combined, outro, CROSSFADE_DURATION)
+
+ # Export to bytes
+ output = io.BytesIO()
+ combined.export(output, format="mp3")
+ return output.getvalue()
+
+ def _process_tts_chunk(self, chunk: str, index: int) -> bytes:
+ """Process a single TTS chunk.
+
+ Args:
+ chunk: Text to convert to speech
+ index: Chunk index for logging
+
+ Returns:
+ Audio data as bytes
+ """
+ logger.info(
+ "Generating TTS for chunk %d (%d chars)",
+ index + 1,
+ len(chunk),
+ )
+
+ response = self.openai_client.audio.speech.create(
+ model=TTS_MODEL,
+ voice=TTS_VOICE,
+ input=chunk,
+ response_format="mp3",
+ )
+
+ return response.content
+
+ @staticmethod
+ def _combine_audio_chunks(audio_chunks: list[bytes]) -> bytes:
+ """Combine multiple audio chunks with silence gaps.
+
+ Args:
+ audio_chunks: List of audio data in order
+
+ Returns:
+ Combined audio data
+ """
+ if not audio_chunks:
+ msg = "No audio chunks to combine"
+ raise ValueError(msg)
+
+ # Create a temporary file for the combined audio
+ with tempfile.NamedTemporaryFile(
+ suffix=".mp3",
+ delete=False,
+ ) as temp_file:
+ temp_path = temp_file.name
+
+ try:
+ # Start with the first chunk
+ combined_audio = AudioSegment.from_mp3(io.BytesIO(audio_chunks[0]))
+
+ # Add remaining chunks with silence gaps
+ for chunk_data in audio_chunks[1:]:
+ chunk_audio = AudioSegment.from_mp3(io.BytesIO(chunk_data))
+ silence = AudioSegment.silent(duration=300) # 300ms gap
+ combined_audio = combined_audio + silence + chunk_audio
+
+ # Export to file
+ combined_audio.export(temp_path, format="mp3", bitrate="128k")
+
+ # Read back the combined audio
+ return Path(temp_path).read_bytes()
+
+ finally:
+ # Clean up temp file
+ if Path(temp_path).exists():
+ Path(temp_path).unlink()
+
+ def _text_to_speech_serial(self, chunks: list[str]) -> bytes:
+ """Fallback serial processing for high memory situations.
+
+ This is the original serial implementation.
+ """
+ # Create a temporary file for streaming audio concatenation
+ with tempfile.NamedTemporaryFile(
+ suffix=".mp3",
+ delete=False,
+ ) as temp_file:
+ temp_path = temp_file.name
+
+ try:
+ # Process first chunk
+ logger.info("Generating TTS for chunk 1/%d", len(chunks))
+ response = self.openai_client.audio.speech.create(
+ model=TTS_MODEL,
+ voice=TTS_VOICE,
+ input=chunks[0],
+ response_format="mp3",
+ )
+
+ # Write first chunk directly to file
+ Path(temp_path).write_bytes(response.content)
+
+ # Process remaining chunks
+ for i, chunk in enumerate(chunks[1:], 1):
+ logger.info(
+ "Generating TTS for chunk %d/%d (%d chars)",
+ i + 1,
+ len(chunks),
+ len(chunk),
+ )
+
+ response = self.openai_client.audio.speech.create(
+ model=TTS_MODEL,
+ voice=TTS_VOICE,
+ input=chunk,
+ response_format="mp3",
+ )
+
+ # Append to existing file with silence gap
+ # Load only the current segment
+ current_segment = AudioSegment.from_mp3(
+ io.BytesIO(response.content),
+ )
+
+ # Load existing audio, append, and save back
+ existing_audio = AudioSegment.from_mp3(temp_path)
+ silence = AudioSegment.silent(duration=300)
+ combined = existing_audio + silence + current_segment
+
+ # Export back to the same file
+ combined.export(temp_path, format="mp3", bitrate="128k")
+
+ # Force garbage collection to free memory
+ del existing_audio, current_segment, combined
+
+ # Small delay between API calls
+ if i < len(chunks) - 1:
+ time.sleep(0.5)
+
+ # Read final result
+ audio_data = Path(temp_path).read_bytes()
+
+ logger.info(
+ "Generated combined TTS audio: %d bytes",
+ len(audio_data),
+ )
+ return audio_data
+
+ finally:
+ # Clean up temp file
+ temp_file_path = Path(temp_path)
+ if temp_file_path.exists():
+ temp_file_path.unlink()
+
+ def upload_to_s3(self, audio_data: bytes, filename: str) -> str:
+ """Upload audio file to S3-compatible storage and return public URL.
+
+ Raises:
+ ValueError: If S3 client is not configured.
+ ClientError: If S3 upload fails.
+ """
+ if not self.s3_client:
+ msg = "S3 client not configured"
+ raise ValueError(msg)
+
+ try:
+ # Upload file using streaming to minimize memory usage
+ audio_stream = io.BytesIO(audio_data)
+ self.s3_client.upload_fileobj(
+ audio_stream,
+ S3_BUCKET,
+ filename,
+ ExtraArgs={
+ "ContentType": "audio/mpeg",
+ "ACL": "public-read",
+ },
+ )
+
+ # Construct public URL
+ audio_url = f"{S3_ENDPOINT}/{S3_BUCKET}/{filename}"
+ logger.info(
+ "Uploaded audio to: %s (%d bytes)",
+ audio_url,
+ len(audio_data),
+ )
+ except ClientError:
+ logger.exception("S3 upload failed")
+ raise
+ else:
+ return audio_url
+
+ @staticmethod
+ def estimate_duration(audio_data: bytes) -> int:
+ """Estimate audio duration in seconds based on file size and bitrate."""
+ # Rough estimation: MP3 at 128kbps = ~16KB per second
+ estimated_seconds = len(audio_data) // 16000
+ return max(1, estimated_seconds) # Minimum 1 second
+
+ @staticmethod
+ def generate_filename(job_id: int, title: str) -> str:
+ """Generate unique filename for audio file."""
+ timestamp = int(datetime.now(tz=timezone.utc).timestamp())
+ # Create safe filename from title
+ safe_title = "".join(
+ c for c in title if c.isalnum() or c in {" ", "-", "_"}
+ ).rstrip()
+ safe_title = safe_title.replace(" ", "_")[:50] # Limit length
+ return f"episode_{timestamp}_{job_id}_{safe_title}.mp3"
+
+ def process_job(
+ self,
+ job: dict[str, Any],
+ ) -> None:
+ """Process a single job through the complete pipeline."""
+ job_id = job["id"]
+ url = job["url"]
+
+ # Check memory before starting
+ mem_usage = check_memory_usage()
+ if mem_usage > MEMORY_THRESHOLD:
+ logger.warning(
+ "High memory usage (%.1f%%), deferring job %d",
+ mem_usage,
+ job_id,
+ )
+ return
+
+ # Track current job for graceful shutdown
+ self.shutdown_handler.set_current_job(job_id)
+
+ try:
+ logger.info("Processing job %d: %s", job_id, url)
+
+ # Update status to processing
+ Core.Database.update_job_status(
+ job_id,
+ "processing",
+ )
+
+ # Check for shutdown before each major step
+ if self.shutdown_handler.is_shutdown_requested():
+ logger.info("Shutdown requested, aborting job %d", job_id)
+ Core.Database.update_job_status(job_id, "pending")
+ return
+
+ # Step 1: Extract article content
+ Core.Database.update_job_status(job_id, "extracting")
+ title, content, author, pub_date = (
+ ArticleProcessor.extract_article_content(url)
+ )
+
+ if self.shutdown_handler.is_shutdown_requested():
+ logger.info("Shutdown requested, aborting job %d", job_id)
+ Core.Database.update_job_status(job_id, "pending")
+ return
+
+ # Step 2: Generate audio with metadata
+ Core.Database.update_job_status(job_id, "synthesizing")
+ audio_data = self.text_to_speech(content, title, author, pub_date)
+
+ if self.shutdown_handler.is_shutdown_requested():
+ logger.info("Shutdown requested, aborting job %d", job_id)
+ Core.Database.update_job_status(job_id, "pending")
+ return
+
+ # Step 3: Upload to S3
+ Core.Database.update_job_status(job_id, "uploading")
+ filename = ArticleProcessor.generate_filename(job_id, title)
+ audio_url = self.upload_to_s3(audio_data, filename)
+
+ # Step 4: Calculate duration
+ duration = ArticleProcessor.estimate_duration(audio_data)
+
+ # Step 5: Create episode record
+ url_hash = Core.hash_url(url)
+ episode_id = Core.Database.create_episode(
+ title=title,
+ audio_url=audio_url,
+ duration=duration,
+ content_length=len(content),
+ user_id=job.get("user_id"),
+ author=job.get("author"), # Pass author from job
+ original_url=url, # Pass the original article URL
+ original_url_hash=url_hash,
+ )
+
+ # Add episode to user's feed
+ user_id = job.get("user_id")
+ if user_id:
+ Core.Database.add_episode_to_user(user_id, episode_id)
+ Core.Database.track_episode_event(
+ episode_id,
+ "added",
+ user_id,
+ )
+
+ # Step 6: Mark job as complete
+ Core.Database.update_job_status(
+ job_id,
+ "completed",
+ )
+
+ logger.info(
+ "Successfully processed job %d -> episode %d",
+ job_id,
+ episode_id,
+ )
+
+ except Exception as e:
+ error_msg = str(e)
+ logger.exception("Job %d failed: %s", job_id, error_msg)
+ Core.Database.update_job_status(
+ job_id,
+ "error",
+ error_msg,
+ )
+ raise
+ finally:
+ # Clear current job
+ self.shutdown_handler.set_current_job(None)
+
+
+def prepare_text_for_tts(text: str, title: str) -> list[str]:
+ """Use LLM to prepare text for TTS, returning chunks ready for speech.
+
+ First splits text mechanically, then has LLM edit each chunk.
+ """
+ # First, split the text into manageable chunks
+ raw_chunks = split_text_into_chunks(text, max_chars=3000)
+
+ logger.info("Split article into %d raw chunks", len(raw_chunks))
+
+ # Prepare the first chunk with intro
+ edited_chunks = []
+
+ for i, chunk in enumerate(raw_chunks):
+ is_first = i == 0
+ is_last = i == len(raw_chunks) - 1
+
+ try:
+ edited_chunk = edit_chunk_for_speech(
+ chunk,
+ title=title if is_first else None,
+ is_first=is_first,
+ is_last=is_last,
+ )
+ edited_chunks.append(edited_chunk)
+ except Exception:
+ logger.exception("Failed to edit chunk %d", i + 1)
+ # Fall back to raw chunk if LLM fails
+ if is_first:
+ edited_chunks.append(
+ f"This is an audio version of {title}. {chunk}",
+ )
+ elif is_last:
+ edited_chunks.append(f"{chunk} This concludes the article.")
+ else:
+ edited_chunks.append(chunk)
+
+ return edited_chunks
+
+
+def split_text_into_chunks(text: str, max_chars: int = 3000) -> list[str]:
+ """Split text into chunks at sentence boundaries."""
+ chunks = []
+ current_chunk = ""
+
+ # Split into paragraphs first
+ paragraphs = text.split("\n\n")
+
+ for para in paragraphs:
+ para_stripped = para.strip()
+ if not para_stripped:
+ continue
+
+ # If paragraph itself is too long, split by sentences
+ if len(para_stripped) > max_chars:
+ sentences = para_stripped.split(". ")
+ for sentence in sentences:
+ if len(current_chunk) + len(sentence) + 2 < max_chars:
+ current_chunk += sentence + ". "
+ else:
+ if current_chunk:
+ chunks.append(current_chunk.strip())
+ current_chunk = sentence + ". "
+ # If adding this paragraph would exceed limit, start new chunk
+ elif len(current_chunk) + len(para_stripped) + 2 > max_chars:
+ if current_chunk:
+ chunks.append(current_chunk.strip())
+ current_chunk = para_stripped + " "
+ else:
+ current_chunk += para_stripped + " "
+
+ # Don't forget the last chunk
+ if current_chunk:
+ chunks.append(current_chunk.strip())
+
+ return chunks
+
+
+def edit_chunk_for_speech(
+ chunk: str,
+ title: str | None = None,
+ *,
+ is_first: bool = False,
+ is_last: bool = False,
+) -> str:
+ """Use LLM to lightly edit a single chunk for speech.
+
+ Raises:
+ ValueError: If no content is returned from LLM.
+ """
+ system_prompt = (
+ "You are a podcast script editor. Your job is to lightly edit text "
+ "to make it sound natural when spoken aloud.\n\n"
+ "Guidelines:\n"
+ )
+ system_prompt += """
+- Remove URLs and email addresses, replacing with descriptive phrases
+- Convert bullet points and lists into flowing sentences
+- Fix any awkward phrasing for speech
+- Remove references like "click here" or "see below"
+- Keep edits minimal - preserve the original content and style
+- Do NOT add commentary or explanations
+- Return ONLY the edited text, no JSON or formatting
+"""
+
+ user_prompt = chunk
+
+ # Add intro/outro if needed
+ if is_first and title:
+ user_prompt = (
+ f"Add a brief intro mentioning this is an audio version of "
+ f"'{title}', then edit this text:\n\n{chunk}"
+ )
+ elif is_last:
+ user_prompt = f"Edit this text and add a brief closing:\n\n{chunk}"
+
+ try:
+ client: openai.OpenAI = openai.OpenAI(api_key=OPENAI_API_KEY)
+ response = client.chat.completions.create(
+ model="gpt-4o-mini",
+ messages=[
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ],
+ temperature=0.3, # Lower temperature for more consistent edits
+ max_tokens=4000,
+ )
+
+ content = response.choices[0].message.content
+ if not content:
+ msg = "No content returned from LLM"
+ raise ValueError(msg) # noqa: TRY301
+
+ # Ensure the chunk isn't too long
+ max_chunk_length = 4000
+ if len(content) > max_chunk_length:
+ # Truncate at sentence boundary
+ sentences = content.split(". ")
+ truncated = ""
+ for sentence in sentences:
+ if len(truncated) + len(sentence) + 2 < max_chunk_length:
+ truncated += sentence + ". "
+ else:
+ break
+ content = truncated.strip()
+
+ except Exception:
+ logger.exception("LLM chunk editing failed")
+ raise
+ else:
+ return content
+
+
+def parse_datetime_with_timezone(created_at: str | datetime) -> datetime:
+ """Parse datetime string and ensure it has timezone info."""
+ if isinstance(created_at, str):
+ # Handle timezone-aware datetime strings
+ if created_at.endswith("Z"):
+ created_at = created_at[:-1] + "+00:00"
+ last_attempt = datetime.fromisoformat(created_at)
+ if last_attempt.tzinfo is None:
+ last_attempt = last_attempt.replace(tzinfo=timezone.utc)
+ else:
+ last_attempt = created_at
+ if last_attempt.tzinfo is None:
+ last_attempt = last_attempt.replace(tzinfo=timezone.utc)
+ return last_attempt
+
+
+def should_retry_job(job: dict[str, Any], max_retries: int) -> bool:
+ """Check if a job should be retried based on retry count and backoff time.
+
+ Uses exponential backoff to determine if enough time has passed.
+ """
+ retry_count = job["retry_count"]
+ if retry_count >= max_retries:
+ return False
+
+ # Exponential backoff: 30s, 60s, 120s
+ backoff_time = 30 * (2**retry_count)
+ last_attempt = parse_datetime_with_timezone(job["created_at"])
+ time_since_attempt = datetime.now(tz=timezone.utc) - last_attempt
+
+ return time_since_attempt > timedelta(seconds=backoff_time)
+
+
+def process_pending_jobs(
+ processor: ArticleProcessor,
+) -> None:
+ """Process all pending jobs."""
+ pending_jobs = Core.Database.get_pending_jobs(
+ limit=5,
+ )
+
+ for job in pending_jobs:
+ if processor.shutdown_handler.is_shutdown_requested():
+ logger.info("Shutdown requested, stopping job processing")
+ break
+
+ current_job = job["id"]
+ try:
+ processor.process_job(job)
+ except Exception as e:
+ # Ensure job is marked as error even if process_job didn't handle it
+ logger.exception("Failed to process job: %d", current_job)
+ # Check if job is still in processing state
+ current_status = Core.Database.get_job_by_id(
+ current_job,
+ )
+ if current_status and current_status.get("status") == "processing":
+ Core.Database.update_job_status(
+ current_job,
+ "error",
+ str(e),
+ )
+ continue
+
+
+def process_retryable_jobs() -> None:
+ """Check and retry failed jobs with exponential backoff."""
+ retryable_jobs = Core.Database.get_retryable_jobs(
+ MAX_RETRIES,
+ )
+
+ for job in retryable_jobs:
+ if should_retry_job(job, MAX_RETRIES):
+ logger.info(
+ "Retrying job %d (attempt %d)",
+ job["id"],
+ job["retry_count"] + 1,
+ )
+ Core.Database.update_job_status(
+ job["id"],
+ "pending",
+ )
+
+
+def check_memory_usage() -> int | Any:
+ """Check current memory usage percentage."""
+ try:
+ process = psutil.Process()
+ # this returns an int but psutil is untyped
+ return process.memory_percent()
+ except (psutil.Error, OSError):
+ logger.warning("Failed to check memory usage")
+ return 0.0
+
+
+def cleanup_stale_jobs() -> None:
+ """Reset jobs stuck in processing state on startup."""
+ with Core.Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ UPDATE queue
+ SET status = 'pending'
+ WHERE status = 'processing'
+ """,
+ )
+ affected = cursor.rowcount
+ conn.commit()
+
+ if affected > 0:
+ logger.info(
+ "Reset %d stale jobs from processing to pending",
+ affected,
+ )
+
+
+def main_loop() -> None:
+ """Poll for jobs and process them in a continuous loop."""
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+
+ # Clean up any stale jobs from previous runs
+ cleanup_stale_jobs()
+
+ logger.info("Worker started, polling for jobs...")
+
+ while not shutdown_handler.is_shutdown_requested():
+ try:
+ # Process pending jobs
+ process_pending_jobs(processor)
+ process_retryable_jobs()
+
+ # Check if there's any work
+ pending_jobs = Core.Database.get_pending_jobs(
+ limit=1,
+ )
+ retryable_jobs = Core.Database.get_retryable_jobs(
+ MAX_RETRIES,
+ )
+
+ if not pending_jobs and not retryable_jobs:
+ logger.debug("No jobs to process, sleeping...")
+
+ except Exception:
+ logger.exception("Error in main loop")
+
+ # Use interruptible sleep
+ if not shutdown_handler.is_shutdown_requested():
+ shutdown_handler.shutdown_requested.wait(timeout=POLL_INTERVAL)
+
+ # Graceful shutdown
+ current_job = shutdown_handler.get_current_job()
+ if current_job:
+ logger.info(
+ "Waiting for job %d to complete before shutdown...",
+ current_job,
+ )
+ # The job will complete or be reset to pending
+
+ logger.info("Worker shutdown complete")
+
+
+def move() -> None:
+ """Make the worker move."""
+ try:
+ # Initialize database
+ Core.Database.init_db()
+
+ # Start main processing loop
+ main_loop()
+
+ except KeyboardInterrupt:
+ logger.info("Worker stopped by user")
+ except Exception:
+ logger.exception("Worker crashed")
+ raise
+
+
+class TestArticleExtraction(Test.TestCase):
+ """Test article extraction functionality."""
+
+ def test_extract_valid_article(self) -> None:
+ """Extract from well-formed HTML."""
+ # Mock trafilatura.fetch_url and extract
+ mock_html = (
+ "<html><body><h1>Test Article</h1><p>Content here</p></body></html>"
+ )
+ mock_result = json.dumps({
+ "title": "Test Article",
+ "text": "Content here",
+ })
+
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ return_value=mock_html,
+ ),
+ unittest.mock.patch(
+ "trafilatura.extract",
+ return_value=mock_result,
+ ),
+ ):
+ title, content, author, pub_date = (
+ ArticleProcessor.extract_article_content(
+ "https://example.com",
+ )
+ )
+
+ self.assertEqual(title, "Test Article")
+ self.assertEqual(content, "Content here")
+ self.assertIsNone(author)
+ self.assertIsNone(pub_date)
+
+ def test_extract_missing_title(self) -> None:
+ """Handle articles without titles."""
+ mock_html = "<html><body><p>Content without title</p></body></html>"
+ mock_result = json.dumps({"text": "Content without title"})
+
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ return_value=mock_html,
+ ),
+ unittest.mock.patch(
+ "trafilatura.extract",
+ return_value=mock_result,
+ ),
+ ):
+ title, content, author, pub_date = (
+ ArticleProcessor.extract_article_content(
+ "https://example.com",
+ )
+ )
+
+ self.assertEqual(title, "Untitled Article")
+ self.assertEqual(content, "Content without title")
+ self.assertIsNone(author)
+ self.assertIsNone(pub_date)
+
+ def test_extract_empty_content(self) -> None:
+ """Handle empty articles."""
+ mock_html = "<html><body></body></html>"
+ mock_result = json.dumps({"title": "Empty Article", "text": ""})
+
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ return_value=mock_html,
+ ),
+ unittest.mock.patch(
+ "trafilatura.extract",
+ return_value=mock_result,
+ ),
+ pytest.raises(ValueError, match="No content extracted") as cm,
+ ):
+ ArticleProcessor.extract_article_content(
+ "https://example.com",
+ )
+
+ self.assertIn("No content extracted", str(cm.value))
+
+ def test_extract_network_error(self) -> None:
+ """Handle connection failures."""
+ with (
+ unittest.mock.patch("trafilatura.fetch_url", return_value=None),
+ pytest.raises(ValueError, match="Failed to download") as cm,
+ ):
+ ArticleProcessor.extract_article_content("https://example.com")
+
+ self.assertIn("Failed to download", str(cm.value))
+
+ @staticmethod
+ def test_extract_timeout() -> None:
+ """Handle slow responses."""
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ side_effect=TimeoutError("Timeout"),
+ ),
+ pytest.raises(TimeoutError),
+ ):
+ ArticleProcessor.extract_article_content("https://example.com")
+
+ def test_content_sanitization(self) -> None:
+ """Remove unwanted elements."""
+ mock_html = """
+ <html><body>
+ <h1>Article</h1>
+ <p>Good content</p>
+ <script>alert('bad')</script>
+ <table><tr><td>data</td></tr></table>
+ </body></html>
+ """
+ mock_result = json.dumps({
+ "title": "Article",
+ "text": "Good content", # Tables and scripts removed
+ })
+
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ return_value=mock_html,
+ ),
+ unittest.mock.patch(
+ "trafilatura.extract",
+ return_value=mock_result,
+ ),
+ ):
+ _title, content, _author, _pub_date = (
+ ArticleProcessor.extract_article_content(
+ "https://example.com",
+ )
+ )
+
+ self.assertEqual(content, "Good content")
+ self.assertNotIn("script", content)
+ self.assertNotIn("table", content)
+
+
+class TestTextToSpeech(Test.TestCase):
+ """Test text-to-speech functionality."""
+
+ def setUp(self) -> None:
+ """Set up mocks."""
+ # Mock OpenAI API key
+ self.env_patcher = unittest.mock.patch.dict(
+ os.environ,
+ {"OPENAI_API_KEY": "test-key"},
+ )
+ self.env_patcher.start()
+
+ # Mock OpenAI response
+ self.mock_audio_response: unittest.mock.MagicMock = (
+ unittest.mock.MagicMock()
+ )
+ self.mock_audio_response.content = b"fake-audio-data"
+
+ # Mock AudioSegment to avoid ffmpeg issues in tests
+ self.mock_audio_segment: unittest.mock.MagicMock = (
+ unittest.mock.MagicMock()
+ )
+ self.mock_audio_segment.export.return_value = None
+ self.audio_segment_patcher = unittest.mock.patch(
+ "pydub.AudioSegment.from_mp3",
+ return_value=self.mock_audio_segment,
+ )
+ self.audio_segment_patcher.start()
+
+ # Mock the concatenation operations
+ self.mock_audio_segment.__add__.return_value = self.mock_audio_segment
+
+ def tearDown(self) -> None:
+ """Clean up mocks."""
+ self.env_patcher.stop()
+ self.audio_segment_patcher.stop()
+
+ def test_tts_generation(self) -> None:
+ """Generate audio from text."""
+
+ # Mock the export to write test audio data
+ def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
+ buffer.write(b"test-audio-output")
+ buffer.seek(0)
+
+ self.mock_audio_segment.export.side_effect = mock_export
+
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+ mock_speech.create.return_value = self.mock_audio_response
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=["Test content"],
+ ),
+ ):
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ audio_data = processor.text_to_speech(
+ "Test content",
+ "Test Title",
+ )
+
+ self.assertIsInstance(audio_data, bytes)
+ self.assertEqual(audio_data, b"test-audio-output")
+
+ def test_tts_chunking(self) -> None:
+ """Handle long articles with chunking."""
+ long_text = "Long content " * 1000
+ chunks = ["Chunk 1", "Chunk 2", "Chunk 3"]
+
+ def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
+ buffer.write(b"test-audio-output")
+ buffer.seek(0)
+
+ self.mock_audio_segment.export.side_effect = mock_export
+
+ # Mock AudioSegment.silent
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+ mock_speech.create.return_value = self.mock_audio_response
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=chunks,
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.silent",
+ return_value=self.mock_audio_segment,
+ ),
+ ):
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ audio_data = processor.text_to_speech(
+ long_text,
+ "Long Article",
+ )
+
+ # Should have called TTS for each chunk
+ self.assertIsInstance(audio_data, bytes)
+ self.assertEqual(audio_data, b"test-audio-output")
+
+ def test_tts_empty_text(self) -> None:
+ """Handle empty input."""
+ with unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=[],
+ ):
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ with pytest.raises(ValueError, match="No chunks generated") as cm:
+ processor.text_to_speech("", "Empty")
+
+ self.assertIn("No chunks generated", str(cm.value))
+
+ def test_tts_special_characters(self) -> None:
+ """Handle unicode and special chars."""
+ special_text = 'Unicode: 你好世界 Émojis: 🎙️📰 Special: <>&"'
+
+ def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
+ buffer.write(b"test-audio-output")
+ buffer.seek(0)
+
+ self.mock_audio_segment.export.side_effect = mock_export
+
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+ mock_speech.create.return_value = self.mock_audio_response
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=[special_text],
+ ),
+ ):
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ audio_data = processor.text_to_speech(
+ special_text,
+ "Special",
+ )
+
+ self.assertIsInstance(audio_data, bytes)
+ self.assertEqual(audio_data, b"test-audio-output")
+
+ def test_llm_text_preparation(self) -> None:
+ """Verify LLM editing."""
+ # Test the actual text preparation functions
+ chunks = split_text_into_chunks("Short text", max_chars=100)
+ self.assertEqual(len(chunks), 1)
+ self.assertEqual(chunks[0], "Short text")
+
+ # Test long text splitting
+ long_text = "Sentence one. " * 100
+ chunks = split_text_into_chunks(long_text, max_chars=100)
+ self.assertGreater(len(chunks), 1)
+ for chunk in chunks:
+ self.assertLessEqual(len(chunk), 100)
+
+ @staticmethod
+ def test_llm_failure_fallback() -> None:
+ """Handle LLM API failures."""
+ # Mock LLM failure
+ with unittest.mock.patch("openai.OpenAI") as mock_openai:
+ mock_client = mock_openai.return_value
+ mock_client.chat.completions.create.side_effect = Exception(
+ "API Error",
+ )
+
+ # Should fall back to raw text
+ with pytest.raises(Exception, match="API Error"):
+ edit_chunk_for_speech("Test chunk", "Title", is_first=True)
+
+ def test_chunk_concatenation(self) -> None:
+ """Verify audio joining."""
+ # Mock multiple audio segments
+ chunks = ["Chunk 1", "Chunk 2"]
+
+ def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
+ buffer.write(b"test-audio-output")
+ buffer.seek(0)
+
+ self.mock_audio_segment.export.side_effect = mock_export
+
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+ mock_speech.create.return_value = self.mock_audio_response
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=chunks,
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.silent",
+ return_value=self.mock_audio_segment,
+ ),
+ ):
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ audio_data = processor.text_to_speech("Test", "Title")
+
+ # Should produce combined audio
+ self.assertIsInstance(audio_data, bytes)
+ self.assertEqual(audio_data, b"test-audio-output")
+
+ def test_parallel_tts_generation(self) -> None:
+ """Test parallel TTS processing."""
+ chunks = ["Chunk 1", "Chunk 2", "Chunk 3", "Chunk 4"]
+
+ # Mock responses for each chunk
+ mock_responses = []
+ for i in range(len(chunks)):
+ mock_resp = unittest.mock.MagicMock()
+ mock_resp.content = f"audio-{i}".encode()
+ mock_responses.append(mock_resp)
+
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+
+ # Make create return different responses for each call
+ mock_speech.create.side_effect = mock_responses
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ # Mock AudioSegment operations
+ mock_segment = unittest.mock.MagicMock()
+ mock_segment.__add__.return_value = mock_segment
+
+ def mock_export(path: str, **_kwargs: typing.Any) -> None:
+ Path(path).write_bytes(b"combined-audio")
+
+ mock_segment.export = mock_export
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=chunks,
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.from_mp3",
+ return_value=mock_segment,
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.silent",
+ return_value=mock_segment,
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.check_memory_usage",
+ return_value=50.0, # Normal memory usage
+ ),
+ ):
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ audio_data = processor.text_to_speech("Test content", "Test Title")
+
+ # Verify all chunks were processed
+ self.assertEqual(mock_speech.create.call_count, len(chunks))
+ self.assertEqual(audio_data, b"combined-audio")
+
+ def test_parallel_tts_high_memory_fallback(self) -> None:
+ """Test fallback to serial processing when memory is high."""
+ chunks = ["Chunk 1", "Chunk 2"]
+
+ def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
+ buffer.write(b"serial-audio")
+ buffer.seek(0)
+
+ self.mock_audio_segment.export.side_effect = mock_export
+
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+ mock_speech.create.return_value = self.mock_audio_response
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=chunks,
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.check_memory_usage",
+ return_value=65.0, # High memory usage
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.silent",
+ return_value=self.mock_audio_segment,
+ ),
+ ):
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ audio_data = processor.text_to_speech("Test content", "Test Title")
+
+ # Should use serial processing
+ self.assertEqual(audio_data, b"serial-audio")
+
+ @staticmethod
+ def test_parallel_tts_error_handling() -> None:
+ """Test error handling in parallel TTS processing."""
+ chunks = ["Chunk 1", "Chunk 2"]
+
+ # Mock OpenAI client with one failure
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+
+ # First call succeeds, second fails
+ mock_resp1 = unittest.mock.MagicMock()
+ mock_resp1.content = b"audio-1"
+ mock_speech.create.side_effect = [mock_resp1, Exception("API Error")]
+
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ # Set up the test context
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=chunks,
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.check_memory_usage",
+ return_value=50.0,
+ ),
+ pytest.raises(Exception, match="API Error"),
+ ):
+ processor.text_to_speech("Test content", "Test Title")
+
+ def test_parallel_tts_order_preservation(self) -> None:
+ """Test that chunks are combined in the correct order."""
+ chunks = ["First", "Second", "Third", "Fourth", "Fifth"]
+
+ # Create mock responses with identifiable content
+ mock_responses = []
+ for chunk in chunks:
+ mock_resp = unittest.mock.MagicMock()
+ mock_resp.content = f"audio-{chunk}".encode()
+ mock_responses.append(mock_resp)
+
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+ mock_speech.create.side_effect = mock_responses
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ # Track the order of segments being combined
+ combined_order = []
+
+ def mock_from_mp3(data: io.BytesIO) -> unittest.mock.MagicMock:
+ content = data.read()
+ combined_order.append(content.decode())
+ segment = unittest.mock.MagicMock()
+ segment.__add__.return_value = segment
+ return segment
+
+ mock_segment = unittest.mock.MagicMock()
+ mock_segment.__add__.return_value = mock_segment
+
+ def mock_export(path: str, **_kwargs: typing.Any) -> None:
+ # Verify order is preserved
+ expected_order = [f"audio-{chunk}" for chunk in chunks]
+ if combined_order != expected_order:
+ msg = f"Order mismatch: {combined_order} != {expected_order}"
+ raise AssertionError(msg)
+ Path(path).write_bytes(b"ordered-audio")
+
+ mock_segment.export = mock_export
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=chunks,
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.from_mp3",
+ side_effect=mock_from_mp3,
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.silent",
+ return_value=mock_segment,
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.check_memory_usage",
+ return_value=50.0,
+ ),
+ ):
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ audio_data = processor.text_to_speech("Test content", "Test Title")
+
+ self.assertEqual(audio_data, b"ordered-audio")
+
+
+class TestIntroOutro(Test.TestCase):
+ """Test intro and outro generation with metadata."""
+
+ def test_create_intro_text_full_metadata(self) -> None:
+ """Test intro text creation with all metadata."""
+ intro = ArticleProcessor._create_intro_text( # noqa: SLF001
+ title="Test Article",
+ author="John Doe",
+ pub_date="2024-01-15",
+ )
+ self.assertIn("Title: Test Article", intro)
+ self.assertIn("Author: John Doe", intro)
+ self.assertIn("Published: 2024-01-15", intro)
+
+ def test_create_intro_text_no_author(self) -> None:
+ """Test intro text without author."""
+ intro = ArticleProcessor._create_intro_text( # noqa: SLF001
+ title="Test Article",
+ author=None,
+ pub_date="2024-01-15",
+ )
+ self.assertIn("Title: Test Article", intro)
+ self.assertNotIn("Author:", intro)
+ self.assertIn("Published: 2024-01-15", intro)
+
+ def test_create_intro_text_minimal(self) -> None:
+ """Test intro text with only title."""
+ intro = ArticleProcessor._create_intro_text( # noqa: SLF001
+ title="Test Article",
+ author=None,
+ pub_date=None,
+ )
+ self.assertEqual(intro, "Title: Test Article.")
+
+ def test_create_outro_text_with_author(self) -> None:
+ """Test outro text with author."""
+ outro = ArticleProcessor._create_outro_text( # noqa: SLF001
+ title="Test Article",
+ author="Jane Smith",
+ )
+ self.assertIn("Test Article", outro)
+ self.assertIn("Jane Smith", outro)
+ self.assertIn("Podcast It Later", outro)
+
+ def test_create_outro_text_no_author(self) -> None:
+ """Test outro text without author."""
+ outro = ArticleProcessor._create_outro_text( # noqa: SLF001
+ title="Test Article",
+ author=None,
+ )
+ self.assertIn("Test Article", outro)
+ self.assertNotIn("by", outro)
+ self.assertIn("Podcast It Later", outro)
+
+ def test_extract_with_metadata(self) -> None:
+ """Test that extraction returns metadata."""
+ mock_html = "<html><body><p>Content</p></body></html>"
+ mock_result = json.dumps({
+ "title": "Test Article",
+ "text": "Article content",
+ "author": "Test Author",
+ "date": "2024-01-15",
+ })
+
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ return_value=mock_html,
+ ),
+ unittest.mock.patch(
+ "trafilatura.extract",
+ return_value=mock_result,
+ ),
+ ):
+ title, content, author, pub_date = (
+ ArticleProcessor.extract_article_content(
+ "https://example.com",
+ )
+ )
+
+ self.assertEqual(title, "Test Article")
+ self.assertEqual(content, "Article content")
+ self.assertEqual(author, "Test Author")
+ self.assertEqual(pub_date, "2024-01-15")
+
+
+class TestMemoryEfficiency(Test.TestCase):
+ """Test memory-efficient processing."""
+
+ def test_large_article_size_limit(self) -> None:
+ """Test that articles exceeding size limits are rejected."""
+ huge_text = "x" * (MAX_ARTICLE_SIZE + 1000) # Exceed limit
+
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ return_value=huge_text * 4, # Simulate large HTML
+ ),
+ pytest.raises(ValueError, match="Article too large") as cm,
+ ):
+ ArticleProcessor.extract_article_content("https://example.com")
+
+ self.assertIn("Article too large", str(cm.value))
+
+ def test_content_truncation(self) -> None:
+ """Test that oversized content is truncated."""
+ large_content = "Content " * 100_000 # Create large content
+ mock_result = json.dumps({
+ "title": "Large Article",
+ "text": large_content,
+ })
+
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ return_value="<html><body>content</body></html>",
+ ),
+ unittest.mock.patch(
+ "trafilatura.extract",
+ return_value=mock_result,
+ ),
+ ):
+ title, content, _author, _pub_date = (
+ ArticleProcessor.extract_article_content(
+ "https://example.com",
+ )
+ )
+
+ self.assertEqual(title, "Large Article")
+ self.assertLessEqual(len(content), MAX_ARTICLE_SIZE)
+
+ def test_memory_usage_check(self) -> None:
+ """Test memory usage monitoring."""
+ usage = check_memory_usage()
+ self.assertIsInstance(usage, float)
+ self.assertGreaterEqual(usage, 0.0)
+ self.assertLessEqual(usage, 100.0)
+
+
+class TestJobProcessing(Test.TestCase):
+ """Test job processing functionality."""
+
+ def setUp(self) -> None:
+ """Set up test environment."""
+ Core.Database.init_db()
+
+ # Create test user and job
+ self.user_id, _ = Core.Database.create_user(
+ "test@example.com",
+ )
+ self.job_id = Core.Database.add_to_queue(
+ "https://example.com/article",
+ "test@example.com",
+ self.user_id,
+ )
+
+ # Mock environment
+ self.env_patcher = unittest.mock.patch.dict(
+ os.environ,
+ {
+ "OPENAI_API_KEY": "test-key",
+ "S3_ENDPOINT": "https://s3.example.com",
+ "S3_BUCKET": "test-bucket",
+ "S3_ACCESS_KEY": "test-access",
+ "S3_SECRET_KEY": "test-secret",
+ },
+ )
+ self.env_patcher.start()
+
+ def tearDown(self) -> None:
+ """Clean up."""
+ self.env_patcher.stop()
+ Core.Database.teardown()
+
+ def test_process_job_success(self) -> None:
+ """Complete pipeline execution."""
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ job = Core.Database.get_job_by_id(self.job_id)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+
+ # Mock all external calls
+ with (
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "extract_article_content",
+ return_value=(
+ "Test Title",
+ "Test content",
+ "Test Author",
+ "2024-01-15",
+ ),
+ ),
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "text_to_speech",
+ return_value=b"audio-data",
+ ),
+ unittest.mock.patch.object(
+ processor,
+ "upload_to_s3",
+ return_value="https://s3.example.com/audio.mp3",
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.update_job_status",
+ ) as mock_update,
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.create_episode",
+ ) as mock_create,
+ ):
+ mock_create.return_value = 1
+ processor.process_job(job)
+
+ # Verify job was marked complete
+ mock_update.assert_called_with(self.job_id, "completed")
+ mock_create.assert_called_once()
+
+ def test_process_job_extraction_failure(self) -> None:
+ """Handle bad URLs."""
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ job = Core.Database.get_job_by_id(self.job_id)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+
+ with (
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "extract_article_content",
+ side_effect=ValueError("Bad URL"),
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.update_job_status",
+ ) as mock_update,
+ pytest.raises(ValueError, match="Bad URL"),
+ ):
+ processor.process_job(job)
+
+ # Job should be marked as error
+ mock_update.assert_called_with(self.job_id, "error", "Bad URL")
+
+ def test_process_job_tts_failure(self) -> None:
+ """Handle TTS errors."""
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ job = Core.Database.get_job_by_id(self.job_id)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+
+ with (
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "extract_article_content",
+ return_value=("Title", "Content"),
+ ),
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "text_to_speech",
+ side_effect=Exception("TTS Error"),
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.update_job_status",
+ ) as mock_update,
+ pytest.raises(Exception, match="TTS Error"),
+ ):
+ processor.process_job(job)
+
+ mock_update.assert_called_with(self.job_id, "error", "TTS Error")
+
+ def test_process_job_s3_failure(self) -> None:
+ """Handle upload errors."""
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ job = Core.Database.get_job_by_id(self.job_id)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+
+ with (
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "extract_article_content",
+ return_value=("Title", "Content"),
+ ),
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "text_to_speech",
+ return_value=b"audio",
+ ),
+ unittest.mock.patch.object(
+ processor,
+ "upload_to_s3",
+ side_effect=ClientError({}, "PutObject"),
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.update_job_status",
+ ),
+ pytest.raises(ClientError),
+ ):
+ processor.process_job(job)
+
+ def test_job_retry_logic(self) -> None:
+ """Verify exponential backoff."""
+ # Set job to error with retry count
+ Core.Database.update_job_status(
+ self.job_id,
+ "error",
+ "First failure",
+ )
+ Core.Database.update_job_status(
+ self.job_id,
+ "error",
+ "Second failure",
+ )
+
+ job = Core.Database.get_job_by_id(self.job_id)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+
+ self.assertEqual(job["retry_count"], 2)
+
+ # Should be retryable
+ retryable = Core.Database.get_retryable_jobs(
+ max_retries=3,
+ )
+ self.assertEqual(len(retryable), 1)
+
+ def test_max_retries(self) -> None:
+ """Stop after max attempts."""
+ # Exceed retry limit
+ for i in range(4):
+ Core.Database.update_job_status(
+ self.job_id,
+ "error",
+ f"Failure {i}",
+ )
+
+ # Should not be retryable
+ retryable = Core.Database.get_retryable_jobs(
+ max_retries=3,
+ )
+ self.assertEqual(len(retryable), 0)
+
+ def test_graceful_shutdown(self) -> None:
+ """Test graceful shutdown during job processing."""
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ job = Core.Database.get_job_by_id(self.job_id)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+
+ def mock_tts(*_args: Any) -> bytes:
+ shutdown_handler.shutdown_requested.set()
+ return b"audio-data"
+
+ with (
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "extract_article_content",
+ return_value=(
+ "Test Title",
+ "Test content",
+ "Test Author",
+ "2024-01-15",
+ ),
+ ),
+ unittest.mock.patch.object(
+ ArticleProcessor,
+ "text_to_speech",
+ side_effect=mock_tts,
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.update_job_status",
+ ) as mock_update,
+ ):
+ processor.process_job(job)
+
+ # Job should be reset to pending due to shutdown
+ mock_update.assert_any_call(self.job_id, "pending")
+
+ def test_cleanup_stale_jobs(self) -> None:
+ """Test cleanup of stale processing jobs."""
+ # Manually set job to processing
+ Core.Database.update_job_status(self.job_id, "processing")
+
+ # Run cleanup
+ cleanup_stale_jobs()
+
+ # Job should be back to pending
+ job = Core.Database.get_job_by_id(self.job_id)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+ self.assertEqual(job["status"], "pending")
+
+ def test_concurrent_processing(self) -> None:
+ """Handle multiple jobs."""
+ # Create multiple jobs
+ job2 = Core.Database.add_to_queue(
+ "https://example.com/2",
+ "test@example.com",
+ self.user_id,
+ )
+ job3 = Core.Database.add_to_queue(
+ "https://example.com/3",
+ "test@example.com",
+ self.user_id,
+ )
+
+ # Get pending jobs
+ jobs = Core.Database.get_pending_jobs(limit=5)
+
+ self.assertEqual(len(jobs), 3)
+ self.assertEqual({j["id"] for j in jobs}, {self.job_id, job2, job3})
+
+ def test_memory_threshold_deferral(self) -> None:
+ """Test that jobs are deferred when memory usage is high."""
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ job = Core.Database.get_job_by_id(self.job_id)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+
+ # Mock high memory usage
+ with (
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.check_memory_usage",
+ return_value=90.0, # High memory usage
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.update_job_status",
+ ) as mock_update,
+ ):
+ processor.process_job(job)
+
+ # Job should not be processed (no status updates)
+ mock_update.assert_not_called()
+
+
+class TestWorkerErrorHandling(Test.TestCase):
+ """Test worker error handling and recovery."""
+
+ def setUp(self) -> None:
+ """Set up test environment."""
+ Core.Database.init_db()
+ self.user_id, _ = Core.Database.create_user("test@example.com")
+ self.job_id = Core.Database.add_to_queue(
+ "https://example.com",
+ "test@example.com",
+ self.user_id,
+ )
+ self.shutdown_handler = ShutdownHandler()
+ self.processor = ArticleProcessor(self.shutdown_handler)
+
+ @staticmethod
+ def tearDown() -> None:
+ """Clean up."""
+ Core.Database.teardown()
+
+ def test_process_pending_jobs_exception_handling(self) -> None:
+ """Test that process_pending_jobs handles exceptions."""
+
+ def side_effect(job: dict[str, Any]) -> None:
+ # Simulate process_job starting and setting status to processing
+ Core.Database.update_job_status(job["id"], "processing")
+ msg = "Unexpected Error"
+ raise ValueError(msg)
+
+ with (
+ unittest.mock.patch.object(
+ self.processor,
+ "process_job",
+ side_effect=side_effect,
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.update_job_status",
+ side_effect=Core.Database.update_job_status,
+ ) as _mock_update,
+ ):
+ process_pending_jobs(self.processor)
+
+ # Job should be marked as error
+ job = Core.Database.get_job_by_id(self.job_id)
+ self.assertIsNotNone(job)
+ if job:
+ self.assertEqual(job["status"], "error")
+ self.assertIn("Unexpected Error", job["error_message"])
+
+ def test_process_retryable_jobs_success(self) -> None:
+ """Test processing of retryable jobs."""
+ # Set up a retryable job
+ Core.Database.update_job_status(self.job_id, "error", "Fail 1")
+
+ # Modify created_at to be in the past to satisfy backoff
+ with Core.Database.get_connection() as conn:
+ conn.execute(
+ "UPDATE queue SET created_at = ? WHERE id = ?",
+ (
+ (
+ datetime.now(tz=timezone.utc) - timedelta(minutes=5)
+ ).isoformat(),
+ self.job_id,
+ ),
+ )
+ conn.commit()
+
+ process_retryable_jobs()
+
+ job = Core.Database.get_job_by_id(self.job_id)
+ self.assertIsNotNone(job)
+ if job:
+ self.assertEqual(job["status"], "pending")
+
+ def test_process_retryable_jobs_not_ready(self) -> None:
+ """Test that jobs are not retried before backoff period."""
+ # Set up a retryable job that just failed
+ Core.Database.update_job_status(self.job_id, "error", "Fail 1")
+
+ # created_at is now, so backoff should prevent retry
+ process_retryable_jobs()
+
+ job = Core.Database.get_job_by_id(self.job_id)
+ self.assertIsNotNone(job)
+ if job:
+ self.assertEqual(job["status"], "error")
+
+
+class TestTextChunking(Test.TestCase):
+ """Test text chunking edge cases."""
+
+ def test_split_text_single_long_word(self) -> None:
+ """Handle text with a single word exceeding limit."""
+ long_word = "a" * 4000
+ chunks = split_text_into_chunks(long_word, max_chars=3000)
+
+ # Should keep it as one chunk or split?
+ # The current implementation does not split words
+ self.assertEqual(len(chunks), 1)
+ self.assertEqual(len(chunks[0]), 4000)
+
+ def test_split_text_no_sentence_boundaries(self) -> None:
+ """Handle long text with no sentence boundaries."""
+ text = "word " * 1000 # 5000 chars
+ chunks = split_text_into_chunks(text, max_chars=3000)
+
+ # Should keep it as one chunk as it can't split by ". "
+ self.assertEqual(len(chunks), 1)
+ self.assertGreater(len(chunks[0]), 3000)
+
+
+def test() -> None:
+ """Run the tests."""
+ Test.run(
+ App.Area.Test,
+ [
+ TestArticleExtraction,
+ TestTextToSpeech,
+ TestMemoryEfficiency,
+ TestJobProcessing,
+ TestWorkerErrorHandling,
+ TestTextChunking,
+ ],
+ )
+
+
+def main() -> None:
+ """Entry point for the worker."""
+ if "test" in sys.argv:
+ test()
+ else:
+ move()