nateraw committed
Commit b19723a · Parent(s): 911e8ea

Files changed (4):
  1. main.sh +1 -1
  2. start_server.sh +12 -12
  3. train.py +185 -0
  4. update_config.py +5 -1
main.sh CHANGED
@@ -3,4 +3,4 @@ svc pre-resample
 svc pre-config
 svc pre-hubert -fm crepe
 python update_config.py
-svc train
+python train.py
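The pipeline's last step now calls the custom Lightning trainer added in this commit rather than the stock svc train CLI. Judging from the __main__ block of train.py (shown below), the new step boils down to:

    # What `python train.py` executes, per its __main__ block.
    from train import train

    train("configs/44k/config.json", "logs/44k_new")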
start_server.sh CHANGED
@@ -6,18 +6,18 @@ echo "Starting Jupyter Lab with token $JUPYTER_TOKEN"
 nohup bash ./main.sh &
 
 export TF_CPP_MIN_LOG_LEVEL="2"
-# tensorboard --logdir=logs/44k --host 0.0.0.0 --port 7860
+tensorboard --logdir=logs/44k --host 0.0.0.0 --port 7860
 
-jupyter-lab \
-    --ip 0.0.0.0 \
-    --port 7860 \
-    --no-browser \
-    --allow-root \
-    --ServerApp.token="$JUPYTER_TOKEN" \
-    --ServerApp.tornado_settings="{'headers': {'Content-Security-Policy': 'frame-ancestors *'}}" \
-    --ServerApp.cookie_options="{'SameSite': 'None', 'Secure': True}" \
-    --ServerApp.disable_check_xsrf=True \
-    --LabApp.news_url=None \
-    --LabApp.check_for_updates_class="jupyterlab.NeverCheckForUpdate"
+# jupyter-lab \
+#     --ip 0.0.0.0 \
+#     --port 7860 \
+#     --no-browser \
+#     --allow-root \
+#     --ServerApp.token="$JUPYTER_TOKEN" \
+#     --ServerApp.tornado_settings="{'headers': {'Content-Security-Policy': 'frame-ancestors *'}}" \
+#     --ServerApp.cookie_options="{'SameSite': 'None', 'Secure': True}" \
+#     --ServerApp.disable_check_xsrf=True \
+#     --LabApp.news_url=None \
+#     --LabApp.check_for_updates_class="jupyterlab.NeverCheckForUpdate"
 
 # python app.py
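With this change, TensorBoard (reading logs/44k) is the process bound to port 7860, and the Jupyter Lab invocation is kept commented out rather than deleted. A minimal reachability check, assuming the server is reachable on localhost (the URL is illustrative):

    # Probe the TensorBoard server started by start_server.sh.
    import urllib.request

    with urllib.request.urlopen("http://localhost:7860", timeout=5) as resp:
        print(resp.status)  # expect 200 once TensorBoard is serving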
train.py ADDED
@@ -0,0 +1,185 @@
+from __future__ import annotations
+
+import os
+import warnings
+from logging import getLogger
+from multiprocessing import cpu_count
+from pathlib import Path
+from typing import Any
+
+import lightning.pytorch as pl
+import torch
+from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator
+from lightning.pytorch.loggers import TensorBoardLogger
+from lightning.pytorch.strategies.ddp import DDPStrategy
+from lightning.pytorch.tuner import Tuner
+from torch.cuda.amp import autocast
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard.writer import SummaryWriter
+
+import so_vits_svc_fork.f0
+import so_vits_svc_fork.modules.commons as commons
+import so_vits_svc_fork.utils
+
+from so_vits_svc_fork import utils
+from so_vits_svc_fork.dataset import TextAudioCollate, TextAudioDataset
+from so_vits_svc_fork.logger import is_notebook
+from so_vits_svc_fork.modules.descriminators import MultiPeriodDiscriminator
+from so_vits_svc_fork.modules.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
+from so_vits_svc_fork.modules.mel_processing import mel_spectrogram_torch
+from so_vits_svc_fork.modules.synthesizers import SynthesizerTrn
+
+from so_vits_svc_fork.train import VitsLightning
+
+LOG = getLogger(__name__)
+torch.set_float32_matmul_precision("high")
+
+
+from pathlib import Path
+
+from huggingface_hub import create_repo, upload_folder, login
+
+if os.environ.get("HF_TOKEN"):
+    login(os.environ.get("HF_TOKEN"))
+
+
+class HuggingFacePushCallback(pl.Callback):
+    def __init__(self, repo_id, private=False, every=100):
+        self.repo_id = repo_id
+        self.private = private
+        self.every = every
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        self.repo_url = create_repo(
+            repo_id=self.repo_id,
+            exist_ok=True,
+            private=self.private
+        )
+        self.repo_id = self.repo_url.repo_id
+        if pl_module.global_step == 0:
+            return
+        print(f"\n🤗 Pushing to Hugging Face Hub: {self.repo_url}...")
+        model_dir = pl_module.hparams.model_dir
+        upload_folder(
+            repo_id=self.repo_id,
+            folder_path=model_dir,
+            path_in_repo=".",
+            commit_message="🍻 cheers",
+            ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
+        )
+
+class VCDataModule(pl.LightningDataModule):
+    batch_size: int
+
+    def __init__(self, hparams: Any):
+        super().__init__()
+        self.__hparams = hparams
+        self.batch_size = hparams.train.batch_size
+        if not isinstance(self.batch_size, int):
+            self.batch_size = 1
+        self.collate_fn = TextAudioCollate()
+
+        # these should be called in setup(), but we need to calculate check_val_every_n_epoch
+        self.train_dataset = TextAudioDataset(self.__hparams, is_validation=False)
+        self.val_dataset = TextAudioDataset(self.__hparams, is_validation=True)
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            num_workers=min(cpu_count(), self.__hparams.train.get("num_workers", 8)),
+            batch_size=self.batch_size,
+            collate_fn=self.collate_fn,
+            persistent_workers=self.__hparams.train.get("persistent_workers", True),
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset,
+            batch_size=1,
+            collate_fn=self.collate_fn,
+        )
+
+
+def train(
+    config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False
+):
+    config_path = Path(config_path)
+    model_path = Path(model_path)
+
+    hparams = utils.get_backup_hparams(config_path, model_path)
+    utils.ensure_pretrained_model(model_path, hparams.model.get("type_", "hifi-gan"))
+
+    datamodule = VCDataModule(hparams)
+    strategy = (
+        (
+            "ddp_find_unused_parameters_true"
+            if os.name != "nt"
+            else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo")
+        )
+        if torch.cuda.device_count() > 1
+        else "auto"
+    )
+    LOG.info(f"Using strategy: {strategy}")
+
+    callbacks = []
+    if hparams.train.push_to_hub:
+        callbacks.append(HuggingFacePushCallback(hparams.train.repo_id, hparams.train.private))
+    if not is_notebook():
+        callbacks.append(pl.callbacks.RichProgressBar())
+    if callbacks == []:
+        callbacks = None
+
+    trainer = pl.Trainer(
+        logger=TensorBoardLogger(
+            model_path, "lightning_logs", hparams.train.get("log_version", 0)
+        ),
+        # profiler="simple",
+        val_check_interval=hparams.train.eval_interval,
+        max_epochs=hparams.train.epochs,
+        check_val_every_n_epoch=None,
+        precision="16-mixed"
+        if hparams.train.fp16_run
+        else "bf16-mixed"
+        if hparams.train.get("bf16_run", False)
+        else 32,
+        strategy=strategy,
+        callbacks=callbacks,
+        benchmark=True,
+        enable_checkpointing=False,
+    )
+    tuner = Tuner(trainer)
+    model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)
+
+    # automatic batch size scaling
+    batch_size = hparams.train.batch_size
+    batch_split = str(batch_size).split("-")
+    batch_size = batch_split[0]
+    init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
+    max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
+    if batch_size == "auto":
+        batch_size = "binsearch"
+    if batch_size in ["power", "binsearch"]:
+        model.tuning = True
+        tuner.scale_batch_size(
+            model,
+            mode=batch_size,
+            datamodule=datamodule,
+            steps_per_trial=1,
+            init_val=init_val,
+            max_trials=max_trials,
+        )
+        model.tuning = False
+    else:
+        batch_size = int(batch_size)
+    # automatic learning rate scaling is not supported for multiple optimizers
+    """if hparams.train.learning_rate == "auto":
+        lr_finder = tuner.lr_find(model)
+        LOG.info(lr_finder.results)
+        fig = lr_finder.plot(suggest=True)
+        fig.savefig(model_path / "lr_finder.png")"""
+
+    trainer.fit(model, datamodule=datamodule)
+
+if __name__ == '__main__':
+    train('configs/44k/config.json', 'logs/44k_new')
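Note how train() overloads the batch_size hyperparameter: a plain integer is used as-is, while a string such as "auto", "power", or "binsearch", optionally suffixed with -<init_val>-<max_trials>, hands sizing over to Lightning's Tuner.scale_batch_size. A minimal sketch of that parsing convention (the example specs are hypothetical):

    # Mirrors the batch_size parsing in train(); the example specs are hypothetical.
    def parse_batch_size_spec(spec):
        parts = str(spec).split("-")
        mode = parts[0]
        init_val = 2 if len(parts) <= 1 else int(parts[1])
        max_trials = 25 if len(parts) <= 2 else int(parts[2])
        if mode == "auto":
            mode = "binsearch"  # "auto" is an alias for binary search
        return mode, init_val, max_trials

    print(parse_batch_size_spec("auto-4-10"))  # -> ('binsearch', 4, 10)
    print(parse_batch_size_spec(16))           # -> ('16', 2, 25), later cast to int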
update_config.py CHANGED
@@ -10,10 +10,14 @@ from pathlib import Path
 def main(config_file="configs/44k/config.json"):
     config_path = Path(config_file)
     data = json.loads(config_path.read_text())
+    data['train']['batch_size'] = 16
+    data['train']['eval_interval'] = 800
     data['train']['num_workers'] = 0
     data['train']['persistent_workers'] = False
+    data['train']['push_to_hub'] = True
+    data['train']['repo_id'] = tuple(data['spk'])[0]
+    data['train']['private'] = True
     config_path.write_text(json.dumps(data, indent=2, sort_keys=False))
 
-
 if __name__ == "__main__":
     main()
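After python update_config.py runs (invoked from main.sh, before training starts), the train section of configs/44k/config.json carries every key the new train.py reads. A hedged sketch of the result, assuming a single-speaker dataset whose spk map has one entry named "my-speaker" (the name is hypothetical):

    # Keys update_config.py sets in data["train"]; "my-speaker" is hypothetical.
    {
        "batch_size": 16,            # a plain int, so the batch-size tuner is skipped
        "eval_interval": 800,        # used as the Trainer's val_check_interval
        "num_workers": 0,
        "persistent_workers": False,
        "push_to_hub": True,         # enables HuggingFacePushCallback
        "repo_id": "my-speaker",     # first key of data["spk"]
        "private": True,             # the Hub repo is created private
    }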