from __future__ import annotations

import os
from logging import getLogger
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any

import lightning.pytorch as pl
import torch
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.tuner import Tuner
from torch.utils.data import DataLoader

from so_vits_svc_fork import utils
from so_vits_svc_fork.dataset import TextAudioCollate, TextAudioDataset
from so_vits_svc_fork.logger import is_notebook
from so_vits_svc_fork.train import VitsLightning

LOG = getLogger(__name__)

# Trade a little float32 matmul precision for TensorFloat32 speed on Ampere+ GPUs.
torch.set_float32_matmul_precision("high")
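
# --- Hugging Face Hub integration --------------------------------------------
# The import and callback below add automatic uploads of the training
# directory to the Hub on top of the stock so-vits-svc-fork training loop;
# the rest of this script closely follows the upstream `train()` entry point.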

from huggingface_hub import create_repo, login, upload_folder

# Authenticate non-interactively when a token is exported in the environment.
if os.environ.get("HF_TOKEN"):
    login(os.environ.get("HF_TOKEN"))


class HuggingFacePushCallback(pl.Callback):
    """Mirror the training directory to a Hub repo after each validation pass."""

    def __init__(self, repo_id, private=False, every=100):
        self.repo_id = repo_id
        self.private = private
        # NOTE: `every` is accepted for compatibility but not currently used;
        # a push happens on every validation epoch end.
        self.every = every

    def on_validation_epoch_end(self, trainer, pl_module):
        # Create the repo lazily (idempotent thanks to exist_ok=True) and
        # normalize repo_id to the canonical "<owner>/<name>" form.
        self.repo_url = create_repo(
            repo_id=self.repo_id,
            exist_ok=True,
            private=self.private,
        )
        self.repo_id = self.repo_url.repo_id
        # Nothing worth uploading before the first optimizer step.
        if pl_module.global_step == 0:
            return
        print(f"\n🤗 Pushing to Hugging Face Hub: {self.repo_url}...")
        model_dir = pl_module.hparams.model_dir
        upload_folder(
            repo_id=self.repo_id,
            folder_path=model_dir,
            path_in_repo=".",
            commit_message="🍻 cheers",
            ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
        )
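
# Minimal standalone usage sketch (the repo id here is illustrative):
#
#     trainer = pl.Trainer(callbacks=[HuggingFacePushCallback("your-name/your-voice-model")])
#
# `train()` below wires this up automatically from `hparams.train.repo_id`
# when `hparams.train.push_to_hub` is truthy.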


class VCDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping the so-vits-svc-fork text/audio datasets."""

    batch_size: int

    def __init__(self, hparams: Any):
        super().__init__()
        self.__hparams = hparams
        # The configured batch size may be a string such as "auto" (resolved
        # later by the tuner); fall back to 1 so the first dataloader works.
        self.batch_size = hparams.train.batch_size
        if not isinstance(self.batch_size, int):
            self.batch_size = 1
        self.collate_fn = TextAudioCollate()

        self.train_dataset = TextAudioDataset(self.__hparams, is_validation=False)
        self.val_dataset = TextAudioDataset(self.__hparams, is_validation=True)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            num_workers=min(cpu_count(), self.__hparams.train.get("num_workers", 8)),
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            persistent_workers=self.__hparams.train.get("persistent_workers", True),
        )

    def val_dataloader(self):
        # Validation runs one utterance at a time.
        return DataLoader(
            self.val_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,
        )
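
# The config's `train.batch_size` may be a plain int, or a tuner spec of the
# form "<mode>[-<init_val>[-<max_trials>]]", where <mode> is "auto" (an alias
# for "binsearch"), "power", or "binsearch"; for example "auto-4-15" starts
# the search at 4 and allows 15 trials. `train()` parses this and lets
# Lightning's Tuner pick the largest batch size that fits in memory.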

def train(
    config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False
):
    config_path = Path(config_path)
    model_path = Path(model_path)

    hparams = utils.get_backup_hparams(config_path, model_path)
    utils.ensure_pretrained_model(model_path, hparams.model.get("type_", "hifi-gan"))

    datamodule = VCDataModule(hparams)
    # Multi-GPU: plain DDP on POSIX; on Windows force the gloo backend, since
    # NCCL is unavailable there. Single-GPU/CPU runs use Lightning's default.
    strategy = (
        (
            "ddp_find_unused_parameters_true"
            if os.name != "nt"
            else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo")
        )
        if torch.cuda.device_count() > 1
        else "auto"
    )
    LOG.info(f"Using strategy: {strategy}")

    callbacks = []
    if hparams.train.push_to_hub:
        callbacks.append(
            HuggingFacePushCallback(hparams.train.repo_id, hparams.train.private)
        )
    if not is_notebook():
        callbacks.append(pl.callbacks.RichProgressBar())
    if not callbacks:
        callbacks = None
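
    # With check_val_every_n_epoch=None, Lightning interprets
    # val_check_interval as a count of training batches, so the config's
    # eval_interval keeps its upstream meaning of "validate every N steps".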
    trainer = pl.Trainer(
        logger=TensorBoardLogger(
            model_path, "lightning_logs", hparams.train.get("log_version", 0)
        ),
        val_check_interval=hparams.train.eval_interval,
        max_epochs=hparams.train.epochs,
        check_val_every_n_epoch=None,
        # Prefer fp16 if requested, then bf16, otherwise full precision.
        precision="16-mixed"
        if hparams.train.fp16_run
        else "bf16-mixed"
        if hparams.train.get("bf16_run", False)
        else 32,
        strategy=strategy,
        callbacks=callbacks,
        benchmark=True,
        # The LightningModule saves its own G_*/D_* checkpoints, so
        # Lightning's built-in checkpointing is disabled.
        enable_checkpointing=False,
    )

    tuner = Tuner(trainer)
    model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)

    # Parse "<mode>[-<init_val>[-<max_trials>]]" batch size specs, e.g. "auto-4-15".
    batch_size = hparams.train.batch_size
    batch_split = str(batch_size).split("-")
    batch_size = batch_split[0]
    init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
    max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
    if batch_size == "auto":
        batch_size = "binsearch"
    if batch_size in ["power", "binsearch"]:
        # Let the tuner grow the batch size until it no longer fits in memory;
        # `model.tuning` signals the LightningModule that a tuning run is in progress.
        model.tuning = True
        tuner.scale_batch_size(
            model,
            mode=batch_size,
            datamodule=datamodule,
            steps_per_trial=1,
            init_val=init_val,
            max_trials=max_trials,
        )
        model.tuning = False
    else:
        batch_size = int(batch_size)

    # Automatic learning-rate search is disabled for now:
    # if hparams.train.learning_rate == "auto":
    #     lr_finder = tuner.lr_find(model)
    #     LOG.info(lr_finder.results)
    #     fig = lr_finder.plot(suggest=True)
    #     fig.savefig(model_path / "lr_finder.png")

    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    train("configs/44k/config.json", "logs/44k")
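    # The paths above assume the standard so-vits-svc-fork project layout
    # (config produced by `svc pre-config`, with pretrained weights fetched
    # into logs/44k); point them at your own config and model directory as needed.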