diff options
Diffstat (limited to 'Biz')
| -rw-r--r-- | Biz/PodcastItLater/Worker.py | 244 |
1 files changed, 198 insertions, 46 deletions
diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py index 8142b50..023d878 100644 --- a/Biz/PodcastItLater/Worker.py +++ b/Biz/PodcastItLater/Worker.py @@ -3,6 +3,7 @@ # : dep boto3 # : dep botocore # : dep openai +# : dep psutil # : dep pydub # : dep pytest # : dep pytest-asyncio @@ -19,9 +20,11 @@ import Omni.Log as Log import Omni.Test as Test import openai import os +import psutil # type: ignore[import-untyped] import pytest import signal import sys +import tempfile import threading import time import trafilatura @@ -31,6 +34,7 @@ from botocore.exceptions import ClientError # type: ignore[import-untyped] from datetime import datetime from datetime import timedelta from datetime import timezone +from pathlib import Path from pydub import AudioSegment # type: ignore[import-untyped] from typing import Any @@ -46,10 +50,12 @@ area = App.from_env() # Worker configuration MAX_CONTENT_LENGTH = 5000 # characters for TTS +MAX_ARTICLE_SIZE = 500_000 # 500KB character limit for articles POLL_INTERVAL = 30 # seconds MAX_RETRIES = 3 TTS_MODEL = "tts-1" TTS_VOICE = "alloy" +MEMORY_THRESHOLD = 80 # Percentage threshold for memory usage class ShutdownHandler: @@ -123,7 +129,7 @@ class ArticleProcessor: """Extract title and content from article URL using trafilatura. Raises: - ValueError: If content cannot be downloaded or extracted. + ValueError: If content cannot be downloaded, extracted, or large. """ try: downloaded = trafilatura.fetch_url(url) @@ -131,6 +137,13 @@ class ArticleProcessor: msg = f"Failed to download content from {url}" raise ValueError(msg) # noqa: TRY301 + # Check size before processing + if ( + len(downloaded) > MAX_ARTICLE_SIZE * 4 + ): # Rough HTML to text ratio + msg = f"Article too large: {len(downloaded)} bytes" + raise ValueError(msg) # noqa: TRY301 + # Extract with metadata result = trafilatura.extract( downloaded, @@ -153,7 +166,15 @@ class ArticleProcessor: msg = f"No content extracted from {url}" raise ValueError(msg) # noqa: TRY301 - # Don't truncate - we'll handle length in text_to_speech + # Enforce content size limit + if len(content) > MAX_ARTICLE_SIZE: + logger.warning( + "Article content truncated from %d to %d characters", + len(content), + MAX_ARTICLE_SIZE, + ) + content = content[:MAX_ARTICLE_SIZE] + logger.info("Extracted article: %s (%d chars)", title, len(content)) except Exception: logger.exception("Failed to extract content from %s", url) @@ -164,7 +185,7 @@ class ArticleProcessor: def text_to_speech(self, text: str, title: str) -> bytes: """Convert text to speech using OpenAI TTS API. - Uses LLM to prepare text, then handles chunking and concatenation. + Uses streaming approach to maintain constant memory usage. Raises: ValueError: If no chunks are generated from text. @@ -179,54 +200,81 @@ class ArticleProcessor: logger.info("Processing %d chunks for TTS", len(chunks)) - # Generate audio for each chunk - audio_segments = [] - for i, chunk in enumerate(chunks): - logger.info( - "Generating TTS for chunk %d/%d (%d chars)", - i + 1, - len(chunks), - len(chunk), - ) + # Create a temporary file for streaming audio concatenation + with tempfile.NamedTemporaryFile( + suffix=".mp3", + delete=False, + ) as temp_file: + temp_path = temp_file.name + try: + # Process first chunk + logger.info("Generating TTS for chunk 1/%d", len(chunks)) response = self.openai_client.audio.speech.create( model=TTS_MODEL, voice=TTS_VOICE, - input=chunk, + input=chunks[0], response_format="mp3", ) - # Convert bytes to AudioSegment - audio_segment = AudioSegment.from_mp3( - io.BytesIO(response.content), - ) - audio_segments.append(audio_segment) - - # Small delay between API calls to be respectful - if i < len(chunks) - 1: - time.sleep(0.5) + # Write first chunk directly to file + Path(temp_path).write_bytes(response.content) + + # Process remaining chunks + for i, chunk in enumerate(chunks[1:], 1): + logger.info( + "Generating TTS for chunk %d/%d (%d chars)", + i + 1, + len(chunks), + len(chunk), + ) + + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunk, + response_format="mp3", + ) + + # Append to existing file with silence gap + # Load only the current segment + current_segment = AudioSegment.from_mp3( + io.BytesIO(response.content), + ) + + # Load existing audio, append, and save back + existing_audio = AudioSegment.from_mp3(temp_path) + silence = AudioSegment.silent(duration=300) + combined = existing_audio + silence + current_segment + + # Export back to the same file + combined.export(temp_path, format="mp3", bitrate="128k") + + # Force garbage collection to free memory + del existing_audio, current_segment, combined + + # Small delay between API calls + if i < len(chunks) - 1: + time.sleep(0.5) + + # Read final result + audio_data = Path(temp_path).read_bytes() - # Concatenate all audio segments - combined_audio = audio_segments[0] - for segment in audio_segments[1:]: - # Add a small silence between chunks for natural pacing - silence = AudioSegment.silent(duration=300) - combined_audio = combined_audio + silence + segment + logger.info( + "Generated combined TTS audio: %d bytes", + len(audio_data), + ) + return audio_data - # Export combined audio to bytes - output_buffer = io.BytesIO() - combined_audio.export(output_buffer, format="mp3", bitrate="128k") - audio_data = output_buffer.getvalue() + finally: + # Clean up temp file + temp_file_path = Path(temp_path) + if temp_file_path.exists(): + temp_file_path.unlink() - logger.info( - "Generated combined TTS audio: %d bytes", - len(audio_data), - ) except Exception: logger.exception("TTS generation failed") raise - else: - return audio_data def upload_to_s3(self, audio_data: bytes, filename: str) -> str: """Upload audio file to S3-compatible storage and return public URL. @@ -240,18 +288,25 @@ class ArticleProcessor: raise ValueError(msg) try: - # Upload file - self.s3_client.put_object( - Bucket=S3_BUCKET, - Key=filename, - Body=audio_data, - ContentType="audio/mpeg", - ACL="public-read", + # Upload file using streaming to minimize memory usage + audio_stream = io.BytesIO(audio_data) + self.s3_client.upload_fileobj( + audio_stream, + S3_BUCKET, + filename, + ExtraArgs={ + "ContentType": "audio/mpeg", + "ACL": "public-read", + }, ) # Construct public URL audio_url = f"{S3_ENDPOINT}/{S3_BUCKET}/{filename}" - logger.info("Uploaded audio to: %s", audio_url) + logger.info( + "Uploaded audio to: %s (%d bytes)", + audio_url, + len(audio_data), + ) except ClientError: logger.exception("S3 upload failed") raise @@ -284,6 +339,16 @@ class ArticleProcessor: job_id = job["id"] url = job["url"] + # Check memory before starting + mem_usage = check_memory_usage() + if mem_usage > MEMORY_THRESHOLD: + logger.warning( + "High memory usage (%.1f%%), deferring job %d", + mem_usage, + job_id, + ) + return + # Track current job for graceful shutdown self.shutdown_handler.set_current_job(job_id) @@ -599,6 +664,17 @@ def process_retryable_jobs() -> None: ) +def check_memory_usage() -> int | Any: + """Check current memory usage percentage.""" + try: + process = psutil.Process() + # this returns an int but psutil is untyped + return process.memory_percent() + except (psutil.Error, OSError): + logger.warning("Failed to check memory usage") + return 0.0 + + def cleanup_stale_jobs() -> None: """Reset jobs stuck in processing state on startup.""" with Core.Database.get_connection() as conn: @@ -1043,6 +1119,57 @@ class TestTextToSpeech(Test.TestCase): self.assertEqual(audio_data, b"test-audio-output") +class TestMemoryEfficiency(Test.TestCase): + """Test memory-efficient processing.""" + + def test_large_article_size_limit(self) -> None: + """Test that articles exceeding size limits are rejected.""" + huge_text = "x" * (MAX_ARTICLE_SIZE + 1000) # Exceed limit + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=huge_text * 4, # Simulate large HTML + ), + pytest.raises(ValueError, match="Article too large") as cm, + ): + ArticleProcessor.extract_article_content("https://example.com") + + self.assertIn("Article too large", str(cm.value)) + + def test_content_truncation(self) -> None: + """Test that oversized content is truncated.""" + large_content = "Content " * 100_000 # Create large content + mock_result = json.dumps({ + "title": "Large Article", + "text": large_content, + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value="<html><body>content</body></html>", + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content = ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertEqual(title, "Large Article") + self.assertLessEqual(len(content), MAX_ARTICLE_SIZE) + + def test_memory_usage_check(self) -> None: + """Test memory usage monitoring.""" + usage = check_memory_usage() + self.assertIsInstance(usage, float) + self.assertGreaterEqual(usage, 0.0) + self.assertLessEqual(usage, 100.0) + + class TestJobProcessing(Test.TestCase): """Test job processing functionality.""" @@ -1314,6 +1441,30 @@ class TestJobProcessing(Test.TestCase): self.assertEqual(len(jobs), 3) self.assertEqual({j["id"] for j in jobs}, {self.job_id, job2, job3}) + def test_memory_threshold_deferral(self) -> None: + """Test that jobs are deferred when memory usage is high.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + # Mock high memory usage + with ( + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=90.0, # High memory usage + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + ) as mock_update, + ): + processor.process_job(job) + + # Job should not be processed (no status updates) + mock_update.assert_not_called() + def test() -> None: """Run the tests.""" @@ -1322,6 +1473,7 @@ def test() -> None: [ TestArticleExtraction, TestTextToSpeech, + TestMemoryEfficiency, TestJobProcessing, ], ) |
