from __future__ import annotations

import os
from logging import getLogger
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any

import lightning.pytorch as pl
import torch
from huggingface_hub import create_repo, login, upload_folder
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.tuner import Tuner
from torch.utils.data import DataLoader

from so_vits_svc_fork import utils
from so_vits_svc_fork.dataset import TextAudioCollate, TextAudioDataset
from so_vits_svc_fork.logger import is_notebook
from so_vits_svc_fork.train import VitsLightning

LOG = getLogger(__name__)

torch.set_float32_matmul_precision("high")

# Log in to the Hugging Face Hub if a token is provided via the environment.
if os.environ.get("HF_TOKEN"):
    login(os.environ.get("HF_TOKEN"))
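# Illustrative only: the token is read from the HF_TOKEN environment variable,
# e.g. set before launching the script:
#   export HF_TOKEN=<your write-access token>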

class HuggingFacePushCallback(pl.Callback):
    """Uploads the model directory to the Hugging Face Hub after each validation."""

    def __init__(self, repo_id, private=False, every=100):
        super().__init__()
        self.repo_id = repo_id
        self.private = private
        # NOTE: `every` is currently unused; pushes happen on every validation
        # epoch end (i.e. every `eval_interval` steps).
        self.every = every

    def on_validation_epoch_end(self, trainer, pl_module):
        # Skip the validation sanity check that runs before training starts.
        if pl_module.global_step == 0:
            return
        self.repo_url = create_repo(
            repo_id=self.repo_id,
            exist_ok=True,
            private=self.private,
        )
        self.repo_id = self.repo_url.repo_id
        print(f"\n🤗 Pushing to Hugging Face Hub: {self.repo_url}...")
        model_dir = pl_module.hparams.model_dir
        upload_folder(
            repo_id=self.repo_id,
            folder_path=model_dir,
            path_in_repo=".",
            commit_message="🍻 cheers",
            ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
        )
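
# The callback expects Hub settings in the training config. A hypothetical
# "train" section of config.json might contain (key names taken from the
# attribute accesses in train() below):
#   "push_to_hub": true,
#   "repo_id": "username/my-voice-model",
#   "private": true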

class VCDataModule(pl.LightningDataModule):
    batch_size: int

    def __init__(self, hparams: Any):
        super().__init__()
        self.__hparams = hparams
        self.batch_size = hparams.train.batch_size
        if not isinstance(self.batch_size, int):
            self.batch_size = 1
        self.collate_fn = TextAudioCollate()

        # These should be created in setup(), but we need them early to
        # calculate check_val_every_n_epoch.
        self.train_dataset = TextAudioDataset(self.__hparams, is_validation=False)
        self.val_dataset = TextAudioDataset(self.__hparams, is_validation=True)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            num_workers=min(cpu_count(), self.__hparams.train.get("num_workers", 8)),
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            persistent_workers=self.__hparams.train.get("persistent_workers", True),
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,
        )
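
# Usage sketch (illustrative): the datamodule can be exercised standalone to
# inspect a collated batch before launching a full run:
#   dm = VCDataModule(hparams)
#   batch = next(iter(dm.train_dataloader()))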

def train(
    config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False
):
    config_path = Path(config_path)
    model_path = Path(model_path)
    hparams = utils.get_backup_hparams(config_path, model_path)
    utils.ensure_pretrained_model(model_path, hparams.model.get("type_", "hifi-gan"))
    datamodule = VCDataModule(hparams)
    # Multi-GPU: NCCL is unavailable on Windows, so fall back to the gloo
    # process-group backend there.
    strategy = (
        (
            "ddp_find_unused_parameters_true"
            if os.name != "nt"
            else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo")
        )
        if torch.cuda.device_count() > 1
        else "auto"
    )
    LOG.info(f"Using strategy: {strategy}")
    callbacks = []
    if hparams.train.get("push_to_hub", False):
        callbacks.append(
            HuggingFacePushCallback(
                hparams.train.repo_id, hparams.train.get("private", False)
            )
        )
    if not is_notebook():
        callbacks.append(pl.callbacks.RichProgressBar())
    if not callbacks:
        callbacks = None
    trainer = pl.Trainer(
        logger=TensorBoardLogger(
            model_path, "lightning_logs", hparams.train.get("log_version", 0)
        ),
        # profiler="simple",
        val_check_interval=hparams.train.eval_interval,
        max_epochs=hparams.train.epochs,
        check_val_every_n_epoch=None,
        precision=(
            "16-mixed"
            if hparams.train.fp16_run
            else "bf16-mixed"
            if hparams.train.get("bf16_run", False)
            else 32
        ),
        strategy=strategy,
        callbacks=callbacks,
        benchmark=True,
        enable_checkpointing=False,
    )
    tuner = Tuner(trainer)
    model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)

    # Automatic batch size scaling: batch_size may be an int, or a string of
    # the form "<mode>[-<init_val>[-<max_trials>]]" (examples below).
    batch_size = hparams.train.batch_size
    batch_split = str(batch_size).split("-")
    batch_size = batch_split[0]
    init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
    max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
    if batch_size == "auto":
        batch_size = "binsearch"
    if batch_size in ["power", "binsearch"]:
        model.tuning = True
        tuner.scale_batch_size(
            model,
            mode=batch_size,
            datamodule=datamodule,
            steps_per_trial=1,
            init_val=init_val,
            max_trials=max_trials,
        )
        model.tuning = False
    else:
        # Fixed batch size: validate it; the datamodule already uses it.
        batch_size = int(batch_size)
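    # Examples of accepted hparams.train.batch_size values, inferred from the
    # parsing above (illustrative):
    #   16                -> fixed batch size of 16
    #   "auto"            -> binary-search scaling with the defaults above
    #   "binsearch-4-10"  -> binsearch starting at 4, at most 10 trials
    #   "power-8"         -> power-of-two scaling starting at 8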
    # Automatic learning rate scaling is not supported for multiple optimizers:
    """if hparams.train.learning_rate == "auto":
        lr_finder = tuner.lr_find(model)
        LOG.info(lr_finder.results)
        fig = lr_finder.plot(suggest=True)
        fig.savefig(model_path / "lr_finder.png")"""
    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    train("configs/44k/config.json", "logs/44k")
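
# Example invocation (illustrative; the dataset referenced by the config must
# already be preprocessed, and HF_TOKEN is only needed when train.push_to_hub
# is enabled):
#   HF_TOKEN=<your token> python train.py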