# Spaces:
# Sleeping
# Sleeping
import io
import os
import time
from typing import Dict, Generator, List, Optional, Tuple

import numpy as np
# Keep `spaces` ahead of torch/kokoro: Hugging Face ZeroGPU expects it to be
# imported before CUDA is initialised.
import spaces
import torch
from fastapi import BackgroundTasks, FastAPI, HTTPException, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response, StreamingResponse
from kokoro import KModel, KPipeline
from pydantic import BaseModel, Field
app = FastAPI(
    title="Kokoro TTS API",
    description="API for Kokoro text-to-speech conversion",
)

# Constants
# A Space whose SPACE_ID does not start with 'hexgrad/' is treated as a
# duplicate (presumably SPACE_ID is 'owner/name'); duplicates get no character cap.
IS_DUPLICATE = not os.environ.get('SPACE_ID', '').startswith('hexgrad/')
CHAR_LIMIT = 5000 if not IS_DUPLICATE else None
CUDA_AVAILABLE = torch.cuda.is_available()
# Initialize models
# One KModel per device: key False -> CPU (always built), key True -> CUDA
# (built only when CUDA_AVAILABLE).
models = {gpu: KModel().to('cuda' if gpu else 'cpu').eval() for gpu in [False] + ([True] if CUDA_AVAILABLE else [])}
# Pipelines keyed by the first letter of the voice IDs ('a', 'b'). They are
# built with model=False; this file produces audio separately through `models`.
pipelines = {lang_code: KPipeline(lang_code=lang_code, model=False) for lang_code in 'ab'}
# Custom pronunciation of the word "kokoro" for each pipeline. Written with
# \u escapes so the phoneme symbols survive any re-encoding of this file;
# readable forms: kˈOkəɹO ('a') and kˈQkəɹQ ('b').
pipelines['a'].g2p.lexicon.golds['kokoro'] = 'k\u02c8Ok\u0259\u0279O'
pipelines['b'].g2p.lexicon.golds['kokoro'] = 'k\u02c8Qk\u0259\u0279Q'
# Voice choices: display name -> voice ID. The first letter of each ID picks
# the entry in `pipelines` ('a' or 'b').
# NOTE(review): the emoji in these display names look mis-encoded (e.g. '๐บ๐ธ'
# where a flag emoji was presumably intended), probably UTF-8 decoded with a
# Thai code page. Confirm against the upstream source before re-encoding, and
# update anything that matches on these exact strings at the same time.
CHOICES = {
    '๐บ๐ธ ๐บ Heart โค๏ธ': 'af_heart',
    '๐บ๐ธ ๐บ Bella ๐ฅ': 'af_bella',
    '๐บ๐ธ ๐บ Nicole ๐ง': 'af_nicole',
    '๐บ๐ธ ๐บ Aoede': 'af_aoede',
    '๐บ๐ธ ๐บ Kore': 'af_kore',
    '๐บ๐ธ ๐บ Sarah': 'af_sarah',
    '๐บ๐ธ ๐บ Nova': 'af_nova',
    '๐บ๐ธ ๐บ Sky': 'af_sky',
    '๐บ๐ธ ๐บ Alloy': 'af_alloy',
    '๐บ๐ธ ๐บ Jessica': 'af_jessica',
    '๐บ๐ธ ๐บ River': 'af_river',
    '๐บ๐ธ ๐น Michael': 'am_michael',
    '๐บ๐ธ ๐น Fenrir': 'am_fenrir',
    '๐บ๐ธ ๐น Puck': 'am_puck',
    '๐บ๐ธ ๐น Echo': 'am_echo',
    '๐บ๐ธ ๐น Eric': 'am_eric',
    '๐บ๐ธ ๐น Liam': 'am_liam',
    '๐บ๐ธ ๐น Onyx': 'am_onyx',
    '๐บ๐ธ ๐น Santa': 'am_santa',
    '๐บ๐ธ ๐น Adam': 'am_adam',
    '๐ฌ๐ง ๐บ Emma': 'bf_emma',
    '๐ฌ๐ง ๐บ Isabella': 'bf_isabella',
    '๐ฌ๐ง ๐บ Alice': 'bf_alice',
    '๐ฌ๐ง ๐บ Lily': 'bf_lily',
    '๐ฌ๐ง ๐น George': 'bm_george',
    '๐ฌ๐ง ๐น Fable': 'bm_fable',
    '๐ฌ๐ง ๐น Lewis': 'bm_lewis',
    '๐ฌ๐ง ๐น Daniel': 'bm_daniel',
}

# Load voices
# Every voice pack is loaded at import time, on the pipeline its ID prefix selects.
for voice_id in CHOICES.values():
    lang_pipeline = pipelines[voice_id[0]]
    lang_pipeline.load_voice(voice_id)
# Sample text files
# en.txt holds one quote per line. The encoding is explicit so loading does not
# depend on the host locale, and blank lines are skipped so a random pick is
# never an empty string.
with open('en.txt', 'r', encoding='utf-8') as quotes_file:
    RANDOM_QUOTES = [line.strip() for line in quotes_file if line.strip()]
def _read_sample(path):
    """Return the whitespace-stripped contents of a UTF-8 sample text file."""
    # Explicit encoding: the platform default is locale-dependent.
    with open(path, 'r', encoding='utf-8') as sample_file:
        return sample_file.read().strip()


def get_gatsby():
    """Return the stripped contents of gatsby5k.md (read on every call)."""
    return _read_sample('gatsby5k.md')


def get_frankenstein():
    """Return the stripped contents of frankenstein5k.md (read on every call)."""
    return _read_sample('frankenstein5k.md')
# Pydantic models
# Request body model for text_to_speech. Field bounds and descriptions drive
# pydantic validation and the generated schema.
class TTSRequest(BaseModel):
    text: str = Field(default=..., description="Text to convert to speech")
    voice: str = Field(
        default="af_heart",
        description="Voice ID to use for TTS",
    )
    # Values outside [0.5, 2.0] fail validation.
    speed: float = Field(
        default=1.0,
        ge=0.5,
        le=2.0,
        description="Speech speed factor (0.5 to 2.0)",
    )
    # Defaults to GPU only when CUDA was detected at import time.
    use_gpu: bool = Field(
        default=CUDA_AVAILABLE,
        description="Whether to use GPU for inference",
    )
# Request body model for tokenize_text.
class TextRequest(BaseModel):
    text: str = Field(default=..., description="Text to tokenize")
    voice: str = Field(
        default="af_heart",
        description="Voice ID to use for tokenization",
    )
# One entry of the voice listing built by list_voices.
class Voice(BaseModel):
    display_name: str  # key of CHOICES
    id: str  # voice ID, e.g. 'af_heart'
    language: str  # "US English" or "UK English" as produced by list_voices
    gender: str  # "Female" or "Male" as produced by list_voices
# Response wrapper returned by list_voices.
class VoiceList(BaseModel):
    voices: List[Voice]  # one entry per CHOICES item
# GPU wrapper function
@spaces.GPU
def forward_gpu(ps, ref_s, speed):
    """Run the CUDA model (``models[True]``) on one phoneme segment.

    ``spaces.GPU`` is the Hugging Face ZeroGPU decorator. On ZeroGPU hardware,
    a GPU is attached only while a decorated function runs. The ``spaces``
    documentation describes the decorator as having no effect in other
    environments.

    ``spaces`` is imported by this file but was otherwise unused. Without the
    decorator, this wrapper cannot get a ZeroGPU device, and callers fall back
    to the CPU model.
    """
    return models[True](ps, ref_s, speed)
# Helper functions
def _clip_text(text: str) -> str:
    """Return ``text`` stripped and cut to CHAR_LIMIT, or unchanged when there is no limit."""
    return text if CHAR_LIMIT is None else text.strip()[:CHAR_LIMIT]


def _synthesize(ps, ref_s, speed: float, use_gpu: bool):
    """Run the acoustic model on one phoneme segment and return its audio tensor.

    When ``use_gpu`` is true the CUDA model is tried first. If it raises, the
    failure is logged and the segment is retried on the CPU model.

    Raises:
        HTTPException: 500 carrying the error text when the CPU model fails.
    """
    if use_gpu:
        try:
            return forward_gpu(ps, ref_s, speed)
        except Exception as gpu_error:  # deliberate best-effort: any GPU failure falls back to CPU
            import logging

            logging.getLogger(__name__).warning(
                "GPU inference failed, retrying on CPU: %s", gpu_error
            )
    try:
        return models[False](ps, ref_s, speed)
    except Exception as cpu_error:
        raise HTTPException(status_code=500, detail=str(cpu_error)) from cpu_error


def _segments(text: str, voice: str, speed: float, use_gpu: bool):
    """Lazily yield ``(phonemes, audio_tensor)`` for each segment of ``text``.

    The text is capped with ``_clip_text`` first. GPU use additionally requires
    CUDA_AVAILABLE.
    """
    pipeline = pipelines[voice[0]]
    pack = pipeline.load_voice(voice)
    use_gpu = use_gpu and CUDA_AVAILABLE
    for _, ps, _ in pipeline(_clip_text(text), voice, speed):
        # The voice pack is indexed by phoneme count (entry len(ps) - 1).
        ref_s = pack[len(ps) - 1]
        yield ps, _synthesize(ps, ref_s, speed, use_gpu)


def generate_first(text: str, voice: str = 'af_heart', speed: float = 1.0, use_gpu: bool = CUDA_AVAILABLE):
    """Generate audio for the first sentence/segment of text.

    Only the first segment is synthesized; later segments are never computed.

    Returns:
        ``((24000, samples), phonemes)`` with ``samples`` a NumPy array, or
        ``(None, '')`` when the pipeline yields no segments.

    Raises:
        HTTPException: 500 if CPU inference fails.
    """
    for ps, audio in _segments(text, voice, speed, use_gpu):
        return (24000, audio.numpy()), ps
    return None, ''


def tokenize_first(text: str, voice: str = 'af_heart'):
    """Return the phoneme string of the first segment of ``text``, or '' if there is none.

    The same CHAR_LIMIT cap as the synthesis helpers is applied, so tokenization
    is not fed more text than synthesis accepts.
    """
    pipeline = pipelines[voice[0]]
    return next((ps for _, ps, _ in pipeline(_clip_text(text), voice)), '')


def generate_all(text: str, voice: str = 'af_heart', speed: float = 1.0, use_gpu: bool = CUDA_AVAILABLE) -> Generator:
    """Lazily yield a NumPy audio array for every segment of ``text``.

    NOTE(review): when this feeds a StreamingResponse, an HTTPException raised
    by a failing segment arrives after the response has started. It can only
    cut the stream short; it cannot change the status code.
    """
    for _, audio in _segments(text, voice, speed, use_gpu):
        yield audio.numpy()
def create_wav(audio_data, sample_rate=24000):
    """Encode float samples as a complete mono 16-bit PCM WAV file.

    Args:
        audio_data: Array-like of float samples, nominally in [-1.0, 1.0];
            written as a single channel.
        sample_rate: Sample rate stored in the header, in Hz.

    Returns:
        The WAV file as ``bytes``.
    """
    import wave

    # Clip before scaling. Values outside [-1, 1] would overflow int16 and wrap
    # around into loud clicks. '<i2' pins the little-endian byte order that WAV
    # requires, independent of the host.
    pcm = (np.clip(audio_data, -1.0, 1.0) * 32767).astype('<i2')

    wav_io = io.BytesIO()
    with wave.open(wav_io, 'wb') as wav_file:
        wav_file.setnchannels(1)  # mono
        wav_file.setsampwidth(2)  # 16-bit
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm.tobytes())
    return wav_io.getvalue()
def stream_wav_chunks(audio_chunks, sample_rate=24000, delay=0.1):
    """Yield a mono 16-bit PCM WAV stream: a 44-byte header, then one PCM block per chunk.

    Args:
        audio_chunks: Iterable of float sample arrays, nominally in [-1.0, 1.0].
        sample_rate: Sample rate stored in the header, in Hz.
        delay: Seconds to sleep after each chunk is yielded (0 disables the pause).

    Yields:
        ``bytes``: the header first, then the PCM bytes of each chunk.
    """
    import wave  # local import keeps this function self-contained

    # The total length is unknown up front, so the header is written for zero
    # frames and advertises an empty data chunk.
    # NOTE(review): some players may need patched size fields for streamed WAV.
    header_io = io.BytesIO()
    with wave.open(header_io, 'wb') as wav_file:
        wav_file.setnchannels(1)  # mono
        wav_file.setsampwidth(2)  # 16-bit
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(b'')
    yield header_io.getvalue()[:44]

    for chunk in audio_chunks:
        # Clip so out-of-range samples cannot wrap around in int16; '<i2'
        # is the little-endian layout WAV requires.
        yield (np.clip(chunk, -1.0, 1.0) * 32767).astype('<i2').tobytes()
        if delay:
            time.sleep(delay)  # pacing between chunks, kept from the original
# API Routes
@app.get("/")
async def root():
    """API root with basic information.

    Registered on ``app``; without the decorator FastAPI never serves this
    handler (or any of the endpoints it advertises).
    """
    return {
        "message": "Kokoro TTS API",
        "description": "Convert text to speech using Kokoro TTS model",
        "endpoints": {
            "GET /voices": "List available voices",
            "POST /tts": "Convert text to speech",
            "POST /tokenize": "Tokenize text",
            "GET /stream": "Stream audio from text",
            "GET /samples": "Get sample texts",
        },
    }
@app.get("/voices")
async def list_voices():
    """List all available voices with their language and gender.

    Language and gender come from the voice ID rather than from emoji text in
    the display name:
    - first letter 'a' means US English, anything else UK English;
    - second letter 'f' means Female, anything else Male.

    The previous display-name check was wrong. The female-marker substring it
    looked for also occurs inside the US-flag text, so every US male voice was
    reported as Female.
    """
    voices = [
        Voice(
            display_name=display_name,
            id=voice_id,
            language="US English" if voice_id.startswith("a") else "UK English",
            gender="Female" if voice_id[1:2] == "f" else "Male",
        )
        for display_name, voice_id in CHOICES.items()
    ]
    return VoiceList(voices=voices)
@app.post("/tts")
async def text_to_speech(request: TTSRequest):
    """Convert text to speech and return it as a WAV attachment.

    NOTE(review): this goes through generate_first, so only the first segment
    of ``request.text`` is synthesized. Confirm that is intended for /tts.

    Raises:
        HTTPException: 400 for an unknown voice or blank text; 500 when no
            audio is produced.
    """
    if request.voice not in CHOICES.values():
        raise HTTPException(status_code=400, detail=f"Voice '{request.voice}' not found. Use /voices to see available options.")
    if not request.text.strip():
        # Blank text yields no segments; report it as a client error, not a 500.
        raise HTTPException(status_code=400, detail="Text must not be empty.")
    # Inference is blocking; run it off the event loop so other requests are
    # not stalled while a clip is generated.
    result, _ = await run_in_threadpool(
        generate_first, request.text, request.voice, request.speed, request.use_gpu
    )
    if result is None:
        raise HTTPException(status_code=500, detail="Failed to generate audio")
    sample_rate, audio_data = result
    # The WAV is fully in memory, so a plain Response (with Content-Length)
    # fits better than a StreamingResponse.
    return Response(
        content=create_wav(audio_data, sample_rate),
        media_type="audio/wav",
        headers={"Content-Disposition": f"attachment; filename=tts_{request.voice}.wav"},
    )
@app.post("/tokenize")
async def tokenize_text(request: TextRequest):
    """Tokenize input text, returning the phonemes of its first segment.

    Raises:
        HTTPException: 400 for an unknown voice.
    """
    if request.voice not in CHOICES.values():
        raise HTTPException(status_code=400, detail=f"Voice '{request.voice}' not found. Use /voices to see available options.")
    # Phonemization is synchronous; keep it off the event loop.
    tokens = await run_in_threadpool(tokenize_first, request.text, request.voice)
    return {"text": request.text, "tokens": tokens}
@app.get("/stream")
async def stream_tts(
    text: str = Query(..., description="Text to convert to speech"),
    voice: str = Query("af_heart", description="Voice ID"),
    speed: float = Query(1.0, description="Speech speed", ge=0.5, le=2.0),
    use_gpu: bool = Query(CUDA_AVAILABLE, description="Use GPU for inference"),
):
    """Stream audio from text as it's generated.

    ``generate_all`` is a synchronous generator. StreamingResponse iterates it
    outside the event loop, and segments are synthesized only as the client
    consumes them.

    Raises:
        HTTPException: 400 for an unknown voice or blank text.
    """
    if voice not in CHOICES.values():
        raise HTTPException(status_code=400, detail=f"Voice '{voice}' not found. Use /voices to see available options.")
    # Limit text if needed
    if CHAR_LIMIT is not None:
        text = text.strip()[:CHAR_LIMIT]
    if not text.strip():
        # Otherwise the client would receive a header-only WAV with no audio.
        raise HTTPException(status_code=400, detail="Text must not be empty.")
    audio_chunks = generate_all(text, voice, speed, use_gpu)
    return StreamingResponse(
        stream_wav_chunks(audio_chunks),
        media_type="audio/wav",
        headers={"Content-Disposition": f"attachment; filename=stream_{voice}.wav"},
    )
@app.get("/samples")
async def get_samples():
    """Get sample texts: one random quote plus the first 200 characters of each long sample."""
    import random

    return {
        "random_quote": random.choice(RANDOM_QUOTES),
        "gatsby_excerpt": get_gatsby()[:200] + "...",  # First 200 chars
        "frankenstein_excerpt": get_frankenstein()[:200] + "...",  # First 200 chars
    }
# NOTE(review): the original route path is not visible; '/samples/{sample_type}'
# is assumed so that sample_type binds as a path parameter. Confirm.
@app.get("/samples/{sample_type}")
async def get_sample(sample_type: str):
    """Get the full text of one sample: 'random', 'gatsby' or 'frankenstein'.

    Raises:
        HTTPException: 404 for any other ``sample_type``.
    """
    import random

    # Loaders are called lazily so only the requested sample is read.
    loaders = {
        "random": lambda: random.choice(RANDOM_QUOTES),
        "gatsby": get_gatsby,
        "frankenstein": get_frankenstein,
    }
    loader = loaders.get(sample_type)
    if loader is None:
        raise HTTPException(status_code=404, detail=f"Sample type '{sample_type}' not found")
    return {"text": loader()}
if __name__ == "__main__":
    import uvicorn

    # Serve the app object already built above rather than the "app:app"
    # import string with reload=True. The reloader imports this module again
    # in a child process, which would load every model and voice pack a second
    # time on top of the copy this script already holds. The import string
    # also only resolves when this file is named app.py.
    # For auto-reload during development, run `uvicorn <module>:app --reload`.
    uvicorn.run(app, host="0.0.0.0", port=8000)