kemuririn
commited on
Commit
·
8ccaa64
1
Parent(s):
6eee896
debug for hf space
Browse files- indextts/infer.py +4 -0
- webui.py +2 -0
indextts/infer.py
CHANGED
|
@@ -44,6 +44,9 @@ class IndexTTS:
|
|
| 44 |
self.bigvgan = self.bigvgan.to(self.device)
|
| 45 |
self.bigvgan.eval()
|
| 46 |
print(">> bigvgan weights restored from:", self.bigvgan_path)
|
|
|
|
|
|
|
|
|
|
| 47 |
self.normalizer = TextNormalizer()
|
| 48 |
self.normalizer.load()
|
| 49 |
print(">> TextNormalizer loaded")
|
|
@@ -156,4 +159,5 @@ class IndexTTS:
|
|
| 156 |
|
| 157 |
if __name__ == "__main__":
|
| 158 |
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")
|
|
|
|
| 159 |
tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!',output_path="gen.wav")
|
|
|
|
| 44 |
self.bigvgan = self.bigvgan.to(self.device)
|
| 45 |
self.bigvgan.eval()
|
| 46 |
print(">> bigvgan weights restored from:", self.bigvgan_path)
|
| 47 |
+
self.normalizer = None
|
| 48 |
+
|
| 49 |
+
def load_normalizer(self):
|
| 50 |
self.normalizer = TextNormalizer()
|
| 51 |
self.normalizer.load()
|
| 52 |
print(">> TextNormalizer loaded")
|
|
|
|
| 159 |
|
| 160 |
if __name__ == "__main__":
|
| 161 |
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")
|
| 162 |
+
tts.load_normalizer()
|
| 163 |
tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!',output_path="gen.wav")
|
webui.py
CHANGED
|
@@ -28,6 +28,7 @@ def infer(voice, text,output_path=None):
|
|
| 28 |
global tts
|
| 29 |
if not tts:
|
| 30 |
tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")
|
|
|
|
| 31 |
if not output_path:
|
| 32 |
output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
|
| 33 |
tts.infer(voice, text, output_path)
|
|
@@ -76,6 +77,7 @@ def main():
|
|
| 76 |
global tts
|
| 77 |
if not tts:
|
| 78 |
tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")
|
|
|
|
| 79 |
demo.queue(20)
|
| 80 |
demo.launch(server_name="0.0.0.0")
|
| 81 |
|
|
|
|
| 28 |
global tts
|
| 29 |
if not tts:
|
| 30 |
tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")
|
| 31 |
+
tts.load_normalizer()
|
| 32 |
if not output_path:
|
| 33 |
output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
|
| 34 |
tts.infer(voice, text, output_path)
|
|
|
|
| 77 |
global tts
|
| 78 |
if not tts:
|
| 79 |
tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")
|
| 80 |
+
tts.load_normalizer()
|
| 81 |
demo.queue(20)
|
| 82 |
demo.launch(server_name="0.0.0.0")
|
| 83 |
|