Hemant0000 committed on
Commit cb84bf3 · verified · 1 Parent(s): d05fcdb

Create train.py

Files changed (1): train.py +92 -0
train.py ADDED
@@ -0,0 +1,92 @@
+ from model import CFM, UNetT, DiT, Trainer
+ from model.utils import get_tokenizer
+ from model.dataset import load_dataset
+
+
+ # -------------------------- Dataset Settings --------------------------- #
+
+ target_sample_rate = 24000
+ n_mel_channels = 100
+ hop_length = 256
+
+ tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
+ tokenizer_path = None  # if tokenizer == 'custom', path to the vocab file to use (a vocab.txt)
+ dataset_name = "Emilia_ZH_EN"
+
+ # -------------------------- Training Settings -------------------------- #
+
+ exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base
+
+ learning_rate = 7.5e-5
+
+ batch_size_per_gpu = 38400  # 8 GPUs, 8 * 38400 = 307200
+ batch_size_type = "frame"  # "frame" or "sample"
+ max_samples = 64  # max sequences per batch when batch_size_type is "frame"; 32 for small models, 64 for base models
+ grad_accumulation_steps = 1  # note: updates = steps / grad_accumulation_steps
+ max_grad_norm = 1.0
+
+ epochs = 11  # the learning rate decays linearly, so the number of epochs controls the slope
+ num_warmup_updates = 20000  # warmup updates
+ save_per_updates = 50000  # save a checkpoint every this many updates
+ last_per_steps = 5000  # save the "last" checkpoint every this many steps
+
+ # model params
+ if exp_name == "F5TTS_Base":
+     wandb_resume_id = None
+     model_cls = DiT
+     model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+ elif exp_name == "E2TTS_Base":
+     wandb_resume_id = None
+     model_cls = UNetT
+     model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+
+
+ # ----------------------------------------------------------------------- #
+
+
+ def main():
+     if tokenizer == "custom":
+         tokenizer_select = tokenizer_path  # custom vocab.txt path; a new local name avoids an UnboundLocalError from reassigning the module-level tokenizer_path
+     else:
+         tokenizer_select = dataset_name
+     vocab_char_map, vocab_size = get_tokenizer(tokenizer_select, tokenizer)
+
+     mel_spec_kwargs = dict(
+         target_sample_rate=target_sample_rate,
+         n_mel_channels=n_mel_channels,
+         hop_length=hop_length,
+     )
+
+     model = CFM(
+         transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
+         mel_spec_kwargs=mel_spec_kwargs,
+         vocab_char_map=vocab_char_map,
+     )
+
+     trainer = Trainer(
+         model,
+         epochs,
+         learning_rate,
+         num_warmup_updates=num_warmup_updates,
+         save_per_updates=save_per_updates,
+         checkpoint_path=f"ckpts/{exp_name}",
+         batch_size=batch_size_per_gpu,
+         batch_size_type=batch_size_type,
+         max_samples=max_samples,
+         grad_accumulation_steps=grad_accumulation_steps,
+         max_grad_norm=max_grad_norm,
+         wandb_project="CFM-TTS",
+         wandb_run_name=exp_name,
+         wandb_resume_id=wandb_resume_id,
+         last_per_steps=last_per_steps,
+     )
+
+     train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
+     trainer.train(
+         train_dataset,
+         resumable_with_seed=666,  # seed for shuffling the dataset
+     )
+
+
+ if __name__ == "__main__":
+     main()
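
A note on the frame-based batch size configured above: with batch_size_type = "frame", batch_size_per_gpu counts mel-spectrogram frames rather than utterances. The arithmetic below is illustrative only (not part of the committed script); it converts the 38,400-frames-per-GPU figure into seconds of audio using the dataset settings in this file, with num_gpus taken from the "8 GPUs" comment.

# Illustrative arithmetic only; reuses the constants defined in train.py above.
target_sample_rate = 24000      # Hz
hop_length = 256                # audio samples per mel frame
batch_size_per_gpu = 38400      # mel frames per GPU per optimizer step
num_gpus = 8                    # from the "8 GPUs" comment in train.py

seconds_per_frame = hop_length / target_sample_rate       # ~0.01067 s per frame
audio_per_gpu = batch_size_per_gpu * seconds_per_frame    # 409.6 s of audio per GPU per step
audio_per_step = audio_per_gpu * num_gpus                 # 3276.8 s (~0.91 h) per step across 8 GPUs

print(f"{audio_per_gpu:.1f} s/GPU, {audio_per_step:.1f} s/step")

For the 8-GPU figure to apply, the script would normally be started under a distributed launcher, e.g. `accelerate launch train.py` if the Trainer is built on Hugging Face accelerate, which this diff does not show.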
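
The comments on epochs and num_warmup_updates describe a warmup followed by linear decay, but the schedule itself lives inside Trainer and is not part of this file. Below is a minimal sketch of that shape, assuming linear warmup to the peak learning rate over num_warmup_updates and then linear decay to zero over the remaining updates; the function name lr_at and the total_updates parameter are hypothetical, for illustration only.

# Hypothetical sketch of the warmup + linear-decay shape the comments describe;
# the real schedule is implemented inside Trainer, not here.
def lr_at(update: int, total_updates: int, peak_lr: float = 7.5e-5, warmup: int = 20000) -> float:
    if update < warmup:
        return peak_lr * update / warmup                                # linear warmup to peak_lr
    remaining = max(total_updates - warmup, 1)
    return peak_lr * max(0.0, 1.0 - (update - warmup) / remaining)      # linear decay to 0

This is also why the epochs comment says the number of epochs controls the slope: for a fixed dataset, more epochs means more total updates, so the same peak learning rate decays more gradually.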