from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import os
import uuid
import torch
import torchaudio
import base64
from transformers import AutoModelForCausalLM
from yarngpt.audiotokenizer import AudioTokenizerV2
import uvicorn
from datetime import datetime, timedelta
import asyncio
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Nigerian TTS API", version="1.0.0")

# Add CORS middleware to allow requests from anywhere
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables for model components
audio_tokenizer = None
model = None
model_loaded = False
loading_error = None

# Model configuration - updated paths for Hugging Face Spaces
tokenizer_path = "saheedniyi/YarnGPT2"

# These files should be downloaded to /tmp during startup
wav_tokenizer_config_path = "/tmp/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
wav_tokenizer_model_path = "/tmp/wavtokenizer_large_speech_320_24k.ckpt"

# Available voices and languages
AVAILABLE_VOICES = {
    "female": ["zainab", "idera", "regina", "chinenye", "joke", "remi"],
    "male": ["jude", "tayo", "umar", "osagie", "onye", "emma"],
}
AVAILABLE_LANGUAGES = ["english", "yoruba", "igbo", "hausa"]

# Input validation model
class TTSRequest(BaseModel):
    text: str
    language: str = "english"
    voice: str = "idera"


# Output model with base64-encoded audio
class TTSResponse(BaseModel):
    audio_base64: str
    audio_url: str
    text: str
    voice: str
    language: str
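
# Illustrative request/response payloads for the models above (example values,
# not taken from the original code):
#   request:  {"text": "Good morning", "language": "english", "voice": "idera"}
#   response: {"audio_base64": "<base64-encoded WAV>",
#              "audio_url": "/audio/<uuid>.wav",
#              "text": "Good morning", "voice": "idera", "language": "english"}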

async def download_model_files():
    """Download required model files"""
    global loading_error
    try:
        import requests
        from pathlib import Path

        logger.info("Starting model file downloads...")

        # URLs for the model files
        config_url = "https://huggingface.co/saheedniyi/YarnGPT2/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
        model_url = "https://huggingface.co/saheedniyi/YarnGPT2/resolve/main/wavtokenizer_large_speech_320_24k.ckpt"

        # Create tmp directory if it doesn't exist
        Path("/tmp").mkdir(exist_ok=True)

        # Download config file
        if not os.path.exists(wav_tokenizer_config_path):
            logger.info("Downloading tokenizer config...")
            response = requests.get(config_url, stream=True)
            response.raise_for_status()
            with open(wav_tokenizer_config_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            logger.info("Config file downloaded successfully")

        # Download model file
        if not os.path.exists(wav_tokenizer_model_path):
            logger.info("Downloading tokenizer model (this may take a while)...")
            response = requests.get(model_url, stream=True)
            response.raise_for_status()
            total_size = int(response.headers.get("content-length", 0))
            downloaded = 0
            next_log = 10 * 1024 * 1024  # log roughly every 10 MB
            with open(wav_tokenizer_model_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0 and downloaded >= next_log:
                        progress = (downloaded / total_size) * 100
                        logger.info(f"Download progress: {progress:.1f}%")
                        next_log += 10 * 1024 * 1024
            logger.info("Model file downloaded successfully")

        logger.info("All model files are ready")
    except Exception as e:
        error_msg = f"Error downloading model files: {str(e)}"
        logger.error(error_msg)
        loading_error = error_msg
        raise

async def load_models():
    """Load the YarnGPT model and tokenizer"""
    global audio_tokenizer, model, model_loaded, loading_error
    try:
        logger.info("Loading YarnGPT model and tokenizer...")

        # First download the required files
        await download_model_files()

        # Initialize audio tokenizer
        logger.info("Initializing audio tokenizer...")
        audio_tokenizer = AudioTokenizerV2(
            tokenizer_path,
            wav_tokenizer_model_path,
            wav_tokenizer_config_path,
        )

        # Load the main model
        logger.info("Loading main model...")
        model = AutoModelForCausalLM.from_pretrained(
            tokenizer_path,
            torch_dtype="auto",
        ).to(audio_tokenizer.device)

        model_loaded = True
        logger.info("Model loaded successfully!")
    except Exception as e:
        error_msg = f"Error loading models: {str(e)}"
        logger.error(error_msg)
        loading_error = error_msg
        model_loaded = False

@app.on_event("startup")
async def startup_event():
    """Load models when the API starts"""
    asyncio.create_task(load_models())

@app.get("/")
async def root():
    """API health check and info"""
    return {
        "status": "ok" if model_loaded else "loading",
        "message": "Nigerian TTS API is running" if model_loaded else "Models are loading...",
        "model_loaded": model_loaded,
        "loading_error": loading_error,
        "available_languages": AVAILABLE_LANGUAGES,
        "available_voices": AVAILABLE_VOICES,
    }

@app.get("/health")  # route path assumed
async def health_check():
    """Detailed health check"""
    return {
        "status": "healthy" if model_loaded else "loading",
        "model_loaded": model_loaded,
        "loading_error": loading_error,
        "timestamp": datetime.now().isoformat(),
    }

@app.post("/tts", response_model=TTSResponse)  # route path assumed
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
    """Convert text to Nigerian-accented speech"""
    # Check if models are loaded
    if not model_loaded:
        if loading_error:
            raise HTTPException(status_code=503, detail=f"Model loading failed: {loading_error}")
        else:
            raise HTTPException(status_code=503, detail="Models are still loading. Please try again in a moment.")

    # Validate inputs
    if request.language not in AVAILABLE_LANGUAGES:
        raise HTTPException(status_code=400, detail=f"Language must be one of {AVAILABLE_LANGUAGES}")

    all_voices = AVAILABLE_VOICES["female"] + AVAILABLE_VOICES["male"]
    if request.voice not in all_voices:
        raise HTTPException(status_code=400, detail=f"Voice must be one of {all_voices}")

    # Generate unique filename
    audio_id = str(uuid.uuid4())
    output_path = f"audio_files/{audio_id}.wav"
    os.makedirs("audio_files", exist_ok=True)

    try:
        logger.info(f"Generating TTS for text: '{request.text[:50]}...' with voice: {request.voice}")

        # Create prompt and generate audio
        prompt = audio_tokenizer.create_prompt(
            request.text,
            lang=request.language,
            speaker_name=request.voice,
        )
        input_ids = audio_tokenizer.tokenize_prompt(prompt)
        output = model.generate(
            input_ids=input_ids,
            temperature=0.1,
            repetition_penalty=1.1,
            max_length=4000,
        )
        codes = audio_tokenizer.get_codes(output)
        audio = audio_tokenizer.get_audio(codes)

        # Save audio file
        torchaudio.save(output_path, audio, sample_rate=24000)
        logger.info(f"Audio saved to {output_path}")

        # Read the file and encode as base64
        with open(output_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
        audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")

        # Clean up old files after a while
        background_tasks.add_task(cleanup_old_files)

        return TTSResponse(
            audio_base64=audio_base64,
            audio_url=f"/audio/{audio_id}.wav",
            text=request.text,
            voice=request.voice,
            language=request.language,
        )
    except Exception as e:
        logger.error(f"Error generating audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating audio: {str(e)}")

@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Serve audio files"""
    file_path = f"audio_files/{filename}"
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(file_path, media_type="audio/wav")

def cleanup_old_files():
    """Delete audio files older than 6 hours to manage disk space"""
    try:
        now = datetime.now()
        audio_dir = "audio_files"
        if not os.path.exists(audio_dir):
            return
        for filename in os.listdir(audio_dir):
            if not filename.endswith(".wav"):
                continue
            file_path = os.path.join(audio_dir, filename)
            file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path))
            # Delete files older than 6 hours
            if now - file_mod_time > timedelta(hours=6):
                os.remove(file_path)
                logger.info(f"Deleted old audio file: {filename}")
    except Exception as e:
        logger.error(f"Error cleaning up old files: {e}")

if __name__ == "__main__":
    logger.info("Starting Nigerian TTS API server...")
    uvicorn.run(app, host="0.0.0.0", port=7860)
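
# Example client call (illustrative sketch only; assumes the service runs
# locally on port 7860 and that the synthesis endpoint is mounted at POST /tts
# as above):
#
#   import base64, requests
#
#   resp = requests.post(
#       "http://localhost:7860/tts",
#       json={"text": "How are you today?", "language": "english", "voice": "idera"},
#       timeout=300,
#   )
#   resp.raise_for_status()
#   with open("output.wav", "wb") as f:
#       f.write(base64.b64decode(resp.json()["audio_base64"]))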