"""Train so-vits-svc-fork with an optional callback that mirrors checkpoints to the Hugging Face Hub."""
from __future__ import annotations

import os
import re
import warnings
from logging import getLogger
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any

import lightning.pytorch as pl
import torch
from huggingface_hub import create_repo, upload_folder, login, list_repo_files, delete_file
from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.tuner import Tuner
from torch.cuda.amp import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

import so_vits_svc_fork.f0
import so_vits_svc_fork.modules.commons as commons
import so_vits_svc_fork.utils
from so_vits_svc_fork import utils
from so_vits_svc_fork.dataset import TextAudioCollate, TextAudioDataset
from so_vits_svc_fork.logger import is_notebook
from so_vits_svc_fork.modules.descriminators import MultiPeriodDiscriminator
from so_vits_svc_fork.modules.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from so_vits_svc_fork.modules.mel_processing import mel_spectrogram_torch
from so_vits_svc_fork.modules.synthesizers import SynthesizerTrn
from so_vits_svc_fork.train import VitsLightning, VCDataModule

LOG = getLogger(__name__)
torch.set_float32_matmul_precision("high")

# Log in to the Hugging Face Hub up front if a token is provided via the environment.
if os.environ.get("HF_TOKEN"):
    login(os.environ.get("HF_TOKEN"))


class HuggingFacePushCallback(pl.Callback):
    """Uploads the model directory to a Hugging Face Hub repo after each validation pass."""

    def __init__(self, repo_id, private=False, every=100):
        self.repo_id = repo_id
        self.private = private
        self.every = every

    def on_validation_epoch_end(self, trainer, pl_module):
        self.repo_url = create_repo(
            repo_id=self.repo_id, exist_ok=True, private=self.private
        )
        self.repo_id = self.repo_url.repo_id
        if pl_module.global_step == 0:
            return
        print(f"\nšŸ¤— Pushing to Hugging Face Hub: {self.repo_url}...")
        model_dir = pl_module.hparams.model_dir
        upload_folder(
            repo_id=self.repo_id,
            folder_path=model_dir,
            path_in_repo=".",
            commit_message="šŸ» cheers",
            ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
        )
        # Prune checkpoints that have been rotated away locally, but keep the
        # pretrained G_0.pth / D_0.pth in the repo.
        ckpt_pattern = r"^(D_|G_)\d+\.pth$"
        repo_ckpts = [
            x
            for x in list_repo_files(self.repo_id)
            if re.match(ckpt_pattern, x) and x not in ["G_0.pth", "D_0.pth"]
        ]
        local_ckpts = [
            x.name for x in Path(model_dir).glob("*.pth") if re.match(ckpt_pattern, x.name)
        ]
        to_delete = set(repo_ckpts) - set(local_ckpts)
        for fname in to_delete:
            print(f"šŸ—‘ Deleting {fname} from repo")
            delete_file(fname, self.repo_id)


def train(
    config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False
):
    config_path = Path(config_path)
    model_path = Path(model_path)
    hparams = utils.get_backup_hparams(config_path, model_path)
    utils.ensure_pretrained_model(model_path, hparams.model.get("type_", "hifi-gan"))
    datamodule = VCDataModule(hparams)
    # Use DDP only when more than one CUDA device is visible; Windows needs the gloo backend.
    strategy = (
        (
            "ddp_find_unused_parameters_true"
            if os.name != "nt"
            else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo")
        )
        if torch.cuda.device_count() > 1
        else "auto"
    )
    LOG.info(f"Using strategy: {strategy}")

    callbacks = []
    if hparams.train.push_to_hub:
        callbacks.append(
            HuggingFacePushCallback(hparams.train.repo_id, hparams.train.private)
        )
    if not is_notebook():
        callbacks.append(pl.callbacks.RichProgressBar())
    if callbacks == []:
        callbacks = None
    trainer = pl.Trainer(
        logger=TensorBoardLogger(
            model_path, "lightning_logs", hparams.train.get("log_version", 0)
        ),
        # profiler="simple",
        val_check_interval=hparams.train.eval_interval,
        max_epochs=hparams.train.epochs,
        check_val_every_n_epoch=None,
        precision="16-mixed"
        if hparams.train.fp16_run
        else ("bf16-mixed" if hparams.train.get("bf16_run", False) else 32),
        strategy=strategy,
        callbacks=callbacks,
        benchmark=True,
        enable_checkpointing=False,
    )
    tuner = Tuner(trainer)
    model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)

    # Automatic batch size scaling: train.batch_size may be an integer or a
    # spec string "mode[-init_val[-max_trials]]", e.g. "binsearch-4-10".
    batch_size = hparams.train.batch_size
    batch_split = str(batch_size).split("-")
    batch_size = batch_split[0]
    init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
    max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
    if batch_size == "auto":
        batch_size = "binsearch"
    if batch_size in ["power", "binsearch"]:
        model.tuning = True
        tuner.scale_batch_size(
            model,
            mode=batch_size,
            datamodule=datamodule,
            steps_per_trial=1,
            init_val=init_val,
            max_trials=max_trials,
        )
        model.tuning = False
    else:
        batch_size = int(batch_size)
    # automatic learning rate scaling is not supported for multiple optimizers
    """if hparams.train.learning_rate == "auto":
        lr_finder = tuner.lr_find(model)
        LOG.info(lr_finder.results)
        fig = lr_finder.plot(suggest=True)
        fig.savefig(model_path / "lr_finder.png")"""

    trainer.fit(model, datamodule=datamodule)


if __name__ == '__main__':
    train('configs/44k/config.json', 'logs/44k')
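
# Usage sketch (a hedged note, not part of the upstream so-vits-svc-fork script):
# assuming this file is saved as, say, train_hub.py and the config's "train"
# section defines push_to_hub, repo_id and private, a run could look like
#
#   HF_TOKEN=<write-access token> python train_hub.py
#
# which trains from configs/44k/config.json into logs/44k and mirrors the
# checkpoints to the Hub after each validation pass. train.batch_size may be an
# int, "auto", "power", "binsearch", or a spec such as "binsearch-4-10"
# (mode-init_val-max_trials); "auto" is treated as "binsearch".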