|
import json
import os
import tarfile

import gdown
import torch
from tqdm import tqdm
from trainer import Trainer, TrainerArgs

from TTS.config import load_config
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig
from TTS.tts.utils.managers import save_file
from TTS.tts.utils.speakers import SpeakerManager

torch.set_num_threads(24)
|
|
|
|
|
def nemo(root_path, meta_file, **kwargs):
    """Normalizes NeMo-style JSON-lines manifests to the Coqui TTS sample format."""
    meta_path = os.path.join(root_path, meta_file)
    items = []
    with open(meta_path, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = json.loads(line)
            items.append(
                {
                    "text": cols["text"],
                    "audio_file": cols["audio_filepath"],
                    "speaker_name": cols.get("speaker_name", "one"),
                    "language": cols.get("language", ""),
                    "root_path": root_path,
                }
            )
    return items
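
# A minimal example of the JSON-lines records this formatter parses
# (the filename and speaker label here are hypothetical):
#   {"audio_filepath": "wavs/utt_0001.wav", "text": "umarnai don zaman tsarki.",
#    "speaker_name": "one", "language": "ha"}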
|
|
|
|
|
def compute_embeddings(
    model_path,
    config_path,
    output_path,
    old_speakers_file=None,
    old_append=False,
    config_dataset_path=None,
    formatter=None,
    dataset_name=None,
    dataset_path=None,
    meta_file_train=None,
    meta_file_val=None,
    disable_cuda=False,
    no_eval=False,
):
    use_cuda = torch.cuda.is_available() and not disable_cuda

    # Load samples either from a full dataset config file or from the
    # individual dataset fields passed in as arguments.
    if config_dataset_path is not None:
        c_dataset = load_config(config_dataset_path)
        meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=not no_eval)
    else:
        c_dataset = BaseDatasetConfig()
        c_dataset.dataset_name = dataset_name
        c_dataset.path = dataset_path
        if meta_file_train is not None:
            c_dataset.meta_file_train = meta_file_train
        if meta_file_val is not None:
            c_dataset.meta_file_val = meta_file_val
        meta_data_train, meta_data_eval = load_tts_samples(c_dataset, eval_split=not no_eval, formatter=formatter)

    if meta_data_eval is None:
        samples = meta_data_train
    else:
        samples = meta_data_train + meta_data_eval

    encoder_manager = SpeakerManager(
        encoder_model_path=model_path,
        encoder_config_path=config_path,
        d_vectors_file_path=old_speakers_file,
        use_cuda=use_cuda,
    )

    class_name_key = encoder_manager.encoder_config.class_name_key

    # Optionally extend the mapping from a previous speakers file instead of
    # starting from scratch.
    if old_speakers_file is not None and old_append:
        speaker_mapping = encoder_manager.embeddings
    else:
        speaker_mapping = {}

    for fields in tqdm(samples):
        class_name = fields[class_name_key]
        audio_file = fields["audio_file"]
        embedding_key = fields["audio_unique_name"]

        if embedding_key in speaker_mapping:
            speaker_mapping[embedding_key]["name"] = class_name
            continue

        if old_speakers_file is not None and embedding_key in encoder_manager.clip_ids:
            # Reuse the embedding stored in the old speakers file.
            embedding = encoder_manager.get_embedding_by_clip(embedding_key)
        else:
            # Extract the embedding directly from the audio clip.
            embedding = encoder_manager.compute_embedding_from_clip(audio_file)

        speaker_mapping[embedding_key] = {"name": class_name, "embedding": embedding}

    if speaker_mapping:
        # Save the mapping; output_path may be a directory or a file path.
        if os.path.isdir(output_path):
            mapping_file_path = os.path.join(output_path, "speakers.pth")
        else:
            mapping_file_path = output_path

        if os.path.dirname(mapping_file_path) != "":
            os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)

        save_file(speaker_mapping, mapping_file_path)
        print("Speaker embeddings saved at:", mapping_file_path)
|
|
|
# Output path and language settings for this run.
OUT_PATH = "yourtts_hausa"
LANG_NAME = "hausa"
ISO = "ha"

RUN_NAME = f"YourTTS-{LANG_NAME.capitalize()}"

# Fine-tune from the multilingual YourTTS checkpoint trained on CML-TTS.
RESTORE_PATH = os.path.join(OUT_PATH, "checkpoints_yourtts_cml_tts_dataset/best_model.pth")

URL = "https://drive.google.com/u/2/uc?id=1yDCSJ1pFZQTHhL09GMbOrdjcPULApa0p"
OUTPUT_CHECKPOINTS_FILEPATH = os.path.join(OUT_PATH, "checkpoints_yourtts_cml_tts_dataset.tar.bz")

# Download and extract the pretrained checkpoint if it is not already present.
if not os.path.exists(RESTORE_PATH):
    print(f"Downloading the CML-TTS checkpoint from {URL}")
    os.makedirs(OUT_PATH, exist_ok=True)
    gdown.download(url=URL, output=OUTPUT_CHECKPOINTS_FILEPATH, quiet=False, fuzzy=True)
    with tarfile.open(OUTPUT_CHECKPOINTS_FILEPATH, "r:bz2") as tar:
        tar.extractall(OUT_PATH)
else:
    print(f"Checkpoint already exists at {RESTORE_PATH}")
|
|
|
|
|
# Skip training and run only evaluation (useful for debugging).
SKIP_TRAIN_EPOCH = False

# Batch size used for both training and evaluation.
BATCH_SIZE = 4

# Sample rate of the dataset audio; it must match the restored checkpoint.
SAMPLE_RATE = 24000

# Clips outside this duration range are discarded during training.
MAX_AUDIO_LEN_IN_SECONDS = 11
MIN_AUDIO_LEN_IN_SECONDS = 0.8
|
|
|
dataset_conf = BaseDatasetConfig(
    dataset_name=f"{ISO}_openbible",
    meta_file_train="manifest_train.jsonl",
    meta_file_val="manifest_dev.jsonl",
    language=ISO,
    path="data/hausa/tts_data",
)
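
# The manifests are expected at the dataset path configured above, i.e.:
#   data/hausa/tts_data/manifest_train.jsonl
#   data/hausa/tts_data/manifest_dev.jsonl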
|
|
|
|
|
# Pretrained speaker encoder checkpoint and config released with Coqui TTS.
SPEAKER_ENCODER_CHECKPOINT_PATH = (
    "https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar"
)
SPEAKER_ENCODER_CONFIG_PATH = "https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json"

# Paths to precomputed d-vector (speaker embedding) files; filled in below.
D_VECTOR_FILES = []
|
|
|
|
|
# Compute the speaker embeddings once and cache them alongside the dataset.
embeddings_file = os.path.join(dataset_conf.path, "speakers.pth")
if not os.path.isfile(embeddings_file):
    print(f">>> Computing the speaker embeddings for the {dataset_conf.dataset_name} dataset")
    compute_embeddings(
        SPEAKER_ENCODER_CHECKPOINT_PATH,
        SPEAKER_ENCODER_CONFIG_PATH,
        embeddings_file,
        formatter=nemo,
        dataset_name=dataset_conf.dataset_name,
        dataset_path=dataset_conf.path,
        meta_file_train=dataset_conf.meta_file_train,
        meta_file_val=dataset_conf.meta_file_val,
    )
D_VECTOR_FILES.append(embeddings_file)
|
|
|
|
|
audio_config = VitsAudioConfig(
    sample_rate=SAMPLE_RATE,
    hop_length=256,
    win_length=1024,
    fft_size=1024,
    mel_fmin=0.0,
    mel_fmax=None,
    num_mels=80,
)
|
|
|
|
|
model_args = VitsArgs(
    spec_segment_size=62,
    hidden_channels=192,
    hidden_channels_ffn_text_encoder=768,
    num_heads_text_encoder=2,
    num_layers_text_encoder=10,
    kernel_size_text_encoder=3,
    dropout_p_text_encoder=0.1,
    d_vector_file=D_VECTOR_FILES,
    use_d_vector_file=True,
    d_vector_dim=512,
    speaker_encoder_model_path=SPEAKER_ENCODER_CHECKPOINT_PATH,
    speaker_encoder_config_path=SPEAKER_ENCODER_CONFIG_PATH,
    resblock_type_decoder="2",
    # The speaker consistency loss is disabled for this run.
    use_speaker_encoder_as_loss=False,
    # Train a language embedding so the multilingual checkpoint can adapt to Hausa.
    use_language_embedding=True,
    embedded_language_dim=4,
)
|
|
|
# Character set covering the Hausa training transcripts.
CHARS = ["'", 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'w', 'y', 'z', 'ā', 'ă', 'ū', 'ƙ', 'ɓ', 'ɗ', '’']
PUNCT = [' ', '!', ',', '.', ':', ';', '?']

# Each test sentence is [text, speaker_name, style_wav, language_name].
TEST_SENTENCES = [
    ["umarnai don zaman tsarki.", "two", None, "ha"],
    ["wanda kuma ya faɗa mana ƙaunar da kuke yi cikin ruhu.", "one", None, "ha"],
    ["gama mun ji labarin bangaskiyarku a cikin yesu kiristi da kuma ƙaunar da kuke yi saboda dukan tsarkaka.", "two", None, "ha"],
]
|
|
|
|
|
config = VitsConfig(
    output_path=OUT_PATH,
    model_args=model_args,
    run_name=RUN_NAME,
    project_name="YourTTS",
    run_description=f"""
        - YourTTS trained using the {LANG_NAME.capitalize()} OpenBible dataset.
    """,
    dashboard_logger="tensorboard",
    logger_uri=None,
    audio=audio_config,
    batch_size=BATCH_SIZE,
    batch_group_size=4,
    eval_batch_size=BATCH_SIZE,
    num_loader_workers=8,
    print_step=50,
    plot_step=100,
    save_step=1000,
    save_n_checkpoints=2,
    save_checkpoints=True,
    target_loss="loss_1",
    print_eval=True,
    compute_input_seq_cache=True,
    add_blank=True,
    text_cleaner="no_cleaners",
    characters=CharactersConfig(
        characters_class="TTS.tts.models.vits.VitsCharacters",
        pad="_",
        eos="&",
        bos="*",
        blank=None,
        characters="".join(CHARS),
        punctuations="".join(PUNCT),
    ),
    phoneme_cache_path=None,
    precompute_num_workers=12,
    start_by_longest=True,
    datasets=[dataset_conf],
    cudnn_benchmark=False,
    min_audio_len=int(SAMPLE_RATE * MIN_AUDIO_LEN_IN_SECONDS),
    max_audio_len=SAMPLE_RATE * MAX_AUDIO_LEN_IN_SECONDS,
    mixed_precision=True,
    test_sentences=TEST_SENTENCES,
    # Weight of the speaker consistency loss (only used when it is enabled).
    speaker_encoder_loss_alpha=9.0,
)
|
|
|
|
|
# Load train/eval samples using the NeMo manifest formatter defined above.
train_samples, eval_samples = load_tts_samples(
    config.datasets,
    eval_split=True,
    formatter=nemo,
    eval_split_max_size=config.eval_split_max_size,
    eval_split_size=config.eval_split_size,
)
print(f"Loaded {len(train_samples)} train samples")
print(f"Loaded {len(eval_samples)} eval samples")
|
|
|
|
|
# Initialize the model from the config.
model = Vits.init_from_config(config)

# Initialize the trainer, restore the CML-TTS weights, and start fine-tuning.
trainer = Trainer(
    TrainerArgs(restore_path=RESTORE_PATH, skip_train_epoch=SKIP_TRAIN_EPOCH),
    config,
    output_path=OUT_PATH,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
)
trainer.fit()
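
# After training, synthesis can be tried with the standard `tts` CLI; the run
# directory name below is an assumption, substitute the one the Trainer creates:
#   tts --model_path yourtts_hausa/<run_dir>/best_model.pth \
#       --config_path yourtts_hausa/<run_dir>/config.json \
#       --speaker_idx "one" --language_idx ha \
#       --text "umarnai don zaman tsarki." --out_path sample.wav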
|
|