from __future__ import annotations

import os
from logging import getLogger
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any

import lightning.pytorch as pl
import torch
from huggingface_hub import create_repo, login, upload_folder
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.tuner import Tuner
from torch.utils.data import DataLoader

from so_vits_svc_fork import utils
from so_vits_svc_fork.dataset import TextAudioCollate, TextAudioDataset
from so_vits_svc_fork.logger import is_notebook
from so_vits_svc_fork.train import VitsLightning

LOG = getLogger(__name__)

torch.set_float32_matmul_precision("high")

# Log in to the Hugging Face Hub up front if a token is provided via the environment.
if os.environ.get("HF_TOKEN"):
    login(token=os.environ["HF_TOKEN"])


class HuggingFacePushCallback(pl.Callback):
    """Pushes the model directory to the Hugging Face Hub after each validation epoch."""

    def __init__(self, repo_id, private=False, every=100):
        self.repo_id = repo_id
        self.private = private
        self.every = every  # currently unused; a push happens on every validation epoch

    def on_validation_epoch_end(self, trainer, pl_module):
        self.repo_url = create_repo(
            repo_id=self.repo_id, exist_ok=True, private=self.private
        )
        self.repo_id = self.repo_url.repo_id
        if pl_module.global_step == 0:
            return
        print(f"\nšŸ¤— Pushing to Hugging Face Hub: {self.repo_url}...")
        model_dir = pl_module.hparams.model_dir
        upload_folder(
            repo_id=self.repo_id,
            folder_path=model_dir,
            path_in_repo=".",
            commit_message="šŸ» cheers",
            ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
        )


class VCDataModule(pl.LightningDataModule):
    batch_size: int

    def __init__(self, hparams: Any):
        super().__init__()
        self.__hparams = hparams
        self.batch_size = hparams.train.batch_size
        if not isinstance(self.batch_size, int):
            # Non-integer values (e.g. "auto") are resolved later by the batch-size tuner.
            self.batch_size = 1
        self.collate_fn = TextAudioCollate()

        # These should be called in setup(), but we need the datasets early
        # to calculate check_val_every_n_epoch.
        self.train_dataset = TextAudioDataset(self.__hparams, is_validation=False)
        self.val_dataset = TextAudioDataset(self.__hparams, is_validation=True)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            num_workers=min(cpu_count(), self.__hparams.train.get("num_workers", 8)),
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            persistent_workers=self.__hparams.train.get("persistent_workers", True),
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,
        )


def train(
    config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False
):
    config_path = Path(config_path)
    model_path = Path(model_path)
    hparams = utils.get_backup_hparams(config_path, model_path)
    utils.ensure_pretrained_model(model_path, hparams.model.get("type_", "hifi-gan"))

    datamodule = VCDataModule(hparams)
    strategy = (
        (
            "ddp_find_unused_parameters_true"
            if os.name != "nt"
            else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo")
        )
        if torch.cuda.device_count() > 1
        else "auto"
    )
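    # NCCL, the default torch.distributed backend for DDP, is not supported on
    # Windows, so the os.name == "nt" branch above falls back to the gloo
    # process-group backend instead.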
    LOG.info(f"Using strategy: {strategy}")

    callbacks = []
    if hparams.train.get("push_to_hub", False):
        callbacks.append(
            HuggingFacePushCallback(
                hparams.train.repo_id, hparams.train.get("private", False)
            )
        )
    if not is_notebook():
        callbacks.append(pl.callbacks.RichProgressBar())

    trainer = pl.Trainer(
        logger=TensorBoardLogger(
            model_path, "lightning_logs", hparams.train.get("log_version", 0)
        ),
        # profiler="simple",
        val_check_interval=hparams.train.eval_interval,
        max_epochs=hparams.train.epochs,
        check_val_every_n_epoch=None,
        precision="16-mixed"
        if hparams.train.fp16_run
        else "bf16-mixed"
        if hparams.train.get("bf16_run", False)
        else 32,
        strategy=strategy,
        callbacks=callbacks or None,
        benchmark=True,
        enable_checkpointing=False,
    )
    tuner = Tuner(trainer)
    model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)

    # Automatic batch size scaling. hparams.train.batch_size may be a plain
    # integer, or a string "<mode>[-<init_val>[-<max_trials>]]" where mode is
    # "auto" (alias for "binsearch"), "power" or "binsearch", e.g. "auto-4-25".
    batch_size = hparams.train.batch_size
    batch_split = str(batch_size).split("-")
    batch_size = batch_split[0]
    init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
    max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
    if batch_size == "auto":
        batch_size = "binsearch"
    if batch_size in ["power", "binsearch"]:
        model.tuning = True
        tuner.scale_batch_size(
            model,
            mode=batch_size,
            datamodule=datamodule,
            steps_per_trial=1,
            init_val=init_val,
            max_trials=max_trials,
        )
        model.tuning = False
    else:
        batch_size = int(batch_size)

    # Automatic learning rate scaling is not supported for multiple optimizers,
    # so lr_find is disabled:
    # if hparams.train.learning_rate == "auto":
    #     lr_finder = tuner.lr_find(model)
    #     LOG.info(lr_finder.results)
    #     fig = lr_finder.plot(suggest=True)
    #     fig.savefig(model_path / "lr_finder.png")

    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    train("configs/44k/config.json", "logs/44k")
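
# For reference, a minimal sketch of the "train" section of config.json that
# this script reads. The keys mirror the lookups above; the values are
# illustrative only ("repo_id" is a placeholder), not project defaults:
#
#   "train": {
#     "batch_size": "auto-4-25",
#     "epochs": 10000,
#     "eval_interval": 200,
#     "fp16_run": false,
#     "bf16_run": false,
#     "num_workers": 8,
#     "persistent_workers": true,
#     "log_version": 0,
#     "push_to_hub": true,
#     "repo_id": "username/my-voice-model",
#     "private": false
#   }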