Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -14,10 +14,9 @@ app = FastAPI()
|
|
14 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
CACHE_DIR = "/tmp/hf-cache"
|
16 |
|
17 |
-
# Load models
|
18 |
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts", cache_dir=CACHE_DIR)
|
19 |
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=CACHE_DIR).to(device)
|
20 |
-
model_male = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/5aad", cache_dir=CACHE_DIR).to(device)
|
21 |
model_female = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/8aad", cache_dir=CACHE_DIR).to(device)
|
22 |
|
23 |
# Speaker encoder
|
@@ -27,7 +26,7 @@ speaker_model = EncoderClassifier.from_hparams(
|
|
27 |
savedir="/tmp/spk_model"
|
28 |
)
|
29 |
|
30 |
-
# Load
|
31 |
def get_embedding(wav_path, pt_path):
|
32 |
if os.path.exists(pt_path):
|
33 |
return torch.load(pt_path).to(device)
|
@@ -39,7 +38,6 @@ def get_embedding(wav_path, pt_path):
|
|
39 |
torch.save(emb.cpu(), pt_path)
|
40 |
return emb
|
41 |
|
42 |
-
embedding_male = get_embedding("Hussein.wav", "/tmp/male_embedding.pt")
|
43 |
embedding_female = get_embedding("caasho.wav", "/tmp/female_embedding.pt")
|
44 |
|
45 |
# Text normalization
|
@@ -75,20 +73,17 @@ def normalize_text(text):
|
|
75 |
text = re.sub(r'[^\w\s]', '', text)
|
76 |
return text
|
77 |
|
78 |
-
#
|
79 |
class TTSRequest(BaseModel):
|
80 |
text: str
|
81 |
-
voice: str # "Male" or "Female"
|
82 |
|
83 |
@app.post("/speak")
|
84 |
def speak(payload: TTSRequest):
|
85 |
clean_text = normalize_text(payload.text)
|
86 |
inputs = processor(text=clean_text, return_tensors="pt").to(device)
|
87 |
-
model = model_male if payload.voice.lower() == "male" else model_female
|
88 |
-
embedding = embedding_male if payload.voice.lower() == "male" else embedding_female
|
89 |
|
90 |
with torch.no_grad():
|
91 |
-
waveform =
|
92 |
|
93 |
out_path = f"/tmp/{uuid.uuid4().hex}.wav"
|
94 |
sf.write(out_path, waveform.cpu().numpy(), 16000)
|
|
|
14 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
CACHE_DIR = "/tmp/hf-cache"
|
16 |
|
17 |
+
# Load models (female only)
|
18 |
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts", cache_dir=CACHE_DIR)
|
19 |
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=CACHE_DIR).to(device)
|
|
|
20 |
model_female = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/8aad", cache_dir=CACHE_DIR).to(device)
|
21 |
|
22 |
# Speaker encoder
|
|
|
26 |
savedir="/tmp/spk_model"
|
27 |
)
|
28 |
|
29 |
+
# Load female embedding only
|
30 |
def get_embedding(wav_path, pt_path):
|
31 |
if os.path.exists(pt_path):
|
32 |
return torch.load(pt_path).to(device)
|
|
|
38 |
torch.save(emb.cpu(), pt_path)
|
39 |
return emb
|
40 |
|
|
|
41 |
embedding_female = get_embedding("caasho.wav", "/tmp/female_embedding.pt")
|
42 |
|
43 |
# Text normalization
|
|
|
73 |
text = re.sub(r'[^\w\s]', '', text)
|
74 |
return text
|
75 |
|
76 |
+
# Request schema without voice choice
|
77 |
class TTSRequest(BaseModel):
|
78 |
text: str
|
|
|
79 |
|
80 |
@app.post("/speak")
|
81 |
def speak(payload: TTSRequest):
|
82 |
clean_text = normalize_text(payload.text)
|
83 |
inputs = processor(text=clean_text, return_tensors="pt").to(device)
|
|
|
|
|
84 |
|
85 |
with torch.no_grad():
|
86 |
+
waveform = model_female.generate_speech(inputs["input_ids"], embedding_female.unsqueeze(0), vocoder=vocoder)
|
87 |
|
88 |
out_path = f"/tmp/{uuid.uuid4().hex}.wav"
|
89 |
sf.write(out_path, waveform.cpu().numpy(), 16000)
|