from __future__ import annotations

import os
import re
from logging import getLogger
from pathlib import Path

import lightning.pytorch as pl
import torch
from huggingface_hub import (
    create_repo,
    delete_file,
    list_repo_files,
    login,
    upload_folder,
)
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.tuner import Tuner

from so_vits_svc_fork import utils
from so_vits_svc_fork.logger import is_notebook
from so_vits_svc_fork.train import VCDataModule, VitsLightning

LOG = getLogger(__name__)
torch.set_float32_matmul_precision("high")

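# Log in to the Hugging Face Hub non-interactively when an HF_TOKEN environment
# variable is set; uploads and private repos require authentication.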
if os.environ.get("HF_TOKEN"):
    login(os.environ.get("HF_TOKEN"))


class HuggingFacePushCallback(pl.Callback):
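    """Upload the run's model directory to the Hugging Face Hub after each
    validation epoch and prune Hub checkpoints that no longer exist locally."""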
    def __init__(self, repo_id, private=False, every=100):
        self.repo_id = repo_id
        self.private = private
        self.every = every  # NOTE: stored but currently unused; a push happens on every validation epoch end

    def on_validation_epoch_end(self, trainer, pl_module):
        self.repo_url = create_repo(
            repo_id=self.repo_id,
            exist_ok=True,
            private=self.private
        )
        self.repo_id = self.repo_url.repo_id
        if pl_module.global_step == 0:
            return
        print(f"\nπŸ€— Pushing to Hugging Face Hub: {self.repo_url}...")
        model_dir = pl_module.hparams.model_dir
        upload_folder(
            repo_id=self.repo_id,
            folder_path=model_dir,
            path_in_repo=".",
            commit_message="🍻 cheers",
            ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
        )
        # Remove checkpoints from the repo that no longer exist locally
        # (the pretrained G_0.pth / D_0.pth are always kept).
        ckpt_pattern = r"^(D_|G_)\d+\.pth$"
        repo_ckpts = [
            x
            for x in list_repo_files(self.repo_id)
            if re.match(ckpt_pattern, x) and x not in ["G_0.pth", "D_0.pth"]
        ]
        local_ckpts = [
            x.name for x in Path(model_dir).glob("*.pth") if re.match(ckpt_pattern, x.name)
        ]
        to_delete = set(repo_ckpts) - set(local_ckpts)

        for fname in to_delete:
            print(f"πŸ—‘ Deleting {fname} from repo")
            delete_file(fname, self.repo_id)


def train(
    config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False
):
    config_path = Path(config_path)
    model_path = Path(model_path)

    hparams = utils.get_backup_hparams(config_path, model_path)
    utils.ensure_pretrained_model(model_path, hparams.model.get("type_", "hifi-gan"))

    datamodule = VCDataModule(hparams)
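    # Use DDP only when more than one CUDA device is visible; on Windows
    # (os.name == "nt"), fall back to a gloo process-group backend since NCCL
    # is not available there.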
    strategy = (
        (
            "ddp_find_unused_parameters_true"
            if os.name != "nt"
            else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo")
        )
        if torch.cuda.device_count() > 1
        else "auto"
    )
    LOG.info(f"Using strategy: {strategy}")
    
    callbacks = []
    if hparams.train.get("push_to_hub", False):
        callbacks.append(
            HuggingFacePushCallback(
                hparams.train.repo_id, hparams.train.get("private", False)
            )
        )
    if not is_notebook():
        callbacks.append(pl.callbacks.RichProgressBar())
    if not callbacks:
        callbacks = None

    trainer = pl.Trainer(
        logger=TensorBoardLogger(
            model_path, "lightning_logs", hparams.train.get("log_version", 0)
        ),
        # profiler="simple",
        val_check_interval=hparams.train.eval_interval,
        max_epochs=hparams.train.epochs,
        check_val_every_n_epoch=None,
        precision="16-mixed"
        if hparams.train.fp16_run
        else "bf16-mixed"
        if hparams.train.get("bf16_run", False)
        else 32,
        strategy=strategy,
        callbacks=callbacks,
        benchmark=True,
        enable_checkpointing=False,
    )
    tuner = Tuner(trainer)
    model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)

    # automatic batch size scaling
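    # batch_size may be an int, or a string such as "auto", "power" or "binsearch",
    # optionally suffixed as "<mode>-<init_val>-<max_trials>" (e.g. "auto-4-25")
    # to seed Lightning's batch-size finder.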
    batch_size = hparams.train.batch_size
    batch_split = str(batch_size).split("-")
    batch_size = batch_split[0]
    init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
    max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
    if batch_size == "auto":
        batch_size = "binsearch"
    if batch_size in ["power", "binsearch"]:
        model.tuning = True
        tuner.scale_batch_size(
            model,
            mode=batch_size,
            datamodule=datamodule,
            steps_per_trial=1,
            init_val=init_val,
            max_trials=max_trials,
        )
        model.tuning = False
    else:
        batch_size = int(batch_size)
    # Automatic learning-rate scaling is not supported with multiple optimizers,
    # so the LR finder below stays disabled.
    # if hparams.train.learning_rate == "auto":
    #     lr_finder = tuner.lr_find(model)
    #     LOG.info(lr_finder.results)
    #     fig = lr_finder.plot(suggest=True)
    #     fig.savefig(model_path / "lr_finder.png")

    trainer.fit(model, datamodule=datamodule)
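

# Usage note: the paths below are illustrative. To enable Hub pushes, the config's
# "train" section should define "push_to_hub", "repo_id" and "private", and an
# HF_TOKEN must be set in the environment before launching training.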

if __name__ == "__main__":
    train("configs/44k/config.json", "logs/44k")