Somalitts committed on
Commit
95b6972
·
verified ·
1 Parent(s): 5444f28

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import uuid
4
+ import torch
5
+ import torchaudio
6
+ import soundfile as sf
7
+ from fastapi import FastAPI
8
+ from fastapi.responses import FileResponse
9
+ from pydantic import BaseModel
10
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
11
+ from speechbrain.inference.speaker import EncoderClassifier
12
+
13
app = FastAPI()

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Download cache directory passed to every from_pretrained() call below.
CACHE_DIR = "/tmp/hf-cache"

# Text processor and vocoder are shared by both voices; each voice has its
# own SpeechT5 checkpoint.
processor = SpeechT5Processor.from_pretrained(
    "microsoft/speecht5_tts",
    cache_dir=CACHE_DIR,
)
vocoder = SpeechT5HifiGan.from_pretrained(
    "microsoft/speecht5_hifigan",
    cache_dir=CACHE_DIR,
).to(device)
model_male = SpeechT5ForTextToSpeech.from_pretrained(
    "Somalitts/5aad",
    cache_dir=CACHE_DIR,
).to(device)
model_female = SpeechT5ForTextToSpeech.from_pretrained(
    "Somalitts/8aad",
    cache_dir=CACHE_DIR,
).to(device)

# Speaker encoder used by get_embedding() to turn a reference recording
# into a speaker embedding.
speaker_model = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-xvect-voxceleb",
    run_opts={"device": device},
    savedir="/tmp/spk_model",
)
29
+
30
+ # Load speaker embeddings
31
def get_embedding(wav_path, pt_path):
    """Return the speaker embedding for ``wav_path``, cached at ``pt_path``.

    If ``pt_path`` already exists, the tensor stored there is loaded and
    moved to ``device``. Otherwise the recording at ``wav_path`` is loaded,
    resampled to 16000 Hz, averaged over its first dimension, encoded with
    ``speaker_model``, L2-normalised along dim 2 and squeezed. A CPU copy is
    saved to ``pt_path`` and the freshly computed tensor is returned.
    """
    if not os.path.exists(pt_path):
        waveform, sample_rate = torchaudio.load(wav_path)
        resampled = torchaudio.functional.resample(waveform, sample_rate, 16000)
        mono = resampled.mean(dim=0).unsqueeze(0).to(device)
        with torch.no_grad():
            xvector = speaker_model.encode_batch(mono)
            xvector = torch.nn.functional.normalize(xvector, dim=2).squeeze()
        # Store on CPU so the cache file does not depend on the current device.
        torch.save(xvector.cpu(), pt_path)
        return xvector

    # Cache hit: reuse the previously computed embedding.
    return torch.load(pt_path).to(device)
41
+
42
# One reference recording per voice; each embedding is cached under /tmp
# by get_embedding() so it is computed at most once per cache file.
embedding_male = get_embedding(
    wav_path="Hussein.wav",
    pt_path="/tmp/male_embedding.pt",
)
embedding_female = get_embedding(
    wav_path="caasho.wav",
    pt_path="/tmp/female_embedding.pt",
)
44
+
45
+ # Text normalization
46
# Somali words for the units, the tens and the two scale words used by
# number_to_words(). There are deliberately no entries for 11-19: those
# are composed from "toban" plus the unit word.
number_words = {
    0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
    6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
    20: "labaatan", 30: "sodon", 40: "afartan", 50: "konton",
    60: "lixdan", 70: "todobaatan", 80: "sideetan", 90: "sagaashan",
    100: "boqol", 1000: "kun",
}


def number_to_words(n):
    """Spell out the integer ``n`` using the words in ``number_words``.

    Args:
        n: Integer to convert.

    Returns:
        The spelled-out form for 0 <= n < 1,000,000. Values outside that
        range (negative numbers, or one million and above) are returned
        unchanged as ``str(n)``.
    """
    if n <= 10:
        # 0-10 have direct entries; anything else here (negatives) falls
        # back to the digits.
        return number_words.get(n, str(n))
    if n < 100:
        # Covers 11-19 as well as 20-99. Previously 11-19 were looked up
        # directly in the table, which has no such keys, so they came back
        # as raw digits. They are now composed tens-then-unit exactly like
        # 21-99 (e.g. 11 -> "toban koow").
        tens, unit = divmod(n, 10)
        return number_words[tens * 10] + (" " + number_words[unit] if unit else "")
    if n < 1000:
        hundreds, rem = divmod(n, 100)
        # Exactly one hundred is just "boqol", without a leading unit word.
        head = number_words[hundreds] + " boqol" if hundreds > 1 else "boqol"
        return head + (" " + number_to_words(rem) if rem else "")
    if n < 1_000_000:
        thousands, rem = divmod(n, 1000)
        return number_to_words(thousands) + " kun" + (" " + number_to_words(rem) if rem else "")
    return str(n)
68
+
69
def replace_numbers_with_words(text):
    """Return ``text`` with every word-bounded run of digits spelled out.

    Each match is converted to ``int`` and passed to ``number_to_words``.
    """
    def _spell(match):
        return number_to_words(int(match.group()))

    return re.sub(r'\b\d+\b', _spell, text)
71
+
72
def normalize_text(text):
    """Prepare raw input text for the TTS processor.

    Steps, in order: lowercase the text, spell out digit runs with
    ``replace_numbers_with_words``, then delete every character that is
    neither a word character nor whitespace.
    """
    lowered = text.lower()
    spelled = replace_numbers_with_words(lowered)
    # Punctuation is removed outright, not replaced by a space.
    return re.sub(r'[^\w\s]', '', spelled)
77
+
78
+ # API request schema
79
class TTSRequest(BaseModel):
    """JSON body accepted by the ``/speak`` endpoint."""

    # Text to synthesise.
    text: str
    # Requested voice: "Male" or "Female".
    voice: str
82
+
83
@app.post("/speak")
def speak(payload: TTSRequest):
    """Synthesise ``payload.text`` and return it as a WAV download.

    The text is normalised with ``normalize_text``, tokenised by
    ``processor`` and fed to the model/embedding pair selected by
    ``payload.voice``: "male" (case-insensitive) selects the male pair,
    any other value selects the female pair. The waveform is written to a
    uniquely named file under ``/tmp`` at 16000 Hz and streamed back as
    ``audio/wav``.

    The temporary file is deleted by a background task once the response
    has been sent; previously it was left behind, so every request leaked
    one file in ``/tmp``.
    """
    # Function-scope import so this handler brings its own dependency;
    # fastapi is already imported by this module.
    from fastapi import BackgroundTasks

    clean_text = normalize_text(payload.text)
    inputs = processor(text=clean_text, return_tensors="pt").to(device)

    # Evaluate the voice choice once and use it for both lookups.
    use_male = payload.voice.lower() == "male"
    model = model_male if use_male else model_female
    embedding = embedding_male if use_male else embedding_female

    with torch.no_grad():
        waveform = model.generate_speech(
            inputs["input_ids"],
            embedding.unsqueeze(0),
            vocoder=vocoder,
        )

    out_path = f"/tmp/{uuid.uuid4().hex}.wav"
    sf.write(out_path, waveform.cpu().numpy(), 16000)

    # Remove the generated file after the response body has been sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, out_path)
    return FileResponse(
        out_path,
        media_type="audio/wav",
        filename="voice.wav",
        background=cleanup,
    )