spagestic committed on
Commit
06a06a0
·
1 Parent(s): e37b0d2

api docs and code added

Browse files
Files changed (7) hide show
  1. api/README.md +61 -0
  2. api/__init__.py +22 -0
  3. api/audio_utils.py +62 -0
  4. api/config.py +23 -0
  5. api/models.py +27 -0
  6. api/tts_service.py +278 -0
  7. requirements.txt +1 -0
api/README.md ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # API Package
2
+
3
+ This package contains the modular components of the Chatterbox TTS API.
4
+
5
+ ## Structure
6
+
7
+ ```
8
+ api/
9
+ ├── __init__.py # Package initialization and exports
10
+ ├── config.py # Modal app configuration and container image setup
11
+ ├── models.py # Pydantic request/response models
12
+ ├── audio_utils.py # Audio processing utilities and helper functions
13
+ ├── tts_service.py # Main TTS service class with all API endpoints
14
+ └── README.md # This file
15
+ ```
16
+
17
+ ## Components
18
+
19
+ ### config.py
20
+
21
+ - Modal app configuration
22
+ - Container image setup with required dependencies
23
+ - Centralized configuration management
24
+
25
+ ### models.py
26
+
27
+ - `TTSRequest`: Request model for TTS generation
28
+ - `TTSResponse`: Response model for JSON endpoints
29
+ - `HealthResponse`: Response model for health checks
30
+ - All models include proper type hints and documentation
31
+
32
+ ### audio_utils.py
33
+
34
+ - `AudioUtils`: Static utility class for audio operations
35
+ - Buffer management for audio data
36
+ - Temporary file handling with automatic cleanup
37
+ - Reusable audio processing functions
38
+
39
+ ### tts_service.py
40
+
41
+ - `ChatterboxTTSService`: Main service class with all endpoints
42
+ - GPU-accelerated TTS model loading and inference
43
+ - Multiple API endpoints for different use cases
44
+ - Comprehensive error handling and validation
45
+
46
+ ## Usage
47
+
48
+ ```python
49
+ from api import app, ChatterboxTTSService
50
+
51
+ # The app is automatically configured and ready to deploy
52
+ # The service class contains all the endpoints
53
+ ```
54
+
55
+ ## Benefits of Modular Architecture
56
+
57
+ 1. **Separation of Concerns**: Each file has a specific responsibility
58
+ 2. **Maintainability**: Easier to update and modify individual components
59
+ 3. **Testability**: Components can be tested in isolation
60
+ 4. **Reusability**: Components can be imported and used in other projects
61
+ 5. **Readability**: Smaller files are easier to understand and navigate
api/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ Chatterbox TTS API package.
4
+
5
+ This package provides a modular text-to-speech API using the Chatterbox TTS model
6
+ deployed on Modal with GPU acceleration.
7
+ """
8
+
9
+ from .config import app, image
10
+ from .models import TTSRequest, TTSResponse, HealthResponse
11
+ from .audio_utils import AudioUtils
12
+ from .tts_service import ChatterboxTTSService
13
+
14
+ __all__ = [
15
+ "app",
16
+ "image",
17
+ "TTSRequest",
18
+ "TTSResponse",
19
+ "HealthResponse",
20
+ "AudioUtils",
21
+ "ChatterboxTTSService"
22
+ ]
api/audio_utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ Audio processing utilities for TTS service.
4
+ """
5
+
6
+ import io
7
+ import tempfile
8
+ import os
9
+ from .config import image
10
+
11
+ with image.imports():
12
+ import torchaudio as ta
13
+
14
+
15
class AudioUtils:
    """Static helpers for audio buffering and temp-file management."""

    @staticmethod
    def save_audio_to_buffer(wav_tensor, sample_rate: int) -> io.BytesIO:
        """Serialize an audio tensor to an in-memory WAV buffer.

        Args:
            wav_tensor: Audio tensor to serialize.
            sample_rate: Sample rate (Hz) of the audio.

        Returns:
            A BytesIO rewound to offset 0, containing WAV-encoded data.
        """
        out = io.BytesIO()
        ta.save(out, wav_tensor, sample_rate, format="wav")
        out.seek(0)
        return out

    @staticmethod
    def save_temp_audio_file(audio_data: bytes) -> str:
        """Persist raw audio bytes to a new temporary ``.wav`` file.

        Args:
            audio_data: Raw audio bytes to write out.

        Returns:
            Filesystem path of the temporary file. The caller owns the
            file and should remove it via :meth:`cleanup_temp_file`.
        """
        temp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        with temp:
            temp.write(audio_data)
        return temp.name

    @staticmethod
    def cleanup_temp_file(file_path: str) -> None:
        """Best-effort removal of a temporary file; never raises.

        Args:
            file_path: Path to delete. Falsy values and already-missing
                paths are silently ignored.
        """
        try:
            if file_path and os.path.exists(file_path):
                os.unlink(file_path)
        except Exception as e:
            # Deliberate best-effort: a cleanup failure must not mask
            # the response that triggered it.
            print(f"Warning: Failed to cleanup temp file {file_path}: {e}")
api/config.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ Modal app configuration and container image setup.
4
+ """
5
+
6
+ import modal
7
+
8
+ # Define a container image with required dependencies
9
# Base image: Debian slim with Python 3.12.
_base_image = modal.Image.debian_slim(python_version="3.12")

# Container image with the full TTS runtime stack installed.
image = _base_image.pip_install(
    "chatterbox-tts==0.1.1",
    "fastapi[standard]",
    "pydantic",
    "numpy",
    "transformers>=4.45.0,<4.47.0",  # Pin to avoid deprecation warnings
    "torch>=2.0.0",
    "torchaudio>=2.0.0",
).env(
    # Suppress the specific transformers deprecation warning
    {"PYTHONWARNINGS": "ignore::FutureWarning:transformers"}
)

# Modal application object shared by every module in this package.
app = modal.App("chatterbox-api-example", image=image)
api/models.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ Pydantic models for request/response validation and API documentation.
4
+ """
5
+
6
+ from typing import Optional
7
+ from pydantic import BaseModel
8
+
9
+
10
class TTSRequest(BaseModel):
    """Payload for TTS generation; voice cloning is opt-in via a prompt."""

    # Text to synthesize into speech.
    text: str
    # Optional base64-encoded audio sample used as the voice-cloning prompt.
    voice_prompt_base64: Optional[str] = None
14
+
15
+
16
class TTSResponse(BaseModel):
    """JSON result of a TTS generation call."""

    # Whether synthesis completed without error.
    success: bool
    # Human-readable status or error description.
    message: str
    # Base64-encoded WAV data; populated only on success.
    audio_base64: Optional[str] = None
    # Length of the generated clip in seconds; populated only on success.
    duration_seconds: Optional[float] = None
22
+
23
+
24
class HealthResponse(BaseModel):
    """Result of the health-check endpoint."""

    # Service status string (e.g. "healthy").
    status: str
    # True once the TTS model has been loaded in the container.
    model_loaded: bool
api/tts_service.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main TTS service class with all API endpoints.
3
+ """
4
+
5
+ import io
6
+ import base64
7
+ import warnings
8
+ from typing import Optional
9
+
10
+ import modal
11
+ from fastapi.responses import StreamingResponse, Response
12
+ from fastapi import HTTPException, File, UploadFile, Form
13
+
14
+ from .config import app, image
15
+ from .models import TTSRequest, TTSResponse, HealthResponse
16
+ from .audio_utils import AudioUtils
17
+
18
+ with image.imports():
19
+ from chatterbox.tts import ChatterboxTTS
20
+ # Suppress specific transformers deprecation warnings
21
+ warnings.filterwarnings("ignore", message=".*past_key_values.*", category=FutureWarning)
22
+
23
+
24
@app.cls(
    gpu="a10g",
    scaledown_window=60 * 5,
    enable_memory_snapshot=True
)
@modal.concurrent(
    max_inputs=10
)
class ChatterboxTTSService:
    """
    Advanced text-to-speech service using Chatterbox TTS model.

    Provides multiple endpoints for different use cases including
    voice cloning, file uploads, and JSON responses.
    """

    @modal.enter()
    def load(self):
        """Load the Chatterbox TTS model on container startup."""
        print("Loading Chatterbox TTS model...")

        # Suppress transformers deprecation warnings
        warnings.filterwarnings("ignore", message=".*past_key_values.*", category=FutureWarning)
        warnings.filterwarnings("ignore", message=".*tuple of tuples.*", category=FutureWarning)

        self.model = ChatterboxTTS.from_pretrained(device="cuda")
        print(f"Model loaded successfully! Sample rate: {self.model.sr}")

    def _validate_text_input(self, text: str) -> None:
        """Raise HTTP 400 when *text* is missing or whitespace-only."""
        if not text or len(text.strip()) == 0:
            raise HTTPException(status_code=400, detail="Text cannot be empty")

    def _process_voice_prompt(self, voice_prompt_base64: Optional[str]) -> Optional[str]:
        """Decode a base64 voice prompt to a temp WAV file.

        Args:
            voice_prompt_base64: Base64-encoded audio, or None/empty.

        Returns:
            Path of the temp file, or None when no prompt was supplied.

        Raises:
            HTTPException: 400 when the payload cannot be decoded/saved.
        """
        if not voice_prompt_base64:
            return None

        try:
            audio_data = base64.b64decode(voice_prompt_base64)
            return AudioUtils.save_temp_audio_file(audio_data)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid voice prompt audio: {str(e)}")

    def _generate_audio(self, text: str, audio_prompt_path: Optional[str] = None):
        """Run model inference with an optional voice prompt.

        The temp voice-prompt file (if any) is always removed, on both
        the success and the error path.
        """
        print(f"Generating audio for text: {text[:50]}...")

        try:
            if audio_prompt_path:
                return self.model.generate(text, audio_prompt_path=audio_prompt_path)
            return self.model.generate(text)
        finally:
            # try/finally replaces the original's duplicated cleanup calls
            # on the success and exception branches.
            if audio_prompt_path:
                AudioUtils.cleanup_temp_file(audio_prompt_path)

    def _duration_seconds(self, wav) -> float:
        """Clip length in seconds; assumes wav is (channels, samples) — TODO confirm."""
        return len(wav[0]) / self.model.sr

    def _wav_streaming_response(self, wav) -> StreamingResponse:
        """Shared helper: wrap a generated tensor as a WAV attachment response."""
        buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)
        # Stream the existing buffer directly; the original copied it
        # through a second BytesIO (io.BytesIO(buffer.read())) for no gain.
        return StreamingResponse(
            buffer,
            media_type="audio/wav",
            headers={
                "Content-Disposition": "attachment; filename=generated_speech.wav",
                "X-Audio-Duration": str(self._duration_seconds(wav))
            }
        )

    @modal.fastapi_endpoint(docs=True, method="GET")
    def health(self) -> HealthResponse:
        """Health check endpoint to verify model status."""
        return HealthResponse(
            status="healthy",
            # getattr avoids the hasattr-then-access double lookup.
            model_loaded=getattr(self, 'model', None) is not None
        )

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_audio(self, request: TTSRequest) -> StreamingResponse:
        """
        Generate speech audio from text with optional voice prompt.

        Args:
            request: TTSRequest containing text and optional voice prompt

        Returns:
            StreamingResponse with generated audio as WAV file

        Raises:
            HTTPException: 400 on invalid input, 500 on generation failure.
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)

            wav = self._generate_audio(request.text, audio_prompt_path)
            return self._wav_streaming_response(wav)

        except HTTPException:
            raise
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_with_file(
        self,
        text: str = Form(..., description="Text to convert to speech"),
        voice_prompt: Optional[UploadFile] = File(None, description="Optional voice prompt audio file")
    ) -> StreamingResponse:
        """
        Generate speech audio from text with optional voice prompt file upload.

        Args:
            text: Text to convert to speech
            voice_prompt: Optional audio file for voice cloning

        Returns:
            StreamingResponse with generated audio as WAV file

        Raises:
            HTTPException: 400 on invalid input/file type, 500 on failure.
        """
        try:
            self._validate_text_input(text)

            # Handle voice prompt file if provided
            audio_prompt_path = None
            if voice_prompt:
                if voice_prompt.content_type not in ["audio/wav", "audio/mpeg", "audio/mp3"]:
                    raise HTTPException(
                        status_code=400,
                        detail="Voice prompt must be WAV, MP3, or MPEG audio file"
                    )

                # Read and save the uploaded file
                audio_data = voice_prompt.file.read()
                audio_prompt_path = AudioUtils.save_temp_audio_file(audio_data)

            wav = self._generate_audio(text, audio_prompt_path)
            return self._wav_streaming_response(wav)

        except HTTPException:
            raise
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_json(self, request: TTSRequest) -> TTSResponse:
        """
        Generate speech audio and return as JSON with base64 encoded audio.

        Args:
            request: TTSRequest containing text and optional voice prompt

        Returns:
            TTSResponse with base64 encoded audio data; on error the
            response carries success=False instead of an HTTP error.
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)

            wav = self._generate_audio(request.text, audio_prompt_path)

            # Convert to base64
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)
            audio_base64 = base64.b64encode(buffer.read()).decode('utf-8')

            return TTSResponse(
                success=True,
                message="Audio generated successfully",
                audio_base64=audio_base64,
                duration_seconds=self._duration_seconds(wav)
            )

        except HTTPException as http_exc:
            return TTSResponse(success=False, message=str(http_exc.detail))
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            return TTSResponse(success=False, message=f"Audio generation failed: {str(e)}")

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate(self, prompt: str):
        """
        Legacy endpoint for backward compatibility.
        Generate audio waveform from the input text.
        """
        try:
            # Generate audio waveform from the input text
            wav = self.model.generate(prompt)
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)

            # Stream the buffer directly; no attachment headers, to keep
            # the legacy response shape unchanged.
            return StreamingResponse(
                buffer,
                media_type="audio/wav",
            )
        except Exception as e:
            print(f"Error in legacy endpoint: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_audio_file(self, request: TTSRequest) -> Response:
        """
        Generate speech audio from text with optional voice prompt and return as a complete file.

        Unlike the streaming endpoint, this returns the entire file at once.

        Args:
            request: TTSRequest containing text and optional voice prompt

        Returns:
            Response with complete audio file data

        Raises:
            HTTPException: 400 on invalid input, 500 on generation failure.
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)

            wav = self._generate_audio(request.text, audio_prompt_path)

            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)

            # Return the complete audio file
            return Response(
                content=buffer.read(),
                media_type="audio/wav",
                headers={
                    "Content-Disposition": "attachment; filename=generated_speech.wav",
                    "X-Audio-Duration": str(self._duration_seconds(wav))
                }
            )

        except HTTPException:
            raise
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")
requirements.txt CHANGED
@@ -24,6 +24,7 @@ Jinja2==3.1.6
24
  markdown-it-py==3.0.0
25
  mdurl==0.1.2
26
  mistralai==1.8.1
 
27
  numpy==2.2.6
28
  orjson==3.10.18
29
  packaging==25.0
 
24
  markdown-it-py==3.0.0
25
  mdurl==0.1.2
26
  mistralai==1.8.1
27
+ modal==1.0.3
28
  numpy==2.2.6
29
  orjson==3.10.18
30
  packaging==25.0