import os

os.environ["HF_HOME"] = "/tmp"
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["TORCH_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"

import io
import re
import math

import numpy as np
import scipy.io.wavfile
import torch
from fastapi import FastAPI, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import VitsModel, AutoTokenizer

app = FastAPI()

model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Somali number words used by number_to_words().
number_words = {
    0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar",
    5: "shan", 6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal",
    10: "toban", 11: "toban iyo koow", 12: "toban iyo labo",
    13: "toban iyo seddex", 14: "toban iyo afar", 15: "toban iyo shan",
    16: "toban iyo lix", 17: "toban iyo todobo", 18: "toban iyo sideed",
    19: "toban iyo sagaal", 20: "labaatan", 30: "sodon", 40: "afartan",
    50: "konton", 60: "lixdan", 70: "todobaatan", 80: "sideetan",
    90: "sagaashan", 100: "boqol", 1000: "kun",
}


def number_to_words(number: int) -> str:
    """Spell out an integer in Somali words; falls back to digits above millions."""
    if number < 20:
        return number_words[number]
    elif number < 100:
        tens, unit = divmod(number, 10)
        return number_words[tens * 10] + (" iyo " + number_words[unit] if unit else "")
    elif number < 1000:
        hundreds, remainder = divmod(number, 100)
        part = (number_words[hundreds] + " boqol") if hundreds > 1 else "boqol"
        if remainder:
            part += " iyo " + number_to_words(remainder)
        return part
    elif number < 1000000:
        thousands, remainder = divmod(number, 1000)
        words = []
        if thousands == 1:
            words.append("kun")
        else:
            words.append(number_to_words(thousands) + " kun")
        if remainder:
            words.append("iyo " + number_to_words(remainder))
        return " ".join(words)
    elif number < 1000000000:
        millions, remainder = divmod(number, 1000000)
        words = []
        if millions == 1:
            words.append("milyan")
        else:
            words.append(number_to_words(millions) + " milyan")
        if remainder:
            words.append(number_to_words(remainder))
        return " ".join(words)
    else:
        return str(number)


def normalize_text(text: str) -> str:
    # Expand digit sequences into Somali words in a single pass, so replacing
    # a short number like "1" cannot corrupt a longer one like "12".
    text = re.sub(r"\d+", lambda m: number_to_words(int(m.group())), text)
    # Crude grapheme substitutions to steer pronunciation.
    text = text.replace("KH", "qa").replace("Z", "S")
    text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
    text = text.replace("ZamZam", "SamSam")
    return text


def waveform_to_wav_bytes(waveform: torch.Tensor, sample_rate: int = 22050) -> bytes:
    np_waveform = waveform.cpu().numpy()
    if np_waveform.ndim == 3:
        np_waveform = np_waveform[0]
    if np_waveform.ndim == 2:
        np_waveform = np_waveform.mean(axis=0)
    np_waveform = np.clip(np_waveform, -1.0, 1.0).astype(np.float32)
    pcm_waveform = (np_waveform * 32767).astype(np.int16)
    buf = io.BytesIO()
    scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
    buf.seek(0)
    return buf.read()


class TextIn(BaseModel):
    inputs: str


@app.post("/synthesize")
async def synthesize_post(data: TextIn):
    text = normalize_text(data.inputs)
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model(**inputs)
    if hasattr(output, "waveform"):
        waveform = output.waveform
    elif isinstance(output, dict) and "waveform" in output:
        waveform = output["waveform"]
    elif isinstance(output, (tuple, list)):
        waveform = output[0]
    else:
        return {"error": "Waveform not found in model output"}
    sample_rate = getattr(model.config, "sampling_rate", 22050)
    wav_bytes = waveform_to_wav_bytes(waveform, sample_rate=sample_rate)
    return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")
@app.get("/synthesize")
async def synthesize_get(
    text: str = Query(..., description="Text to synthesize"),
    test: bool = Query(False),
):
    if test:
        # Return a 2-second 440 Hz sine tone so the audio path can be checked
        # without running the model.
        duration_s = 2.0
        sample_rate = 22050
        t = np.linspace(0, duration_s, int(sample_rate * duration_s), endpoint=False)
        freq = 440
        waveform = 0.5 * np.sin(2 * math.pi * freq * t).astype(np.float32)
        pcm_waveform = (waveform * 32767).astype(np.int16)
        buf = io.BytesIO()
        scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
        buf.seek(0)
        return StreamingResponse(buf, media_type="audio/wav")

    normalized = normalize_text(text)
    inputs = tokenizer(normalized, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model(**inputs)
    if hasattr(output, "waveform"):
        waveform = output.waveform
    elif isinstance(output, dict) and "waveform" in output:
        waveform = output["waveform"]
    elif isinstance(output, (tuple, list)):
        waveform = output[0]
    else:
        return {"error": "Waveform not found in model output"}
    sample_rate = getattr(model.config, "sampling_rate", 22050)
    wav_bytes = waveform_to_wav_bytes(waveform, sample_rate=sample_rate)
    return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")
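

# A minimal sketch of how to run and exercise the server locally. The module
# name "app", the port 7860, and the use of uvicorn are assumptions about the
# deployment environment, not part of the original application:
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
# Example requests against that hypothetical local instance:
#
#   curl -X POST http://localhost:7860/synthesize \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "Waxaan hayaa 25 buug"}' \
#        --output out.wav
#
#   curl "http://localhost:7860/synthesize?text=Salaam&test=true" --output tone.wav
#
if __name__ == "__main__":
    # Optional entry point so `python app.py` also starts the server
    # (assumes uvicorn is installed).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)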