Sync from GitHub repo
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the GitHub repo rather than directly to the Space.
- model/dataset.py +18 -2
- model/utils.py +7 -0
- train.py +7 -4
model/dataset.py
CHANGED

@@ -184,10 +184,15 @@ class DynamicBatchSampler(Sampler[list[int]]):
 
 def load_dataset(
     dataset_name: str,
-    tokenizer: str,
+    tokenizer: str = "pinyin",
     dataset_type: str = "CustomDataset",
     audio_type: str = "raw",
     mel_spec_kwargs: dict = dict()
+) -> CustomDataset | HFDataset:
+    '''
+    dataset_type    - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
+                    - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
+    '''
 ) -> CustomDataset:
 
     print("Loading dataset ...")

@@ -206,7 +211,18 @@ def load_dataset(
             data_dict = json.load(f)
         durations = data_dict["duration"]
         train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
-
+
+    elif dataset_type == "CustomDatasetPath":
+        try:
+            train_dataset = load_from_disk(f"{dataset_name}/raw")
+        except:
+            train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
+
+        with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:
+            data_dict = json.load(f)
+        durations = data_dict["duration"]
+        train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
+
     elif dataset_type == "HFDataset":
         print("Should manually modify the path of huggingface dataset to your need.\n" +
               "May also the corresponding script cuz different dataset may have different format.")
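For context, the new "CustomDatasetPath" branch treats dataset_name as a full path to a directory that already contains the preprocessed data (raw/ or raw.arrow) plus duration.json, instead of resolving the default data path from dataset name and tokenizer. A minimal usage sketch, assuming the repo root is importable and using a hypothetical path (the mel settings mirror the defaults in train.py):

# Minimal usage sketch (hypothetical path): load a preprocessed dataset by its
# full on-disk path via the new "CustomDatasetPath" branch.
from model.dataset import load_dataset  # assumes the repo root is on PYTHONPATH

train_dataset = load_dataset(
    dataset_name="/data/my_preprocessed_set",  # dir containing raw/ (or raw.arrow) and duration.json
    dataset_type="CustomDatasetPath",
    mel_spec_kwargs=dict(
        target_sample_rate=24000,  # same values as the defaults in train.py
        n_mel_channels=100,
        hop_length=256,
    ),
)
print(type(train_dataset).__name__)  # CustomDataset wrapping the loaded arrow data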
model/utils.py
CHANGED

@@ -129,6 +129,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
     tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
                 - "char" for char-wise tokenizer, need .txt vocab_file
                 - "byte" for utf-8 tokenizer
+                - "custom" if you're directly passing in a path to the vocab.txt you want to use
     vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
                 - if use "char", derived from unfiltered character & symbol counts of custom dataset
                 - if use "byte", set to 256 (unicode byte range)

@@ -144,6 +145,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
     elif tokenizer == "byte":
         vocab_char_map = None
         vocab_size = 256
+    elif tokenizer == "custom":
+        with open(dataset_name, "r", encoding="utf-8") as f:
+            vocab_char_map = {}
+            for i, char in enumerate(f):
+                vocab_char_map[char[:-1]] = i
+        vocab_size = len(vocab_char_map)
 
     return vocab_char_map, vocab_size
 
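For context, the new "custom" option reuses get_tokenizer's first argument as a path to the vocab file rather than a dataset name: the file is read line by line, the trailing newline is stripped (char[:-1]), and each token maps to its 0-based line index. A minimal usage sketch with a hypothetical vocab path:

# Minimal usage sketch (hypothetical vocab path). With tokenizer="custom", the
# first argument of get_tokenizer is the path to a vocab.txt listing one token
# per line; vocab_size is simply the number of entries in the file.
from model.utils import get_tokenizer  # assumes the repo root is on PYTHONPATH

vocab_char_map, vocab_size = get_tokenizer("/data/my_vocab.txt", tokenizer="custom")
print(vocab_size)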
train.py
CHANGED

@@ -9,10 +9,10 @@ target_sample_rate = 24000
 n_mel_channels = 100
 hop_length = 256
 
-tokenizer = "pinyin"
+tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
+tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
 dataset_name = "Emilia_ZH_EN"
 
-
 # -------------------------- Training Settings -------------------------- #
 
 exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base

@@ -44,8 +44,11 @@ elif exp_name == "E2TTS_Base":
 # ----------------------------------------------------------------------- #
 
 def main():
-
-    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
+    if tokenizer == "custom":
+        tokenizer_path = tokenizer_path
+    else:
+        tokenizer_path = dataset_name
+    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
 
     mel_spec_kwargs = dict(
         target_sample_rate = target_sample_rate,