""" | |
Main TTS service class with all API endpoints. | |
""" | |
import io | |
import base64 | |
import warnings | |
from typing import Optional | |
import modal | |
from fastapi.responses import StreamingResponse, Response | |
from fastapi import HTTPException, File, UploadFile, Form | |
from .config import app, image | |
from .models import TTSRequest, TTSResponse, HealthResponse, FullTextTTSRequest, FullTextTTSResponse | |
from .audio_utils import AudioUtils | |
from .text_processing import TextChunker | |
from .audio_concatenator import AudioConcatenator | |
with image.imports(): | |
from chatterbox.tts import ChatterboxTTS | |
import torch # Add torch import here | |
# Suppress specific transformers deprecation warnings | |
warnings.filterwarnings("ignore", message=".*past_key_values.*", category=FutureWarning) | |


class ChatterboxTTSService:
    """
    Advanced text-to-speech service using the Chatterbox TTS model.

    Provides multiple endpoints for different use cases, including
    voice cloning, file uploads, and JSON responses.
    """

    def load(self):
        """Load the Chatterbox TTS model on container startup."""
        print("Loading Chatterbox TTS model...")
        # Suppress transformers deprecation warnings
        warnings.filterwarnings("ignore", message=".*past_key_values.*", category=FutureWarning)
        warnings.filterwarnings("ignore", message=".*tuple of tuples.*", category=FutureWarning)

        self.model = ChatterboxTTS.from_pretrained(device="cuda")
        print(f"Model loaded successfully! Sample rate: {self.model.sr}")

    def _validate_text_input(self, text: str) -> None:
        """Validate text input parameters."""
        if not text or len(text.strip()) == 0:
            raise HTTPException(status_code=400, detail="Text cannot be empty")

    def _process_voice_prompt(self, voice_prompt_base64: Optional[str]) -> Optional[str]:
        """Decode a base64-encoded voice prompt and return the temp file path."""
        if not voice_prompt_base64:
            return None
        try:
            audio_data = base64.b64decode(voice_prompt_base64)
            return AudioUtils.save_temp_audio_file(audio_data)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid voice prompt audio: {str(e)}")
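
    # Example (illustrative): producing `voice_prompt_base64` on the client.
    # Standard-library code only; the filename is a placeholder.
    #
    #     import base64
    #     with open("prompt.wav", "rb") as f:
    #         voice_prompt_base64 = base64.b64encode(f.read()).decode("utf-8")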

    def _generate_audio(self, text: str, audio_prompt_path: Optional[str] = None):
        """Generate audio, optionally conditioned on a voice prompt file.

        The temp prompt file is cleaned up whether generation succeeds or fails.
        """
        print(f"Generating audio for text: {text[:50]}...")
        try:
            if audio_prompt_path:
                wav = self.model.generate(text, audio_prompt_path=audio_prompt_path)
                AudioUtils.cleanup_temp_file(audio_prompt_path)
            else:
                wav = self.model.generate(text)
            return wav
        except Exception:
            if audio_prompt_path:
                AudioUtils.cleanup_temp_file(audio_prompt_path)
            raise

    def health(self) -> HealthResponse:
        """Health check endpoint to verify model status."""
        return HealthResponse(
            status="healthy",
            model_loaded=hasattr(self, "model") and self.model is not None,
        )

    def generate_audio(self, request: TTSRequest) -> StreamingResponse:
        """
        Generate speech audio from text with an optional voice prompt.

        Args:
            request: TTSRequest containing text and optional voice prompt

        Returns:
            StreamingResponse with generated audio as a WAV file
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)

            # Generate audio
            wav = self._generate_audio(request.text, audio_prompt_path)

            # Create audio buffer
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)

            return StreamingResponse(
                io.BytesIO(buffer.read()),
                media_type="audio/wav",
                headers={
                    "Content-Disposition": "attachment; filename=generated_speech.wav",
                    "X-Audio-Duration": str(len(wav[0]) / self.model.sr),
                },
            )
        except HTTPException:
            raise
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")
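
    # Example request (illustrative; the deployed URL and route name are
    # assumptions, since the endpoint decorators live outside this class):
    #
    #     curl -X POST "https://<your-deployment>/generate_audio" \
    #          -H "Content-Type: application/json" \
    #          -d '{"text": "Hello, world!"}' \
    #          --output generated_speech.wav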

    def generate_with_file(
        self,
        text: str = Form(..., description="Text to convert to speech"),
        voice_prompt: Optional[UploadFile] = File(None, description="Optional voice prompt audio file"),
    ) -> StreamingResponse:
        """
        Generate speech audio from text with an optional voice prompt file upload.

        Args:
            text: Text to convert to speech
            voice_prompt: Optional audio file for voice cloning

        Returns:
            StreamingResponse with generated audio as a WAV file
        """
        try:
            self._validate_text_input(text)

            # Handle the voice prompt file if provided
            audio_prompt_path = None
            if voice_prompt:
                if voice_prompt.content_type not in ["audio/wav", "audio/mpeg", "audio/mp3"]:
                    raise HTTPException(
                        status_code=400,
                        detail="Voice prompt must be a WAV, MP3, or MPEG audio file",
                    )
                # Read and save the uploaded file
                audio_data = voice_prompt.file.read()
                audio_prompt_path = AudioUtils.save_temp_audio_file(audio_data)

            # Generate audio
            wav = self._generate_audio(text, audio_prompt_path)

            # Create audio buffer
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)

            return StreamingResponse(
                io.BytesIO(buffer.read()),
                media_type="audio/wav",
                headers={
                    "Content-Disposition": "attachment; filename=generated_speech.wav",
                    "X-Audio-Duration": str(len(wav[0]) / self.model.sr),
                },
            )
        except HTTPException:
            raise
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")
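
    # Example multipart request (illustrative; URL and route name are assumptions):
    #
    #     curl -X POST "https://<your-deployment>/generate_with_file" \
    #          -F "text=Hello, world!" \
    #          -F "voice_prompt=@prompt.wav;type=audio/wav" \
    #          --output generated_speech.wav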

    def generate_json(self, request: TTSRequest) -> TTSResponse:
        """
        Generate speech audio and return it as JSON with base64-encoded audio.

        Args:
            request: TTSRequest containing text and optional voice prompt

        Returns:
            TTSResponse with base64-encoded audio data
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)

            # Generate audio
            wav = self._generate_audio(request.text, audio_prompt_path)

            # Convert to base64
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)
            audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")
            duration = len(wav[0]) / self.model.sr

            return TTSResponse(
                success=True,
                message="Audio generated successfully",
                audio_base64=audio_base64,
                duration_seconds=duration,
            )
        except HTTPException as http_exc:
            return TTSResponse(success=False, message=str(http_exc.detail))
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            return TTSResponse(success=False, message=f"Audio generation failed: {str(e)}")
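
    # Example (illustrative): decoding the JSON response on the client side,
    # assuming `response` is a requests.Response from this endpoint.
    #
    #     import base64
    #     data = response.json()  # parsed TTSResponse
    #     if data["success"]:
    #         with open("speech.wav", "wb") as f:
    #             f.write(base64.b64decode(data["audio_base64"]))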

    def generate(self, prompt: str):
        """
        Legacy endpoint kept for backward compatibility.

        Generates an audio waveform from the input text.
        """
        try:
            # Generate audio waveform from the input text
            wav = self.model.generate(prompt)

            # Create audio buffer
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)

            # Return the audio as a streaming response with the appropriate MIME type
            return StreamingResponse(
                io.BytesIO(buffer.read()),
                media_type="audio/wav",
            )
        except Exception as e:
            print(f"Error in legacy endpoint: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")

    def generate_audio_file(self, request: TTSRequest) -> Response:
        """
        Generate speech audio from text with an optional voice prompt and
        return it as a complete file.

        Unlike the streaming endpoint, this returns the entire file at once.

        Args:
            request: TTSRequest containing text and optional voice prompt

        Returns:
            Response with the complete audio file data
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)

            # Generate audio
            wav = self._generate_audio(request.text, audio_prompt_path)

            # Create audio buffer
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)
            audio_data = buffer.read()
            duration = len(wav[0]) / self.model.sr

            # Return the complete audio file
            return Response(
                content=audio_data,
                media_type="audio/wav",
                headers={
                    "Content-Disposition": "attachment; filename=generated_speech.wav",
                    "X-Audio-Duration": str(duration),
                },
            )
        except HTTPException:
            raise
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")

    def generate_full_text_audio(self, request: FullTextTTSRequest) -> StreamingResponse:
        """
        Generate speech audio from full text with server-side chunking and
        parallel processing.

        This endpoint handles texts of any length by:
        1. Chunking the text intelligently (respecting sentence/paragraph boundaries)
        2. Processing chunks in parallel using GPU resources
        3. Concatenating audio chunks with proper transitions
        4. Returning the final audio file

        Args:
            request: FullTextTTSRequest containing text and processing parameters

        Returns:
            StreamingResponse with the final concatenated audio as a WAV file
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)

            print(f"Processing full text ({len(request.text)} chars) with server-side chunking...")

            # Initialize the text chunker with request parameters
            chunker = TextChunker(
                max_chunk_size=request.max_chunk_size,
                overlap_sentences=request.overlap_sentences,
            )

            # Chunk the text
            text_chunks = chunker.chunk_text(request.text)
            chunk_info = chunker.get_chunk_info(text_chunks)
            print(f"Split text into {len(text_chunks)} chunks for processing")
            # Track generated chunks for the processing summary
            audio_chunks = []

            # If there is only one chunk, process it directly
            if len(text_chunks) == 1:
                wav = self._generate_audio(text_chunks[0], audio_prompt_path)
                # Keep the full wav object so downstream handling stays consistent
                final_audio = wav
                audio_chunks = [wav]
            else:
                # Process chunks in parallel
                import concurrent.futures

                def process_chunk(chunk_text: str):
                    """Generate audio for a single chunk.

                    Calls the model directly rather than _generate_audio, so the
                    shared voice prompt file is not deleted while other chunks
                    still need it.
                    """
                    print(f"Generating audio for text: {chunk_text[:50]}...")
                    if audio_prompt_path:
                        return self.model.generate(chunk_text, audio_prompt_path=audio_prompt_path)
                    return self.model.generate(chunk_text)

                try:
                    # Use a ThreadPoolExecutor for parallel processing
                    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                        # Submit all chunks for processing
                        future_to_chunk = {
                            executor.submit(process_chunk, chunk): i
                            for i, chunk in enumerate(text_chunks)
                        }

                        # Collect results in submission order
                        results = [None] * len(text_chunks)
                        for future in concurrent.futures.as_completed(future_to_chunk):
                            chunk_index = future_to_chunk[future]
                            try:
                                results[chunk_index] = future.result()
                            except Exception as exc:
                                print(f"Chunk {chunk_index} generated an exception: {exc}")
                                raise HTTPException(
                                    status_code=500,
                                    detail=f"Failed to process chunk {chunk_index}: {str(exc)}",
                                )
                finally:
                    # The voice prompt is no longer needed once all chunks have finished
                    if audio_prompt_path:
                        AudioUtils.cleanup_temp_file(audio_prompt_path)

                # Filter out any missing results
                audio_chunks = [result for result in results if result is not None]
                if len(audio_chunks) != len(text_chunks):
                    raise HTTPException(
                        status_code=500,
                        detail=f"Only {len(audio_chunks)} out of {len(text_chunks)} chunks processed successfully",
                    )

                # Concatenate audio chunks
                print("Concatenating audio chunks...")
                concatenator = AudioConcatenator(
                    silence_duration=request.silence_duration,
                    fade_duration=request.fade_duration,
                )
                final_audio = concatenator.concatenate_audio_chunks(audio_chunks, self.model.sr)
            # Normalize the concatenated audio into a (1, L) tensor for saving
            import numpy as np

            processed_tensor = final_audio

            # Unwrap nested single-element tuples (e.g. a model returning (wav,))
            while isinstance(processed_tensor, tuple) and len(processed_tensor) == 1:
                processed_tensor = processed_tensor[0]

            # Convert NumPy arrays to PyTorch tensors
            if isinstance(processed_tensor, np.ndarray):
                processed_tensor = torch.from_numpy(processed_tensor.astype(np.float32))

            if not isinstance(processed_tensor, torch.Tensor):
                raise TypeError(f"Audio data after concatenation is not a tensor. Got type: {type(processed_tensor)}")

            # Ensure shape (C, L) as expected by torchaudio.save
            if processed_tensor.ndim == 1:  # Shape (L,)
                audio_to_save = processed_tensor.unsqueeze(0)  # Convert to (1, L)
            elif processed_tensor.ndim == 2:  # Shape (C, L)
                if processed_tensor.shape[0] == 0:
                    raise ValueError(f"Audio tensor has 0 channels: {processed_tensor.shape}")
                if processed_tensor.shape[0] > 1:  # Stereo/multi-channel
                    print(f"Multi-channel audio (shape {processed_tensor.shape}) detected. Taking the first channel.")
                    audio_to_save = processed_tensor[0, :].unsqueeze(0)  # Result is (1, L)
                else:  # Already (1, L)
                    audio_to_save = processed_tensor
            else:
                raise ValueError(f"Unexpected audio tensor dimensions: {processed_tensor.ndim}, shape: {processed_tensor.shape}")

            buffer = AudioUtils.save_audio_to_buffer(audio_to_save, self.model.sr)
            duration = audio_to_save.shape[1] / self.model.sr  # shape[1] is the sample count

            # Reset the buffer position for reading
            buffer.seek(0)
            # Summarize the run; not returned in the body, but useful in logs
            processing_info = {
                "total_chunks": len(text_chunks),
                "processed_chunks": len(audio_chunks),
                "failed_chunks": len(text_chunks) - len(audio_chunks),
                "sample_rate": self.model.sr,
                "duration": duration,
            }
            print(f"Processing summary: {processing_info}")
print(f"Full text processing complete! Final audio duration: {duration:.2f} seconds") | |
return StreamingResponse( | |
buffer, | |
media_type="audio/wav", | |
headers={ | |
"Content-Disposition": "attachment; filename=generated_full_text_speech.wav", | |
"X-Audio-Duration": str(duration), | |
"X-Chunks-Processed": str(len(audio_chunks)), | |
"X-Total-Characters": str(len(request.text)) | |
} | |
) | |
except HTTPException as http_exc: | |
print(f"HTTP exception in full text generation: {http_exc.detail}") | |
raise http_exc | |
except Exception as e: | |
error_msg = f"Full text audio generation failed: {str(e)}" | |
print(f"Exception in full text generation: {error_msg}") | |
raise HTTPException(status_code=500, detail=error_msg) | |
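

# Example request body for the full-text endpoint (illustrative; the values
# shown are assumptions, only the field names come from the code above):
#
#     {
#         "text": "A long article to narrate...",
#         "max_chunk_size": 500,
#         "overlap_sentences": 1,
#         "silence_duration": 0.5,
#         "fade_duration": 0.1,
#         "voice_prompt_base64": null
#     }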