TaiYa1 committed
Commit 748ab1e · verified · 1 Parent(s): 048ff41

Update tools/api.py

Files changed (1):
tools/api.py +440 -440
tools/api.py CHANGED
@@ -1,440 +1,440 @@
 import base64
 import io
 import json
 import queue
 import random
 import sys
 import traceback
 import wave
 from argparse import ArgumentParser
 from http import HTTPStatus
 from pathlib import Path
 from typing import Annotated, Any, Literal, Optional

 import numpy as np
 import ormsgpack
 import pyrootutils
 import soundfile as sf
 import torch
 import torchaudio
 from baize.datastructures import ContentType
 from kui.asgi import (
     Body,
     FactoryClass,
     HTTPException,
     HttpRequest,
     HttpView,
     JSONResponse,
     Kui,
     OpenAPI,
     StreamResponse,
 )
 from kui.asgi.routing import MultimethodRoutes
 from loguru import logger
 from pydantic import BaseModel, Field, conint

 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

 # from fish_speech.models.vqgan.lit_module import VQGAN
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from fish_speech.utils import autocast_exclude_mps
 from tools.commons import ServeReferenceAudio, ServeTTSRequest
 from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
     WrappedGenerateResponse,
     launch_thread_safe_queue,
 )
 from tools.vqgan.inference import load_model as load_decoder_model


 def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
     buffer = io.BytesIO()

     with wave.open(buffer, "wb") as wav_file:
         wav_file.setnchannels(channels)
         wav_file.setsampwidth(bit_depth // 8)
         wav_file.setframerate(sample_rate)

     wav_header_bytes = buffer.getvalue()
     buffer.close()
     return wav_header_bytes


 # Define utils for web server
 async def http_execption_handler(exc: HTTPException):
     return JSONResponse(
         dict(
             statusCode=exc.status_code,
             message=exc.content,
             error=HTTPStatus(exc.status_code).phrase,
         ),
         exc.status_code,
         exc.headers,
     )


 async def other_exception_handler(exc: "Exception"):
     traceback.print_exc()

     status = HTTPStatus.INTERNAL_SERVER_ERROR
     return JSONResponse(
         dict(statusCode=status, message=str(exc), error=status.phrase),
         status,
     )


 def load_audio(reference_audio, sr):
     if len(reference_audio) > 255 or not Path(reference_audio).exists():
         audio_data = reference_audio
         reference_audio = io.BytesIO(audio_data)

     waveform, original_sr = torchaudio.load(
         reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
     )

     if waveform.shape[0] > 1:
         waveform = torch.mean(waveform, dim=0, keepdim=True)

     if original_sr != sr:
         resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
         waveform = resampler(waveform)

     audio = waveform.squeeze().numpy()
     return audio


 def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
     if enable_reference_audio and reference_audio is not None:
         # Load audios, and prepare basic info here
         reference_audio_content = load_audio(
             reference_audio, decoder_model.spec_transform.sample_rate
         )

         audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
             None, None, :
         ]
         audio_lengths = torch.tensor(
             [audios.shape[2]], device=decoder_model.device, dtype=torch.long
         )
         logger.info(
             f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
         )

         # VQ Encoder
         if isinstance(decoder_model, FireflyArchitecture):
             prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]

         logger.info(f"Encoded prompt: {prompt_tokens.shape}")
     else:
         prompt_tokens = None
         logger.info("No reference audio provided")

     return prompt_tokens


 def decode_vq_tokens(
     *,
     decoder_model,
     codes,
 ):
     feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
     logger.info(f"VQ features: {codes.shape}")

     if isinstance(decoder_model, FireflyArchitecture):
         # VQGAN Inference
         return decoder_model.decode(
             indices=codes[None],
             feature_lengths=feature_lengths,
         )[0].squeeze()

     raise ValueError(f"Unknown model type: {type(decoder_model)}")


 routes = MultimethodRoutes(base_class=HttpView)


 def get_content_type(audio_format):
     if audio_format == "wav":
         return "audio/wav"
     elif audio_format == "flac":
         return "audio/flac"
     elif audio_format == "mp3":
         return "audio/mpeg"
     else:
         return "application/octet-stream"


 @torch.inference_mode()
 def inference(req: ServeTTSRequest):

     idstr: str | None = req.reference_id
     if idstr is not None:
         ref_folder = Path("references") / idstr
         ref_folder.mkdir(parents=True, exist_ok=True)
         ref_audios = list_files(
             ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
         )
         prompt_tokens = [
             encode_reference(
                 decoder_model=decoder_model,
                 reference_audio=audio_to_bytes(str(ref_audio)),
                 enable_reference_audio=True,
             )
             for ref_audio in ref_audios
         ]
         prompt_texts = [
             read_ref_text(str(ref_audio.with_suffix(".lab")))
             for ref_audio in ref_audios
         ]

     else:
         # Parse reference audio aka prompt
         refs = req.references
         if refs is None:
             refs = []
         prompt_tokens = [
             encode_reference(
                 decoder_model=decoder_model,
                 reference_audio=ref.audio,
                 enable_reference_audio=True,
             )
             for ref in refs
         ]
         prompt_texts = [ref.text for ref in refs]

     # LLAMA Inference
     request = dict(
         device=decoder_model.device,
         max_new_tokens=req.max_new_tokens,
         text=(
             req.text
             if not req.normalize
             else ChnNormedText(raw_text=req.text).normalize()
         ),
         top_p=req.top_p,
         repetition_penalty=req.repetition_penalty,
         temperature=req.temperature,
         compile=args.compile,
         iterative_prompt=req.chunk_length > 0,
         chunk_length=req.chunk_length,
         max_length=2048,
         prompt_tokens=prompt_tokens,
         prompt_text=prompt_texts,
     )

     response_queue = queue.Queue()
     llama_queue.put(
         GenerateRequest(
             request=request,
             response_queue=response_queue,
         )
     )

     if req.streaming:
         yield wav_chunk_header()

     segments = []
     while True:
         result: WrappedGenerateResponse = response_queue.get()
         if result.status == "error":
             raise result.response
             break

         result: GenerateResponse = result.response
         if result.action == "next":
             break

         with autocast_exclude_mps(
             device_type=decoder_model.device.type, dtype=args.precision
         ):
             fake_audios = decode_vq_tokens(
                 decoder_model=decoder_model,
                 codes=result.codes,
             )

         fake_audios = fake_audios.float().cpu().numpy()

         if req.streaming:
             yield (fake_audios * 32768).astype(np.int16).tobytes()
         else:
             segments.append(fake_audios)

     if req.streaming:
         return

     if len(segments) == 0:
         raise HTTPException(
             HTTPStatus.INTERNAL_SERVER_ERROR,
             content="No audio generated, please check the input text.",
         )

     fake_audios = np.concatenate(segments, axis=0)
     yield fake_audios


 async def inference_async(req: ServeTTSRequest):
     for chunk in inference(req):
         yield chunk


 async def buffer_to_async_generator(buffer):
     yield buffer


 @routes.http.post("/v1/tts")
 async def api_invoke_model(
     req: Annotated[ServeTTSRequest, Body(exclusive=True)],
 ):
     """
     Invoke model and generate audio
     """

     if args.max_text_length > 0 and len(req.text) > args.max_text_length:
         raise HTTPException(
             HTTPStatus.BAD_REQUEST,
             content=f"Text is too long, max length is {args.max_text_length}",
         )

     if req.streaming and req.format != "wav":
         raise HTTPException(
             HTTPStatus.BAD_REQUEST,
             content="Streaming only supports WAV format",
         )

     if req.streaming:
         return StreamResponse(
             iterable=inference_async(req),
             headers={
                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
             },
             content_type=get_content_type(req.format),
         )
     else:
         fake_audios = next(inference(req))
         buffer = io.BytesIO()
         sf.write(
             buffer,
             fake_audios,
             decoder_model.spec_transform.sample_rate,
             format=req.format,
         )

         return StreamResponse(
             iterable=buffer_to_async_generator(buffer.getvalue()),
             headers={
                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
             },
             content_type=get_content_type(req.format),
         )


 @routes.http.post("/v1/health")
 async def api_health():
     """
     Health check
     """

     return JSONResponse({"status": "ok"})


 def parse_args():
     parser = ArgumentParser()
     parser.add_argument(
         "--llama-checkpoint-path",
         type=str,
-        default="checkpoints/fish-speech-1.4",
+        default="checkpoints/fish-speech-1.4-sft-yth-lora",
     )
     parser.add_argument(
         "--decoder-checkpoint-path",
         type=str,
         default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
     )
     parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("--max-text-length", type=int, default=0)
     parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
     parser.add_argument("--workers", type=int, default=1)

     return parser.parse_args()


 # Define Kui app
 openapi = OpenAPI(
     {
         "title": "Fish Speech API",
     },
 ).routes


 class MsgPackRequest(HttpRequest):
     async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
         if self.content_type == "application/msgpack":
             return ormsgpack.unpackb(await self.body)

         raise HTTPException(
             HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
             headers={"Accept": "application/msgpack"},
         )


 app = Kui(
     routes=routes + openapi[1:],  # Remove the default route
     exception_handlers={
         HTTPException: http_execption_handler,
         Exception: other_exception_handler,
     },
     factory_class=FactoryClass(http=MsgPackRequest),
     cors_config={},
 )


 if __name__ == "__main__":

     import uvicorn

     args = parse_args()
     args.precision = torch.half if args.half else torch.bfloat16

     logger.info("Loading Llama model...")
     llama_queue = launch_thread_safe_queue(
         checkpoint_path=args.llama_checkpoint_path,
         device=args.device,
         precision=args.precision,
         compile=args.compile,
     )
     logger.info("Llama model loaded, loading VQ-GAN model...")

     decoder_model = load_decoder_model(
         config_name=args.decoder_config_name,
         checkpoint_path=args.decoder_checkpoint_path,
         device=args.device,
     )

     logger.info("VQ-GAN model loaded, warming up...")

     # Dry run to check if the model is loaded correctly and avoid the first-time latency
     list(
         inference(
             ServeTTSRequest(
                 text="Hello world.",
                 references=[],
                 reference_id=None,
                 max_new_tokens=1024,
                 chunk_length=200,
                 top_p=0.7,
                 repetition_penalty=1.2,
                 temperature=0.7,
                 emotion=None,
                 format="wav",
             )
         )
     )

     logger.info(f"Warming up done, starting server at http://{args.listen}")
     host, port = args.listen.split(":")
     uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
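
For reference, a minimal client sketch for the /v1/tts endpoint defined above: the server only accepts application/msgpack bodies (see MsgPackRequest), and the payload fields mirror the warm-up ServeTTSRequest call. The `requests` dependency, the plain-dict payload, and a server already running at the default --listen address (127.0.0.1:8080, e.g. started with `python tools/api.py`) are assumptions for illustration, not part of this commit.

# Hypothetical client sketch: msgpack-encode a TTS request and save the WAV reply.
import ormsgpack
import requests

payload = {
    "text": "Hello world.",
    "references": [],        # optionally [{"audio": <bytes>, "text": <str>}, ...]
    "reference_id": None,
    "max_new_tokens": 1024,
    "chunk_length": 200,
    "top_p": 0.7,
    "repetition_penalty": 1.2,
    "temperature": 0.7,
    "emotion": None,
    "format": "wav",
    "streaming": False,
}

response = requests.post(
    "http://127.0.0.1:8080/v1/tts",
    data=ormsgpack.packb(payload),
    headers={"Content-Type": "application/msgpack"},
)
response.raise_for_status()  # non-2xx responses carry a JSON error body

with open("audio.wav", "wb") as f:
    f.write(response.content)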