Somalitts commited on
Commit
c82ae02
·
verified ·
1 Parent(s): 3af67f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -9
app.py CHANGED
@@ -14,10 +14,9 @@ app = FastAPI()
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  CACHE_DIR = "/tmp/hf-cache"
16
 
17
- # Load models
18
  processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts", cache_dir=CACHE_DIR)
19
  vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=CACHE_DIR).to(device)
20
- model_male = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/5aad", cache_dir=CACHE_DIR).to(device)
21
  model_female = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/8aad", cache_dir=CACHE_DIR).to(device)
22
 
23
  # Speaker encoder
@@ -27,7 +26,7 @@ speaker_model = EncoderClassifier.from_hparams(
27
  savedir="/tmp/spk_model"
28
  )
29
 
30
- # Load speaker embeddings
31
  def get_embedding(wav_path, pt_path):
32
  if os.path.exists(pt_path):
33
  return torch.load(pt_path).to(device)
@@ -39,7 +38,6 @@ def get_embedding(wav_path, pt_path):
39
  torch.save(emb.cpu(), pt_path)
40
  return emb
41
 
42
- embedding_male = get_embedding("Hussein.wav", "/tmp/male_embedding.pt")
43
  embedding_female = get_embedding("caasho.wav", "/tmp/female_embedding.pt")
44
 
45
  # Text normalization
@@ -75,20 +73,17 @@ def normalize_text(text):
75
  text = re.sub(r'[^\w\s]', '', text)
76
  return text
77
 
78
- # API request schema
79
  class TTSRequest(BaseModel):
80
  text: str
81
- voice: str # "Male" or "Female"
82
 
83
  @app.post("/speak")
84
  def speak(payload: TTSRequest):
85
  clean_text = normalize_text(payload.text)
86
  inputs = processor(text=clean_text, return_tensors="pt").to(device)
87
- model = model_male if payload.voice.lower() == "male" else model_female
88
- embedding = embedding_male if payload.voice.lower() == "male" else embedding_female
89
 
90
  with torch.no_grad():
91
- waveform = model.generate_speech(inputs["input_ids"], embedding.unsqueeze(0), vocoder=vocoder)
92
 
93
  out_path = f"/tmp/{uuid.uuid4().hex}.wav"
94
  sf.write(out_path, waveform.cpu().numpy(), 16000)
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  CACHE_DIR = "/tmp/hf-cache"
16
 
17
+ # Load models (female only)
18
  processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts", cache_dir=CACHE_DIR)
19
  vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=CACHE_DIR).to(device)
 
20
  model_female = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/8aad", cache_dir=CACHE_DIR).to(device)
21
 
22
  # Speaker encoder
 
26
  savedir="/tmp/spk_model"
27
  )
28
 
29
+ # Load female embedding only
30
  def get_embedding(wav_path, pt_path):
31
  if os.path.exists(pt_path):
32
  return torch.load(pt_path).to(device)
 
38
  torch.save(emb.cpu(), pt_path)
39
  return emb
40
 
 
41
  embedding_female = get_embedding("caasho.wav", "/tmp/female_embedding.pt")
42
 
43
  # Text normalization
 
73
  text = re.sub(r'[^\w\s]', '', text)
74
  return text
75
 
76
+ # Request schema without voice choice
77
  class TTSRequest(BaseModel):
78
  text: str
 
79
 
80
  @app.post("/speak")
81
  def speak(payload: TTSRequest):
82
  clean_text = normalize_text(payload.text)
83
  inputs = processor(text=clean_text, return_tensors="pt").to(device)
 
 
84
 
85
  with torch.no_grad():
86
+ waveform = model_female.generate_speech(inputs["input_ids"], embedding_female.unsqueeze(0), vocoder=vocoder)
87
 
88
  out_path = f"/tmp/{uuid.uuid4().hex}.wav"
89
  sf.write(out_path, waveform.cpu().numpy(), 16000)