diff options
Diffstat (limited to 'Biz')
| -rw-r--r-- | Biz/PodcastItLater/Worker.py | 458 |
1 files changed, 392 insertions, 66 deletions
diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py index 4dfbfe2..2807e8a 100644 --- a/Biz/PodcastItLater/Worker.py +++ b/Biz/PodcastItLater/Worker.py @@ -13,12 +13,14 @@ # : run ffmpeg import Biz.PodcastItLater.Core as Core import boto3 # type: ignore[import-untyped] +import concurrent.futures import io import json import Omni.App as App import Omni.Log as Log import Omni.Test as Test import openai +import operator import os import psutil # type: ignore[import-untyped] import pytest @@ -185,7 +187,7 @@ class ArticleProcessor: def text_to_speech(self, text: str, title: str) -> bytes: """Convert text to speech using OpenAI TTS API. - Uses streaming approach to maintain constant memory usage. + Uses parallel processing for chunks while maintaining order. Raises: ValueError: If no chunks are generated from text. @@ -200,81 +202,205 @@ class ArticleProcessor: logger.info("Processing %d chunks for TTS", len(chunks)) - # Create a temporary file for streaming audio concatenation - with tempfile.NamedTemporaryFile( - suffix=".mp3", - delete=False, - ) as temp_file: - temp_path = temp_file.name + # Check memory before parallel processing + mem_usage = check_memory_usage() + if mem_usage > MEMORY_THRESHOLD - 20: # Leave 20% buffer + logger.warning( + "High memory usage (%.1f%%), falling back to serial " + "processing", + mem_usage, + ) + return self._text_to_speech_serial(chunks) + + # Determine max workers based on chunk count and system resources + max_workers = min( + 4, # Reasonable limit to avoid rate limiting + len(chunks), # No more workers than chunks + max(1, psutil.cpu_count() // 2), # Use half of CPU cores + ) + + logger.info( + "Using %d workers for parallel TTS processing", + max_workers, + ) + + # Process chunks in parallel + chunk_results: list[tuple[int, bytes]] = [] + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers, + ) as executor: + # Submit all chunks for processing + future_to_index = { + executor.submit(self._process_tts_chunk, chunk, i): i + for i, chunk in enumerate(chunks) + } + + # Collect results as they complete + for future in concurrent.futures.as_completed(future_to_index): + index = future_to_index[future] + try: + audio_data = future.result() + chunk_results.append((index, audio_data)) + except Exception: + logger.exception("Failed to process chunk %d", index) + raise + + # Sort results by index to maintain order + chunk_results.sort(key=operator.itemgetter(0)) + + # Combine audio chunks + return self._combine_audio_chunks([ + data for _, data in chunk_results + ]) + + except Exception: + logger.exception("TTS generation failed") + raise + + def _process_tts_chunk(self, chunk: str, index: int) -> bytes: + """Process a single TTS chunk. + + Args: + chunk: Text to convert to speech + index: Chunk index for logging + + Returns: + Audio data as bytes + """ + logger.info( + "Generating TTS for chunk %d (%d chars)", + index + 1, + len(chunk), + ) + + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunk, + response_format="mp3", + ) + + return response.content + + @staticmethod + def _combine_audio_chunks(audio_chunks: list[bytes]) -> bytes: + """Combine multiple audio chunks with silence gaps. + + Args: + audio_chunks: List of audio data in order + + Returns: + Combined audio data + """ + if not audio_chunks: + msg = "No audio chunks to combine" + raise ValueError(msg) + + # Create a temporary file for the combined audio + with tempfile.NamedTemporaryFile( + suffix=".mp3", + delete=False, + ) as temp_file: + temp_path = temp_file.name + + try: + # Start with the first chunk + combined_audio = AudioSegment.from_mp3(io.BytesIO(audio_chunks[0])) + + # Add remaining chunks with silence gaps + for chunk_data in audio_chunks[1:]: + chunk_audio = AudioSegment.from_mp3(io.BytesIO(chunk_data)) + silence = AudioSegment.silent(duration=300) # 300ms gap + combined_audio = combined_audio + silence + chunk_audio + + # Export to file + combined_audio.export(temp_path, format="mp3", bitrate="128k") + + # Read back the combined audio + return Path(temp_path).read_bytes() + + finally: + # Clean up temp file + if Path(temp_path).exists(): + Path(temp_path).unlink() + + def _text_to_speech_serial(self, chunks: list[str]) -> bytes: + """Fallback serial processing for high memory situations. + + This is the original serial implementation. + """ + # Create a temporary file for streaming audio concatenation + with tempfile.NamedTemporaryFile( + suffix=".mp3", + delete=False, + ) as temp_file: + temp_path = temp_file.name + + try: + # Process first chunk + logger.info("Generating TTS for chunk 1/%d", len(chunks)) + response = self.openai_client.audio.speech.create( + model=TTS_MODEL, + voice=TTS_VOICE, + input=chunks[0], + response_format="mp3", + ) + + # Write first chunk directly to file + Path(temp_path).write_bytes(response.content) + + # Process remaining chunks + for i, chunk in enumerate(chunks[1:], 1): + logger.info( + "Generating TTS for chunk %d/%d (%d chars)", + i + 1, + len(chunks), + len(chunk), + ) - try: - # Process first chunk - logger.info("Generating TTS for chunk 1/%d", len(chunks)) response = self.openai_client.audio.speech.create( model=TTS_MODEL, voice=TTS_VOICE, - input=chunks[0], + input=chunk, response_format="mp3", ) - # Write first chunk directly to file - Path(temp_path).write_bytes(response.content) - - # Process remaining chunks - for i, chunk in enumerate(chunks[1:], 1): - logger.info( - "Generating TTS for chunk %d/%d (%d chars)", - i + 1, - len(chunks), - len(chunk), - ) - - response = self.openai_client.audio.speech.create( - model=TTS_MODEL, - voice=TTS_VOICE, - input=chunk, - response_format="mp3", - ) - - # Append to existing file with silence gap - # Load only the current segment - current_segment = AudioSegment.from_mp3( - io.BytesIO(response.content), - ) - - # Load existing audio, append, and save back - existing_audio = AudioSegment.from_mp3(temp_path) - silence = AudioSegment.silent(duration=300) - combined = existing_audio + silence + current_segment - - # Export back to the same file - combined.export(temp_path, format="mp3", bitrate="128k") - - # Force garbage collection to free memory - del existing_audio, current_segment, combined - - # Small delay between API calls - if i < len(chunks) - 1: - time.sleep(0.5) - - # Read final result - audio_data = Path(temp_path).read_bytes() - - logger.info( - "Generated combined TTS audio: %d bytes", - len(audio_data), + # Append to existing file with silence gap + # Load only the current segment + current_segment = AudioSegment.from_mp3( + io.BytesIO(response.content), ) - return audio_data - finally: - # Clean up temp file - temp_file_path = Path(temp_path) - if temp_file_path.exists(): - temp_file_path.unlink() + # Load existing audio, append, and save back + existing_audio = AudioSegment.from_mp3(temp_path) + silence = AudioSegment.silent(duration=300) + combined = existing_audio + silence + current_segment - except Exception: - logger.exception("TTS generation failed") - raise + # Export back to the same file + combined.export(temp_path, format="mp3", bitrate="128k") + + # Force garbage collection to free memory + del existing_audio, current_segment, combined + + # Small delay between API calls + if i < len(chunks) - 1: + time.sleep(0.5) + + # Read final result + audio_data = Path(temp_path).read_bytes() + + logger.info( + "Generated combined TTS audio: %d bytes", + len(audio_data), + ) + return audio_data + + finally: + # Clean up temp file + temp_file_path = Path(temp_path) + if temp_file_path.exists(): + temp_file_path.unlink() def upload_to_s3(self, audio_data: bytes, filename: str) -> str: """Upload audio file to S3-compatible storage and return public URL. @@ -1117,6 +1243,206 @@ class TestTextToSpeech(Test.TestCase): self.assertIsInstance(audio_data, bytes) self.assertEqual(audio_data, b"test-audio-output") + def test_parallel_tts_generation(self) -> None: + """Test parallel TTS processing.""" + chunks = ["Chunk 1", "Chunk 2", "Chunk 3", "Chunk 4"] + + # Mock responses for each chunk + mock_responses = [] + for i in range(len(chunks)): + mock_resp = unittest.mock.MagicMock() + mock_resp.content = f"audio-{i}".encode() + mock_responses.append(mock_resp) + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + + # Make create return different responses for each call + mock_speech.create.side_effect = mock_responses + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + # Mock AudioSegment operations + mock_segment = unittest.mock.MagicMock() + mock_segment.__add__.return_value = mock_segment + + def mock_export(path: str, **_kwargs: typing.Any) -> None: + Path(path).write_bytes(b"combined-audio") + + mock_segment.export = mock_export + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + return_value=mock_segment, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=mock_segment, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=50.0, # Normal memory usage + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test content", "Test Title") + + # Verify all chunks were processed + self.assertEqual(mock_speech.create.call_count, len(chunks)) + self.assertEqual(audio_data, b"combined-audio") + + def test_parallel_tts_high_memory_fallback(self) -> None: + """Test fallback to serial processing when memory is high.""" + chunks = ["Chunk 1", "Chunk 2"] + + def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None: + buffer.write(b"serial-audio") + buffer.seek(0) + + self.mock_audio_segment.export.side_effect = mock_export + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.return_value = self.mock_audio_response + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=65.0, # High memory usage + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=self.mock_audio_segment, + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test content", "Test Title") + + # Should use serial processing + self.assertEqual(audio_data, b"serial-audio") + + @staticmethod + def test_parallel_tts_error_handling() -> None: + """Test error handling in parallel TTS processing.""" + chunks = ["Chunk 1", "Chunk 2"] + + # Mock OpenAI client with one failure + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + + # First call succeeds, second fails + mock_resp1 = unittest.mock.MagicMock() + mock_resp1.content = b"audio-1" + mock_speech.create.side_effect = [mock_resp1, Exception("API Error")] + + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + # Set up the test context + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=50.0, + ), + pytest.raises(Exception, match="API Error"), + ): + processor.text_to_speech("Test content", "Test Title") + + def test_parallel_tts_order_preservation(self) -> None: + """Test that chunks are combined in the correct order.""" + chunks = ["First", "Second", "Third", "Fourth", "Fifth"] + + # Create mock responses with identifiable content + mock_responses = [] + for chunk in chunks: + mock_resp = unittest.mock.MagicMock() + mock_resp.content = f"audio-{chunk}".encode() + mock_responses.append(mock_resp) + + # Mock OpenAI client + mock_client = unittest.mock.MagicMock() + mock_audio = unittest.mock.MagicMock() + mock_speech = unittest.mock.MagicMock() + mock_speech.create.side_effect = mock_responses + mock_audio.speech = mock_speech + mock_client.audio = mock_audio + + # Track the order of segments being combined + combined_order = [] + + def mock_from_mp3(data: io.BytesIO) -> unittest.mock.MagicMock: + content = data.read() + combined_order.append(content.decode()) + segment = unittest.mock.MagicMock() + segment.__add__.return_value = segment + return segment + + mock_segment = unittest.mock.MagicMock() + mock_segment.__add__.return_value = mock_segment + + def mock_export(path: str, **_kwargs: typing.Any) -> None: + # Verify order is preserved + expected_order = [f"audio-{chunk}" for chunk in chunks] + if combined_order != expected_order: + msg = f"Order mismatch: {combined_order} != {expected_order}" + raise AssertionError(msg) + Path(path).write_bytes(b"ordered-audio") + + mock_segment.export = mock_export + + with ( + unittest.mock.patch("openai.OpenAI", return_value=mock_client), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.prepare_text_for_tts", + return_value=chunks, + ), + unittest.mock.patch( + "pydub.AudioSegment.from_mp3", + side_effect=mock_from_mp3, + ), + unittest.mock.patch( + "pydub.AudioSegment.silent", + return_value=mock_segment, + ), + unittest.mock.patch( + "Biz.PodcastItLater.Worker.check_memory_usage", + return_value=50.0, + ), + ): + shutdown_handler = ShutdownHandler() + processor = ArticleProcessor(shutdown_handler) + audio_data = processor.text_to_speech("Test content", "Test Title") + + self.assertEqual(audio_data, b"ordered-audio") + class TestMemoryEfficiency(Test.TestCase): """Test memory-efficient processing.""" |
