summaryrefslogtreecommitdiff
path: root/Biz/PodcastItLater/Worker.py
diff options
context:
space:
mode:
Diffstat (limited to 'Biz/PodcastItLater/Worker.py')
-rw-r--r--Biz/PodcastItLater/Worker.py244
1 files changed, 198 insertions, 46 deletions
diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py
index 8142b50..023d878 100644
--- a/Biz/PodcastItLater/Worker.py
+++ b/Biz/PodcastItLater/Worker.py
@@ -3,6 +3,7 @@
# : dep boto3
# : dep botocore
# : dep openai
+# : dep psutil
# : dep pydub
# : dep pytest
# : dep pytest-asyncio
@@ -19,9 +20,11 @@ import Omni.Log as Log
import Omni.Test as Test
import openai
import os
+import psutil # type: ignore[import-untyped]
import pytest
import signal
import sys
+import tempfile
import threading
import time
import trafilatura
@@ -31,6 +34,7 @@ from botocore.exceptions import ClientError # type: ignore[import-untyped]
from datetime import datetime
from datetime import timedelta
from datetime import timezone
+from pathlib import Path
from pydub import AudioSegment # type: ignore[import-untyped]
from typing import Any
@@ -46,10 +50,12 @@ area = App.from_env()
# Worker configuration
MAX_CONTENT_LENGTH = 5000 # characters for TTS
+MAX_ARTICLE_SIZE = 500_000 # 500KB character limit for articles
POLL_INTERVAL = 30 # seconds
MAX_RETRIES = 3
TTS_MODEL = "tts-1"
TTS_VOICE = "alloy"
+MEMORY_THRESHOLD = 80 # Percentage threshold for memory usage
class ShutdownHandler:
@@ -123,7 +129,7 @@ class ArticleProcessor:
"""Extract title and content from article URL using trafilatura.
Raises:
- ValueError: If content cannot be downloaded or extracted.
+        ValueError: If content cannot be downloaded or extracted, or is too large.
"""
try:
downloaded = trafilatura.fetch_url(url)
@@ -131,6 +137,13 @@ class ArticleProcessor:
msg = f"Failed to download content from {url}"
raise ValueError(msg) # noqa: TRY301
+ # Check size before processing
+ if (
+ len(downloaded) > MAX_ARTICLE_SIZE * 4
+ ): # Rough HTML to text ratio
+ msg = f"Article too large: {len(downloaded)} bytes"
+ raise ValueError(msg) # noqa: TRY301
+
# Extract with metadata
result = trafilatura.extract(
downloaded,
@@ -153,7 +166,15 @@ class ArticleProcessor:
msg = f"No content extracted from {url}"
raise ValueError(msg) # noqa: TRY301
- # Don't truncate - we'll handle length in text_to_speech
+ # Enforce content size limit
+ if len(content) > MAX_ARTICLE_SIZE:
+ logger.warning(
+ "Article content truncated from %d to %d characters",
+ len(content),
+ MAX_ARTICLE_SIZE,
+ )
+ content = content[:MAX_ARTICLE_SIZE]
+
logger.info("Extracted article: %s (%d chars)", title, len(content))
except Exception:
logger.exception("Failed to extract content from %s", url)
@@ -164,7 +185,7 @@ class ArticleProcessor:
def text_to_speech(self, text: str, title: str) -> bytes:
"""Convert text to speech using OpenAI TTS API.
- Uses LLM to prepare text, then handles chunking and concatenation.
+    Uses a temp-file based approach to reduce peak memory usage.
Raises:
ValueError: If no chunks are generated from text.
@@ -179,54 +200,81 @@ class ArticleProcessor:
logger.info("Processing %d chunks for TTS", len(chunks))
- # Generate audio for each chunk
- audio_segments = []
- for i, chunk in enumerate(chunks):
- logger.info(
- "Generating TTS for chunk %d/%d (%d chars)",
- i + 1,
- len(chunks),
- len(chunk),
- )
+ # Create a temporary file for streaming audio concatenation
+ with tempfile.NamedTemporaryFile(
+ suffix=".mp3",
+ delete=False,
+ ) as temp_file:
+ temp_path = temp_file.name
+ try:
+ # Process first chunk
+ logger.info("Generating TTS for chunk 1/%d", len(chunks))
response = self.openai_client.audio.speech.create(
model=TTS_MODEL,
voice=TTS_VOICE,
- input=chunk,
+ input=chunks[0],
response_format="mp3",
)
- # Convert bytes to AudioSegment
- audio_segment = AudioSegment.from_mp3(
- io.BytesIO(response.content),
- )
- audio_segments.append(audio_segment)
-
- # Small delay between API calls to be respectful
- if i < len(chunks) - 1:
- time.sleep(0.5)
+ # Write first chunk directly to file
+ Path(temp_path).write_bytes(response.content)
+
+ # Process remaining chunks
+ for i, chunk in enumerate(chunks[1:], 1):
+ logger.info(
+ "Generating TTS for chunk %d/%d (%d chars)",
+ i + 1,
+ len(chunks),
+ len(chunk),
+ )
+
+ response = self.openai_client.audio.speech.create(
+ model=TTS_MODEL,
+ voice=TTS_VOICE,
+ input=chunk,
+ response_format="mp3",
+ )
+
+ # Append to existing file with silence gap
+ # Load only the current segment
+ current_segment = AudioSegment.from_mp3(
+ io.BytesIO(response.content),
+ )
+
+ # Load existing audio, append, and save back
+ existing_audio = AudioSegment.from_mp3(temp_path)
+ silence = AudioSegment.silent(duration=300)
+ combined = existing_audio + silence + current_segment
+
+ # Export back to the same file
+ combined.export(temp_path, format="mp3", bitrate="128k")
+
+                    # Drop references so the large segments can be reclaimed
+ del existing_audio, current_segment, combined
+
+ # Small delay between API calls
+ if i < len(chunks) - 1:
+ time.sleep(0.5)
+
+ # Read final result
+ audio_data = Path(temp_path).read_bytes()
- # Concatenate all audio segments
- combined_audio = audio_segments[0]
- for segment in audio_segments[1:]:
- # Add a small silence between chunks for natural pacing
- silence = AudioSegment.silent(duration=300)
- combined_audio = combined_audio + silence + segment
+ logger.info(
+ "Generated combined TTS audio: %d bytes",
+ len(audio_data),
+ )
+ return audio_data
- # Export combined audio to bytes
- output_buffer = io.BytesIO()
- combined_audio.export(output_buffer, format="mp3", bitrate="128k")
- audio_data = output_buffer.getvalue()
+ finally:
+ # Clean up temp file
+ temp_file_path = Path(temp_path)
+ if temp_file_path.exists():
+ temp_file_path.unlink()
- logger.info(
- "Generated combined TTS audio: %d bytes",
- len(audio_data),
- )
except Exception:
logger.exception("TTS generation failed")
raise
- else:
- return audio_data
def upload_to_s3(self, audio_data: bytes, filename: str) -> str:
"""Upload audio file to S3-compatible storage and return public URL.
@@ -240,18 +288,25 @@ class ArticleProcessor:
raise ValueError(msg)
try:
- # Upload file
- self.s3_client.put_object(
- Bucket=S3_BUCKET,
- Key=filename,
- Body=audio_data,
- ContentType="audio/mpeg",
- ACL="public-read",
+ # Upload file using streaming to minimize memory usage
+ audio_stream = io.BytesIO(audio_data)
+ self.s3_client.upload_fileobj(
+ audio_stream,
+ S3_BUCKET,
+ filename,
+ ExtraArgs={
+ "ContentType": "audio/mpeg",
+ "ACL": "public-read",
+ },
)
# Construct public URL
            audio_url = f"{S3_ENDPOINT}/{S3_BUCKET}/{filename}"
- logger.info("Uploaded audio to: %s", audio_url)
+ logger.info(
+ "Uploaded audio to: %s (%d bytes)",
+ audio_url,
+ len(audio_data),
+ )
except ClientError:
logger.exception("S3 upload failed")
raise
@@ -284,6 +339,16 @@ class ArticleProcessor:
job_id = job["id"]
url = job["url"]
+ # Check memory before starting
+ mem_usage = check_memory_usage()
+ if mem_usage > MEMORY_THRESHOLD:
+ logger.warning(
+ "High memory usage (%.1f%%), deferring job %d",
+ mem_usage,
+ job_id,
+ )
+ return
+
# Track current job for graceful shutdown
self.shutdown_handler.set_current_job(job_id)
@@ -599,6 +664,17 @@ def process_retryable_jobs() -> None:
)
+def check_memory_usage() -> int | Any:
+ """Check current memory usage percentage."""
+ try:
+ process = psutil.Process()
+        # this returns a float but psutil is untyped
+ return process.memory_percent()
+ except (psutil.Error, OSError):
+ logger.warning("Failed to check memory usage")
+ return 0.0
+
+
def cleanup_stale_jobs() -> None:
"""Reset jobs stuck in processing state on startup."""
with Core.Database.get_connection() as conn:
@@ -1043,6 +1119,57 @@ class TestTextToSpeech(Test.TestCase):
self.assertEqual(audio_data, b"test-audio-output")
+class TestMemoryEfficiency(Test.TestCase):
+ """Test memory-efficient processing."""
+
+ def test_large_article_size_limit(self) -> None:
+ """Test that articles exceeding size limits are rejected."""
+ huge_text = "x" * (MAX_ARTICLE_SIZE + 1000) # Exceed limit
+
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ return_value=huge_text * 4, # Simulate large HTML
+ ),
+ pytest.raises(ValueError, match="Article too large") as cm,
+ ):
+ ArticleProcessor.extract_article_content("https://example.com")
+
+ self.assertIn("Article too large", str(cm.value))
+
+ def test_content_truncation(self) -> None:
+ """Test that oversized content is truncated."""
+ large_content = "Content " * 100_000 # Create large content
+ mock_result = json.dumps({
+ "title": "Large Article",
+ "text": large_content,
+ })
+
+ with (
+ unittest.mock.patch(
+ "trafilatura.fetch_url",
+ return_value="<html><body>content</body></html>",
+ ),
+ unittest.mock.patch(
+ "trafilatura.extract",
+ return_value=mock_result,
+ ),
+ ):
+ title, content = ArticleProcessor.extract_article_content(
+ "https://example.com",
+ )
+
+ self.assertEqual(title, "Large Article")
+ self.assertLessEqual(len(content), MAX_ARTICLE_SIZE)
+
+ def test_memory_usage_check(self) -> None:
+ """Test memory usage monitoring."""
+ usage = check_memory_usage()
+ self.assertIsInstance(usage, float)
+ self.assertGreaterEqual(usage, 0.0)
+ self.assertLessEqual(usage, 100.0)
+
+
class TestJobProcessing(Test.TestCase):
"""Test job processing functionality."""
@@ -1314,6 +1441,30 @@ class TestJobProcessing(Test.TestCase):
self.assertEqual(len(jobs), 3)
self.assertEqual({j["id"] for j in jobs}, {self.job_id, job2, job3})
+ def test_memory_threshold_deferral(self) -> None:
+ """Test that jobs are deferred when memory usage is high."""
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ job = Core.Database.get_job_by_id(self.job_id)
+ if job is None:
+ msg = "no job found for %s"
+ raise Test.TestError(msg, self.job_id)
+
+ # Mock high memory usage
+ with (
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.check_memory_usage",
+ return_value=90.0, # High memory usage
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Core.Database.update_job_status",
+ ) as mock_update,
+ ):
+ processor.process_job(job)
+
+ # Job should not be processed (no status updates)
+ mock_update.assert_not_called()
+
def test() -> None:
"""Run the tests."""
@@ -1322,6 +1473,7 @@ def test() -> None:
[
TestArticleExtraction,
TestTextToSpeech,
+ TestMemoryEfficiency,
TestJobProcessing,
],
)