nateraw committed
Commit 64c7e3d · 1 Parent(s): 911e8ea

Create train.py

Files changed (1)
  1. train.py +185 -0
train.py ADDED
@@ -0,0 +1,185 @@
from __future__ import annotations

import os
import warnings
from logging import getLogger
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any

import lightning.pytorch as pl
import torch
from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.tuner import Tuner
from torch.cuda.amp import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

import so_vits_svc_fork.f0
import so_vits_svc_fork.modules.commons as commons
import so_vits_svc_fork.utils

from so_vits_svc_fork import utils
from so_vits_svc_fork.dataset import TextAudioCollate, TextAudioDataset
from so_vits_svc_fork.logger import is_notebook
from so_vits_svc_fork.modules.descriminators import MultiPeriodDiscriminator
from so_vits_svc_fork.modules.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from so_vits_svc_fork.modules.mel_processing import mel_spectrogram_torch
from so_vits_svc_fork.modules.synthesizers import SynthesizerTrn

from so_vits_svc_fork.train import VitsLightning

LOG = getLogger(__name__)
torch.set_float32_matmul_precision("high")


from huggingface_hub import create_repo, upload_folder, login

# Log in to the Hugging Face Hub when a token is provided via the environment.
if os.environ.get("HF_TOKEN"):
    login(os.environ.get("HF_TOKEN"))


class HuggingFacePushCallback(pl.Callback):
    """Upload the run's model directory to the Hugging Face Hub after each
    validation epoch (skipped at global step 0). The ``every`` argument is
    stored but not currently used."""

    def __init__(self, repo_id, private=False, every=100):
        self.repo_id = repo_id
        self.private = private
        self.every = every

    def on_validation_epoch_end(self, trainer, pl_module):
        self.repo_url = create_repo(
            repo_id=self.repo_id,
            exist_ok=True,
            private=self.private,
        )
        self.repo_id = self.repo_url.repo_id
        if pl_module.global_step == 0:
            return
        print(f"\n🤗 Pushing to Hugging Face Hub: {self.repo_url}...")
        model_dir = pl_module.hparams.model_dir
        upload_folder(
            repo_id=self.repo_id,
            folder_path=model_dir,
            path_in_repo=".",
            commit_message="🍻 cheers",
            ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
        )


class VCDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping the so-vits-svc text/audio datasets."""

    batch_size: int

    def __init__(self, hparams: Any):
        super().__init__()
        self.__hparams = hparams
        self.batch_size = hparams.train.batch_size
        if not isinstance(self.batch_size, int):
            self.batch_size = 1
        self.collate_fn = TextAudioCollate()

        # these should be called in setup(), but we need to calculate check_val_every_n_epoch
        self.train_dataset = TextAudioDataset(self.__hparams, is_validation=False)
        self.val_dataset = TextAudioDataset(self.__hparams, is_validation=True)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            num_workers=min(cpu_count(), self.__hparams.train.get("num_workers", 8)),
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            persistent_workers=self.__hparams.train.get("persistent_workers", True),
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,
        )


def train(
    config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False
):
    """Train a so-vits-svc model from the given config, logging under model_path."""
    config_path = Path(config_path)
    model_path = Path(model_path)

    hparams = utils.get_backup_hparams(config_path, model_path)
    utils.ensure_pretrained_model(model_path, hparams.model.get("type_", "hifi-gan"))

    datamodule = VCDataModule(hparams)
    strategy = (
        (
            "ddp_find_unused_parameters_true"
            if os.name != "nt"
            else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo")
        )
        if torch.cuda.device_count() > 1
        else "auto"
    )
    LOG.info(f"Using strategy: {strategy}")

    callbacks = []
    if hparams.train.push_to_hub:
        callbacks.append(HuggingFacePushCallback(hparams.train.repo_id, hparams.train.private))
    if not is_notebook():
        callbacks.append(pl.callbacks.RichProgressBar())
    if callbacks == []:
        callbacks = None

    trainer = pl.Trainer(
        logger=TensorBoardLogger(
            model_path, "lightning_logs", hparams.train.get("log_version", 0)
        ),
        # profiler="simple",
        val_check_interval=hparams.train.eval_interval,
        max_epochs=hparams.train.epochs,
        check_val_every_n_epoch=None,
        precision="16-mixed"
        if hparams.train.fp16_run
        else "bf16-mixed"
        if hparams.train.get("bf16_run", False)
        else 32,
        strategy=strategy,
        callbacks=callbacks,
        benchmark=True,
        enable_checkpointing=False,
    )
    tuner = Tuner(trainer)
    model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)

    # automatic batch size scaling
    batch_size = hparams.train.batch_size
    batch_split = str(batch_size).split("-")
    batch_size = batch_split[0]
    init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
    max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
    if batch_size == "auto":
        batch_size = "binsearch"
    if batch_size in ["power", "binsearch"]:
        model.tuning = True
        tuner.scale_batch_size(
            model,
            mode=batch_size,
            datamodule=datamodule,
            steps_per_trial=1,
            init_val=init_val,
            max_trials=max_trials,
        )
        model.tuning = False
    else:
        batch_size = int(batch_size)
    # automatic learning rate scaling is not supported for multiple optimizers
    """if hparams.train.learning_rate == "auto":
        lr_finder = tuner.lr_find(model)
        LOG.info(lr_finder.results)
        fig = lr_finder.plot(suggest=True)
        fig.savefig(model_path / "lr_finder.png")"""

    trainer.fit(model, datamodule=datamodule)


if __name__ == '__main__':
    train('configs/44k/config.json', 'logs/44k_new')
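
For reference, the Hub-push behaviour above is driven entirely by keys that train() reads from the config's "train" section (push_to_hub, repo_id, private), and the batch-size auto-tuning accepts a string of the form mode-init_val-max_trials. Below is a minimal sketch of wiring those keys into configs/44k/config.json before running this script; the helper itself and the repo id are illustrative placeholders, not part of the commit.

# Hypothetical helper (not part of this commit): add the keys train() reads
# from the config's "train" section before launching training.
import json
from pathlib import Path

config_file = Path("configs/44k/config.json")
config = json.loads(config_file.read_text())
config["train"].update(
    {
        "push_to_hub": True,                    # enables HuggingFacePushCallback
        "repo_id": "your-username/your-model",  # placeholder Hub repo id
        "private": True,                        # create the Hub repo as private
        "batch_size": "binsearch-2-25",         # tuner string: mode-init_val-max_trials
    }
)
config_file.write_text(json.dumps(config, indent=2))

Authentication still comes from the HF_TOKEN environment variable checked at import time (or a token previously cached by huggingface-cli login); without it, create_repo and upload_folder in the callback will fail.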