# cozytales-backend / server.py
# (Hugging Face viewer chrome removed; provenance: SebastianSchramm,
#  commit bfd2ec1 "fix background tasks", 2.45 kB)
import logging
import random
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.responses import FileResponse
from fastapi import BackgroundTasks
from starlette.requests import Request
from kokoro import KPipeline
import soundfile as sf
import tempfile
import numpy as np
import os
# Fixed seed for reproducibility (NOTE(review): nothing visible in this file
# actually draws from `random` — confirm it is used elsewhere or drop it).
random.seed(42)

# Root logging config with default handler/format; this module logs at INFO.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# API language code -> kokoro `lang_code` ("a" = American English, "e" = Spanish
# per kokoro's convention — presumably; verify against kokoro docs).
language_map = {"en": "a", "es": "e"}
# API language code -> default kokoro voice id for that language.
speaker_map = {"en": "af_heart", "es": "em_santa"}
def cleanup_temp_file(file_path: str):
    """Best-effort removal of a temp audio file after the response is sent.

    A file that is already gone (or otherwise cannot be removed) is ignored;
    this must never raise into the response background-task machinery.
    """
    try:
        os.remove(file_path)  # os.remove is an alias of os.unlink
    except OSError:
        pass
# One pipeline per kokoro lang_code, built lazily. Constructing KPipeline loads
# model weights, which is far too expensive to repeat on every request (the
# original rebuilt it per call).
_pipelines = {}


def text_to_audio_chunks(text, voice="af_heart", language="a"):
    """Run kokoro TTS over *text* and return the list of raw audio chunks.

    Args:
        text: The text to synthesize.
        voice: kokoro voice id (see `speaker_map`).
        language: kokoro lang_code (see `language_map`).

    Returns:
        List of per-segment audio chunks as yielded by the kokoro pipeline
        (torch tensors or arrays — `concat_chunks` handles either).
    """
    pipeline = _pipelines.get(language)
    if pipeline is None:
        pipeline = KPipeline(lang_code=language)
        _pipelines[language] = pipeline
    generator = pipeline(text, voice=voice)
    # The pipeline yields (graphemes, phonemes, audio); keep only the audio.
    return [audio for (gs, ps, audio) in generator]
def concat_chunks(audios, samplerate=24000, silence_dur=0.3):
    """Join audio chunks into a single 1-D array, separated by short silences.

    Args:
        audios: Sequence of chunks; torch tensors (anything with `.numpy()`)
            are converted, plain arrays pass through unchanged.
        samplerate: Samples per second, used to size the silence gap.
        silence_dur: Gap length in seconds inserted between adjacent chunks.

    Returns:
        One concatenated numpy array; empty array when no chunks are given.
    """
    arrays = [a.numpy() if hasattr(a, 'numpy') else a for a in audios]
    if not arrays:
        return np.array([])
    gap = np.zeros(int(samplerate * silence_dur), dtype=arrays[0].dtype)
    # Interleave: chunk, gap, chunk, gap, ..., chunk (no trailing gap).
    pieces = []
    for arr in arrays[:-1]:
        pieces.append(arr)
        pieces.append(gap)
    pieces.append(arrays[-1])
    return np.concatenate(pieces)
def get_audio(text: str, language: str):
    """Synthesize *text* into a temporary WAV file and return its path.

    Unknown language codes fall back to the English voice/pipeline. The caller
    is responsible for deleting the returned file.
    """
    voice = speaker_map.get(language, "af_heart")
    lang_code = language_map.get(language, "a")
    chunks = text_to_audio_chunks(text, voice=voice, language=lang_code)
    final_audio = concat_chunks(chunks)
    # delete=False: the file must outlive this function so it can be served.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        wav_path = tmp.name
    sf.write(wav_path, final_audio, 24000)
    return wav_path
class InputLoad(BaseModel):
    """Request body for POST /answer/: text to narrate plus its language."""
    # Text to be synthesized to speech.
    text: str
    # Language code ("en" or "es"); anything else falls back to English defaults.
    language: str
app = FastAPI()


@app.get("/health")
def health_check():
    """Liveness probe: reports that the server process is up."""
    return {"server": "running"}
@app.post("/answer/")
async def receive(input_load: InputLoad, request: Request) -> FileResponse:
audio_path = get_audio(input_load.text, input_load.language)
background_tasks = BackgroundTasks()
background_tasks.add_task(cleanup_temp_file, audio_path)
return FileResponse(
path=audio_path,
media_type="audio/wav",
filename="generated_audio.wav",
background=background_tasks
)