Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- src/f5_tts/api.py +11 -8
- src/f5_tts/infer/utils_infer.py +6 -6
src/f5_tts/api.py
CHANGED
|
@@ -32,6 +32,7 @@ class F5TTS:
|
|
| 32 |
vocoder_name="vocos",
|
| 33 |
local_path=None,
|
| 34 |
device=None,
|
|
|
|
| 35 |
):
|
| 36 |
# Initialize parameters
|
| 37 |
self.final_wave = None
|
|
@@ -46,29 +47,31 @@ class F5TTS:
|
|
| 46 |
)
|
| 47 |
|
| 48 |
# Load models
|
| 49 |
-
self.load_vocoder_model(vocoder_name, local_path=local_path)
|
| 50 |
-
self.load_ema_model(
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
def load_vocoder_model(self, vocoder_name, local_path=None):
|
| 53 |
-
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
|
| 54 |
|
| 55 |
-
def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema,
|
| 56 |
if model_type == "F5-TTS":
|
| 57 |
if not ckpt_file:
|
| 58 |
if mel_spec_type == "vocos":
|
| 59 |
ckpt_file = str(
|
| 60 |
-
cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=
|
| 61 |
)
|
| 62 |
elif mel_spec_type == "bigvgan":
|
| 63 |
ckpt_file = str(
|
| 64 |
-
cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=
|
| 65 |
)
|
| 66 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
| 67 |
model_cls = DiT
|
| 68 |
elif model_type == "E2-TTS":
|
| 69 |
if not ckpt_file:
|
| 70 |
ckpt_file = str(
|
| 71 |
-
cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=
|
| 72 |
)
|
| 73 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
| 74 |
model_cls = UNetT
|
|
|
|
| 32 |
vocoder_name="vocos",
|
| 33 |
local_path=None,
|
| 34 |
device=None,
|
| 35 |
+
hf_cache_dir=None,
|
| 36 |
):
|
| 37 |
# Initialize parameters
|
| 38 |
self.final_wave = None
|
|
|
|
| 47 |
)
|
| 48 |
|
| 49 |
# Load models
|
| 50 |
+
self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
|
| 51 |
+
self.load_ema_model(
|
| 52 |
+
model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
|
| 53 |
+
)
|
| 54 |
|
| 55 |
+
def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
|
| 56 |
+
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
|
| 57 |
|
| 58 |
+
def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
|
| 59 |
if model_type == "F5-TTS":
|
| 60 |
if not ckpt_file:
|
| 61 |
if mel_spec_type == "vocos":
|
| 62 |
ckpt_file = str(
|
| 63 |
+
cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
|
| 64 |
)
|
| 65 |
elif mel_spec_type == "bigvgan":
|
| 66 |
ckpt_file = str(
|
| 67 |
+
cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
|
| 68 |
)
|
| 69 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
| 70 |
model_cls = DiT
|
| 71 |
elif model_type == "E2-TTS":
|
| 72 |
if not ckpt_file:
|
| 73 |
ckpt_file = str(
|
| 74 |
+
cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
|
| 75 |
)
|
| 76 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
| 77 |
model_cls = UNetT
|
src/f5_tts/infer/utils_infer.py
CHANGED
|
@@ -90,18 +90,18 @@ def chunk_text(text, max_chars=135):
|
|
| 90 |
|
| 91 |
|
| 92 |
# load vocoder
|
| 93 |
-
def load_vocoder(vocoder_name="vocos", is_local=False, local_path=
|
| 94 |
if vocoder_name == "vocos":
|
| 95 |
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
|
| 96 |
-
if is_local
|
| 97 |
print(f"Load vocos from local path {local_path}")
|
| 98 |
config_path = f"{local_path}/config.yaml"
|
| 99 |
model_path = f"{local_path}/pytorch_model.bin"
|
| 100 |
else:
|
| 101 |
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
|
| 102 |
repo_id = "charactr/vocos-mel-24khz"
|
| 103 |
-
config_path = hf_hub_download(repo_id=repo_id, cache_dir=
|
| 104 |
-
model_path = hf_hub_download(repo_id=repo_id, cache_dir=
|
| 105 |
vocoder = Vocos.from_hparams(config_path)
|
| 106 |
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
| 107 |
from vocos.feature_extractors import EncodecFeatures
|
|
@@ -119,11 +119,11 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path=None, device=d
|
|
| 119 |
from third_party.BigVGAN import bigvgan
|
| 120 |
except ImportError:
|
| 121 |
print("You need to follow the README to init submodule and change the BigVGAN source code.")
|
| 122 |
-
if is_local
|
| 123 |
"""download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
|
| 124 |
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
| 125 |
else:
|
| 126 |
-
local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=
|
| 127 |
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
| 128 |
|
| 129 |
vocoder.remove_weight_norm()
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
# load vocoder
|
| 93 |
+
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None):
|
| 94 |
if vocoder_name == "vocos":
|
| 95 |
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
|
| 96 |
+
if is_local:
|
| 97 |
print(f"Load vocos from local path {local_path}")
|
| 98 |
config_path = f"{local_path}/config.yaml"
|
| 99 |
model_path = f"{local_path}/pytorch_model.bin"
|
| 100 |
else:
|
| 101 |
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
|
| 102 |
repo_id = "charactr/vocos-mel-24khz"
|
| 103 |
+
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
|
| 104 |
+
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
|
| 105 |
vocoder = Vocos.from_hparams(config_path)
|
| 106 |
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
| 107 |
from vocos.feature_extractors import EncodecFeatures
|
|
|
|
| 119 |
from third_party.BigVGAN import bigvgan
|
| 120 |
except ImportError:
|
| 121 |
print("You need to follow the README to init submodule and change the BigVGAN source code.")
|
| 122 |
+
if is_local:
|
| 123 |
"""download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
|
| 124 |
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
| 125 |
else:
|
| 126 |
+
local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=hf_cache_dir)
|
| 127 |
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
| 128 |
|
| 129 |
vocoder.remove_weight_norm()
|