summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Ben Sima <ben@bsima.me> 2025-09-06 13:56:13 -0400
committer Ben Sima (aider) <ben@bsima.me> 2025-09-06 13:56:13 -0400
commitce6d313edbf5c545d16d88d28be867122b7c3d1b (patch)
tree7a19edf204f5bcbaea28deecfee56e8d129e9d41
parente0a2f7ec3e21891784c874ab6c90953bc4fedc19 (diff)
Implement Parallel TTS Processing with Robust Error Handling
-rw-r--r--Biz/PodcastItLater/Worker.py458
1 file changed, 392 insertions, 66 deletions
diff --git a/Biz/PodcastItLater/Worker.py b/Biz/PodcastItLater/Worker.py
index 4dfbfe2..2807e8a 100644
--- a/Biz/PodcastItLater/Worker.py
+++ b/Biz/PodcastItLater/Worker.py
@@ -13,12 +13,14 @@
# : run ffmpeg
import Biz.PodcastItLater.Core as Core
import boto3 # type: ignore[import-untyped]
+import concurrent.futures
import io
import json
import Omni.App as App
import Omni.Log as Log
import Omni.Test as Test
import openai
+import operator
import os
import psutil # type: ignore[import-untyped]
import pytest
@@ -185,7 +187,7 @@ class ArticleProcessor:
def text_to_speech(self, text: str, title: str) -> bytes:
"""Convert text to speech using OpenAI TTS API.
- Uses streaming approach to maintain constant memory usage.
+ Uses parallel processing for chunks while maintaining order.
Raises:
ValueError: If no chunks are generated from text.
@@ -200,81 +202,205 @@ class ArticleProcessor:
logger.info("Processing %d chunks for TTS", len(chunks))
- # Create a temporary file for streaming audio concatenation
- with tempfile.NamedTemporaryFile(
- suffix=".mp3",
- delete=False,
- ) as temp_file:
- temp_path = temp_file.name
+ # Check memory before parallel processing
+ mem_usage = check_memory_usage()
+ if mem_usage > MEMORY_THRESHOLD - 20: # Leave 20% buffer
+ logger.warning(
+ "High memory usage (%.1f%%), falling back to serial "
+ "processing",
+ mem_usage,
+ )
+ return self._text_to_speech_serial(chunks)
+
+ # Determine max workers based on chunk count and system resources
+ max_workers = min(
+ 4, # Reasonable limit to avoid rate limiting
+ len(chunks), # No more workers than chunks
+ max(1, psutil.cpu_count() // 2), # Use half of CPU cores
+ )
+
+ logger.info(
+ "Using %d workers for parallel TTS processing",
+ max_workers,
+ )
+
+ # Process chunks in parallel
+ chunk_results: list[tuple[int, bytes]] = []
+
+ with concurrent.futures.ThreadPoolExecutor(
+ max_workers=max_workers,
+ ) as executor:
+ # Submit all chunks for processing
+ future_to_index = {
+ executor.submit(self._process_tts_chunk, chunk, i): i
+ for i, chunk in enumerate(chunks)
+ }
+
+ # Collect results as they complete
+ for future in concurrent.futures.as_completed(future_to_index):
+ index = future_to_index[future]
+ try:
+ audio_data = future.result()
+ chunk_results.append((index, audio_data))
+ except Exception:
+ logger.exception("Failed to process chunk %d", index)
+ raise
+
+ # Sort results by index to maintain order
+ chunk_results.sort(key=operator.itemgetter(0))
+
+ # Combine audio chunks
+ return self._combine_audio_chunks([
+ data for _, data in chunk_results
+ ])
+
+ except Exception:
+ logger.exception("TTS generation failed")
+ raise
+
+ def _process_tts_chunk(self, chunk: str, index: int) -> bytes:
+ """Process a single TTS chunk.
+
+ Args:
+ chunk: Text to convert to speech
+ index: Chunk index for logging
+
+ Returns:
+ Audio data as bytes
+ """
+ logger.info(
+ "Generating TTS for chunk %d (%d chars)",
+ index + 1,
+ len(chunk),
+ )
+
+ response = self.openai_client.audio.speech.create(
+ model=TTS_MODEL,
+ voice=TTS_VOICE,
+ input=chunk,
+ response_format="mp3",
+ )
+
+ return response.content
+
+ @staticmethod
+ def _combine_audio_chunks(audio_chunks: list[bytes]) -> bytes:
+ """Combine multiple audio chunks with silence gaps.
+
+ Args:
+ audio_chunks: List of audio data in order
+
+ Returns:
+ Combined audio data
+ """
+ if not audio_chunks:
+ msg = "No audio chunks to combine"
+ raise ValueError(msg)
+
+ # Create a temporary file for the combined audio
+ with tempfile.NamedTemporaryFile(
+ suffix=".mp3",
+ delete=False,
+ ) as temp_file:
+ temp_path = temp_file.name
+
+ try:
+ # Start with the first chunk
+ combined_audio = AudioSegment.from_mp3(io.BytesIO(audio_chunks[0]))
+
+ # Add remaining chunks with silence gaps
+ for chunk_data in audio_chunks[1:]:
+ chunk_audio = AudioSegment.from_mp3(io.BytesIO(chunk_data))
+ silence = AudioSegment.silent(duration=300) # 300ms gap
+ combined_audio = combined_audio + silence + chunk_audio
+
+ # Export to file
+ combined_audio.export(temp_path, format="mp3", bitrate="128k")
+
+ # Read back the combined audio
+ return Path(temp_path).read_bytes()
+
+ finally:
+ # Clean up temp file
+ if Path(temp_path).exists():
+ Path(temp_path).unlink()
+
+ def _text_to_speech_serial(self, chunks: list[str]) -> bytes:
+ """Fallback serial processing for high memory situations.
+
+ This is the original serial implementation.
+ """
+ # Create a temporary file for streaming audio concatenation
+ with tempfile.NamedTemporaryFile(
+ suffix=".mp3",
+ delete=False,
+ ) as temp_file:
+ temp_path = temp_file.name
+
+ try:
+ # Process first chunk
+ logger.info("Generating TTS for chunk 1/%d", len(chunks))
+ response = self.openai_client.audio.speech.create(
+ model=TTS_MODEL,
+ voice=TTS_VOICE,
+ input=chunks[0],
+ response_format="mp3",
+ )
+
+ # Write first chunk directly to file
+ Path(temp_path).write_bytes(response.content)
+
+ # Process remaining chunks
+ for i, chunk in enumerate(chunks[1:], 1):
+ logger.info(
+ "Generating TTS for chunk %d/%d (%d chars)",
+ i + 1,
+ len(chunks),
+ len(chunk),
+ )
- try:
- # Process first chunk
- logger.info("Generating TTS for chunk 1/%d", len(chunks))
response = self.openai_client.audio.speech.create(
model=TTS_MODEL,
voice=TTS_VOICE,
- input=chunks[0],
+ input=chunk,
response_format="mp3",
)
- # Write first chunk directly to file
- Path(temp_path).write_bytes(response.content)
-
- # Process remaining chunks
- for i, chunk in enumerate(chunks[1:], 1):
- logger.info(
- "Generating TTS for chunk %d/%d (%d chars)",
- i + 1,
- len(chunks),
- len(chunk),
- )
-
- response = self.openai_client.audio.speech.create(
- model=TTS_MODEL,
- voice=TTS_VOICE,
- input=chunk,
- response_format="mp3",
- )
-
- # Append to existing file with silence gap
- # Load only the current segment
- current_segment = AudioSegment.from_mp3(
- io.BytesIO(response.content),
- )
-
- # Load existing audio, append, and save back
- existing_audio = AudioSegment.from_mp3(temp_path)
- silence = AudioSegment.silent(duration=300)
- combined = existing_audio + silence + current_segment
-
- # Export back to the same file
- combined.export(temp_path, format="mp3", bitrate="128k")
-
- # Force garbage collection to free memory
- del existing_audio, current_segment, combined
-
- # Small delay between API calls
- if i < len(chunks) - 1:
- time.sleep(0.5)
-
- # Read final result
- audio_data = Path(temp_path).read_bytes()
-
- logger.info(
- "Generated combined TTS audio: %d bytes",
- len(audio_data),
+ # Append to existing file with silence gap
+ # Load only the current segment
+ current_segment = AudioSegment.from_mp3(
+ io.BytesIO(response.content),
)
- return audio_data
- finally:
- # Clean up temp file
- temp_file_path = Path(temp_path)
- if temp_file_path.exists():
- temp_file_path.unlink()
+ # Load existing audio, append, and save back
+ existing_audio = AudioSegment.from_mp3(temp_path)
+ silence = AudioSegment.silent(duration=300)
+ combined = existing_audio + silence + current_segment
- except Exception:
- logger.exception("TTS generation failed")
- raise
+ # Export back to the same file
+ combined.export(temp_path, format="mp3", bitrate="128k")
+
+ # Force garbage collection to free memory
+ del existing_audio, current_segment, combined
+
+ # Small delay between API calls
+ if i < len(chunks) - 1:
+ time.sleep(0.5)
+
+ # Read final result
+ audio_data = Path(temp_path).read_bytes()
+
+ logger.info(
+ "Generated combined TTS audio: %d bytes",
+ len(audio_data),
+ )
+ return audio_data
+
+ finally:
+ # Clean up temp file
+ temp_file_path = Path(temp_path)
+ if temp_file_path.exists():
+ temp_file_path.unlink()
def upload_to_s3(self, audio_data: bytes, filename: str) -> str:
"""Upload audio file to S3-compatible storage and return public URL.
@@ -1117,6 +1243,206 @@ class TestTextToSpeech(Test.TestCase):
self.assertIsInstance(audio_data, bytes)
self.assertEqual(audio_data, b"test-audio-output")
+ def test_parallel_tts_generation(self) -> None:
+ """Test parallel TTS processing."""
+ chunks = ["Chunk 1", "Chunk 2", "Chunk 3", "Chunk 4"]
+
+ # Mock responses for each chunk
+ mock_responses = []
+ for i in range(len(chunks)):
+ mock_resp = unittest.mock.MagicMock()
+ mock_resp.content = f"audio-{i}".encode()
+ mock_responses.append(mock_resp)
+
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+
+ # Make create return different responses for each call
+ mock_speech.create.side_effect = mock_responses
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ # Mock AudioSegment operations
+ mock_segment = unittest.mock.MagicMock()
+ mock_segment.__add__.return_value = mock_segment
+
+ def mock_export(path: str, **_kwargs: typing.Any) -> None:
+ Path(path).write_bytes(b"combined-audio")
+
+ mock_segment.export = mock_export
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=chunks,
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.from_mp3",
+ return_value=mock_segment,
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.silent",
+ return_value=mock_segment,
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.check_memory_usage",
+ return_value=50.0, # Normal memory usage
+ ),
+ ):
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ audio_data = processor.text_to_speech("Test content", "Test Title")
+
+ # Verify all chunks were processed
+ self.assertEqual(mock_speech.create.call_count, len(chunks))
+ self.assertEqual(audio_data, b"combined-audio")
+
+ def test_parallel_tts_high_memory_fallback(self) -> None:
+ """Test fallback to serial processing when memory is high."""
+ chunks = ["Chunk 1", "Chunk 2"]
+
+ def mock_export(buffer: io.BytesIO, **_kwargs: typing.Any) -> None:
+ buffer.write(b"serial-audio")
+ buffer.seek(0)
+
+ self.mock_audio_segment.export.side_effect = mock_export
+
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+ mock_speech.create.return_value = self.mock_audio_response
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=chunks,
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.check_memory_usage",
+ return_value=65.0, # High memory usage
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.silent",
+ return_value=self.mock_audio_segment,
+ ),
+ ):
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ audio_data = processor.text_to_speech("Test content", "Test Title")
+
+ # Should use serial processing
+ self.assertEqual(audio_data, b"serial-audio")
+
+ @staticmethod
+ def test_parallel_tts_error_handling() -> None:
+ """Test error handling in parallel TTS processing."""
+ chunks = ["Chunk 1", "Chunk 2"]
+
+ # Mock OpenAI client with one failure
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+
+ # First call succeeds, second fails
+ mock_resp1 = unittest.mock.MagicMock()
+ mock_resp1.content = b"audio-1"
+ mock_speech.create.side_effect = [mock_resp1, Exception("API Error")]
+
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ # Set up the test context
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=chunks,
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.check_memory_usage",
+ return_value=50.0,
+ ),
+ pytest.raises(Exception, match="API Error"),
+ ):
+ processor.text_to_speech("Test content", "Test Title")
+
+ def test_parallel_tts_order_preservation(self) -> None:
+ """Test that chunks are combined in the correct order."""
+ chunks = ["First", "Second", "Third", "Fourth", "Fifth"]
+
+ # Create mock responses with identifiable content
+ mock_responses = []
+ for chunk in chunks:
+ mock_resp = unittest.mock.MagicMock()
+ mock_resp.content = f"audio-{chunk}".encode()
+ mock_responses.append(mock_resp)
+
+ # Mock OpenAI client
+ mock_client = unittest.mock.MagicMock()
+ mock_audio = unittest.mock.MagicMock()
+ mock_speech = unittest.mock.MagicMock()
+ mock_speech.create.side_effect = mock_responses
+ mock_audio.speech = mock_speech
+ mock_client.audio = mock_audio
+
+ # Track the order of segments being combined
+ combined_order = []
+
+ def mock_from_mp3(data: io.BytesIO) -> unittest.mock.MagicMock:
+ content = data.read()
+ combined_order.append(content.decode())
+ segment = unittest.mock.MagicMock()
+ segment.__add__.return_value = segment
+ return segment
+
+ mock_segment = unittest.mock.MagicMock()
+ mock_segment.__add__.return_value = mock_segment
+
+ def mock_export(path: str, **_kwargs: typing.Any) -> None:
+ # Verify order is preserved
+ expected_order = [f"audio-{chunk}" for chunk in chunks]
+ if combined_order != expected_order:
+ msg = f"Order mismatch: {combined_order} != {expected_order}"
+ raise AssertionError(msg)
+ Path(path).write_bytes(b"ordered-audio")
+
+ mock_segment.export = mock_export
+
+ with (
+ unittest.mock.patch("openai.OpenAI", return_value=mock_client),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.prepare_text_for_tts",
+ return_value=chunks,
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.from_mp3",
+ side_effect=mock_from_mp3,
+ ),
+ unittest.mock.patch(
+ "pydub.AudioSegment.silent",
+ return_value=mock_segment,
+ ),
+ unittest.mock.patch(
+ "Biz.PodcastItLater.Worker.check_memory_usage",
+ return_value=50.0,
+ ),
+ ):
+ shutdown_handler = ShutdownHandler()
+ processor = ArticleProcessor(shutdown_handler)
+ audio_data = processor.text_to_speech("Test content", "Test Title")
+
+ self.assertEqual(audio_data, b"ordered-audio")
+
class TestMemoryEfficiency(Test.TestCase):
"""Test memory-efficient processing."""