jupyterlab-test2 / train.py
from __future__ import annotations
import os
import re
import warnings
from logging import getLogger
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any
import lightning.pytorch as pl
import torch
from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.tuner import Tuner
from torch.cuda.amp import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
import so_vits_svc_fork.f0
import so_vits_svc_fork.modules.commons as commons
import so_vits_svc_fork.utils
from so_vits_svc_fork import utils
from so_vits_svc_fork.dataset import TextAudioCollate, TextAudioDataset
from so_vits_svc_fork.logger import is_notebook
from so_vits_svc_fork.modules.descriminators import MultiPeriodDiscriminator
from so_vits_svc_fork.modules.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from so_vits_svc_fork.modules.mel_processing import mel_spectrogram_torch
from so_vits_svc_fork.modules.synthesizers import SynthesizerTrn
from so_vits_svc_fork.train import VitsLightning, VCDataModule
LOG = getLogger(__name__)
torch.set_float32_matmul_precision("high")
from huggingface_hub import create_repo, upload_folder, login, list_repo_files, delete_file

# Log in to the Hugging Face Hub when a token is provided via the environment.
if os.environ.get("HF_TOKEN"):
    login(os.environ.get("HF_TOKEN"))
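

# Lightning callback that pushes the run's model directory to a Hugging Face Hub repo
# after each validation pass and prunes Hub checkpoints that are no longer kept locally.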
class HuggingFacePushCallback(pl.Callback):
    def __init__(self, repo_id, private=False, every=100):
        super().__init__()
        self.repo_id = repo_id
        self.private = private
        # `every` is stored but not currently consulted; pushes happen on every validation pass.
        self.every = every
def on_validation_epoch_end(self, trainer, pl_module):
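        # Create (or reuse) the target repo; exist_ok=True makes this idempotent.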
self.repo_url = create_repo(
repo_id=self.repo_id,
exist_ok=True,
private=self.private
)
self.repo_id = self.repo_url.repo_id
if pl_module.global_step == 0:
return
print(f"\nπŸ€— Pushing to Hugging Face Hub: {self.repo_url}...")
model_dir = pl_module.hparams.model_dir
upload_folder(
repo_id=self.repo_id,
folder_path=model_dir,
path_in_repo=".",
commit_message="🍻 cheers",
ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
)
        # Prune stale checkpoints from the Hub: any D_*.pth / G_*.pth no longer present
        # locally is deleted (the pretrained G_0.pth / D_0.pth are always kept).
        ckpt_pattern = r'^(D_|G_)\d+\.pth$'
        repo_ckpts = [
            x
            for x in list_repo_files(self.repo_id)
            if re.match(ckpt_pattern, x) and x not in ["G_0.pth", "D_0.pth"]
        ]
        local_ckpts = [
            x.name for x in Path(model_dir).glob("*.pth") if re.match(ckpt_pattern, x.name)
        ]
        to_delete = set(repo_ckpts) - set(local_ckpts)
        for fname in to_delete:
            print(f"πŸ—‘ Deleting {fname} from repo")
            delete_file(fname, self.repo_id)
def train(
config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False
):
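    """Train so-vits-svc-fork with PyTorch Lightning using the given config file and model directory."""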
config_path = Path(config_path)
model_path = Path(model_path)
hparams = utils.get_backup_hparams(config_path, model_path)
utils.ensure_pretrained_model(model_path, hparams.model.get("type_", "hifi-gan"))
datamodule = VCDataModule(hparams)
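
    # Multi-GPU runs use DDP; on Windows ("nt") fall back to the gloo process-group backend.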
strategy = (
(
"ddp_find_unused_parameters_true"
if os.name != "nt"
else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo")
)
if torch.cuda.device_count() > 1
else "auto"
)
LOG.info(f"Using strategy: {strategy}")
    callbacks = []
    if hparams.train.get("push_to_hub", False):
        callbacks.append(
            HuggingFacePushCallback(hparams.train.repo_id, hparams.train.private)
        )
    if not is_notebook():
        callbacks.append(pl.callbacks.RichProgressBar())
    if not callbacks:
        callbacks = None
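
    # Validation (and hence Hub pushes) runs every eval_interval steps rather than per epoch.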
trainer = pl.Trainer(
logger=TensorBoardLogger(
model_path, "lightning_logs", hparams.train.get("log_version", 0)
),
# profiler="simple",
val_check_interval=hparams.train.eval_interval,
max_epochs=hparams.train.epochs,
check_val_every_n_epoch=None,
precision="16-mixed"
if hparams.train.fp16_run
else "bf16-mixed"
if hparams.train.get("bf16_run", False)
else 32,
strategy=strategy,
callbacks=callbacks,
benchmark=True,
enable_checkpointing=False,
)
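
    # The tuner is used only for automatic batch-size scaling below.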
tuner = Tuner(trainer)
model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)
    # Automatic batch size scaling: batch_size may be an int, "auto"/"power"/"binsearch",
    # or "<mode>-<init_val>-<max_trials>" (e.g. "auto-4-30").
    batch_size = hparams.train.batch_size
    batch_split = str(batch_size).split("-")
    batch_size = batch_split[0]
    init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
    max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
if batch_size == "auto":
batch_size = "binsearch"
if batch_size in ["power", "binsearch"]:
model.tuning = True
tuner.scale_batch_size(
model,
mode=batch_size,
datamodule=datamodule,
steps_per_trial=1,
init_val=init_val,
max_trials=max_trials,
)
model.tuning = False
else:
batch_size = int(batch_size)
# automatic learning rate scaling is not supported for multiple optimizers
"""if hparams.train.learning_rate == "auto":
lr_finder = tuner.lr_find(model)
LOG.info(lr_finder.results)
fig = lr_finder.plot(suggest=True)
fig.savefig(model_path / "lr_finder.png")"""
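    # Start training; Lightning's built-in checkpointing is disabled above (enable_checkpointing=False).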
trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    train("configs/44k/config.json", "logs/44k")