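"""Train so-vits-svc-fork models with automatic checkpoint pushes to the
Hugging Face Hub.

A variant of so_vits_svc_fork.train.train() extended with
HuggingFacePushCallback, which uploads the model directory after each
validation epoch. Set HF_TOKEN in the environment for authenticated pushes.
"""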
from __future__ import annotations
import os
import re
import warnings
from logging import getLogger
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any
import lightning.pytorch as pl
import torch
from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.tuner import Tuner
from torch.cuda.amp import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
import so_vits_svc_fork.f0
import so_vits_svc_fork.modules.commons as commons
import so_vits_svc_fork.utils
from so_vits_svc_fork import utils
from so_vits_svc_fork.dataset import TextAudioCollate, TextAudioDataset
from so_vits_svc_fork.logger import is_notebook
from so_vits_svc_fork.modules.descriminators import MultiPeriodDiscriminator
from so_vits_svc_fork.modules.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from so_vits_svc_fork.modules.mel_processing import mel_spectrogram_torch
from so_vits_svc_fork.modules.synthesizers import SynthesizerTrn
from so_vits_svc_fork.train import VitsLightning, VCDataModule
LOG = getLogger(__name__)
torch.set_float32_matmul_precision("high")
from huggingface_hub import create_repo, upload_folder, login, list_repo_files, delete_file
if os.environ.get("HF_TOKEN"):
    login(os.environ.get("HF_TOKEN"))
class HuggingFacePushCallback(pl.Callback):
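    """Lightning callback that uploads the model directory to the Hugging Face
    Hub at the end of every validation epoch, then prunes remote D_*/G_*
    checkpoints that no longer exist locally (keeping G_0.pth/D_0.pth)."""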
    def __init__(self, repo_id, private=False, every=100):
        self.repo_id = repo_id
        self.private = private
        self.every = every  # note: currently unused

    def on_validation_epoch_end(self, trainer, pl_module):
        self.repo_url = create_repo(
            repo_id=self.repo_id,
            exist_ok=True,
            private=self.private,
        )
        self.repo_id = self.repo_url.repo_id
        if pl_module.global_step == 0:
            return
        print(f"\n🤗 Pushing to Hugging Face Hub: {self.repo_url}...")
        model_dir = pl_module.hparams.model_dir
        upload_folder(
            repo_id=self.repo_id,
            folder_path=model_dir,
            path_in_repo=".",
            commit_message="🍻 cheers",
            ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
        )
        # Prune stale checkpoints: anything matching D_*/G_* on the Hub that is
        # no longer present locally, except the pretrained G_0.pth/D_0.pth pair.
        ckpt_pattern = r"^(D_|G_)\d+\.pth$"
        repo_ckpts = [
            x
            for x in list_repo_files(self.repo_id)
            if re.match(ckpt_pattern, x) and x not in ["G_0.pth", "D_0.pth"]
        ]
        local_ckpts = [
            x.name for x in Path(model_dir).glob("*.pth") if re.match(ckpt_pattern, x.name)
        ]
        to_delete = set(repo_ckpts) - set(local_ckpts)
        for fname in to_delete:
            print(f"🗑 Deleting {fname} from repo")
            delete_file(fname, self.repo_id)
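
# Usage sketch (hypothetical repo id); train() below attaches this callback
# automatically when hparams.train.push_to_hub is set:
#   trainer = pl.Trainer(callbacks=[HuggingFacePushCallback("user/my-voice-model")])
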
def train(
    config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False
):
    config_path = Path(config_path)
    model_path = Path(model_path)
    hparams = utils.get_backup_hparams(config_path, model_path)
    utils.ensure_pretrained_model(model_path, hparams.model.get("type_", "hifi-gan"))
    datamodule = VCDataModule(hparams)
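    # Multi-GPU runs use DDP; Windows ("nt") has no NCCL, so a gloo-backed
    # DDPStrategy is used there instead. Otherwise let Lightning pick ("auto").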
    strategy = (
        (
            "ddp_find_unused_parameters_true"
            if os.name != "nt"
            else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo")
        )
        if torch.cuda.device_count() > 1
        else "auto"
    )
    LOG.info(f"Using strategy: {strategy}")
    callbacks = []
    if hparams.train.push_to_hub:
        callbacks.append(
            HuggingFacePushCallback(hparams.train.repo_id, hparams.train.private)
        )
    if not is_notebook():
        callbacks.append(pl.callbacks.RichProgressBar())
    if callbacks == []:
        callbacks = None
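    # Precision: fp16 mixed if fp16_run is set, else bf16 mixed if bf16_run is
    # set, else full fp32.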
    trainer = pl.Trainer(
        logger=TensorBoardLogger(
            model_path, "lightning_logs", hparams.train.get("log_version", 0)
        ),
        # profiler="simple",
        val_check_interval=hparams.train.eval_interval,
        max_epochs=hparams.train.epochs,
        check_val_every_n_epoch=None,
        precision="16-mixed"
        if hparams.train.fp16_run
        else "bf16-mixed"
        if hparams.train.get("bf16_run", False)
        else 32,
        strategy=strategy,
        callbacks=callbacks,
        benchmark=True,
        enable_checkpointing=False,
    )
    tuner = Tuner(trainer)
    model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)
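    # hparams.train.batch_size is either an int or a "-"-separated string
    # "<mode>[-<init_val>[-<max_trials>]]" (e.g. "auto-4-30"), with <mode> one
    # of "auto", "power", or "binsearch".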
    # automatic batch size scaling
    batch_size = hparams.train.batch_size
    batch_split = str(batch_size).split("-")
    batch_size = batch_split[0]
    init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
    max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
    if batch_size == "auto":
        batch_size = "binsearch"
    if batch_size in ["power", "binsearch"]:
        model.tuning = True
        tuner.scale_batch_size(
            model,
            mode=batch_size,
            datamodule=datamodule,
            steps_per_trial=1,
            init_val=init_val,
            max_trials=max_trials,
        )
        model.tuning = False
    else:
        batch_size = int(batch_size)
    # automatic learning rate scaling is not supported for multiple optimizers
    """if hparams.train.learning_rate == "auto":
        lr_finder = tuner.lr_find(model)
        LOG.info(lr_finder.results)
        fig = lr_finder.plot(suggest=True)
        fig.savefig(model_path / "lr_finder.png")"""
    trainer.fit(model, datamodule=datamodule)
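
# The Hub-push settings are read from the config's "train" section; the keys
# mirror the attribute accesses in train() above (repo id illustrative):
#   "push_to_hub": true, "repo_id": "your-username/your-model", "private": false
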
if __name__ == '__main__':
    train("configs/44k/config.json", "logs/44k")