Spaces:
Running
Running
mrfakename
commited on
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- api.py +2 -2
- inference-cli.py +2 -2
- model/utils_infer.py +22 -10
api.py
CHANGED
@@ -33,10 +33,10 @@ class F5TTS:
|
|
33 |
)
|
34 |
|
35 |
# Load models
|
36 |
-
self.
|
37 |
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
|
38 |
|
39 |
-
def
|
40 |
self.vocos = load_vocoder(local_path is not None, local_path, self.device)
|
41 |
|
42 |
def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
|
|
|
33 |
)
|
34 |
|
35 |
# Load models
|
36 |
+
self.load_vocoder_model(local_path)
|
37 |
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
|
38 |
|
39 |
+
def load_vocoder_model(self, local_path):
|
40 |
self.vocos = load_vocoder(local_path is not None, local_path, self.device)
|
41 |
|
42 |
def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
|
inference-cli.py
CHANGED
@@ -104,7 +104,7 @@ if model == "F5-TTS":
|
|
104 |
exp_name = "F5TTS_Base"
|
105 |
ckpt_step = 1200000
|
106 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
107 |
-
#
|
108 |
|
109 |
elif model == "E2-TTS":
|
110 |
model_cls = UNetT
|
@@ -114,7 +114,7 @@ elif model == "E2-TTS":
|
|
114 |
exp_name = "E2TTS_Base"
|
115 |
ckpt_step = 1200000
|
116 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
117 |
-
#
|
118 |
|
119 |
print(f"Using {model}...")
|
120 |
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
|
|
|
104 |
exp_name = "F5TTS_Base"
|
105 |
ckpt_step = 1200000
|
106 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
107 |
+
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
108 |
|
109 |
elif model == "E2-TTS":
|
110 |
model_cls = UNetT
|
|
|
114 |
exp_name = "E2TTS_Base"
|
115 |
ckpt_step = 1200000
|
116 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
117 |
+
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
118 |
|
119 |
print(f"Using {model}...")
|
120 |
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
|
model/utils_infer.py
CHANGED
@@ -22,13 +22,6 @@ from model.utils import (
|
|
22 |
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
23 |
print(f"Using {device} device")
|
24 |
|
25 |
-
asr_pipe = pipeline(
|
26 |
-
"automatic-speech-recognition",
|
27 |
-
model="openai/whisper-large-v3-turbo",
|
28 |
-
torch_dtype=torch.float16,
|
29 |
-
device=device,
|
30 |
-
)
|
31 |
-
|
32 |
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
33 |
|
34 |
|
@@ -82,8 +75,6 @@ def chunk_text(text, max_chars=135):
|
|
82 |
|
83 |
|
84 |
# load vocoder
|
85 |
-
|
86 |
-
|
87 |
def load_vocoder(is_local=False, local_path="", device=device):
|
88 |
if is_local:
|
89 |
print(f"Load vocos from local path {local_path}")
|
@@ -97,6 +88,22 @@ def load_vocoder(is_local=False, local_path="", device=device):
|
|
97 |
return vocos
|
98 |
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
# load model for inference
|
101 |
|
102 |
|
@@ -133,7 +140,7 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method="euler
|
|
133 |
# preprocess reference audio and text
|
134 |
|
135 |
|
136 |
-
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
|
137 |
show_info("Converting audio...")
|
138 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
139 |
aseg = AudioSegment.from_file(ref_audio_orig)
|
@@ -152,6 +159,9 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
|
|
152 |
ref_audio = f.name
|
153 |
|
154 |
if not ref_text.strip():
|
|
|
|
|
|
|
155 |
show_info("No reference text provided, transcribing reference audio...")
|
156 |
ref_text = asr_pipe(
|
157 |
ref_audio,
|
@@ -329,6 +339,8 @@ def infer_batch_process(
|
|
329 |
|
330 |
|
331 |
# remove silence from generated wav
|
|
|
|
|
332 |
def remove_silence_for_generated_wav(filename):
|
333 |
aseg = AudioSegment.from_file(filename)
|
334 |
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
|
|
|
22 |
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
23 |
print(f"Using {device} device")
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
26 |
|
27 |
|
|
|
75 |
|
76 |
|
77 |
# load vocoder
|
|
|
|
|
78 |
def load_vocoder(is_local=False, local_path="", device=device):
|
79 |
if is_local:
|
80 |
print(f"Load vocos from local path {local_path}")
|
|
|
88 |
return vocos
|
89 |
|
90 |
|
91 |
+
# load asr pipeline
|
92 |
+
|
93 |
+
asr_pipe = None
|
94 |
+
|
95 |
+
|
96 |
+
def initialize_asr_pipeline(device=device):
|
97 |
+
global asr_pipe
|
98 |
+
|
99 |
+
asr_pipe = pipeline(
|
100 |
+
"automatic-speech-recognition",
|
101 |
+
model="openai/whisper-large",
|
102 |
+
torch_dtype=torch.float16,
|
103 |
+
device=device,
|
104 |
+
)
|
105 |
+
|
106 |
+
|
107 |
# load model for inference
|
108 |
|
109 |
|
|
|
140 |
# preprocess reference audio and text
|
141 |
|
142 |
|
143 |
+
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
|
144 |
show_info("Converting audio...")
|
145 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
146 |
aseg = AudioSegment.from_file(ref_audio_orig)
|
|
|
159 |
ref_audio = f.name
|
160 |
|
161 |
if not ref_text.strip():
|
162 |
+
global asr_pipe
|
163 |
+
if asr_pipe is None:
|
164 |
+
initialize_asr_pipeline(device=device)
|
165 |
show_info("No reference text provided, transcribing reference audio...")
|
166 |
ref_text = asr_pipe(
|
167 |
ref_audio,
|
|
|
339 |
|
340 |
|
341 |
# remove silence from generated wav
|
342 |
+
|
343 |
+
|
344 |
def remove_silence_for_generated_wav(filename):
|
345 |
aseg = AudioSegment.from_file(filename)
|
346 |
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
|