diff options
| author | Ben Sima <ben@bsima.me> | 2025-09-05 15:47:15 -0400 |
|---|---|---|
| committer | Ben Sima <ben@bsima.me> | 2025-09-05 15:47:15 -0400 |
| commit | 8f381492ce545bcfe6608f56c0134c26f42f0506 (patch) | |
| tree | a51722f703ec0c755db0f72cd4e214a0eb1dd5ce /Biz | |
| parent | eaa387204433999c2600a592d3e822d3ef8f2899 (diff) | |
Enhance worker memory management
Add a check to prevent processing of large articles, truncate oversized
content, defer jobs during high memory usage, and use streaming TTS
generation and upload to minimize memory consumption.
Diffstat (limited to 'Biz')
| -rw-r--r-- | Biz/PodcastItLater/Worker.py | 244 |
1 files changed, 198 insertions, 46 deletions
diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py index 8142b50..023d878 100644 --- a/Biz/PodcastItLater/Worker.py +++ b/Biz/PodcastItLater/Worker.py @@ -3,6 +3,7 @@ # : dep boto3 # : dep botocore # : dep openai +# : dep psutil # : dep pydub # : dep pytest # : dep pytest-asyncio @@ -19,9 +20,11 @@ import Omni.Log as Log import Omni.Test as Test import openai import os +import psutil # type: ignore[import-untyped] import pytest import signal import sys +import tempfile import threading import time import trafilatura @@ -31,6 +34,7 @@ from botocore.exceptions import ClientError # type: ignore[import-untyped] from datetime import datetime from datetime import timedelta from datetime import timezone +from pathlib import Path from pydub import AudioSegment # type: ignore[import-untyped] from typing import Any @@ -46,10 +50,12 @@ area = App.from_env() # Worker configuration MAX_CONTENT_LENGTH = 5000 # characters for TTS +MAX_ARTICLE_SIZE = 500_000 # 500KB character limit for articles POLL_INTERVAL = 30 # seconds MAX_RETRIES = 3 TTS_MODEL = "tts-1" TTS_VOICE = "alloy" +MEMORY_THRESHOLD = 80 # Percentage threshold for memory usage class ShutdownHandler: @@ -123,7 +129,7 @@ class ArticleProcessor: """Extract title and content from article URL using trafilatura. Raises: - ValueError: If content cannot be downloaded or extracted. + ValueError: If content cannot be downloaded, extracted, or large. 
""" try: downloaded = trafilatura.fetch_url(url) @@ -131,6 +137,13 @@ class ArticleProcessor: msg = f"Failed to download content from {url}" raise ValueError(msg) # noqa: TRY301 + # Check size before processing + if ( + len(downloaded) > MAX_ARTICLE_SIZE * 4 + ): # Rough HTML to text ratio + msg = f"Article too large: {len(downloaded)} bytes" + raise ValueError(msg) # noqa: TRY301 + # Extract with metadata result = trafilatura.extract( downloaded, @@ -153,7 +166,15 @@ class ArticleProcessor: msg = f"No content extracted from {url}" raise ValueError(msg) # noqa: TRY301 - # Don't truncate - we'll handle length in text_to_speech + # Enforce content size limit + if len(content) > MAX_ARTICLE_SIZE: + logger.warning( + "Article content truncated from %d to %d characters", + len(content), + MAX_ARTICLE_SIZE, + ) + content = content[:MAX_ARTICLE_SIZE] + logger.info("Extracted article: %s (%d chars)", title, len(content)) except Exception: logger.exception("Failed to extract content from %s", url) @@ -164,7 +185,7 @@ class ArticleProcessor: def text_to_speech(self, text: str, title: str) -> bytes: """Convert text to speech using OpenAI TTS API. - Uses LLM to prepare text, then handles chunking and concatenation. + Uses streaming approach to maintain constant memory usage. Raises: ValueError: If no chunks are generated from text. 
@@ -179,54 +200,81 @@ class ArticleProcessor: logger.info("Processing %d chunks for TTS", len(chunks)) - # Generate audio for each chunk - audio_segments = [] - for i, chunk in enumerate(chunks): - logger.info( - "Generating TTS for chunk %d/%d (%d chars)", - i + 1, - len(chunks), - len(chunk), - ) + # Create a temporary file for streaming audio concatenation + with tempfile.NamedTemporaryFile( + suffix=".mp3", + delete=False, + ) as temp_file: + temp_path = temp_file.name + try: + # Process first chunk + logger.info("Generating TTS for chunk 1/%d", len(chunks)) response = self.openai_client.audio.speech.create( model=TTS_MODEL, voice=TTS_VOICE, - input=chunk, + input=chunks[0], response_format="mp3", ) - # Convert bytes to AudioSegment - audio_segment = AudioSegment.from_mp3( - io.BytesIO(response.content), - ) - audio_segments.append(audio_segment) - - # Small delay between API calls to be respectful - if i < len(chunks) - 1: - time.sleep(0.5) + # Write first chunk directly to file + Path(temp_path).write_bytes(response.content) + + # Process remaining chunks + for i, chunk in enumerate(chunks[1:], 1): + logger.info( + "Generating TTS for chunk %d/%d (%d chars)", + i + 1, + len(chunks), + len(chunk), + ) + + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunk, + response_format="mp3", + ) + + # Append to existing file with silence gap + # Load only the current segment + current_segment = AudioSegment.from_mp3( + io.BytesIO(response.content), + ) + + # Load existing audio, append, and save back + existing_audio = AudioSegment.from_mp3(temp_path) + silence = AudioSegment.silent(duration=300) + combined = existing_audio + silence + current_segment + + # Export back to the same file + combined.export(temp_path, format="mp3", bitrate="128k") + + # Force garbage collection to free memory + del existing_audio, current_segment, combined + + # Small delay between API calls + if i < len(chunks) - 1: + time.sleep(0.5) + + # 
Read final result + audio_data = Path(temp_path).read_bytes() - # Concatenate all audio segments - combined_audio = audio_segments[0] - for segment in audio_segments[1:]: - # Add a small silence between chunks for natural pacing - silence = AudioSegment.silent(duration=300) - combined_audio = combined_audio + silence + segment + logger.info( + "Generated combined TTS audio: %d bytes", + len(audio_data), + ) + return audio_data - # Export combined audio to bytes - output_buffer = io.BytesIO() - combined_audio.export(output_buffer, format="mp3", bitrate="128k") - audio_data = output_buffer.getvalue() + finally: + # Clean up temp file + temp_file_path = Path(temp_path) + if temp_file_path.exists(): + temp_file_path.unlink() - logger.info( - "Generated combined TTS audio: %d bytes", - len(audio_data), - ) except Exception: logger.exception("TTS generation failed") raise - else: - return audio_data def upload_to_s3(self, audio_data: bytes, filename: str) -> str: """Upload audio file to S3-compatible storage and return public URL. 
@@ -240,18 +288,25 @@ class ArticleProcessor: raise ValueError(msg) try: - # Upload file - self.s3_client.put_object( - Bucket=S3_BUCKET, - Key=filename, - Body=audio_data, - ContentType="audio/mpeg", - ACL="public-read", + # Upload file using streaming to minimize memory usage + audio_stream = io.BytesIO(audio_data) + self.s3_client.upload_fileobj( + audio_stream, + S3_BUCKET, + filename, + ExtraArgs={ + "ContentType": "audio/mpeg", + "ACL": "public-read", + }, ) # Construct public URL audio_url = f"{S3_ENDPOINT}/{S3_BUCKET}/{filename}" - logger.info("Uploaded audio to: %s", audio_url) + logger.info( + "Uploaded audio to: %s (%d bytes)", + audio_url, + len(audio_data), + ) except ClientError: logger.exception("S3 upload failed") raise @@ -284,6 +339,16 @@ class ArticleProcessor: job_id = job["id"] url = job["url"] + # Check memory before starting + mem_usage = check_memory_usage() + if mem_usage > MEMORY_THRESHOLD: + logger.warning( + "High memory usage (%.1f%%), deferring job %d", + mem_usage, + job_id, + ) + return + # Track current job for graceful shutdown self.shutdown_handler.set_current_job(job_id) @@ -599,6 +664,17 @@ def process_retryable_jobs() -> None: ) + +def check_memory_usage() -> int | Any: + """Check current memory usage percentage.""" + try: + process = psutil.Process() + # this returns an int but psutil is untyped + return process.memory_percent() + except (psutil.Error, OSError): + logger.warning("Failed to check memory usage") + return 0.0 + + def cleanup_stale_jobs() -> None: """Reset jobs stuck in processing state on startup.""" with Core.Database.get_connection() as conn: @@ -1043,6 +1119,57 @@ class TestTextToSpeech(Test.TestCase): self.assertEqual(audio_data, b"test-audio-output") +class TestMemoryEfficiency(Test.TestCase): + """Test memory-efficient processing.""" + + def test_large_article_size_limit(self) -> None: + """Test that articles exceeding size limits are rejected.""" + huge_text = "x" * (MAX_ARTICLE_SIZE + 1000) # Exceed limit + 
+ with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value=huge_text * 4, # Simulate large HTML + ), + pytest.raises(ValueError, match="Article too large") as cm, + ): + ArticleProcessor.extract_article_content("https://example.com") + + self.assertIn("Article too large", str(cm.value)) + + def test_content_truncation(self) -> None: + """Test that oversized content is truncated.""" + large_content = "Content " * 100_000 # Create large content + mock_result = json.dumps({ + "title": "Large Article", + "text": large_content, + }) + + with ( + unittest.mock.patch( + "trafilatura.fetch_url", + return_value="<html><body>content</body></html>", + ), + unittest.mock.patch( + "trafilatura.extract", + return_value=mock_result, + ), + ): + title, content = ArticleProcessor.extract_article_content( + "https://example.com", + ) + + self.assertEqual(title, "Large Article") + self.assertLessEqual(len(content), MAX_ARTICLE_SIZE) + + def test_memory_usage_check(self) -> None: + """Test memory usage monitoring.""" + usage = check_memory_usage() + self.assertIsInstance(usage, float) + self.assertGreaterEqual(usage, 0.0) + self.assertLessEqual(usage, 100.0) + + class TestJobProcessing(Test.TestCase): """Test job processing functionality.""" @@ -1314,6 +1441,30 @@ class TestJobProcessing(Test.TestCase): self.assertEqual(len(jobs), 3) self.assertEqual({j["id"] for j in jobs}, {self.job_id, job2, job3}) + def test_memory_threshold_deferral(self) -> None: + """Test that jobs are deferred when memory usage is high.""" + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + job = Core.Database.get_job_by_id(self.job_id) + if job is None: + msg = "no job found for %s" + raise Test.TestError(msg, self.job_id) + + # Mock high memory usage + with ( + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=90.0, # High memory usage + ), + unittest.mock.patch( + "Biz.PodcastItLater.Core.Database.update_job_status", + 
) as mock_update, + ): + processor.process_job(job) + + # Job should not be processed (no status updates) + mock_update.assert_not_called() + def test() -> None: """Run the tests.""" @@ -1322,6 +1473,7 @@ def test() -> None: [ TestArticleExtraction, TestTextToSpeech, + TestMemoryEfficiency, TestJobProcessing, ], ) |
