Spaces:
Sleeping
Sleeping
Vladimir Alabov
commited on
Commit
•
46b0a70
1
Parent(s):
1ba66c7
Refactor #3
Browse files- so_vits_svc_fork/__init__.py +5 -0
- so_vits_svc_fork/__main__.py +917 -0
- so_vits_svc_fork/cluster/__init__.py +48 -0
- so_vits_svc_fork/cluster/train_cluster.py +141 -0
- so_vits_svc_fork/dataset.py +87 -0
- so_vits_svc_fork/default_gui_presets.json +92 -0
- so_vits_svc_fork/f0.py +239 -0
- so_vits_svc_fork/gui.py +851 -0
- so_vits_svc_fork/hparams.py +38 -0
- so_vits_svc_fork/inference/__init__.py +0 -0
- so_vits_svc_fork/inference/core.py +692 -0
- so_vits_svc_fork/inference/main.py +272 -0
- so_vits_svc_fork/logger.py +46 -0
- so_vits_svc_fork/modules/__init__.py +0 -0
- so_vits_svc_fork/modules/attentions.py +488 -0
- so_vits_svc_fork/modules/commons.py +132 -0
- so_vits_svc_fork/modules/decoders/__init__.py +0 -0
- so_vits_svc_fork/modules/decoders/f0.py +46 -0
- so_vits_svc_fork/modules/decoders/hifigan/__init__.py +3 -0
- so_vits_svc_fork/modules/decoders/hifigan/_models.py +311 -0
- so_vits_svc_fork/modules/decoders/hifigan/_utils.py +15 -0
- so_vits_svc_fork/modules/decoders/mb_istft/__init__.py +15 -0
- so_vits_svc_fork/modules/decoders/mb_istft/_generators.py +376 -0
- so_vits_svc_fork/modules/decoders/mb_istft/_loss.py +11 -0
- so_vits_svc_fork/modules/decoders/mb_istft/_pqmf.py +128 -0
- so_vits_svc_fork/modules/decoders/mb_istft/_stft.py +244 -0
- so_vits_svc_fork/modules/decoders/mb_istft/_stft_loss.py +142 -0
- so_vits_svc_fork/modules/descriminators.py +177 -0
- so_vits_svc_fork/modules/encoders.py +136 -0
- so_vits_svc_fork/modules/flows.py +48 -0
- so_vits_svc_fork/modules/losses.py +58 -0
- so_vits_svc_fork/modules/mel_processing.py +205 -0
- so_vits_svc_fork/modules/modules.py +452 -0
- so_vits_svc_fork/modules/synthesizers.py +233 -0
- so_vits_svc_fork/preprocessing/__init__.py +0 -0
- so_vits_svc_fork/preprocessing/config_templates/quickvc.json +78 -0
- so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json +69 -0
- so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json +71 -0
- so_vits_svc_fork/preprocessing/preprocess_classify.py +95 -0
- so_vits_svc_fork/preprocessing/preprocess_flist_config.py +86 -0
- so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py +157 -0
- so_vits_svc_fork/preprocessing/preprocess_resample.py +144 -0
- so_vits_svc_fork/preprocessing/preprocess_speaker_diarization.py +93 -0
- so_vits_svc_fork/preprocessing/preprocess_split.py +78 -0
- so_vits_svc_fork/preprocessing/preprocess_utils.py +5 -0
- so_vits_svc_fork/py.typed +0 -0
- so_vits_svc_fork/train.py +571 -0
- so_vits_svc_fork/utils.py +478 -0
so_vits_svc_fork/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = "4.1.1"
|
2 |
+
|
3 |
+
from .logger import init_logger
|
4 |
+
|
5 |
+
init_logger()
|
so_vits_svc_fork/__main__.py
ADDED
@@ -0,0 +1,917 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
from logging import getLogger
|
5 |
+
from multiprocessing import freeze_support
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Literal
|
8 |
+
|
9 |
+
import click
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from so_vits_svc_fork import __version__
|
13 |
+
from so_vits_svc_fork.utils import get_optimal_device
|
14 |
+
|
15 |
+
LOG = getLogger(__name__)
|
16 |
+
|
17 |
+
IS_TEST = "test" in Path(__file__).parent.stem
|
18 |
+
if IS_TEST:
|
19 |
+
LOG.debug("Test mode is on.")
|
20 |
+
|
21 |
+
|
22 |
+
class RichHelpFormatter(click.HelpFormatter):
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
indent_increment: int = 2,
|
26 |
+
width: int | None = None,
|
27 |
+
max_width: int | None = None,
|
28 |
+
) -> None:
|
29 |
+
width = 100
|
30 |
+
super().__init__(indent_increment, width, max_width)
|
31 |
+
LOG.info(f"Version: {__version__}")
|
32 |
+
|
33 |
+
|
34 |
+
def patch_wrap_text():
|
35 |
+
orig_wrap_text = click.formatting.wrap_text
|
36 |
+
|
37 |
+
def wrap_text(
|
38 |
+
text,
|
39 |
+
width=78,
|
40 |
+
initial_indent="",
|
41 |
+
subsequent_indent="",
|
42 |
+
preserve_paragraphs=False,
|
43 |
+
):
|
44 |
+
return orig_wrap_text(
|
45 |
+
text.replace("\n", "\n\n"),
|
46 |
+
width=width,
|
47 |
+
initial_indent=initial_indent,
|
48 |
+
subsequent_indent=subsequent_indent,
|
49 |
+
preserve_paragraphs=True,
|
50 |
+
).replace("\n\n", "\n")
|
51 |
+
|
52 |
+
click.formatting.wrap_text = wrap_text
|
53 |
+
|
54 |
+
|
55 |
+
patch_wrap_text()
|
56 |
+
|
57 |
+
CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"], show_default=True)
|
58 |
+
click.Context.formatter_class = RichHelpFormatter
|
59 |
+
|
60 |
+
|
61 |
+
@click.group(context_settings=CONTEXT_SETTINGS)
|
62 |
+
def cli():
|
63 |
+
"""so-vits-svc allows any folder structure for training data.
|
64 |
+
However, the following folder structure is recommended.\n
|
65 |
+
When training: dataset_raw/{speaker_name}/**/{wav_name}.{any_format}\n
|
66 |
+
When inference: configs/44k/config.json, logs/44k/G_XXXX.pth\n
|
67 |
+
If the folder structure is followed, you DO NOT NEED TO SPECIFY model path, config path, etc.
|
68 |
+
(The latest model will be automatically loaded.)\n
|
69 |
+
To train a model, run pre-resample, pre-config, pre-hubert, train.\n
|
70 |
+
To infer a model, run infer.
|
71 |
+
"""
|
72 |
+
|
73 |
+
|
74 |
+
@cli.command()
|
75 |
+
@click.option(
|
76 |
+
"-c",
|
77 |
+
"--config-path",
|
78 |
+
type=click.Path(exists=True),
|
79 |
+
help="path to config",
|
80 |
+
default=Path("./configs/44k/config.json"),
|
81 |
+
)
|
82 |
+
@click.option(
|
83 |
+
"-m",
|
84 |
+
"--model-path",
|
85 |
+
type=click.Path(),
|
86 |
+
help="path to output dir",
|
87 |
+
default=Path("./logs/44k"),
|
88 |
+
)
|
89 |
+
@click.option(
|
90 |
+
"-t/-nt",
|
91 |
+
"--tensorboard/--no-tensorboard",
|
92 |
+
default=False,
|
93 |
+
type=bool,
|
94 |
+
help="launch tensorboard",
|
95 |
+
)
|
96 |
+
@click.option(
|
97 |
+
"-r",
|
98 |
+
"--reset-optimizer",
|
99 |
+
default=False,
|
100 |
+
type=bool,
|
101 |
+
help="reset optimizer",
|
102 |
+
is_flag=True,
|
103 |
+
)
|
104 |
+
def train(
|
105 |
+
config_path: Path,
|
106 |
+
model_path: Path,
|
107 |
+
tensorboard: bool = False,
|
108 |
+
reset_optimizer: bool = False,
|
109 |
+
):
|
110 |
+
"""Train model
|
111 |
+
If D_0.pth or G_0.pth not found, automatically download from hub."""
|
112 |
+
from .train import train
|
113 |
+
|
114 |
+
config_path = Path(config_path)
|
115 |
+
model_path = Path(model_path)
|
116 |
+
|
117 |
+
if tensorboard:
|
118 |
+
import webbrowser
|
119 |
+
|
120 |
+
from tensorboard import program
|
121 |
+
|
122 |
+
getLogger("tensorboard").setLevel(30)
|
123 |
+
tb = program.TensorBoard()
|
124 |
+
tb.configure(argv=[None, "--logdir", model_path.as_posix()])
|
125 |
+
url = tb.launch()
|
126 |
+
webbrowser.open(url)
|
127 |
+
|
128 |
+
train(
|
129 |
+
config_path=config_path, model_path=model_path, reset_optimizer=reset_optimizer
|
130 |
+
)
|
131 |
+
|
132 |
+
|
133 |
+
@cli.command()
|
134 |
+
def gui():
|
135 |
+
"""Opens GUI
|
136 |
+
for conversion and realtime inference"""
|
137 |
+
from .gui import main
|
138 |
+
|
139 |
+
main()
|
140 |
+
|
141 |
+
|
142 |
+
@cli.command()
|
143 |
+
@click.argument(
|
144 |
+
"input-path",
|
145 |
+
type=click.Path(exists=True),
|
146 |
+
)
|
147 |
+
@click.option(
|
148 |
+
"-o",
|
149 |
+
"--output-path",
|
150 |
+
type=click.Path(),
|
151 |
+
help="path to output dir",
|
152 |
+
)
|
153 |
+
@click.option("-s", "--speaker", type=str, default=None, help="speaker name")
|
154 |
+
@click.option(
|
155 |
+
"-m",
|
156 |
+
"--model-path",
|
157 |
+
type=click.Path(exists=True),
|
158 |
+
default=Path("./logs/44k/"),
|
159 |
+
help="path to model",
|
160 |
+
)
|
161 |
+
@click.option(
|
162 |
+
"-c",
|
163 |
+
"--config-path",
|
164 |
+
type=click.Path(exists=True),
|
165 |
+
default=Path("./configs/44k/config.json"),
|
166 |
+
help="path to config",
|
167 |
+
)
|
168 |
+
@click.option(
|
169 |
+
"-k",
|
170 |
+
"--cluster-model-path",
|
171 |
+
type=click.Path(exists=True),
|
172 |
+
default=None,
|
173 |
+
help="path to cluster model",
|
174 |
+
)
|
175 |
+
@click.option(
|
176 |
+
"-re",
|
177 |
+
"--recursive",
|
178 |
+
type=bool,
|
179 |
+
default=False,
|
180 |
+
help="Search recursively",
|
181 |
+
is_flag=True,
|
182 |
+
)
|
183 |
+
@click.option("-t", "--transpose", type=int, default=0, help="transpose")
|
184 |
+
@click.option(
|
185 |
+
"-db", "--db-thresh", type=int, default=-20, help="threshold (DB) (RELATIVE)"
|
186 |
+
)
|
187 |
+
@click.option(
|
188 |
+
"-fm",
|
189 |
+
"--f0-method",
|
190 |
+
type=click.Choice(["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"]),
|
191 |
+
default="dio",
|
192 |
+
help="f0 prediction method",
|
193 |
+
)
|
194 |
+
@click.option(
|
195 |
+
"-a/-na",
|
196 |
+
"--auto-predict-f0/--no-auto-predict-f0",
|
197 |
+
type=bool,
|
198 |
+
default=True,
|
199 |
+
help="auto predict f0",
|
200 |
+
)
|
201 |
+
@click.option(
|
202 |
+
"-r", "--cluster-infer-ratio", type=float, default=0, help="cluster infer ratio"
|
203 |
+
)
|
204 |
+
@click.option("-n", "--noise-scale", type=float, default=0.4, help="noise scale")
|
205 |
+
@click.option("-p", "--pad-seconds", type=float, default=0.5, help="pad seconds")
|
206 |
+
@click.option(
|
207 |
+
"-d",
|
208 |
+
"--device",
|
209 |
+
type=str,
|
210 |
+
default=get_optimal_device(),
|
211 |
+
help="device",
|
212 |
+
)
|
213 |
+
@click.option("-ch", "--chunk-seconds", type=float, default=0.5, help="chunk seconds")
|
214 |
+
@click.option(
|
215 |
+
"-ab/-nab",
|
216 |
+
"--absolute-thresh/--no-absolute-thresh",
|
217 |
+
type=bool,
|
218 |
+
default=False,
|
219 |
+
help="absolute thresh",
|
220 |
+
)
|
221 |
+
@click.option(
|
222 |
+
"-mc",
|
223 |
+
"--max-chunk-seconds",
|
224 |
+
type=float,
|
225 |
+
default=40,
|
226 |
+
help="maximum allowed single chunk length, set lower if you get out of memory (0 to disable)",
|
227 |
+
)
|
228 |
+
def infer(
|
229 |
+
# paths
|
230 |
+
input_path: Path,
|
231 |
+
output_path: Path,
|
232 |
+
model_path: Path,
|
233 |
+
config_path: Path,
|
234 |
+
recursive: bool,
|
235 |
+
# svc config
|
236 |
+
speaker: str,
|
237 |
+
cluster_model_path: Path | None = None,
|
238 |
+
transpose: int = 0,
|
239 |
+
auto_predict_f0: bool = False,
|
240 |
+
cluster_infer_ratio: float = 0,
|
241 |
+
noise_scale: float = 0.4,
|
242 |
+
f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
|
243 |
+
# slice config
|
244 |
+
db_thresh: int = -40,
|
245 |
+
pad_seconds: float = 0.5,
|
246 |
+
chunk_seconds: float = 0.5,
|
247 |
+
absolute_thresh: bool = False,
|
248 |
+
max_chunk_seconds: float = 40,
|
249 |
+
device: str | torch.device = get_optimal_device(),
|
250 |
+
):
|
251 |
+
"""Inference"""
|
252 |
+
from so_vits_svc_fork.inference.main import infer
|
253 |
+
|
254 |
+
if not auto_predict_f0:
|
255 |
+
LOG.warning(
|
256 |
+
f"auto_predict_f0 = False, transpose = {transpose}. If you want to change the pitch, please set transpose."
|
257 |
+
"Generally transpose = 0 does not work because your voice pitch and target voice pitch are different."
|
258 |
+
)
|
259 |
+
|
260 |
+
input_path = Path(input_path)
|
261 |
+
if output_path is None:
|
262 |
+
output_path = input_path.parent / f"{input_path.stem}.out{input_path.suffix}"
|
263 |
+
output_path = Path(output_path)
|
264 |
+
if input_path.is_dir() and not recursive:
|
265 |
+
raise ValueError(
|
266 |
+
"input_path is a directory. Use 0re or --recursive to infer recursively."
|
267 |
+
)
|
268 |
+
model_path = Path(model_path)
|
269 |
+
if model_path.is_dir():
|
270 |
+
model_path = list(
|
271 |
+
sorted(model_path.glob("G_*.pth"), key=lambda x: x.stat().st_mtime)
|
272 |
+
)[-1]
|
273 |
+
LOG.info(f"Since model_path is a directory, use {model_path}")
|
274 |
+
config_path = Path(config_path)
|
275 |
+
if cluster_model_path is not None:
|
276 |
+
cluster_model_path = Path(cluster_model_path)
|
277 |
+
infer(
|
278 |
+
# paths
|
279 |
+
input_path=input_path,
|
280 |
+
output_path=output_path,
|
281 |
+
model_path=model_path,
|
282 |
+
config_path=config_path,
|
283 |
+
recursive=recursive,
|
284 |
+
# svc config
|
285 |
+
speaker=speaker,
|
286 |
+
cluster_model_path=cluster_model_path,
|
287 |
+
transpose=transpose,
|
288 |
+
auto_predict_f0=auto_predict_f0,
|
289 |
+
cluster_infer_ratio=cluster_infer_ratio,
|
290 |
+
noise_scale=noise_scale,
|
291 |
+
f0_method=f0_method,
|
292 |
+
# slice config
|
293 |
+
db_thresh=db_thresh,
|
294 |
+
pad_seconds=pad_seconds,
|
295 |
+
chunk_seconds=chunk_seconds,
|
296 |
+
absolute_thresh=absolute_thresh,
|
297 |
+
max_chunk_seconds=max_chunk_seconds,
|
298 |
+
device=device,
|
299 |
+
)
|
300 |
+
|
301 |
+
|
302 |
+
@cli.command()
|
303 |
+
@click.option(
|
304 |
+
"-m",
|
305 |
+
"--model-path",
|
306 |
+
type=click.Path(exists=True),
|
307 |
+
default=Path("./logs/44k/"),
|
308 |
+
help="path to model",
|
309 |
+
)
|
310 |
+
@click.option(
|
311 |
+
"-c",
|
312 |
+
"--config-path",
|
313 |
+
type=click.Path(exists=True),
|
314 |
+
default=Path("./configs/44k/config.json"),
|
315 |
+
help="path to config",
|
316 |
+
)
|
317 |
+
@click.option(
|
318 |
+
"-k",
|
319 |
+
"--cluster-model-path",
|
320 |
+
type=click.Path(exists=True),
|
321 |
+
default=None,
|
322 |
+
help="path to cluster model",
|
323 |
+
)
|
324 |
+
@click.option("-t", "--transpose", type=int, default=12, help="transpose")
|
325 |
+
@click.option(
|
326 |
+
"-a/-na",
|
327 |
+
"--auto-predict-f0/--no-auto-predict-f0",
|
328 |
+
type=bool,
|
329 |
+
default=True,
|
330 |
+
help="auto predict f0 (not recommended for realtime since voice pitch will not be stable)",
|
331 |
+
)
|
332 |
+
@click.option(
|
333 |
+
"-r", "--cluster-infer-ratio", type=float, default=0, help="cluster infer ratio"
|
334 |
+
)
|
335 |
+
@click.option("-n", "--noise-scale", type=float, default=0.4, help="noise scale")
|
336 |
+
@click.option(
|
337 |
+
"-db", "--db-thresh", type=int, default=-30, help="threshold (DB) (ABSOLUTE)"
|
338 |
+
)
|
339 |
+
@click.option(
|
340 |
+
"-fm",
|
341 |
+
"--f0-method",
|
342 |
+
type=click.Choice(["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"]),
|
343 |
+
default="dio",
|
344 |
+
help="f0 prediction method",
|
345 |
+
)
|
346 |
+
@click.option("-p", "--pad-seconds", type=float, default=0.02, help="pad seconds")
|
347 |
+
@click.option("-ch", "--chunk-seconds", type=float, default=0.5, help="chunk seconds")
|
348 |
+
@click.option(
|
349 |
+
"-cr",
|
350 |
+
"--crossfade-seconds",
|
351 |
+
type=float,
|
352 |
+
default=0.01,
|
353 |
+
help="crossfade seconds",
|
354 |
+
)
|
355 |
+
@click.option(
|
356 |
+
"-ab",
|
357 |
+
"--additional-infer-before-seconds",
|
358 |
+
type=float,
|
359 |
+
default=0.2,
|
360 |
+
help="additional infer before seconds",
|
361 |
+
)
|
362 |
+
@click.option(
|
363 |
+
"-aa",
|
364 |
+
"--additional-infer-after-seconds",
|
365 |
+
type=float,
|
366 |
+
default=0.1,
|
367 |
+
help="additional infer after seconds",
|
368 |
+
)
|
369 |
+
@click.option("-b", "--block-seconds", type=float, default=0.5, help="block seconds")
|
370 |
+
@click.option(
|
371 |
+
"-d",
|
372 |
+
"--device",
|
373 |
+
type=str,
|
374 |
+
default=get_optimal_device(),
|
375 |
+
help="device",
|
376 |
+
)
|
377 |
+
@click.option("-s", "--speaker", type=str, default=None, help="speaker name")
|
378 |
+
@click.option("-v", "--version", type=int, default=2, help="version")
|
379 |
+
@click.option("-i", "--input-device", type=int, default=None, help="input device")
|
380 |
+
@click.option("-o", "--output-device", type=int, default=None, help="output device")
|
381 |
+
@click.option(
|
382 |
+
"-po",
|
383 |
+
"--passthrough-original",
|
384 |
+
type=bool,
|
385 |
+
default=False,
|
386 |
+
is_flag=True,
|
387 |
+
help="passthrough original (for latency check)",
|
388 |
+
)
|
389 |
+
def vc(
|
390 |
+
# paths
|
391 |
+
model_path: Path,
|
392 |
+
config_path: Path,
|
393 |
+
# svc config
|
394 |
+
speaker: str,
|
395 |
+
cluster_model_path: Path | None,
|
396 |
+
transpose: int,
|
397 |
+
auto_predict_f0: bool,
|
398 |
+
cluster_infer_ratio: float,
|
399 |
+
noise_scale: float,
|
400 |
+
f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
|
401 |
+
# slice config
|
402 |
+
db_thresh: int,
|
403 |
+
pad_seconds: float,
|
404 |
+
chunk_seconds: float,
|
405 |
+
# realtime config
|
406 |
+
crossfade_seconds: float,
|
407 |
+
additional_infer_before_seconds: float,
|
408 |
+
additional_infer_after_seconds: float,
|
409 |
+
block_seconds: float,
|
410 |
+
version: int,
|
411 |
+
input_device: int | str | None,
|
412 |
+
output_device: int | str | None,
|
413 |
+
device: torch.device,
|
414 |
+
passthrough_original: bool = False,
|
415 |
+
) -> None:
|
416 |
+
"""Realtime inference from microphone"""
|
417 |
+
from so_vits_svc_fork.inference.main import realtime
|
418 |
+
|
419 |
+
if auto_predict_f0:
|
420 |
+
LOG.warning(
|
421 |
+
"auto_predict_f0 = True in realtime inference will cause unstable voice pitch, use with caution"
|
422 |
+
)
|
423 |
+
else:
|
424 |
+
LOG.warning(
|
425 |
+
f"auto_predict_f0 = False, transpose = {transpose}. If you want to change the pitch, please change the transpose value."
|
426 |
+
"Generally transpose = 0 does not work because your voice pitch and target voice pitch are different."
|
427 |
+
)
|
428 |
+
model_path = Path(model_path)
|
429 |
+
config_path = Path(config_path)
|
430 |
+
if cluster_model_path is not None:
|
431 |
+
cluster_model_path = Path(cluster_model_path)
|
432 |
+
if model_path.is_dir():
|
433 |
+
model_path = list(
|
434 |
+
sorted(model_path.glob("G_*.pth"), key=lambda x: x.stat().st_mtime)
|
435 |
+
)[-1]
|
436 |
+
LOG.info(f"Since model_path is a directory, use {model_path}")
|
437 |
+
|
438 |
+
realtime(
|
439 |
+
# paths
|
440 |
+
model_path=model_path,
|
441 |
+
config_path=config_path,
|
442 |
+
# svc config
|
443 |
+
speaker=speaker,
|
444 |
+
cluster_model_path=cluster_model_path,
|
445 |
+
transpose=transpose,
|
446 |
+
auto_predict_f0=auto_predict_f0,
|
447 |
+
cluster_infer_ratio=cluster_infer_ratio,
|
448 |
+
noise_scale=noise_scale,
|
449 |
+
f0_method=f0_method,
|
450 |
+
# slice config
|
451 |
+
db_thresh=db_thresh,
|
452 |
+
pad_seconds=pad_seconds,
|
453 |
+
chunk_seconds=chunk_seconds,
|
454 |
+
# realtime config
|
455 |
+
crossfade_seconds=crossfade_seconds,
|
456 |
+
additional_infer_before_seconds=additional_infer_before_seconds,
|
457 |
+
additional_infer_after_seconds=additional_infer_after_seconds,
|
458 |
+
block_seconds=block_seconds,
|
459 |
+
version=version,
|
460 |
+
input_device=input_device,
|
461 |
+
output_device=output_device,
|
462 |
+
device=device,
|
463 |
+
passthrough_original=passthrough_original,
|
464 |
+
)
|
465 |
+
|
466 |
+
|
467 |
+
@cli.command()
|
468 |
+
@click.option(
|
469 |
+
"-i",
|
470 |
+
"--input-dir",
|
471 |
+
type=click.Path(exists=True),
|
472 |
+
default=Path("./dataset_raw"),
|
473 |
+
help="path to source dir",
|
474 |
+
)
|
475 |
+
@click.option(
|
476 |
+
"-o",
|
477 |
+
"--output-dir",
|
478 |
+
type=click.Path(),
|
479 |
+
default=Path("./dataset/44k"),
|
480 |
+
help="path to output dir",
|
481 |
+
)
|
482 |
+
@click.option("-s", "--sampling-rate", type=int, default=44100, help="sampling rate")
|
483 |
+
@click.option(
|
484 |
+
"-n",
|
485 |
+
"--n-jobs",
|
486 |
+
type=int,
|
487 |
+
default=-1,
|
488 |
+
help="number of jobs (optimal value may depend on your RAM capacity and audio duration per file)",
|
489 |
+
)
|
490 |
+
@click.option("-d", "--top-db", type=float, default=30, help="top db")
|
491 |
+
@click.option("-f", "--frame-seconds", type=float, default=1, help="frame seconds")
|
492 |
+
@click.option(
|
493 |
+
"-ho", "-hop", "--hop-seconds", type=float, default=0.3, help="hop seconds"
|
494 |
+
)
|
495 |
+
def pre_resample(
|
496 |
+
input_dir: Path,
|
497 |
+
output_dir: Path,
|
498 |
+
sampling_rate: int,
|
499 |
+
n_jobs: int,
|
500 |
+
top_db: int,
|
501 |
+
frame_seconds: float,
|
502 |
+
hop_seconds: float,
|
503 |
+
) -> None:
|
504 |
+
"""Preprocessing part 1: resample"""
|
505 |
+
from so_vits_svc_fork.preprocessing.preprocess_resample import preprocess_resample
|
506 |
+
|
507 |
+
input_dir = Path(input_dir)
|
508 |
+
output_dir = Path(output_dir)
|
509 |
+
preprocess_resample(
|
510 |
+
input_dir=input_dir,
|
511 |
+
output_dir=output_dir,
|
512 |
+
sampling_rate=sampling_rate,
|
513 |
+
n_jobs=n_jobs,
|
514 |
+
top_db=top_db,
|
515 |
+
frame_seconds=frame_seconds,
|
516 |
+
hop_seconds=hop_seconds,
|
517 |
+
)
|
518 |
+
|
519 |
+
|
520 |
+
from so_vits_svc_fork.preprocessing.preprocess_flist_config import CONFIG_TEMPLATE_DIR
|
521 |
+
|
522 |
+
|
523 |
+
@cli.command()
|
524 |
+
@click.option(
|
525 |
+
"-i",
|
526 |
+
"--input-dir",
|
527 |
+
type=click.Path(exists=True),
|
528 |
+
default=Path("./dataset/44k"),
|
529 |
+
help="path to source dir",
|
530 |
+
)
|
531 |
+
@click.option(
|
532 |
+
"-f",
|
533 |
+
"--filelist-path",
|
534 |
+
type=click.Path(),
|
535 |
+
default=Path("./filelists/44k"),
|
536 |
+
help="path to filelist dir",
|
537 |
+
)
|
538 |
+
@click.option(
|
539 |
+
"-c",
|
540 |
+
"--config-path",
|
541 |
+
type=click.Path(),
|
542 |
+
default=Path("./configs/44k/config.json"),
|
543 |
+
help="path to config",
|
544 |
+
)
|
545 |
+
@click.option(
|
546 |
+
"-t",
|
547 |
+
"--config-type",
|
548 |
+
type=click.Choice([x.stem for x in CONFIG_TEMPLATE_DIR.rglob("*.json")]),
|
549 |
+
default="so-vits-svc-4.0v1",
|
550 |
+
help="config type",
|
551 |
+
)
|
552 |
+
def pre_config(
|
553 |
+
input_dir: Path,
|
554 |
+
filelist_path: Path,
|
555 |
+
config_path: Path,
|
556 |
+
config_type: str,
|
557 |
+
):
|
558 |
+
"""Preprocessing part 2: config"""
|
559 |
+
from so_vits_svc_fork.preprocessing.preprocess_flist_config import preprocess_config
|
560 |
+
|
561 |
+
input_dir = Path(input_dir)
|
562 |
+
filelist_path = Path(filelist_path)
|
563 |
+
config_path = Path(config_path)
|
564 |
+
preprocess_config(
|
565 |
+
input_dir=input_dir,
|
566 |
+
train_list_path=filelist_path / "train.txt",
|
567 |
+
val_list_path=filelist_path / "val.txt",
|
568 |
+
test_list_path=filelist_path / "test.txt",
|
569 |
+
config_path=config_path,
|
570 |
+
config_name=config_type,
|
571 |
+
)
|
572 |
+
|
573 |
+
|
574 |
+
@cli.command()
|
575 |
+
@click.option(
|
576 |
+
"-i",
|
577 |
+
"--input-dir",
|
578 |
+
type=click.Path(exists=True),
|
579 |
+
default=Path("./dataset/44k"),
|
580 |
+
help="path to source dir",
|
581 |
+
)
|
582 |
+
@click.option(
|
583 |
+
"-c",
|
584 |
+
"--config-path",
|
585 |
+
type=click.Path(exists=True),
|
586 |
+
help="path to config",
|
587 |
+
default=Path("./configs/44k/config.json"),
|
588 |
+
)
|
589 |
+
@click.option(
|
590 |
+
"-n",
|
591 |
+
"--n-jobs",
|
592 |
+
type=int,
|
593 |
+
default=None,
|
594 |
+
help="number of jobs (optimal value may depend on your VRAM capacity and audio duration per file)",
|
595 |
+
)
|
596 |
+
@click.option(
|
597 |
+
"-f/-nf",
|
598 |
+
"--force-rebuild/--no-force-rebuild",
|
599 |
+
type=bool,
|
600 |
+
default=True,
|
601 |
+
help="force rebuild existing preprocessed files",
|
602 |
+
)
|
603 |
+
@click.option(
|
604 |
+
"-fm",
|
605 |
+
"--f0-method",
|
606 |
+
type=click.Choice(["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"]),
|
607 |
+
default="dio",
|
608 |
+
)
|
609 |
+
def pre_hubert(
|
610 |
+
input_dir: Path,
|
611 |
+
config_path: Path,
|
612 |
+
n_jobs: bool,
|
613 |
+
force_rebuild: bool,
|
614 |
+
f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
|
615 |
+
) -> None:
|
616 |
+
"""Preprocessing part 3: hubert
|
617 |
+
If the HuBERT model is not found, it will be downloaded automatically."""
|
618 |
+
from so_vits_svc_fork.preprocessing.preprocess_hubert_f0 import preprocess_hubert_f0
|
619 |
+
|
620 |
+
input_dir = Path(input_dir)
|
621 |
+
config_path = Path(config_path)
|
622 |
+
preprocess_hubert_f0(
|
623 |
+
input_dir=input_dir,
|
624 |
+
config_path=config_path,
|
625 |
+
n_jobs=n_jobs,
|
626 |
+
force_rebuild=force_rebuild,
|
627 |
+
f0_method=f0_method,
|
628 |
+
)
|
629 |
+
|
630 |
+
|
631 |
+
@cli.command()
|
632 |
+
@click.option(
|
633 |
+
"-i",
|
634 |
+
"--input-dir",
|
635 |
+
type=click.Path(exists=True),
|
636 |
+
default=Path("./dataset_raw_raw/"),
|
637 |
+
help="path to source dir",
|
638 |
+
)
|
639 |
+
@click.option(
|
640 |
+
"-o",
|
641 |
+
"--output-dir",
|
642 |
+
type=click.Path(),
|
643 |
+
default=Path("./dataset_raw/"),
|
644 |
+
help="path to output dir",
|
645 |
+
)
|
646 |
+
@click.option(
|
647 |
+
"-n",
|
648 |
+
"--n-jobs",
|
649 |
+
type=int,
|
650 |
+
default=-1,
|
651 |
+
help="number of jobs (optimal value may depend on your VRAM capacity and audio duration per file)",
|
652 |
+
)
|
653 |
+
@click.option("-min", "--min-speakers", type=int, default=2, help="min speakers")
|
654 |
+
@click.option("-max", "--max-speakers", type=int, default=2, help="max speakers")
|
655 |
+
@click.option(
|
656 |
+
"-t", "--huggingface-token", type=str, default=None, help="huggingface token"
|
657 |
+
)
|
658 |
+
@click.option("-s", "--sr", type=int, default=44100, help="sampling rate")
|
659 |
+
def pre_sd(
|
660 |
+
input_dir: Path | str,
|
661 |
+
output_dir: Path | str,
|
662 |
+
min_speakers: int,
|
663 |
+
max_speakers: int,
|
664 |
+
huggingface_token: str | None,
|
665 |
+
n_jobs: int,
|
666 |
+
sr: int,
|
667 |
+
):
|
668 |
+
"""Speech diarization using pyannote.audio"""
|
669 |
+
if huggingface_token is None:
|
670 |
+
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN", None)
|
671 |
+
if huggingface_token is None:
|
672 |
+
huggingface_token = click.prompt(
|
673 |
+
"Please enter your HuggingFace token", hide_input=True
|
674 |
+
)
|
675 |
+
if os.environ.get("HUGGINGFACE_TOKEN", None) is None:
|
676 |
+
LOG.info("You can also set the HUGGINGFACE_TOKEN environment variable.")
|
677 |
+
assert huggingface_token is not None
|
678 |
+
huggingface_token = huggingface_token.rstrip(" \n\r\t\0")
|
679 |
+
if len(huggingface_token) <= 1:
|
680 |
+
raise ValueError("HuggingFace token is empty: " + huggingface_token)
|
681 |
+
|
682 |
+
if max_speakers == 1:
|
683 |
+
LOG.warning("Consider using pre-split if max_speakers == 1")
|
684 |
+
from so_vits_svc_fork.preprocessing.preprocess_speaker_diarization import (
|
685 |
+
preprocess_speaker_diarization,
|
686 |
+
)
|
687 |
+
|
688 |
+
preprocess_speaker_diarization(
|
689 |
+
input_dir=input_dir,
|
690 |
+
output_dir=output_dir,
|
691 |
+
min_speakers=min_speakers,
|
692 |
+
max_speakers=max_speakers,
|
693 |
+
huggingface_token=huggingface_token,
|
694 |
+
n_jobs=n_jobs,
|
695 |
+
sr=sr,
|
696 |
+
)
|
697 |
+
|
698 |
+
|
699 |
+
@cli.command()
|
700 |
+
@click.option(
|
701 |
+
"-i",
|
702 |
+
"--input-dir",
|
703 |
+
type=click.Path(exists=True),
|
704 |
+
default=Path("./dataset_raw_raw/"),
|
705 |
+
help="path to source dir",
|
706 |
+
)
|
707 |
+
@click.option(
|
708 |
+
"-o",
|
709 |
+
"--output-dir",
|
710 |
+
type=click.Path(),
|
711 |
+
default=Path("./dataset_raw/"),
|
712 |
+
help="path to output dir",
|
713 |
+
)
|
714 |
+
@click.option(
|
715 |
+
"-n",
|
716 |
+
"--n-jobs",
|
717 |
+
type=int,
|
718 |
+
default=-1,
|
719 |
+
help="number of jobs (optimal value may depend on your RAM capacity and audio duration per file)",
|
720 |
+
)
|
721 |
+
@click.option(
|
722 |
+
"-l",
|
723 |
+
"--max-length",
|
724 |
+
type=float,
|
725 |
+
default=10,
|
726 |
+
help="max length of each split in seconds",
|
727 |
+
)
|
728 |
+
@click.option("-d", "--top-db", type=float, default=30, help="top db")
|
729 |
+
@click.option("-f", "--frame-seconds", type=float, default=1, help="frame seconds")
|
730 |
+
@click.option(
|
731 |
+
"-ho", "-hop", "--hop-seconds", type=float, default=0.3, help="hop seconds"
|
732 |
+
)
|
733 |
+
@click.option("-s", "--sr", type=int, default=44100, help="sample rate")
|
734 |
+
def pre_split(
|
735 |
+
input_dir: Path | str,
|
736 |
+
output_dir: Path | str,
|
737 |
+
max_length: float,
|
738 |
+
top_db: int,
|
739 |
+
frame_seconds: float,
|
740 |
+
hop_seconds: float,
|
741 |
+
n_jobs: int,
|
742 |
+
sr: int,
|
743 |
+
):
|
744 |
+
"""Split audio files into multiple files"""
|
745 |
+
from so_vits_svc_fork.preprocessing.preprocess_split import preprocess_split
|
746 |
+
|
747 |
+
preprocess_split(
|
748 |
+
input_dir=input_dir,
|
749 |
+
output_dir=output_dir,
|
750 |
+
max_length=max_length,
|
751 |
+
top_db=top_db,
|
752 |
+
frame_seconds=frame_seconds,
|
753 |
+
hop_seconds=hop_seconds,
|
754 |
+
n_jobs=n_jobs,
|
755 |
+
sr=sr,
|
756 |
+
)
|
757 |
+
|
758 |
+
|
759 |
+
@cli.command()
|
760 |
+
@click.option(
|
761 |
+
"-i",
|
762 |
+
"--input-dir",
|
763 |
+
type=click.Path(exists=True),
|
764 |
+
required=True,
|
765 |
+
help="path to source dir",
|
766 |
+
)
|
767 |
+
@click.option(
|
768 |
+
"-o",
|
769 |
+
"--output-dir",
|
770 |
+
type=click.Path(),
|
771 |
+
default=None,
|
772 |
+
help="path to output dir",
|
773 |
+
)
|
774 |
+
@click.option(
|
775 |
+
"-c/-nc",
|
776 |
+
"--create-new/--no-create-new",
|
777 |
+
type=bool,
|
778 |
+
default=True,
|
779 |
+
help="create a new folder for the speaker if not exist",
|
780 |
+
)
|
781 |
+
def pre_classify(
|
782 |
+
input_dir: Path | str,
|
783 |
+
output_dir: Path | str | None,
|
784 |
+
create_new: bool,
|
785 |
+
) -> None:
|
786 |
+
"""Classify multiple audio files into multiple files"""
|
787 |
+
from so_vits_svc_fork.preprocessing.preprocess_classify import preprocess_classify
|
788 |
+
|
789 |
+
if output_dir is None:
|
790 |
+
output_dir = input_dir
|
791 |
+
preprocess_classify(
|
792 |
+
input_dir=input_dir,
|
793 |
+
output_dir=output_dir,
|
794 |
+
create_new=create_new,
|
795 |
+
)
|
796 |
+
|
797 |
+
|
798 |
+
@cli.command
|
799 |
+
def clean():
|
800 |
+
"""Clean up files, only useful if you are using the default file structure"""
|
801 |
+
import shutil
|
802 |
+
|
803 |
+
folders = ["dataset", "filelists", "logs"]
|
804 |
+
# if pyip.inputYesNo(f"Are you sure you want to delete files in {folders}?") == "yes":
|
805 |
+
if input("Are you sure you want to delete files in {folders}?") in ["yes", "y"]:
|
806 |
+
for folder in folders:
|
807 |
+
if Path(folder).exists():
|
808 |
+
shutil.rmtree(folder)
|
809 |
+
LOG.info("Cleaned up files")
|
810 |
+
else:
|
811 |
+
LOG.info("Aborted")
|
812 |
+
|
813 |
+
|
814 |
+
@cli.command
|
815 |
+
@click.option(
|
816 |
+
"-i",
|
817 |
+
"--input-path",
|
818 |
+
type=click.Path(exists=True),
|
819 |
+
help="model path",
|
820 |
+
default=Path("./logs/44k/"),
|
821 |
+
)
|
822 |
+
@click.option(
|
823 |
+
"-o",
|
824 |
+
"--output-path",
|
825 |
+
type=click.Path(),
|
826 |
+
help="onnx model path to save",
|
827 |
+
default=None,
|
828 |
+
)
|
829 |
+
@click.option(
|
830 |
+
"-c",
|
831 |
+
"--config-path",
|
832 |
+
type=click.Path(),
|
833 |
+
help="config path",
|
834 |
+
default=Path("./configs/44k/config.json"),
|
835 |
+
)
|
836 |
+
@click.option(
|
837 |
+
"-d",
|
838 |
+
"--device",
|
839 |
+
type=str,
|
840 |
+
default="cpu",
|
841 |
+
help="device to use",
|
842 |
+
)
|
843 |
+
def onnx(
|
844 |
+
input_path: Path, output_path: Path, config_path: Path, device: torch.device | str
|
845 |
+
) -> None:
|
846 |
+
"""Export model to onnx (currently not working)"""
|
847 |
+
raise NotImplementedError("ONNX export is not yet supported")
|
848 |
+
input_path = Path(input_path)
|
849 |
+
if input_path.is_dir():
|
850 |
+
input_path = list(input_path.glob("*.pth"))[0]
|
851 |
+
if output_path is None:
|
852 |
+
output_path = input_path.with_suffix(".onnx")
|
853 |
+
output_path = Path(output_path)
|
854 |
+
if output_path.is_dir():
|
855 |
+
output_path = output_path / (input_path.stem + ".onnx")
|
856 |
+
config_path = Path(config_path)
|
857 |
+
device_ = torch.device(device)
|
858 |
+
from so_vits_svc_fork.modules.onnx._export import onnx_export
|
859 |
+
|
860 |
+
onnx_export(
|
861 |
+
input_path=input_path,
|
862 |
+
output_path=output_path,
|
863 |
+
config_path=config_path,
|
864 |
+
device=device_,
|
865 |
+
)
|
866 |
+
|
867 |
+
|
868 |
+
@cli.command
|
869 |
+
@click.option(
|
870 |
+
"-i",
|
871 |
+
"--input-dir",
|
872 |
+
type=click.Path(exists=True),
|
873 |
+
help="dataset directory",
|
874 |
+
default=Path("./dataset/44k"),
|
875 |
+
)
|
876 |
+
@click.option(
|
877 |
+
"-o",
|
878 |
+
"--output-path",
|
879 |
+
type=click.Path(),
|
880 |
+
help="model path to save",
|
881 |
+
default=Path("./logs/44k/kmeans.pt"),
|
882 |
+
)
|
883 |
+
@click.option("-n", "--n-clusters", type=int, help="number of clusters", default=2000)
|
884 |
+
@click.option(
|
885 |
+
"-m/-nm", "--minibatch/--no-minibatch", default=True, help="use minibatch k-means"
|
886 |
+
)
|
887 |
+
@click.option(
|
888 |
+
"-b", "--batch-size", type=int, default=4096, help="batch size for minibatch kmeans"
|
889 |
+
)
|
890 |
+
@click.option(
|
891 |
+
"-p/-np", "--partial-fit", default=False, help="use partial fit (only use with -m)"
|
892 |
+
)
|
893 |
+
def train_cluster(
|
894 |
+
input_dir: Path,
|
895 |
+
output_path: Path,
|
896 |
+
n_clusters: int,
|
897 |
+
minibatch: bool,
|
898 |
+
batch_size: int,
|
899 |
+
partial_fit: bool,
|
900 |
+
) -> None:
|
901 |
+
"""Train k-means clustering"""
|
902 |
+
from .cluster.train_cluster import main
|
903 |
+
|
904 |
+
main(
|
905 |
+
input_dir=input_dir,
|
906 |
+
output_path=output_path,
|
907 |
+
n_clusters=n_clusters,
|
908 |
+
verbose=True,
|
909 |
+
use_minibatch=minibatch,
|
910 |
+
batch_size=batch_size,
|
911 |
+
partial_fit=partial_fit,
|
912 |
+
)
|
913 |
+
|
914 |
+
|
915 |
+
if __name__ == "__main__":
|
916 |
+
freeze_support()
|
917 |
+
cli()
|
so_vits_svc_fork/cluster/__init__.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Any
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from sklearn.cluster import KMeans
|
8 |
+
|
9 |
+
|
10 |
+
def get_cluster_model(ckpt_path: Path | str):
|
11 |
+
with Path(ckpt_path).open("rb") as f:
|
12 |
+
checkpoint = torch.load(
|
13 |
+
f, map_location="cpu"
|
14 |
+
) # Danger of arbitrary code execution
|
15 |
+
kmeans_dict = {}
|
16 |
+
for spk, ckpt in checkpoint.items():
|
17 |
+
km = KMeans(ckpt["n_features_in_"])
|
18 |
+
km.__dict__["n_features_in_"] = ckpt["n_features_in_"]
|
19 |
+
km.__dict__["_n_threads"] = ckpt["_n_threads"]
|
20 |
+
km.__dict__["cluster_centers_"] = ckpt["cluster_centers_"]
|
21 |
+
kmeans_dict[spk] = km
|
22 |
+
return kmeans_dict
|
23 |
+
|
24 |
+
|
25 |
+
def check_speaker(model: Any, speaker: Any):
|
26 |
+
if speaker not in model:
|
27 |
+
raise ValueError(f"Speaker {speaker} not in {list(model.keys())}")
|
28 |
+
|
29 |
+
|
30 |
+
def get_cluster_result(model: Any, x: Any, speaker: Any):
|
31 |
+
"""
|
32 |
+
x: np.array [t, 256]
|
33 |
+
return cluster class result
|
34 |
+
"""
|
35 |
+
check_speaker(model, speaker)
|
36 |
+
return model[speaker].predict(x)
|
37 |
+
|
38 |
+
|
39 |
+
def get_cluster_center_result(model: Any, x: Any, speaker: Any):
|
40 |
+
"""x: np.array [t, 256]"""
|
41 |
+
check_speaker(model, speaker)
|
42 |
+
predict = model[speaker].predict(x)
|
43 |
+
return model[speaker].cluster_centers_[predict]
|
44 |
+
|
45 |
+
|
46 |
+
def get_center(model: Any, x: Any, speaker: Any):
|
47 |
+
check_speaker(model, speaker)
|
48 |
+
return model[speaker].cluster_centers_[x]
|
so_vits_svc_fork/cluster/train_cluster.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import math
|
4 |
+
from logging import getLogger
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Any
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from cm_time import timer
|
11 |
+
from joblib import Parallel, delayed
|
12 |
+
from sklearn.cluster import KMeans, MiniBatchKMeans
|
13 |
+
from tqdm_joblib import tqdm_joblib
|
14 |
+
|
15 |
+
LOG = getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
def train_cluster(
|
19 |
+
input_dir: Path | str,
|
20 |
+
n_clusters: int,
|
21 |
+
use_minibatch: bool = True,
|
22 |
+
batch_size: int = 4096,
|
23 |
+
partial_fit: bool = False,
|
24 |
+
verbose: bool = False,
|
25 |
+
) -> dict:
|
26 |
+
input_dir = Path(input_dir)
|
27 |
+
if not partial_fit:
|
28 |
+
LOG.info(f"Loading features from {input_dir}")
|
29 |
+
features = []
|
30 |
+
for path in input_dir.rglob("*.data.pt"):
|
31 |
+
with path.open("rb") as f:
|
32 |
+
features.append(
|
33 |
+
torch.load(f, weights_only=True)["content"].squeeze(0).numpy().T
|
34 |
+
)
|
35 |
+
if not features:
|
36 |
+
raise ValueError(f"No features found in {input_dir}")
|
37 |
+
features = np.concatenate(features, axis=0).astype(np.float32)
|
38 |
+
if features.shape[0] < n_clusters:
|
39 |
+
raise ValueError(
|
40 |
+
"Too few HuBERT features to cluster. Consider using a smaller number of clusters."
|
41 |
+
)
|
42 |
+
LOG.info(
|
43 |
+
f"shape: {features.shape}, size: {features.nbytes/1024**2:.2f} MB, dtype: {features.dtype}"
|
44 |
+
)
|
45 |
+
with timer() as t:
|
46 |
+
if use_minibatch:
|
47 |
+
kmeans = MiniBatchKMeans(
|
48 |
+
n_clusters=n_clusters,
|
49 |
+
verbose=verbose,
|
50 |
+
batch_size=batch_size,
|
51 |
+
max_iter=80,
|
52 |
+
n_init="auto",
|
53 |
+
).fit(features)
|
54 |
+
else:
|
55 |
+
kmeans = KMeans(
|
56 |
+
n_clusters=n_clusters, verbose=verbose, n_init="auto"
|
57 |
+
).fit(features)
|
58 |
+
LOG.info(f"Clustering took {t.elapsed:.2f} seconds")
|
59 |
+
|
60 |
+
x = {
|
61 |
+
"n_features_in_": kmeans.n_features_in_,
|
62 |
+
"_n_threads": kmeans._n_threads,
|
63 |
+
"cluster_centers_": kmeans.cluster_centers_,
|
64 |
+
}
|
65 |
+
return x
|
66 |
+
else:
|
67 |
+
# minibatch partial fit
|
68 |
+
paths = list(input_dir.rglob("*.data.pt"))
|
69 |
+
if len(paths) == 0:
|
70 |
+
raise ValueError(f"No features found in {input_dir}")
|
71 |
+
LOG.info(f"Found {len(paths)} features in {input_dir}")
|
72 |
+
n_batches = math.ceil(len(paths) / batch_size)
|
73 |
+
LOG.info(f"Splitting into {n_batches} batches")
|
74 |
+
with timer() as t:
|
75 |
+
kmeans = MiniBatchKMeans(
|
76 |
+
n_clusters=n_clusters,
|
77 |
+
verbose=verbose,
|
78 |
+
batch_size=batch_size,
|
79 |
+
max_iter=80,
|
80 |
+
n_init="auto",
|
81 |
+
)
|
82 |
+
for i in range(0, len(paths), batch_size):
|
83 |
+
LOG.info(
|
84 |
+
f"Processing batch {i//batch_size+1}/{n_batches} for speaker {input_dir.stem}"
|
85 |
+
)
|
86 |
+
features = []
|
87 |
+
for path in paths[i : i + batch_size]:
|
88 |
+
with path.open("rb") as f:
|
89 |
+
features.append(
|
90 |
+
torch.load(f, weights_only=True)["content"]
|
91 |
+
.squeeze(0)
|
92 |
+
.numpy()
|
93 |
+
.T
|
94 |
+
)
|
95 |
+
features = np.concatenate(features, axis=0).astype(np.float32)
|
96 |
+
kmeans.partial_fit(features)
|
97 |
+
LOG.info(f"Clustering took {t.elapsed:.2f} seconds")
|
98 |
+
|
99 |
+
x = {
|
100 |
+
"n_features_in_": kmeans.n_features_in_,
|
101 |
+
"_n_threads": kmeans._n_threads,
|
102 |
+
"cluster_centers_": kmeans.cluster_centers_,
|
103 |
+
}
|
104 |
+
return x
|
105 |
+
|
106 |
+
|
107 |
+
def main(
|
108 |
+
input_dir: Path | str,
|
109 |
+
output_path: Path | str,
|
110 |
+
n_clusters: int = 10000,
|
111 |
+
use_minibatch: bool = True,
|
112 |
+
batch_size: int = 4096,
|
113 |
+
partial_fit: bool = False,
|
114 |
+
verbose: bool = False,
|
115 |
+
) -> None:
|
116 |
+
input_dir = Path(input_dir)
|
117 |
+
output_path = Path(output_path)
|
118 |
+
|
119 |
+
if not (use_minibatch or not partial_fit):
|
120 |
+
raise ValueError("partial_fit requires use_minibatch")
|
121 |
+
|
122 |
+
def train_cluster_(input_path: Path, **kwargs: Any) -> tuple[str, dict]:
|
123 |
+
return input_path.stem, train_cluster(input_path, **kwargs)
|
124 |
+
|
125 |
+
with tqdm_joblib(desc="Training clusters", total=len(list(input_dir.iterdir()))):
|
126 |
+
parallel_result = Parallel(n_jobs=-1)(
|
127 |
+
delayed(train_cluster_)(
|
128 |
+
speaker_name,
|
129 |
+
n_clusters=n_clusters,
|
130 |
+
use_minibatch=use_minibatch,
|
131 |
+
batch_size=batch_size,
|
132 |
+
partial_fit=partial_fit,
|
133 |
+
verbose=verbose,
|
134 |
+
)
|
135 |
+
for speaker_name in input_dir.iterdir()
|
136 |
+
)
|
137 |
+
assert parallel_result is not None
|
138 |
+
checkpoint = dict(parallel_result)
|
139 |
+
output_path.parent.mkdir(exist_ok=True, parents=True)
|
140 |
+
with output_path.open("wb") as f:
|
141 |
+
torch.save(checkpoint, f)
|
so_vits_svc_fork/dataset.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
from random import Random
|
5 |
+
from typing import Sequence
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
|
12 |
+
from .hparams import HParams
|
13 |
+
|
14 |
+
|
15 |
+
class TextAudioDataset(Dataset):
|
16 |
+
def __init__(self, hps: HParams, is_validation: bool = False):
|
17 |
+
self.datapaths = [
|
18 |
+
Path(x).parent / (Path(x).name + ".data.pt")
|
19 |
+
for x in Path(
|
20 |
+
hps.data.validation_files if is_validation else hps.data.training_files
|
21 |
+
)
|
22 |
+
.read_text("utf-8")
|
23 |
+
.splitlines()
|
24 |
+
]
|
25 |
+
self.hps = hps
|
26 |
+
self.random = Random(hps.train.seed)
|
27 |
+
self.random.shuffle(self.datapaths)
|
28 |
+
self.max_spec_len = 800
|
29 |
+
|
30 |
+
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
|
31 |
+
with Path(self.datapaths[index]).open("rb") as f:
|
32 |
+
data = torch.load(f, weights_only=True, map_location="cpu")
|
33 |
+
|
34 |
+
# cut long data randomly
|
35 |
+
spec_len = data["mel_spec"].shape[1]
|
36 |
+
hop_len = self.hps.data.hop_length
|
37 |
+
if spec_len > self.max_spec_len:
|
38 |
+
start = self.random.randint(0, spec_len - self.max_spec_len)
|
39 |
+
end = start + self.max_spec_len - 10
|
40 |
+
for key in data.keys():
|
41 |
+
if key == "audio":
|
42 |
+
data[key] = data[key][:, start * hop_len : end * hop_len]
|
43 |
+
elif key == "spk":
|
44 |
+
continue
|
45 |
+
else:
|
46 |
+
data[key] = data[key][..., start:end]
|
47 |
+
torch.cuda.empty_cache()
|
48 |
+
return data
|
49 |
+
|
50 |
+
def __len__(self) -> int:
|
51 |
+
return len(self.datapaths)
|
52 |
+
|
53 |
+
|
54 |
+
def _pad_stack(array: Sequence[torch.Tensor]) -> torch.Tensor:
|
55 |
+
max_idx = torch.argmax(torch.tensor([x_.shape[-1] for x_ in array]))
|
56 |
+
max_x = array[max_idx]
|
57 |
+
x_padded = [
|
58 |
+
F.pad(x_, (0, max_x.shape[-1] - x_.shape[-1]), mode="constant", value=0)
|
59 |
+
for x_ in array
|
60 |
+
]
|
61 |
+
return torch.stack(x_padded)
|
62 |
+
|
63 |
+
|
64 |
+
class TextAudioCollate(nn.Module):
|
65 |
+
def forward(
|
66 |
+
self, batch: Sequence[dict[str, torch.Tensor]]
|
67 |
+
) -> tuple[torch.Tensor, ...]:
|
68 |
+
batch = [b for b in batch if b is not None]
|
69 |
+
batch = list(sorted(batch, key=lambda x: x["mel_spec"].shape[1], reverse=True))
|
70 |
+
lengths = torch.tensor([b["mel_spec"].shape[1] for b in batch]).long()
|
71 |
+
results = {}
|
72 |
+
for key in batch[0].keys():
|
73 |
+
if key not in ["spk"]:
|
74 |
+
results[key] = _pad_stack([b[key] for b in batch]).cpu()
|
75 |
+
else:
|
76 |
+
results[key] = torch.tensor([[b[key]] for b in batch]).cpu()
|
77 |
+
|
78 |
+
return (
|
79 |
+
results["content"],
|
80 |
+
results["f0"],
|
81 |
+
results["spec"],
|
82 |
+
results["mel_spec"],
|
83 |
+
results["audio"],
|
84 |
+
results["spk"],
|
85 |
+
lengths,
|
86 |
+
results["uv"],
|
87 |
+
)
|
so_vits_svc_fork/default_gui_presets.json
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"Default VC (GPU, GTX 1060)": {
|
3 |
+
"silence_threshold": -35.0,
|
4 |
+
"transpose": 12.0,
|
5 |
+
"auto_predict_f0": false,
|
6 |
+
"f0_method": "dio",
|
7 |
+
"cluster_infer_ratio": 0.0,
|
8 |
+
"noise_scale": 0.4,
|
9 |
+
"pad_seconds": 0.1,
|
10 |
+
"chunk_seconds": 0.5,
|
11 |
+
"absolute_thresh": true,
|
12 |
+
"max_chunk_seconds": 40,
|
13 |
+
"crossfade_seconds": 0.05,
|
14 |
+
"block_seconds": 0.35,
|
15 |
+
"additional_infer_before_seconds": 0.15,
|
16 |
+
"additional_infer_after_seconds": 0.1,
|
17 |
+
"realtime_algorithm": "1 (Divide constantly)",
|
18 |
+
"passthrough_original": false,
|
19 |
+
"use_gpu": true
|
20 |
+
},
|
21 |
+
"Default VC (CPU)": {
|
22 |
+
"silence_threshold": -35.0,
|
23 |
+
"transpose": 12.0,
|
24 |
+
"auto_predict_f0": false,
|
25 |
+
"f0_method": "dio",
|
26 |
+
"cluster_infer_ratio": 0.0,
|
27 |
+
"noise_scale": 0.4,
|
28 |
+
"pad_seconds": 0.1,
|
29 |
+
"chunk_seconds": 0.5,
|
30 |
+
"absolute_thresh": true,
|
31 |
+
"max_chunk_seconds": 40,
|
32 |
+
"crossfade_seconds": 0.05,
|
33 |
+
"block_seconds": 1.5,
|
34 |
+
"additional_infer_before_seconds": 0.01,
|
35 |
+
"additional_infer_after_seconds": 0.01,
|
36 |
+
"realtime_algorithm": "1 (Divide constantly)",
|
37 |
+
"passthrough_original": false,
|
38 |
+
"use_gpu": false
|
39 |
+
},
|
40 |
+
"Default VC (Mobile CPU)": {
|
41 |
+
"silence_threshold": -35.0,
|
42 |
+
"transpose": 12.0,
|
43 |
+
"auto_predict_f0": false,
|
44 |
+
"f0_method": "dio",
|
45 |
+
"cluster_infer_ratio": 0.0,
|
46 |
+
"noise_scale": 0.4,
|
47 |
+
"pad_seconds": 0.1,
|
48 |
+
"chunk_seconds": 0.5,
|
49 |
+
"absolute_thresh": true,
|
50 |
+
"max_chunk_seconds": 40,
|
51 |
+
"crossfade_seconds": 0.05,
|
52 |
+
"block_seconds": 2.5,
|
53 |
+
"additional_infer_before_seconds": 0.01,
|
54 |
+
"additional_infer_after_seconds": 0.01,
|
55 |
+
"realtime_algorithm": "1 (Divide constantly)",
|
56 |
+
"passthrough_original": false,
|
57 |
+
"use_gpu": false
|
58 |
+
},
|
59 |
+
"Default VC (Crooning)": {
|
60 |
+
"silence_threshold": -35.0,
|
61 |
+
"transpose": 12.0,
|
62 |
+
"auto_predict_f0": false,
|
63 |
+
"f0_method": "dio",
|
64 |
+
"cluster_infer_ratio": 0.0,
|
65 |
+
"noise_scale": 0.4,
|
66 |
+
"pad_seconds": 0.1,
|
67 |
+
"chunk_seconds": 0.5,
|
68 |
+
"absolute_thresh": true,
|
69 |
+
"max_chunk_seconds": 40,
|
70 |
+
"crossfade_seconds": 0.04,
|
71 |
+
"block_seconds": 0.15,
|
72 |
+
"additional_infer_before_seconds": 0.05,
|
73 |
+
"additional_infer_after_seconds": 0.05,
|
74 |
+
"realtime_algorithm": "1 (Divide constantly)",
|
75 |
+
"passthrough_original": false,
|
76 |
+
"use_gpu": true
|
77 |
+
},
|
78 |
+
"Default File": {
|
79 |
+
"silence_threshold": -35.0,
|
80 |
+
"transpose": 0.0,
|
81 |
+
"auto_predict_f0": true,
|
82 |
+
"f0_method": "crepe",
|
83 |
+
"cluster_infer_ratio": 0.0,
|
84 |
+
"noise_scale": 0.4,
|
85 |
+
"pad_seconds": 0.1,
|
86 |
+
"chunk_seconds": 0.5,
|
87 |
+
"absolute_thresh": true,
|
88 |
+
"max_chunk_seconds": 40,
|
89 |
+
"auto_play": true,
|
90 |
+
"passthrough_original": false
|
91 |
+
}
|
92 |
+
}
|
so_vits_svc_fork/f0.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from logging import getLogger
|
4 |
+
from typing import Any, Literal
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torchcrepe
|
9 |
+
from cm_time import timer
|
10 |
+
from numpy import dtype, float32, ndarray
|
11 |
+
from torch import FloatTensor, Tensor
|
12 |
+
|
13 |
+
from so_vits_svc_fork.utils import get_optimal_device
|
14 |
+
|
15 |
+
LOG = getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
def normalize_f0(
|
19 |
+
f0: FloatTensor, x_mask: FloatTensor, uv: FloatTensor, random_scale=True
|
20 |
+
) -> FloatTensor:
|
21 |
+
# calculate means based on x_mask
|
22 |
+
uv_sum = torch.sum(uv, dim=1, keepdim=True)
|
23 |
+
uv_sum[uv_sum == 0] = 9999
|
24 |
+
means = torch.sum(f0[:, 0, :] * uv, dim=1, keepdim=True) / uv_sum
|
25 |
+
|
26 |
+
if random_scale:
|
27 |
+
factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device)
|
28 |
+
else:
|
29 |
+
factor = torch.ones(f0.shape[0], 1).to(f0.device)
|
30 |
+
# normalize f0 based on means and factor
|
31 |
+
f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
|
32 |
+
if torch.isnan(f0_norm).any():
|
33 |
+
exit(0)
|
34 |
+
return f0_norm * x_mask
|
35 |
+
|
36 |
+
|
37 |
+
def interpolate_f0(
|
38 |
+
f0: ndarray[Any, dtype[float32]]
|
39 |
+
) -> tuple[ndarray[Any, dtype[float32]], ndarray[Any, dtype[float32]]]:
|
40 |
+
data = np.reshape(f0, (f0.size, 1))
|
41 |
+
|
42 |
+
vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
|
43 |
+
vuv_vector[data > 0.0] = 1.0
|
44 |
+
vuv_vector[data <= 0.0] = 0.0
|
45 |
+
|
46 |
+
ip_data = data
|
47 |
+
|
48 |
+
frame_number = data.size
|
49 |
+
last_value = 0.0
|
50 |
+
for i in range(frame_number):
|
51 |
+
if data[i] <= 0.0:
|
52 |
+
j = i + 1
|
53 |
+
for j in range(i + 1, frame_number):
|
54 |
+
if data[j] > 0.0:
|
55 |
+
break
|
56 |
+
if j < frame_number - 1:
|
57 |
+
if last_value > 0.0:
|
58 |
+
step = (data[j] - data[i - 1]) / float(j - i)
|
59 |
+
for k in range(i, j):
|
60 |
+
ip_data[k] = data[i - 1] + step * (k - i + 1)
|
61 |
+
else:
|
62 |
+
for k in range(i, j):
|
63 |
+
ip_data[k] = data[j]
|
64 |
+
else:
|
65 |
+
for k in range(i, frame_number):
|
66 |
+
ip_data[k] = last_value
|
67 |
+
else:
|
68 |
+
ip_data[i] = data[i]
|
69 |
+
last_value = data[i]
|
70 |
+
|
71 |
+
return ip_data[:, 0], vuv_vector[:, 0]
|
72 |
+
|
73 |
+
|
74 |
+
def compute_f0_parselmouth(
|
75 |
+
wav_numpy: ndarray[Any, dtype[float32]],
|
76 |
+
p_len: None | int = None,
|
77 |
+
sampling_rate: int = 44100,
|
78 |
+
hop_length: int = 512,
|
79 |
+
):
|
80 |
+
import parselmouth
|
81 |
+
|
82 |
+
x = wav_numpy
|
83 |
+
if p_len is None:
|
84 |
+
p_len = x.shape[0] // hop_length
|
85 |
+
else:
|
86 |
+
assert abs(p_len - x.shape[0] // hop_length) < 4, "pad length error"
|
87 |
+
time_step = hop_length / sampling_rate * 1000
|
88 |
+
f0_min = 50
|
89 |
+
f0_max = 1100
|
90 |
+
f0 = (
|
91 |
+
parselmouth.Sound(x, sampling_rate)
|
92 |
+
.to_pitch_ac(
|
93 |
+
time_step=time_step / 1000,
|
94 |
+
voicing_threshold=0.6,
|
95 |
+
pitch_floor=f0_min,
|
96 |
+
pitch_ceiling=f0_max,
|
97 |
+
)
|
98 |
+
.selected_array["frequency"]
|
99 |
+
)
|
100 |
+
|
101 |
+
pad_size = (p_len - len(f0) + 1) // 2
|
102 |
+
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
|
103 |
+
f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
|
104 |
+
return f0
|
105 |
+
|
106 |
+
|
107 |
+
def _resize_f0(
|
108 |
+
x: ndarray[Any, dtype[float32]], target_len: int
|
109 |
+
) -> ndarray[Any, dtype[float32]]:
|
110 |
+
source = np.array(x)
|
111 |
+
source[source < 0.001] = np.nan
|
112 |
+
target = np.interp(
|
113 |
+
np.arange(0, len(source) * target_len, len(source)) / target_len,
|
114 |
+
np.arange(0, len(source)),
|
115 |
+
source,
|
116 |
+
)
|
117 |
+
res = np.nan_to_num(target)
|
118 |
+
return res
|
119 |
+
|
120 |
+
|
121 |
+
def compute_f0_pyworld(
|
122 |
+
wav_numpy: ndarray[Any, dtype[float32]],
|
123 |
+
p_len: None | int = None,
|
124 |
+
sampling_rate: int = 44100,
|
125 |
+
hop_length: int = 512,
|
126 |
+
type_: Literal["dio", "harvest"] = "dio",
|
127 |
+
):
|
128 |
+
import pyworld
|
129 |
+
|
130 |
+
if p_len is None:
|
131 |
+
p_len = wav_numpy.shape[0] // hop_length
|
132 |
+
if type_ == "dio":
|
133 |
+
f0, t = pyworld.dio(
|
134 |
+
wav_numpy.astype(np.double),
|
135 |
+
fs=sampling_rate,
|
136 |
+
f0_ceil=f0_max,
|
137 |
+
f0_floor=f0_min,
|
138 |
+
frame_period=1000 * hop_length / sampling_rate,
|
139 |
+
)
|
140 |
+
elif type_ == "harvest":
|
141 |
+
f0, t = pyworld.harvest(
|
142 |
+
wav_numpy.astype(np.double),
|
143 |
+
fs=sampling_rate,
|
144 |
+
f0_ceil=f0_max,
|
145 |
+
f0_floor=f0_min,
|
146 |
+
frame_period=1000 * hop_length / sampling_rate,
|
147 |
+
)
|
148 |
+
f0 = pyworld.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate)
|
149 |
+
for index, pitch in enumerate(f0):
|
150 |
+
f0[index] = round(pitch, 1)
|
151 |
+
return _resize_f0(f0, p_len)
|
152 |
+
|
153 |
+
|
154 |
+
def compute_f0_crepe(
|
155 |
+
wav_numpy: ndarray[Any, dtype[float32]],
|
156 |
+
p_len: None | int = None,
|
157 |
+
sampling_rate: int = 44100,
|
158 |
+
hop_length: int = 512,
|
159 |
+
device: str | torch.device = get_optimal_device(),
|
160 |
+
model: Literal["full", "tiny"] = "full",
|
161 |
+
):
|
162 |
+
audio = torch.from_numpy(wav_numpy).to(device, copy=True)
|
163 |
+
audio = torch.unsqueeze(audio, dim=0)
|
164 |
+
|
165 |
+
if audio.ndim == 2 and audio.shape[0] > 1:
|
166 |
+
audio = torch.mean(audio, dim=0, keepdim=True).detach()
|
167 |
+
# (T) -> (1, T)
|
168 |
+
audio = audio.detach()
|
169 |
+
|
170 |
+
pitch: Tensor = torchcrepe.predict(
|
171 |
+
audio,
|
172 |
+
sampling_rate,
|
173 |
+
hop_length,
|
174 |
+
f0_min,
|
175 |
+
f0_max,
|
176 |
+
model,
|
177 |
+
batch_size=hop_length * 2,
|
178 |
+
device=device,
|
179 |
+
pad=True,
|
180 |
+
)
|
181 |
+
|
182 |
+
f0 = pitch.squeeze(0).cpu().float().numpy()
|
183 |
+
p_len = p_len or wav_numpy.shape[0] // hop_length
|
184 |
+
f0 = _resize_f0(f0, p_len)
|
185 |
+
return f0
|
186 |
+
|
187 |
+
|
188 |
+
def compute_f0(
|
189 |
+
wav_numpy: ndarray[Any, dtype[float32]],
|
190 |
+
p_len: None | int = None,
|
191 |
+
sampling_rate: int = 44100,
|
192 |
+
hop_length: int = 512,
|
193 |
+
method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
|
194 |
+
**kwargs,
|
195 |
+
):
|
196 |
+
with timer() as t:
|
197 |
+
wav_numpy = wav_numpy.astype(np.float32)
|
198 |
+
wav_numpy /= np.quantile(np.abs(wav_numpy), 0.999)
|
199 |
+
if method in ["dio", "harvest"]:
|
200 |
+
f0 = compute_f0_pyworld(wav_numpy, p_len, sampling_rate, hop_length, method)
|
201 |
+
elif method == "crepe":
|
202 |
+
f0 = compute_f0_crepe(wav_numpy, p_len, sampling_rate, hop_length, **kwargs)
|
203 |
+
elif method == "crepe-tiny":
|
204 |
+
f0 = compute_f0_crepe(
|
205 |
+
wav_numpy, p_len, sampling_rate, hop_length, model="tiny", **kwargs
|
206 |
+
)
|
207 |
+
elif method == "parselmouth":
|
208 |
+
f0 = compute_f0_parselmouth(wav_numpy, p_len, sampling_rate, hop_length)
|
209 |
+
else:
|
210 |
+
raise ValueError(
|
211 |
+
"type must be dio, crepe, crepe-tiny, harvest or parselmouth"
|
212 |
+
)
|
213 |
+
rtf = t.elapsed / (len(wav_numpy) / sampling_rate)
|
214 |
+
LOG.info(f"F0 inference time: {t.elapsed:.3f}s, RTF: {rtf:.3f}")
|
215 |
+
return f0
|
216 |
+
|
217 |
+
|
218 |
+
def f0_to_coarse(f0: torch.Tensor | float):
|
219 |
+
is_torch = isinstance(f0, torch.Tensor)
|
220 |
+
f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
|
221 |
+
f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (
|
222 |
+
f0_mel_max - f0_mel_min
|
223 |
+
) + 1
|
224 |
+
|
225 |
+
f0_mel[f0_mel <= 1] = 1
|
226 |
+
f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
|
227 |
+
f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int)
|
228 |
+
assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (
|
229 |
+
f0_coarse.max(),
|
230 |
+
f0_coarse.min(),
|
231 |
+
)
|
232 |
+
return f0_coarse
|
233 |
+
|
234 |
+
|
235 |
+
f0_bin = 256
|
236 |
+
f0_max = 1100.0
|
237 |
+
f0_min = 50.0
|
238 |
+
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
|
239 |
+
f0_mel_max = 1127 * np.log(1 + f0_max / 700)
|
so_vits_svc_fork/gui.py
ADDED
@@ -0,0 +1,851 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import json
|
4 |
+
import multiprocessing
|
5 |
+
import os
|
6 |
+
from copy import copy
|
7 |
+
from logging import getLogger
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import PySimpleGUI as sg
|
11 |
+
import sounddevice as sd
|
12 |
+
import soundfile as sf
|
13 |
+
import torch
|
14 |
+
from pebble import ProcessFuture, ProcessPool
|
15 |
+
|
16 |
+
from . import __version__
|
17 |
+
from .utils import get_optimal_device
|
18 |
+
|
19 |
+
GUI_DEFAULT_PRESETS_PATH = Path(__file__).parent / "default_gui_presets.json"
|
20 |
+
GUI_PRESETS_PATH = Path("./user_gui_presets.json").absolute()
|
21 |
+
|
22 |
+
LOG = getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
def play_audio(path: Path | str):
|
26 |
+
if isinstance(path, Path):
|
27 |
+
path = path.as_posix()
|
28 |
+
data, sr = sf.read(path)
|
29 |
+
sd.play(data, sr)
|
30 |
+
|
31 |
+
|
32 |
+
def load_presets() -> dict:
|
33 |
+
defaults = json.loads(GUI_DEFAULT_PRESETS_PATH.read_text("utf-8"))
|
34 |
+
users = (
|
35 |
+
json.loads(GUI_PRESETS_PATH.read_text("utf-8"))
|
36 |
+
if GUI_PRESETS_PATH.exists()
|
37 |
+
else {}
|
38 |
+
)
|
39 |
+
# prioriy: defaults > users
|
40 |
+
# order: defaults -> users
|
41 |
+
return {**defaults, **users, **defaults}
|
42 |
+
|
43 |
+
|
44 |
+
def add_preset(name: str, preset: dict) -> dict:
|
45 |
+
presets = load_presets()
|
46 |
+
presets[name] = preset
|
47 |
+
with GUI_PRESETS_PATH.open("w") as f:
|
48 |
+
json.dump(presets, f, indent=2)
|
49 |
+
return load_presets()
|
50 |
+
|
51 |
+
|
52 |
+
def delete_preset(name: str) -> dict:
|
53 |
+
presets = load_presets()
|
54 |
+
if name in presets:
|
55 |
+
del presets[name]
|
56 |
+
else:
|
57 |
+
LOG.warning(f"Cannot delete preset {name} because it does not exist.")
|
58 |
+
with GUI_PRESETS_PATH.open("w") as f:
|
59 |
+
json.dump(presets, f, indent=2)
|
60 |
+
return load_presets()
|
61 |
+
|
62 |
+
|
63 |
+
def get_output_path(input_path: Path) -> Path:
|
64 |
+
# Default output path
|
65 |
+
output_path = input_path.parent / f"{input_path.stem}.out{input_path.suffix}"
|
66 |
+
|
67 |
+
# Increment file number in path if output file already exists
|
68 |
+
file_num = 1
|
69 |
+
while output_path.exists():
|
70 |
+
output_path = (
|
71 |
+
input_path.parent / f"{input_path.stem}.out_{file_num}{input_path.suffix}"
|
72 |
+
)
|
73 |
+
file_num += 1
|
74 |
+
return output_path
|
75 |
+
|
76 |
+
|
77 |
+
def get_supported_file_types() -> tuple[tuple[str, str], ...]:
|
78 |
+
res = tuple(
|
79 |
+
[
|
80 |
+
(extension, f".{extension.lower()}")
|
81 |
+
for extension in sf.available_formats().keys()
|
82 |
+
]
|
83 |
+
)
|
84 |
+
|
85 |
+
# Sort by popularity
|
86 |
+
common_file_types = ["WAV", "MP3", "FLAC", "OGG", "M4A", "WMA"]
|
87 |
+
res = sorted(
|
88 |
+
res,
|
89 |
+
key=lambda x: common_file_types.index(x[0])
|
90 |
+
if x[0] in common_file_types
|
91 |
+
else len(common_file_types),
|
92 |
+
)
|
93 |
+
return res
|
94 |
+
|
95 |
+
|
96 |
+
def get_supported_file_types_concat() -> tuple[tuple[str, str], ...]:
|
97 |
+
return (("Audio", " ".join(sf.available_formats().keys())),)
|
98 |
+
|
99 |
+
|
100 |
+
def validate_output_file_type(output_path: Path) -> bool:
|
101 |
+
supported_file_types = sorted(
|
102 |
+
[f".{extension.lower()}" for extension in sf.available_formats().keys()]
|
103 |
+
)
|
104 |
+
if not output_path.suffix:
|
105 |
+
sg.popup_ok(
|
106 |
+
"Error: Output path missing file type extension, enter "
|
107 |
+
+ "one of the following manually:\n\n"
|
108 |
+
+ "\n".join(supported_file_types)
|
109 |
+
)
|
110 |
+
return False
|
111 |
+
if output_path.suffix.lower() not in supported_file_types:
|
112 |
+
sg.popup_ok(
|
113 |
+
f"Error: {output_path.suffix.lower()} is not a supported "
|
114 |
+
+ "extension; use one of the following:\n\n"
|
115 |
+
+ "\n".join(supported_file_types)
|
116 |
+
)
|
117 |
+
return False
|
118 |
+
return True
|
119 |
+
|
120 |
+
|
121 |
+
def get_devices(
|
122 |
+
update: bool = True,
|
123 |
+
) -> tuple[list[str], list[str], list[int], list[int]]:
|
124 |
+
if update:
|
125 |
+
sd._terminate()
|
126 |
+
sd._initialize()
|
127 |
+
devices = sd.query_devices()
|
128 |
+
hostapis = sd.query_hostapis()
|
129 |
+
for hostapi in hostapis:
|
130 |
+
for device_idx in hostapi["devices"]:
|
131 |
+
devices[device_idx]["hostapi_name"] = hostapi["name"]
|
132 |
+
input_devices = [
|
133 |
+
f"{d['name']} ({d['hostapi_name']})"
|
134 |
+
for d in devices
|
135 |
+
if d["max_input_channels"] > 0
|
136 |
+
]
|
137 |
+
output_devices = [
|
138 |
+
f"{d['name']} ({d['hostapi_name']})"
|
139 |
+
for d in devices
|
140 |
+
if d["max_output_channels"] > 0
|
141 |
+
]
|
142 |
+
input_devices_indices = [d["index"] for d in devices if d["max_input_channels"] > 0]
|
143 |
+
output_devices_indices = [
|
144 |
+
d["index"] for d in devices if d["max_output_channels"] > 0
|
145 |
+
]
|
146 |
+
return input_devices, output_devices, input_devices_indices, output_devices_indices
|
147 |
+
|
148 |
+
|
149 |
+
def after_inference(window: sg.Window, path: Path, auto_play: bool, output_path: Path):
|
150 |
+
try:
|
151 |
+
LOG.info(f"Finished inference for {path.stem}{path.suffix}")
|
152 |
+
window["infer"].update(disabled=False)
|
153 |
+
|
154 |
+
if auto_play:
|
155 |
+
play_audio(output_path)
|
156 |
+
except Exception as e:
|
157 |
+
LOG.exception(e)
|
158 |
+
|
159 |
+
|
160 |
+
def main():
|
161 |
+
LOG.info(f"version: {__version__}")
|
162 |
+
|
163 |
+
# sg.theme("Dark")
|
164 |
+
sg.theme_add_new(
|
165 |
+
"Very Dark",
|
166 |
+
{
|
167 |
+
"BACKGROUND": "#111111",
|
168 |
+
"TEXT": "#FFFFFF",
|
169 |
+
"INPUT": "#444444",
|
170 |
+
"TEXT_INPUT": "#FFFFFF",
|
171 |
+
"SCROLL": "#333333",
|
172 |
+
"BUTTON": ("white", "#112233"),
|
173 |
+
"PROGRESS": ("#111111", "#333333"),
|
174 |
+
"BORDER": 2,
|
175 |
+
"SLIDER_DEPTH": 2,
|
176 |
+
"PROGRESS_DEPTH": 2,
|
177 |
+
},
|
178 |
+
)
|
179 |
+
sg.theme("Very Dark")
|
180 |
+
|
181 |
+
model_candidates = list(sorted(Path("./logs/44k/").glob("G_*.pth")))
|
182 |
+
|
183 |
+
frame_contents = {
|
184 |
+
"Paths": [
|
185 |
+
[
|
186 |
+
sg.Text("Model path"),
|
187 |
+
sg.Push(),
|
188 |
+
sg.InputText(
|
189 |
+
key="model_path",
|
190 |
+
default_text=model_candidates[-1].absolute().as_posix()
|
191 |
+
if model_candidates
|
192 |
+
else "",
|
193 |
+
enable_events=True,
|
194 |
+
),
|
195 |
+
sg.FileBrowse(
|
196 |
+
initial_folder=Path("./logs/44k/").absolute
|
197 |
+
if Path("./logs/44k/").exists()
|
198 |
+
else Path(".").absolute().as_posix(),
|
199 |
+
key="model_path_browse",
|
200 |
+
file_types=(
|
201 |
+
("PyTorch", "G_*.pth G_*.pt"),
|
202 |
+
("Pytorch", "*.pth *.pt"),
|
203 |
+
),
|
204 |
+
),
|
205 |
+
],
|
206 |
+
[
|
207 |
+
sg.Text("Config path"),
|
208 |
+
sg.Push(),
|
209 |
+
sg.InputText(
|
210 |
+
key="config_path",
|
211 |
+
default_text=Path("./configs/44k/config.json").absolute().as_posix()
|
212 |
+
if Path("./configs/44k/config.json").exists()
|
213 |
+
else "",
|
214 |
+
enable_events=True,
|
215 |
+
),
|
216 |
+
sg.FileBrowse(
|
217 |
+
initial_folder=Path("./configs/44k/").as_posix()
|
218 |
+
if Path("./configs/44k/").exists()
|
219 |
+
else Path(".").absolute().as_posix(),
|
220 |
+
key="config_path_browse",
|
221 |
+
file_types=(("JSON", "*.json"),),
|
222 |
+
),
|
223 |
+
],
|
224 |
+
[
|
225 |
+
sg.Text("Cluster model path (Optional)"),
|
226 |
+
sg.Push(),
|
227 |
+
sg.InputText(
|
228 |
+
key="cluster_model_path",
|
229 |
+
default_text=Path("./logs/44k/kmeans.pt").absolute().as_posix()
|
230 |
+
if Path("./logs/44k/kmeans.pt").exists()
|
231 |
+
else "",
|
232 |
+
enable_events=True,
|
233 |
+
),
|
234 |
+
sg.FileBrowse(
|
235 |
+
initial_folder="./logs/44k/"
|
236 |
+
if Path("./logs/44k/").exists()
|
237 |
+
else ".",
|
238 |
+
key="cluster_model_path_browse",
|
239 |
+
file_types=(("PyTorch", "*.pt"), ("Pickle", "*.pt *.pth *.pkl")),
|
240 |
+
),
|
241 |
+
],
|
242 |
+
],
|
243 |
+
"Common": [
|
244 |
+
[
|
245 |
+
sg.Text("Speaker"),
|
246 |
+
sg.Push(),
|
247 |
+
sg.Combo(values=[], key="speaker", size=(20, 1)),
|
248 |
+
],
|
249 |
+
[
|
250 |
+
sg.Text("Silence threshold"),
|
251 |
+
sg.Push(),
|
252 |
+
sg.Slider(
|
253 |
+
range=(-60.0, 0),
|
254 |
+
orientation="h",
|
255 |
+
key="silence_threshold",
|
256 |
+
resolution=0.1,
|
257 |
+
),
|
258 |
+
],
|
259 |
+
[
|
260 |
+
sg.Text(
|
261 |
+
"Pitch (12 = 1 octave)\n"
|
262 |
+
"ADJUST THIS based on your voice\n"
|
263 |
+
"when Auto predict F0 is turned off.",
|
264 |
+
size=(None, 4),
|
265 |
+
),
|
266 |
+
sg.Push(),
|
267 |
+
sg.Slider(
|
268 |
+
range=(-36, 36),
|
269 |
+
orientation="h",
|
270 |
+
key="transpose",
|
271 |
+
tick_interval=12,
|
272 |
+
),
|
273 |
+
],
|
274 |
+
[
|
275 |
+
sg.Checkbox(
|
276 |
+
key="auto_predict_f0",
|
277 |
+
text="Auto predict F0 (Pitch may become unstable when turned on in real-time inference.)",
|
278 |
+
)
|
279 |
+
],
|
280 |
+
[
|
281 |
+
sg.Text("F0 prediction method"),
|
282 |
+
sg.Push(),
|
283 |
+
sg.Combo(
|
284 |
+
["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
|
285 |
+
key="f0_method",
|
286 |
+
),
|
287 |
+
],
|
288 |
+
[
|
289 |
+
sg.Text("Cluster infer ratio"),
|
290 |
+
sg.Push(),
|
291 |
+
sg.Slider(
|
292 |
+
range=(0, 1.0),
|
293 |
+
orientation="h",
|
294 |
+
key="cluster_infer_ratio",
|
295 |
+
resolution=0.01,
|
296 |
+
),
|
297 |
+
],
|
298 |
+
[
|
299 |
+
sg.Text("Noise scale"),
|
300 |
+
sg.Push(),
|
301 |
+
sg.Slider(
|
302 |
+
range=(0.0, 1.0),
|
303 |
+
orientation="h",
|
304 |
+
key="noise_scale",
|
305 |
+
resolution=0.01,
|
306 |
+
),
|
307 |
+
],
|
308 |
+
[
|
309 |
+
sg.Text("Pad seconds"),
|
310 |
+
sg.Push(),
|
311 |
+
sg.Slider(
|
312 |
+
range=(0.0, 1.0),
|
313 |
+
orientation="h",
|
314 |
+
key="pad_seconds",
|
315 |
+
resolution=0.01,
|
316 |
+
),
|
317 |
+
],
|
318 |
+
[
|
319 |
+
sg.Text("Chunk seconds"),
|
320 |
+
sg.Push(),
|
321 |
+
sg.Slider(
|
322 |
+
range=(0.0, 3.0),
|
323 |
+
orientation="h",
|
324 |
+
key="chunk_seconds",
|
325 |
+
resolution=0.01,
|
326 |
+
),
|
327 |
+
],
|
328 |
+
[
|
329 |
+
sg.Text("Max chunk seconds (set lower if Out Of Memory, 0 to disable)"),
|
330 |
+
sg.Push(),
|
331 |
+
sg.Slider(
|
332 |
+
range=(0.0, 240.0),
|
333 |
+
orientation="h",
|
334 |
+
key="max_chunk_seconds",
|
335 |
+
resolution=1.0,
|
336 |
+
),
|
337 |
+
],
|
338 |
+
[
|
339 |
+
sg.Checkbox(
|
340 |
+
key="absolute_thresh",
|
341 |
+
text="Absolute threshold (ignored (True) in realtime inference)",
|
342 |
+
)
|
343 |
+
],
|
344 |
+
],
|
345 |
+
"File": [
|
346 |
+
[
|
347 |
+
sg.Text("Input audio path"),
|
348 |
+
sg.Push(),
|
349 |
+
sg.InputText(key="input_path", enable_events=True),
|
350 |
+
sg.FileBrowse(
|
351 |
+
initial_folder=".",
|
352 |
+
key="input_path_browse",
|
353 |
+
file_types=get_supported_file_types_concat(),
|
354 |
+
),
|
355 |
+
sg.FolderBrowse(
|
356 |
+
button_text="Browse(Folder)",
|
357 |
+
initial_folder=".",
|
358 |
+
key="input_path_folder_browse",
|
359 |
+
target="input_path",
|
360 |
+
),
|
361 |
+
sg.Button("Play", key="play_input"),
|
362 |
+
],
|
363 |
+
[
|
364 |
+
sg.Text("Output audio path"),
|
365 |
+
sg.Push(),
|
366 |
+
sg.InputText(key="output_path"),
|
367 |
+
sg.FileSaveAs(
|
368 |
+
initial_folder=".",
|
369 |
+
key="output_path_browse",
|
370 |
+
file_types=get_supported_file_types(),
|
371 |
+
),
|
372 |
+
],
|
373 |
+
[sg.Checkbox(key="auto_play", text="Auto play", default=True)],
|
374 |
+
],
|
375 |
+
"Realtime": [
|
376 |
+
[
|
377 |
+
sg.Text("Crossfade seconds"),
|
378 |
+
sg.Push(),
|
379 |
+
sg.Slider(
|
380 |
+
range=(0, 0.6),
|
381 |
+
orientation="h",
|
382 |
+
key="crossfade_seconds",
|
383 |
+
resolution=0.001,
|
384 |
+
),
|
385 |
+
],
|
386 |
+
[
|
387 |
+
sg.Text(
|
388 |
+
"Block seconds", # \n(big -> more robust, slower, (the same) latency)"
|
389 |
+
tooltip="Big -> more robust, slower, (the same) latency",
|
390 |
+
),
|
391 |
+
sg.Push(),
|
392 |
+
sg.Slider(
|
393 |
+
range=(0, 3.0),
|
394 |
+
orientation="h",
|
395 |
+
key="block_seconds",
|
396 |
+
resolution=0.001,
|
397 |
+
),
|
398 |
+
],
|
399 |
+
[
|
400 |
+
sg.Text(
|
401 |
+
"Additional Infer seconds (before)", # \n(big -> more robust, slower)"
|
402 |
+
tooltip="Big -> more robust, slower, additional latency",
|
403 |
+
),
|
404 |
+
sg.Push(),
|
405 |
+
sg.Slider(
|
406 |
+
range=(0, 2.0),
|
407 |
+
orientation="h",
|
408 |
+
key="additional_infer_before_seconds",
|
409 |
+
resolution=0.001,
|
410 |
+
),
|
411 |
+
],
|
412 |
+
[
|
413 |
+
sg.Text(
|
414 |
+
"Additional Infer seconds (after)", # \n(big -> more robust, slower, additional latency)"
|
415 |
+
tooltip="Big -> more robust, slower, additional latency",
|
416 |
+
),
|
417 |
+
sg.Push(),
|
418 |
+
sg.Slider(
|
419 |
+
range=(0, 2.0),
|
420 |
+
orientation="h",
|
421 |
+
key="additional_infer_after_seconds",
|
422 |
+
resolution=0.001,
|
423 |
+
),
|
424 |
+
],
|
425 |
+
[
|
426 |
+
sg.Text("Realtime algorithm"),
|
427 |
+
sg.Push(),
|
428 |
+
sg.Combo(
|
429 |
+
["2 (Divide by speech)", "1 (Divide constantly)"],
|
430 |
+
default_value="1 (Divide constantly)",
|
431 |
+
key="realtime_algorithm",
|
432 |
+
),
|
433 |
+
],
|
434 |
+
[
|
435 |
+
sg.Text("Input device"),
|
436 |
+
sg.Push(),
|
437 |
+
sg.Combo(
|
438 |
+
key="input_device",
|
439 |
+
values=[],
|
440 |
+
size=(60, 1),
|
441 |
+
),
|
442 |
+
],
|
443 |
+
[
|
444 |
+
sg.Text("Output device"),
|
445 |
+
sg.Push(),
|
446 |
+
sg.Combo(
|
447 |
+
key="output_device",
|
448 |
+
values=[],
|
449 |
+
size=(60, 1),
|
450 |
+
),
|
451 |
+
],
|
452 |
+
[
|
453 |
+
sg.Checkbox(
|
454 |
+
"Passthrough original audio (for latency check)",
|
455 |
+
key="passthrough_original",
|
456 |
+
default=False,
|
457 |
+
),
|
458 |
+
sg.Push(),
|
459 |
+
sg.Button("Refresh devices", key="refresh_devices"),
|
460 |
+
],
|
461 |
+
[
|
462 |
+
sg.Frame(
|
463 |
+
"Notes",
|
464 |
+
[
|
465 |
+
[
|
466 |
+
sg.Text(
|
467 |
+
"In Realtime Inference:\n"
|
468 |
+
" - Setting F0 prediction method to 'crepe` may cause performance degradation.\n"
|
469 |
+
" - Auto Predict F0 must be turned off.\n"
|
470 |
+
"If the audio sounds mumbly and choppy:\n"
|
471 |
+
" Case: The inference has not been made in time (Increase Block seconds)\n"
|
472 |
+
" Case: Mic input is low (Decrease Silence threshold)\n"
|
473 |
+
)
|
474 |
+
]
|
475 |
+
],
|
476 |
+
),
|
477 |
+
],
|
478 |
+
],
|
479 |
+
"Presets": [
|
480 |
+
[
|
481 |
+
sg.Text("Presets"),
|
482 |
+
sg.Push(),
|
483 |
+
sg.Combo(
|
484 |
+
key="presets",
|
485 |
+
values=list(load_presets().keys()),
|
486 |
+
size=(40, 1),
|
487 |
+
enable_events=True,
|
488 |
+
),
|
489 |
+
sg.Button("Delete preset", key="delete_preset"),
|
490 |
+
],
|
491 |
+
[
|
492 |
+
sg.Text("Preset name"),
|
493 |
+
sg.Stretch(),
|
494 |
+
sg.InputText(key="preset_name", size=(26, 1)),
|
495 |
+
sg.Button("Add current settings as a preset", key="add_preset"),
|
496 |
+
],
|
497 |
+
],
|
498 |
+
}
|
499 |
+
|
500 |
+
# frames
|
501 |
+
frames = {}
|
502 |
+
for name, items in frame_contents.items():
|
503 |
+
frame = sg.Frame(name, items)
|
504 |
+
frame.expand_x = True
|
505 |
+
frames[name] = [frame]
|
506 |
+
|
507 |
+
bottoms = [
|
508 |
+
[
|
509 |
+
sg.Checkbox(
|
510 |
+
key="use_gpu",
|
511 |
+
default=get_optimal_device() != torch.device("cpu"),
|
512 |
+
text="Use GPU"
|
513 |
+
+ (
|
514 |
+
" (not available; if your device has GPU, make sure you installed PyTorch with CUDA support)"
|
515 |
+
if get_optimal_device() == torch.device("cpu")
|
516 |
+
else ""
|
517 |
+
),
|
518 |
+
disabled=get_optimal_device() == torch.device("cpu"),
|
519 |
+
)
|
520 |
+
],
|
521 |
+
[
|
522 |
+
sg.Button("Infer", key="infer"),
|
523 |
+
sg.Button("(Re)Start Voice Changer", key="start_vc"),
|
524 |
+
sg.Button("Stop Voice Changer", key="stop_vc"),
|
525 |
+
sg.Push(),
|
526 |
+
# sg.Button("ONNX Export", key="onnx_export"),
|
527 |
+
],
|
528 |
+
]
|
529 |
+
column1 = sg.Column(
|
530 |
+
[
|
531 |
+
frames["Paths"],
|
532 |
+
frames["Common"],
|
533 |
+
],
|
534 |
+
vertical_alignment="top",
|
535 |
+
)
|
536 |
+
column2 = sg.Column(
|
537 |
+
[
|
538 |
+
frames["File"],
|
539 |
+
frames["Realtime"],
|
540 |
+
frames["Presets"],
|
541 |
+
]
|
542 |
+
+ bottoms
|
543 |
+
)
|
544 |
+
# columns
|
545 |
+
layout = [[column1, column2]]
|
546 |
+
# get screen size
|
547 |
+
screen_width, screen_height = sg.Window.get_screen_size()
|
548 |
+
if screen_height < 720:
|
549 |
+
layout = [
|
550 |
+
[
|
551 |
+
sg.Column(
|
552 |
+
layout,
|
553 |
+
vertical_alignment="top",
|
554 |
+
scrollable=False,
|
555 |
+
expand_x=True,
|
556 |
+
expand_y=True,
|
557 |
+
vertical_scroll_only=True,
|
558 |
+
key="main_column",
|
559 |
+
)
|
560 |
+
]
|
561 |
+
]
|
562 |
+
window = sg.Window(
|
563 |
+
f"{__name__.split('.')[0].replace('_', '-')} v{__version__}",
|
564 |
+
layout,
|
565 |
+
grab_anywhere=True,
|
566 |
+
finalize=True,
|
567 |
+
scaling=1,
|
568 |
+
font=("Yu Gothic UI", 11) if os.name == "nt" else None,
|
569 |
+
# resizable=True,
|
570 |
+
# size=(1280, 720),
|
571 |
+
# Below disables taskbar, which may be not useful for some users
|
572 |
+
# use_custom_titlebar=True, no_titlebar=False
|
573 |
+
# Keep on top
|
574 |
+
# keep_on_top=True
|
575 |
+
)
|
576 |
+
|
577 |
+
# event, values = window.read(timeout=0.01)
|
578 |
+
# window["main_column"].Scrollable = True
|
579 |
+
|
580 |
+
# make slider height smaller
|
581 |
+
try:
|
582 |
+
for v in window.element_list():
|
583 |
+
if isinstance(v, sg.Slider):
|
584 |
+
v.Widget.configure(sliderrelief="flat", width=10, sliderlength=20)
|
585 |
+
except Exception as e:
|
586 |
+
LOG.exception(e)
|
587 |
+
|
588 |
+
# for n in ["input_device", "output_device"]:
|
589 |
+
# window[n].Widget.configure(justify="right")
|
590 |
+
event, values = window.read(timeout=0.01)
|
591 |
+
|
592 |
+
def update_speaker() -> None:
|
593 |
+
from . import utils
|
594 |
+
|
595 |
+
config_path = Path(values["config_path"])
|
596 |
+
if config_path.exists() and config_path.is_file():
|
597 |
+
hp = utils.get_hparams(values["config_path"])
|
598 |
+
LOG.debug(f"Loaded config from {values['config_path']}")
|
599 |
+
window["speaker"].update(
|
600 |
+
values=list(hp.__dict__["spk"].keys()), set_to_index=0
|
601 |
+
)
|
602 |
+
|
603 |
+
def update_devices() -> None:
|
604 |
+
(
|
605 |
+
input_devices,
|
606 |
+
output_devices,
|
607 |
+
input_device_indices,
|
608 |
+
output_device_indices,
|
609 |
+
) = get_devices()
|
610 |
+
input_device_indices_reversed = {
|
611 |
+
v: k for k, v in enumerate(input_device_indices)
|
612 |
+
}
|
613 |
+
output_device_indices_reversed = {
|
614 |
+
v: k for k, v in enumerate(output_device_indices)
|
615 |
+
}
|
616 |
+
window["input_device"].update(
|
617 |
+
values=input_devices, value=values["input_device"]
|
618 |
+
)
|
619 |
+
window["output_device"].update(
|
620 |
+
values=output_devices, value=values["output_device"]
|
621 |
+
)
|
622 |
+
input_default, output_default = sd.default.device
|
623 |
+
if values["input_device"] not in input_devices:
|
624 |
+
window["input_device"].update(
|
625 |
+
values=input_devices,
|
626 |
+
set_to_index=input_device_indices_reversed.get(input_default, 0),
|
627 |
+
)
|
628 |
+
if values["output_device"] not in output_devices:
|
629 |
+
window["output_device"].update(
|
630 |
+
values=output_devices,
|
631 |
+
set_to_index=output_device_indices_reversed.get(output_default, 0),
|
632 |
+
)
|
633 |
+
|
634 |
+
PRESET_KEYS = [
|
635 |
+
key
|
636 |
+
for key in values.keys()
|
637 |
+
if not any(exclude in key for exclude in ["preset", "browse"])
|
638 |
+
]
|
639 |
+
|
640 |
+
def apply_preset(name: str) -> None:
|
641 |
+
for key, value in load_presets()[name].items():
|
642 |
+
if key in PRESET_KEYS:
|
643 |
+
window[key].update(value)
|
644 |
+
values[key] = value
|
645 |
+
|
646 |
+
default_name = list(load_presets().keys())[0]
|
647 |
+
apply_preset(default_name)
|
648 |
+
window["presets"].update(default_name)
|
649 |
+
del default_name
|
650 |
+
update_speaker()
|
651 |
+
update_devices()
|
652 |
+
# with ProcessPool(max_workers=1) as pool:
|
653 |
+
# to support Linux
|
654 |
+
with ProcessPool(
|
655 |
+
max_workers=min(2, multiprocessing.cpu_count()),
|
656 |
+
context=multiprocessing.get_context("spawn"),
|
657 |
+
) as pool:
|
658 |
+
future: None | ProcessFuture = None
|
659 |
+
infer_futures: set[ProcessFuture] = set()
|
660 |
+
while True:
|
661 |
+
event, values = window.read(200)
|
662 |
+
if event == sg.WIN_CLOSED:
|
663 |
+
break
|
664 |
+
if not event == sg.EVENT_TIMEOUT:
|
665 |
+
LOG.info(f"Event {event}, values {values}")
|
666 |
+
if event.endswith("_path"):
|
667 |
+
for name in window.AllKeysDict:
|
668 |
+
if str(name).endswith("_browse"):
|
669 |
+
browser = window[name]
|
670 |
+
if isinstance(browser, sg.Button):
|
671 |
+
LOG.info(
|
672 |
+
f"Updating browser {browser} to {Path(values[event]).parent}"
|
673 |
+
)
|
674 |
+
browser.InitialFolder = Path(values[event]).parent
|
675 |
+
browser.update()
|
676 |
+
else:
|
677 |
+
LOG.warning(f"Browser {browser} is not a FileBrowse")
|
678 |
+
window["transpose"].update(
|
679 |
+
disabled=values["auto_predict_f0"],
|
680 |
+
visible=not values["auto_predict_f0"],
|
681 |
+
)
|
682 |
+
|
683 |
+
input_path = Path(values["input_path"])
|
684 |
+
output_path = Path(values["output_path"])
|
685 |
+
|
686 |
+
if event == "add_preset":
|
687 |
+
presets = add_preset(
|
688 |
+
values["preset_name"], {key: values[key] for key in PRESET_KEYS}
|
689 |
+
)
|
690 |
+
window["presets"].update(values=list(presets.keys()))
|
691 |
+
elif event == "delete_preset":
|
692 |
+
presets = delete_preset(values["presets"])
|
693 |
+
window["presets"].update(values=list(presets.keys()))
|
694 |
+
elif event == "presets":
|
695 |
+
apply_preset(values["presets"])
|
696 |
+
update_speaker()
|
697 |
+
elif event == "refresh_devices":
|
698 |
+
update_devices()
|
699 |
+
elif event == "config_path":
|
700 |
+
update_speaker()
|
701 |
+
elif event == "input_path":
|
702 |
+
# Don't change the output path if it's already set
|
703 |
+
# if values["output_path"]:
|
704 |
+
# continue
|
705 |
+
# Set a sensible default output path
|
706 |
+
window.Element("output_path").Update(str(get_output_path(input_path)))
|
707 |
+
elif event == "infer":
|
708 |
+
if "Default VC" in values["presets"]:
|
709 |
+
window["presets"].update(
|
710 |
+
set_to_index=list(load_presets().keys()).index("Default File")
|
711 |
+
)
|
712 |
+
apply_preset("Default File")
|
713 |
+
if values["input_path"] == "":
|
714 |
+
LOG.warning("Input path is empty.")
|
715 |
+
continue
|
716 |
+
if not input_path.exists():
|
717 |
+
LOG.warning(f"Input path {input_path} does not exist.")
|
718 |
+
continue
|
719 |
+
# if not validate_output_file_type(output_path):
|
720 |
+
# continue
|
721 |
+
|
722 |
+
try:
|
723 |
+
from so_vits_svc_fork.inference.main import infer
|
724 |
+
|
725 |
+
LOG.info("Starting inference...")
|
726 |
+
window["infer"].update(disabled=True)
|
727 |
+
infer_future = pool.schedule(
|
728 |
+
infer,
|
729 |
+
kwargs=dict(
|
730 |
+
# paths
|
731 |
+
model_path=Path(values["model_path"]),
|
732 |
+
output_path=output_path,
|
733 |
+
input_path=input_path,
|
734 |
+
config_path=Path(values["config_path"]),
|
735 |
+
recursive=True,
|
736 |
+
# svc config
|
737 |
+
speaker=values["speaker"],
|
738 |
+
cluster_model_path=Path(values["cluster_model_path"])
|
739 |
+
if values["cluster_model_path"]
|
740 |
+
else None,
|
741 |
+
transpose=values["transpose"],
|
742 |
+
auto_predict_f0=values["auto_predict_f0"],
|
743 |
+
cluster_infer_ratio=values["cluster_infer_ratio"],
|
744 |
+
noise_scale=values["noise_scale"],
|
745 |
+
f0_method=values["f0_method"],
|
746 |
+
# slice config
|
747 |
+
db_thresh=values["silence_threshold"],
|
748 |
+
pad_seconds=values["pad_seconds"],
|
749 |
+
chunk_seconds=values["chunk_seconds"],
|
750 |
+
absolute_thresh=values["absolute_thresh"],
|
751 |
+
max_chunk_seconds=values["max_chunk_seconds"],
|
752 |
+
device="cpu"
|
753 |
+
if not values["use_gpu"]
|
754 |
+
else get_optimal_device(),
|
755 |
+
),
|
756 |
+
)
|
757 |
+
infer_future.add_done_callback(
|
758 |
+
lambda _future: after_inference(
|
759 |
+
window, input_path, values["auto_play"], output_path
|
760 |
+
)
|
761 |
+
)
|
762 |
+
infer_futures.add(infer_future)
|
763 |
+
except Exception as e:
|
764 |
+
LOG.exception(e)
|
765 |
+
elif event == "play_input":
|
766 |
+
if Path(values["input_path"]).exists():
|
767 |
+
pool.schedule(play_audio, args=[Path(values["input_path"])])
|
768 |
+
elif event == "start_vc":
|
769 |
+
_, _, input_device_indices, output_device_indices = get_devices(
|
770 |
+
update=False
|
771 |
+
)
|
772 |
+
from so_vits_svc_fork.inference.main import realtime
|
773 |
+
|
774 |
+
if future:
|
775 |
+
LOG.info("Canceling previous task")
|
776 |
+
future.cancel()
|
777 |
+
future = pool.schedule(
|
778 |
+
realtime,
|
779 |
+
kwargs=dict(
|
780 |
+
# paths
|
781 |
+
model_path=Path(values["model_path"]),
|
782 |
+
config_path=Path(values["config_path"]),
|
783 |
+
speaker=values["speaker"],
|
784 |
+
# svc config
|
785 |
+
cluster_model_path=Path(values["cluster_model_path"])
|
786 |
+
if values["cluster_model_path"]
|
787 |
+
else None,
|
788 |
+
transpose=values["transpose"],
|
789 |
+
auto_predict_f0=values["auto_predict_f0"],
|
790 |
+
cluster_infer_ratio=values["cluster_infer_ratio"],
|
791 |
+
noise_scale=values["noise_scale"],
|
792 |
+
f0_method=values["f0_method"],
|
793 |
+
# slice config
|
794 |
+
db_thresh=values["silence_threshold"],
|
795 |
+
pad_seconds=values["pad_seconds"],
|
796 |
+
chunk_seconds=values["chunk_seconds"],
|
797 |
+
# realtime config
|
798 |
+
crossfade_seconds=values["crossfade_seconds"],
|
799 |
+
additional_infer_before_seconds=values[
|
800 |
+
"additional_infer_before_seconds"
|
801 |
+
],
|
802 |
+
additional_infer_after_seconds=values[
|
803 |
+
"additional_infer_after_seconds"
|
804 |
+
],
|
805 |
+
block_seconds=values["block_seconds"],
|
806 |
+
version=int(values["realtime_algorithm"][0]),
|
807 |
+
input_device=input_device_indices[
|
808 |
+
window["input_device"].widget.current()
|
809 |
+
],
|
810 |
+
output_device=output_device_indices[
|
811 |
+
window["output_device"].widget.current()
|
812 |
+
],
|
813 |
+
device=get_optimal_device() if values["use_gpu"] else "cpu",
|
814 |
+
passthrough_original=values["passthrough_original"],
|
815 |
+
),
|
816 |
+
)
|
817 |
+
elif event == "stop_vc":
|
818 |
+
if future:
|
819 |
+
future.cancel()
|
820 |
+
future = None
|
821 |
+
elif event == "onnx_export":
|
822 |
+
try:
|
823 |
+
raise NotImplementedError("ONNX export is not implemented yet.")
|
824 |
+
from so_vits_svc_fork.modules.onnx._export import onnx_export
|
825 |
+
|
826 |
+
onnx_export(
|
827 |
+
input_path=Path(values["model_path"]),
|
828 |
+
output_path=Path(values["model_path"]).with_suffix(".onnx"),
|
829 |
+
config_path=Path(values["config_path"]),
|
830 |
+
device="cpu",
|
831 |
+
)
|
832 |
+
except Exception as e:
|
833 |
+
LOG.exception(e)
|
834 |
+
if future is not None and future.done():
|
835 |
+
try:
|
836 |
+
future.result()
|
837 |
+
except Exception as e:
|
838 |
+
LOG.error("Error in realtime: ")
|
839 |
+
LOG.exception(e)
|
840 |
+
future = None
|
841 |
+
for future in copy(infer_futures):
|
842 |
+
if future.done():
|
843 |
+
try:
|
844 |
+
future.result()
|
845 |
+
except Exception as e:
|
846 |
+
LOG.error("Error in inference: ")
|
847 |
+
LOG.exception(e)
|
848 |
+
infer_futures.remove(future)
|
849 |
+
if future:
|
850 |
+
future.cancel()
|
851 |
+
window.close()
|
so_vits_svc_fork/hparams.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
|
6 |
+
class HParams:
|
7 |
+
def __init__(self, **kwargs: Any) -> None:
|
8 |
+
for k, v in kwargs.items():
|
9 |
+
if type(v) == dict:
|
10 |
+
v = HParams(**v)
|
11 |
+
self[k] = v
|
12 |
+
|
13 |
+
def keys(self):
|
14 |
+
return self.__dict__.keys()
|
15 |
+
|
16 |
+
def items(self):
|
17 |
+
return self.__dict__.items()
|
18 |
+
|
19 |
+
def values(self):
|
20 |
+
return self.__dict__.values()
|
21 |
+
|
22 |
+
def get(self, key: str, default: Any = None):
|
23 |
+
return self.__dict__.get(key, default)
|
24 |
+
|
25 |
+
def __len__(self):
|
26 |
+
return len(self.__dict__)
|
27 |
+
|
28 |
+
def __getitem__(self, key):
|
29 |
+
return getattr(self, key)
|
30 |
+
|
31 |
+
def __setitem__(self, key, value):
|
32 |
+
return setattr(self, key, value)
|
33 |
+
|
34 |
+
def __contains__(self, key):
|
35 |
+
return key in self.__dict__
|
36 |
+
|
37 |
+
def __repr__(self):
|
38 |
+
return self.__dict__.__repr__()
|
so_vits_svc_fork/inference/__init__.py
ADDED
File without changes
|
so_vits_svc_fork/inference/core.py
ADDED
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from copy import deepcopy
|
4 |
+
from logging import getLogger
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Any, Callable, Iterable, Literal
|
7 |
+
|
8 |
+
import attrs
|
9 |
+
import librosa
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from cm_time import timer
|
13 |
+
from numpy import dtype, float32, ndarray
|
14 |
+
|
15 |
+
import so_vits_svc_fork.f0
|
16 |
+
from so_vits_svc_fork import cluster, utils
|
17 |
+
|
18 |
+
from ..modules.synthesizers import SynthesizerTrn
|
19 |
+
from ..utils import get_optimal_device
|
20 |
+
|
21 |
+
LOG = getLogger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
def pad_array(array_, target_length: int):
|
25 |
+
current_length = array_.shape[0]
|
26 |
+
if current_length >= target_length:
|
27 |
+
return array_[
|
28 |
+
(current_length - target_length)
|
29 |
+
// 2 : (current_length - target_length)
|
30 |
+
// 2
|
31 |
+
+ target_length,
|
32 |
+
...,
|
33 |
+
]
|
34 |
+
else:
|
35 |
+
pad_width = target_length - current_length
|
36 |
+
pad_left = pad_width // 2
|
37 |
+
pad_right = pad_width - pad_left
|
38 |
+
padded_arr = np.pad(
|
39 |
+
array_, (pad_left, pad_right), "constant", constant_values=(0, 0)
|
40 |
+
)
|
41 |
+
return padded_arr
|
42 |
+
|
43 |
+
|
44 |
+
@attrs.frozen(kw_only=True)
|
45 |
+
class Chunk:
|
46 |
+
is_speech: bool
|
47 |
+
audio: ndarray[Any, dtype[float32]]
|
48 |
+
start: int
|
49 |
+
end: int
|
50 |
+
|
51 |
+
@property
|
52 |
+
def duration(self) -> float32:
|
53 |
+
# return self.end - self.start
|
54 |
+
return float32(self.audio.shape[0])
|
55 |
+
|
56 |
+
def __repr__(self) -> str:
|
57 |
+
return f"Chunk(Speech: {self.is_speech}, {self.duration})"
|
58 |
+
|
59 |
+
|
60 |
+
def split_silence(
|
61 |
+
audio: ndarray[Any, dtype[float32]],
|
62 |
+
top_db: int = 40,
|
63 |
+
ref: float | Callable[[ndarray[Any, dtype[float32]]], float] = 1,
|
64 |
+
frame_length: int = 2048,
|
65 |
+
hop_length: int = 512,
|
66 |
+
aggregate: Callable[[ndarray[Any, dtype[float32]]], float] = np.mean,
|
67 |
+
max_chunk_length: int = 0,
|
68 |
+
) -> Iterable[Chunk]:
|
69 |
+
non_silence_indices = librosa.effects.split(
|
70 |
+
audio,
|
71 |
+
top_db=top_db,
|
72 |
+
ref=ref,
|
73 |
+
frame_length=frame_length,
|
74 |
+
hop_length=hop_length,
|
75 |
+
aggregate=aggregate,
|
76 |
+
)
|
77 |
+
last_end = 0
|
78 |
+
for start, end in non_silence_indices:
|
79 |
+
if start != last_end:
|
80 |
+
yield Chunk(
|
81 |
+
is_speech=False, audio=audio[last_end:start], start=last_end, end=start
|
82 |
+
)
|
83 |
+
while max_chunk_length > 0 and end - start > max_chunk_length:
|
84 |
+
yield Chunk(
|
85 |
+
is_speech=True,
|
86 |
+
audio=audio[start : start + max_chunk_length],
|
87 |
+
start=start,
|
88 |
+
end=start + max_chunk_length,
|
89 |
+
)
|
90 |
+
start += max_chunk_length
|
91 |
+
if end - start > 0:
|
92 |
+
yield Chunk(is_speech=True, audio=audio[start:end], start=start, end=end)
|
93 |
+
last_end = end
|
94 |
+
if last_end != len(audio):
|
95 |
+
yield Chunk(
|
96 |
+
is_speech=False, audio=audio[last_end:], start=last_end, end=len(audio)
|
97 |
+
)
|
98 |
+
|
99 |
+
|
100 |
+
class Svc:
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
*,
|
104 |
+
net_g_path: Path | str,
|
105 |
+
config_path: Path | str,
|
106 |
+
device: torch.device | str | None = None,
|
107 |
+
cluster_model_path: Path | str | None = None,
|
108 |
+
half: bool = False,
|
109 |
+
):
|
110 |
+
self.net_g_path = net_g_path
|
111 |
+
if device is None:
|
112 |
+
self.device = (get_optimal_device(),)
|
113 |
+
else:
|
114 |
+
self.device = torch.device(device)
|
115 |
+
self.hps = utils.get_hparams(config_path)
|
116 |
+
self.target_sample = self.hps.data.sampling_rate
|
117 |
+
self.hop_size = self.hps.data.hop_length
|
118 |
+
self.spk2id = self.hps.spk
|
119 |
+
self.hubert_model = utils.get_hubert_model(
|
120 |
+
self.device, self.hps.data.get("contentvec_final_proj", True)
|
121 |
+
)
|
122 |
+
self.dtype = torch.float16 if half else torch.float32
|
123 |
+
self.contentvec_final_proj = self.hps.data.__dict__.get(
|
124 |
+
"contentvec_final_proj", True
|
125 |
+
)
|
126 |
+
self.load_model()
|
127 |
+
if cluster_model_path is not None and Path(cluster_model_path).exists():
|
128 |
+
self.cluster_model = cluster.get_cluster_model(cluster_model_path)
|
129 |
+
|
130 |
+
def load_model(self):
|
131 |
+
self.net_g = SynthesizerTrn(
|
132 |
+
self.hps.data.filter_length // 2 + 1,
|
133 |
+
self.hps.train.segment_size // self.hps.data.hop_length,
|
134 |
+
**self.hps.model,
|
135 |
+
)
|
136 |
+
_ = utils.load_checkpoint(self.net_g_path, self.net_g, None)
|
137 |
+
_ = self.net_g.eval()
|
138 |
+
for m in self.net_g.modules():
|
139 |
+
utils.remove_weight_norm_if_exists(m)
|
140 |
+
_ = self.net_g.to(self.device, dtype=self.dtype)
|
141 |
+
self.net_g = self.net_g
|
142 |
+
|
143 |
+
def get_unit_f0(
|
144 |
+
self,
|
145 |
+
audio: ndarray[Any, dtype[float32]],
|
146 |
+
tran: int,
|
147 |
+
cluster_infer_ratio: float,
|
148 |
+
speaker: int | str,
|
149 |
+
f0_method: Literal[
|
150 |
+
"crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
|
151 |
+
] = "dio",
|
152 |
+
):
|
153 |
+
f0 = so_vits_svc_fork.f0.compute_f0(
|
154 |
+
audio,
|
155 |
+
sampling_rate=self.target_sample,
|
156 |
+
hop_length=self.hop_size,
|
157 |
+
method=f0_method,
|
158 |
+
)
|
159 |
+
f0, uv = so_vits_svc_fork.f0.interpolate_f0(f0)
|
160 |
+
f0 = torch.as_tensor(f0, dtype=self.dtype, device=self.device)
|
161 |
+
uv = torch.as_tensor(uv, dtype=self.dtype, device=self.device)
|
162 |
+
f0 = f0 * 2 ** (tran / 12)
|
163 |
+
f0 = f0.unsqueeze(0)
|
164 |
+
uv = uv.unsqueeze(0)
|
165 |
+
|
166 |
+
c = utils.get_content(
|
167 |
+
self.hubert_model,
|
168 |
+
audio,
|
169 |
+
self.device,
|
170 |
+
self.target_sample,
|
171 |
+
self.contentvec_final_proj,
|
172 |
+
).to(self.dtype)
|
173 |
+
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1])
|
174 |
+
|
175 |
+
if cluster_infer_ratio != 0:
|
176 |
+
cluster_c = cluster.get_cluster_center_result(
|
177 |
+
self.cluster_model, c.cpu().numpy().T, speaker
|
178 |
+
).T
|
179 |
+
cluster_c = torch.FloatTensor(cluster_c).to(self.device)
|
180 |
+
c = cluster_infer_ratio * cluster_c + (1 - cluster_infer_ratio) * c
|
181 |
+
|
182 |
+
c = c.unsqueeze(0)
|
183 |
+
return c, f0, uv
|
184 |
+
|
185 |
+
def infer(
|
186 |
+
self,
|
187 |
+
speaker: int | str,
|
188 |
+
transpose: int,
|
189 |
+
audio: ndarray[Any, dtype[float32]],
|
190 |
+
cluster_infer_ratio: float = 0,
|
191 |
+
auto_predict_f0: bool = False,
|
192 |
+
noise_scale: float = 0.4,
|
193 |
+
f0_method: Literal[
|
194 |
+
"crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
|
195 |
+
] = "dio",
|
196 |
+
) -> tuple[torch.Tensor, int]:
|
197 |
+
audio = audio.astype(np.float32)
|
198 |
+
# get speaker id
|
199 |
+
if isinstance(speaker, int):
|
200 |
+
if len(self.spk2id.__dict__) >= speaker:
|
201 |
+
speaker_id = speaker
|
202 |
+
else:
|
203 |
+
raise ValueError(
|
204 |
+
f"Speaker id {speaker} >= number of speakers {len(self.spk2id.__dict__)}"
|
205 |
+
)
|
206 |
+
else:
|
207 |
+
if speaker in self.spk2id.__dict__:
|
208 |
+
speaker_id = self.spk2id.__dict__[speaker]
|
209 |
+
else:
|
210 |
+
LOG.warning(f"Speaker {speaker} is not found. Use speaker 0 instead.")
|
211 |
+
speaker_id = 0
|
212 |
+
speaker_candidates = list(
|
213 |
+
filter(lambda x: x[1] == speaker_id, self.spk2id.__dict__.items())
|
214 |
+
)
|
215 |
+
if len(speaker_candidates) > 1:
|
216 |
+
raise ValueError(
|
217 |
+
f"Speaker_id {speaker_id} is not unique. Candidates: {speaker_candidates}"
|
218 |
+
)
|
219 |
+
elif len(speaker_candidates) == 0:
|
220 |
+
raise ValueError(f"Speaker_id {speaker_id} is not found.")
|
221 |
+
speaker = speaker_candidates[0][0]
|
222 |
+
sid = torch.LongTensor([int(speaker_id)]).to(self.device).unsqueeze(0)
|
223 |
+
|
224 |
+
# get unit f0
|
225 |
+
c, f0, uv = self.get_unit_f0(
|
226 |
+
audio, transpose, cluster_infer_ratio, speaker, f0_method
|
227 |
+
)
|
228 |
+
|
229 |
+
# inference
|
230 |
+
with torch.no_grad():
|
231 |
+
with timer() as t:
|
232 |
+
audio = self.net_g.infer(
|
233 |
+
c,
|
234 |
+
f0=f0,
|
235 |
+
g=sid,
|
236 |
+
uv=uv,
|
237 |
+
predict_f0=auto_predict_f0,
|
238 |
+
noice_scale=noise_scale,
|
239 |
+
)[0, 0].data.float()
|
240 |
+
audio_duration = audio.shape[-1] / self.target_sample
|
241 |
+
LOG.info(
|
242 |
+
f"Inference time: {t.elapsed:.2f}s, RTF: {t.elapsed / audio_duration:.2f}"
|
243 |
+
)
|
244 |
+
torch.cuda.empty_cache()
|
245 |
+
return audio, audio.shape[-1]
|
246 |
+
|
247 |
+
def infer_silence(
|
248 |
+
self,
|
249 |
+
audio: np.ndarray[Any, np.dtype[np.float32]],
|
250 |
+
*,
|
251 |
+
# svc config
|
252 |
+
speaker: int | str,
|
253 |
+
transpose: int = 0,
|
254 |
+
auto_predict_f0: bool = False,
|
255 |
+
cluster_infer_ratio: float = 0,
|
256 |
+
noise_scale: float = 0.4,
|
257 |
+
f0_method: Literal[
|
258 |
+
"crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
|
259 |
+
] = "dio",
|
260 |
+
# slice config
|
261 |
+
db_thresh: int = -40,
|
262 |
+
pad_seconds: float = 0.5,
|
263 |
+
chunk_seconds: float = 0.5,
|
264 |
+
absolute_thresh: bool = False,
|
265 |
+
max_chunk_seconds: float = 40,
|
266 |
+
# fade_seconds: float = 0.0,
|
267 |
+
) -> np.ndarray[Any, np.dtype[np.float32]]:
|
268 |
+
sr = self.target_sample
|
269 |
+
result_audio = np.array([], dtype=np.float32)
|
270 |
+
chunk_length_min = chunk_length_min = (
|
271 |
+
int(
|
272 |
+
min(
|
273 |
+
sr / so_vits_svc_fork.f0.f0_min * 20 + 1,
|
274 |
+
chunk_seconds * sr,
|
275 |
+
)
|
276 |
+
)
|
277 |
+
// 2
|
278 |
+
)
|
279 |
+
for chunk in split_silence(
|
280 |
+
audio,
|
281 |
+
top_db=-db_thresh,
|
282 |
+
frame_length=chunk_length_min * 2,
|
283 |
+
hop_length=chunk_length_min,
|
284 |
+
ref=1 if absolute_thresh else np.max,
|
285 |
+
max_chunk_length=int(max_chunk_seconds * sr),
|
286 |
+
):
|
287 |
+
LOG.info(f"Chunk: {chunk}")
|
288 |
+
if not chunk.is_speech:
|
289 |
+
audio_chunk_infer = np.zeros_like(chunk.audio)
|
290 |
+
else:
|
291 |
+
# pad
|
292 |
+
pad_len = int(sr * pad_seconds)
|
293 |
+
audio_chunk_pad = np.concatenate(
|
294 |
+
[
|
295 |
+
np.zeros([pad_len], dtype=np.float32),
|
296 |
+
chunk.audio,
|
297 |
+
np.zeros([pad_len], dtype=np.float32),
|
298 |
+
]
|
299 |
+
)
|
300 |
+
audio_chunk_pad_infer_tensor, _ = self.infer(
|
301 |
+
speaker,
|
302 |
+
transpose,
|
303 |
+
audio_chunk_pad,
|
304 |
+
cluster_infer_ratio=cluster_infer_ratio,
|
305 |
+
auto_predict_f0=auto_predict_f0,
|
306 |
+
noise_scale=noise_scale,
|
307 |
+
f0_method=f0_method,
|
308 |
+
)
|
309 |
+
audio_chunk_pad_infer = audio_chunk_pad_infer_tensor.cpu().numpy()
|
310 |
+
pad_len = int(self.target_sample * pad_seconds)
|
311 |
+
cut_len_2 = (len(audio_chunk_pad_infer) - len(chunk.audio)) // 2
|
312 |
+
audio_chunk_infer = audio_chunk_pad_infer[
|
313 |
+
cut_len_2 : cut_len_2 + len(chunk.audio)
|
314 |
+
]
|
315 |
+
|
316 |
+
# add fade
|
317 |
+
# fade_len = int(self.target_sample * fade_seconds)
|
318 |
+
# _audio[:fade_len] = _audio[:fade_len] * np.linspace(0, 1, fade_len)
|
319 |
+
# _audio[-fade_len:] = _audio[-fade_len:] * np.linspace(1, 0, fade_len)
|
320 |
+
|
321 |
+
# empty cache
|
322 |
+
torch.cuda.empty_cache()
|
323 |
+
result_audio = np.concatenate([result_audio, audio_chunk_infer])
|
324 |
+
result_audio = result_audio[: audio.shape[0]]
|
325 |
+
return result_audio
|
326 |
+
|
327 |
+
|
328 |
+
def sola_crossfade(
|
329 |
+
first: ndarray[Any, dtype[float32]],
|
330 |
+
second: ndarray[Any, dtype[float32]],
|
331 |
+
crossfade_len: int,
|
332 |
+
sola_search_len: int,
|
333 |
+
) -> ndarray[Any, dtype[float32]]:
|
334 |
+
cor_nom = np.convolve(
|
335 |
+
second[: sola_search_len + crossfade_len],
|
336 |
+
np.flip(first[-crossfade_len:]),
|
337 |
+
"valid",
|
338 |
+
)
|
339 |
+
cor_den = np.sqrt(
|
340 |
+
np.convolve(
|
341 |
+
second[: sola_search_len + crossfade_len] ** 2,
|
342 |
+
np.ones(crossfade_len),
|
343 |
+
"valid",
|
344 |
+
)
|
345 |
+
+ 1e-8
|
346 |
+
)
|
347 |
+
sola_shift = np.argmax(cor_nom / cor_den)
|
348 |
+
LOG.info(f"SOLA shift: {sola_shift}")
|
349 |
+
second = second[sola_shift : sola_shift + len(second) - sola_search_len]
|
350 |
+
return np.concatenate(
|
351 |
+
[
|
352 |
+
first[:-crossfade_len],
|
353 |
+
first[-crossfade_len:] * np.linspace(1, 0, crossfade_len)
|
354 |
+
+ second[:crossfade_len] * np.linspace(0, 1, crossfade_len),
|
355 |
+
second[crossfade_len:],
|
356 |
+
]
|
357 |
+
)
|
358 |
+
|
359 |
+
|
360 |
+
class Crossfader:
|
361 |
+
def __init__(
|
362 |
+
self,
|
363 |
+
*,
|
364 |
+
additional_infer_before_len: int,
|
365 |
+
additional_infer_after_len: int,
|
366 |
+
crossfade_len: int,
|
367 |
+
sola_search_len: int = 384,
|
368 |
+
) -> None:
|
369 |
+
if additional_infer_before_len < 0:
|
370 |
+
raise ValueError("additional_infer_len must be >= 0")
|
371 |
+
if crossfade_len < 0:
|
372 |
+
raise ValueError("crossfade_len must be >= 0")
|
373 |
+
if additional_infer_after_len < 0:
|
374 |
+
raise ValueError("additional_infer_len must be >= 0")
|
375 |
+
if additional_infer_before_len < 0:
|
376 |
+
raise ValueError("additional_infer_len must be >= 0")
|
377 |
+
self.additional_infer_before_len = additional_infer_before_len
|
378 |
+
self.additional_infer_after_len = additional_infer_after_len
|
379 |
+
self.crossfade_len = crossfade_len
|
380 |
+
self.sola_search_len = sola_search_len
|
381 |
+
self.last_input_left = np.zeros(
|
382 |
+
sola_search_len
|
383 |
+
+ crossfade_len
|
384 |
+
+ additional_infer_before_len
|
385 |
+
+ additional_infer_after_len,
|
386 |
+
dtype=np.float32,
|
387 |
+
)
|
388 |
+
self.last_infered_left = np.zeros(crossfade_len, dtype=np.float32)
|
389 |
+
|
390 |
+
def process(
|
391 |
+
self, input_audio: ndarray[Any, dtype[float32]], *args, **kwargs: Any
|
392 |
+
) -> ndarray[Any, dtype[float32]]:
|
393 |
+
"""
|
394 |
+
chunks : ■■■■■■□□□□□□
|
395 |
+
add last input:□■■■■■■
|
396 |
+
■□□□□□□
|
397 |
+
infer :□■■■■■■
|
398 |
+
■□□□□□□
|
399 |
+
crossfade :▲■■■■■
|
400 |
+
▲□□□□□
|
401 |
+
"""
|
402 |
+
# check input
|
403 |
+
if input_audio.ndim != 1:
|
404 |
+
raise ValueError("Input audio must be 1-dimensional.")
|
405 |
+
if (
|
406 |
+
input_audio.shape[0] + self.additional_infer_before_len
|
407 |
+
<= self.crossfade_len
|
408 |
+
):
|
409 |
+
raise ValueError(
|
410 |
+
f"Input audio length ({input_audio.shape[0]}) + additional_infer_len ({self.additional_infer_before_len}) must be greater than crossfade_len ({self.crossfade_len})."
|
411 |
+
)
|
412 |
+
input_audio = input_audio.astype(np.float32)
|
413 |
+
input_audio_len = len(input_audio)
|
414 |
+
|
415 |
+
# concat last input and infer
|
416 |
+
input_audio_concat = np.concatenate([self.last_input_left, input_audio])
|
417 |
+
del input_audio
|
418 |
+
pad_len = 0
|
419 |
+
if pad_len:
|
420 |
+
infer_audio_concat = self.infer(
|
421 |
+
np.pad(input_audio_concat, (pad_len, pad_len), mode="reflect"),
|
422 |
+
*args,
|
423 |
+
**kwargs,
|
424 |
+
)[pad_len:-pad_len]
|
425 |
+
else:
|
426 |
+
infer_audio_concat = self.infer(input_audio_concat, *args, **kwargs)
|
427 |
+
|
428 |
+
# debug SOLA (using copy synthesis with a random shift)
|
429 |
+
"""
|
430 |
+
rs = int(np.random.uniform(-200,200))
|
431 |
+
LOG.info(f"Debug random shift: {rs}")
|
432 |
+
infer_audio_concat = np.roll(input_audio_concat, rs)
|
433 |
+
"""
|
434 |
+
|
435 |
+
if len(infer_audio_concat) != len(input_audio_concat):
|
436 |
+
raise ValueError(
|
437 |
+
f"Inferred audio length ({len(infer_audio_concat)}) should be equal to input audio length ({len(input_audio_concat)})."
|
438 |
+
)
|
439 |
+
infer_audio_to_use = infer_audio_concat[
|
440 |
+
-(
|
441 |
+
self.sola_search_len
|
442 |
+
+ self.crossfade_len
|
443 |
+
+ input_audio_len
|
444 |
+
+ self.additional_infer_after_len
|
445 |
+
) : -self.additional_infer_after_len
|
446 |
+
]
|
447 |
+
assert (
|
448 |
+
len(infer_audio_to_use)
|
449 |
+
== input_audio_len + self.sola_search_len + self.crossfade_len
|
450 |
+
), f"{len(infer_audio_to_use)} != {input_audio_len + self.sola_search_len + self.cross_fade_len}"
|
451 |
+
_audio = sola_crossfade(
|
452 |
+
self.last_infered_left,
|
453 |
+
infer_audio_to_use,
|
454 |
+
self.crossfade_len,
|
455 |
+
self.sola_search_len,
|
456 |
+
)
|
457 |
+
result_audio = _audio[: -self.crossfade_len]
|
458 |
+
assert (
|
459 |
+
len(result_audio) == input_audio_len
|
460 |
+
), f"{len(result_audio)} != {input_audio_len}"
|
461 |
+
|
462 |
+
# update last input and inferred
|
463 |
+
self.last_input_left = input_audio_concat[
|
464 |
+
-(
|
465 |
+
self.sola_search_len
|
466 |
+
+ self.crossfade_len
|
467 |
+
+ self.additional_infer_before_len
|
468 |
+
+ self.additional_infer_after_len
|
469 |
+
) :
|
470 |
+
]
|
471 |
+
self.last_infered_left = _audio[-self.crossfade_len :]
|
472 |
+
return result_audio
|
473 |
+
|
474 |
+
def infer(
|
475 |
+
self, input_audio: ndarray[Any, dtype[float32]]
|
476 |
+
) -> ndarray[Any, dtype[float32]]:
|
477 |
+
return input_audio
|
478 |
+
|
479 |
+
|
480 |
+
class RealtimeVC(Crossfader):
|
481 |
+
def __init__(
|
482 |
+
self,
|
483 |
+
*,
|
484 |
+
svc_model: Svc,
|
485 |
+
crossfade_len: int = 3840,
|
486 |
+
additional_infer_before_len: int = 7680,
|
487 |
+
additional_infer_after_len: int = 7680,
|
488 |
+
split: bool = True,
|
489 |
+
) -> None:
|
490 |
+
self.svc_model = svc_model
|
491 |
+
self.split = split
|
492 |
+
super().__init__(
|
493 |
+
crossfade_len=crossfade_len,
|
494 |
+
additional_infer_before_len=additional_infer_before_len,
|
495 |
+
additional_infer_after_len=additional_infer_after_len,
|
496 |
+
)
|
497 |
+
|
498 |
+
def process(
|
499 |
+
self,
|
500 |
+
input_audio: ndarray[Any, dtype[float32]],
|
501 |
+
*args: Any,
|
502 |
+
**kwargs: Any,
|
503 |
+
) -> ndarray[Any, dtype[float32]]:
|
504 |
+
return super().process(input_audio, *args, **kwargs)
|
505 |
+
|
506 |
+
def infer(
|
507 |
+
self,
|
508 |
+
input_audio: np.ndarray[Any, np.dtype[np.float32]],
|
509 |
+
# svc config
|
510 |
+
speaker: int | str,
|
511 |
+
transpose: int,
|
512 |
+
cluster_infer_ratio: float = 0,
|
513 |
+
auto_predict_f0: bool = False,
|
514 |
+
noise_scale: float = 0.4,
|
515 |
+
f0_method: Literal[
|
516 |
+
"crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
|
517 |
+
] = "dio",
|
518 |
+
# slice config
|
519 |
+
db_thresh: int = -40,
|
520 |
+
pad_seconds: float = 0.5,
|
521 |
+
chunk_seconds: float = 0.5,
|
522 |
+
) -> ndarray[Any, dtype[float32]]:
|
523 |
+
# infer
|
524 |
+
if self.split:
|
525 |
+
return self.svc_model.infer_silence(
|
526 |
+
audio=input_audio,
|
527 |
+
speaker=speaker,
|
528 |
+
transpose=transpose,
|
529 |
+
cluster_infer_ratio=cluster_infer_ratio,
|
530 |
+
auto_predict_f0=auto_predict_f0,
|
531 |
+
noise_scale=noise_scale,
|
532 |
+
f0_method=f0_method,
|
533 |
+
db_thresh=db_thresh,
|
534 |
+
pad_seconds=pad_seconds,
|
535 |
+
chunk_seconds=chunk_seconds,
|
536 |
+
absolute_thresh=True,
|
537 |
+
)
|
538 |
+
else:
|
539 |
+
rms = np.sqrt(np.mean(input_audio**2))
|
540 |
+
min_rms = 10 ** (db_thresh / 20)
|
541 |
+
if rms < min_rms:
|
542 |
+
LOG.info(f"Skip silence: RMS={rms:.2f} < {min_rms:.2f}")
|
543 |
+
return np.zeros_like(input_audio)
|
544 |
+
else:
|
545 |
+
LOG.info(f"Start inference: RMS={rms:.2f} >= {min_rms:.2f}")
|
546 |
+
infered_audio_c, _ = self.svc_model.infer(
|
547 |
+
speaker=speaker,
|
548 |
+
transpose=transpose,
|
549 |
+
audio=input_audio,
|
550 |
+
cluster_infer_ratio=cluster_infer_ratio,
|
551 |
+
auto_predict_f0=auto_predict_f0,
|
552 |
+
noise_scale=noise_scale,
|
553 |
+
f0_method=f0_method,
|
554 |
+
)
|
555 |
+
return infered_audio_c.cpu().numpy()
|
556 |
+
|
557 |
+
|
558 |
+
class RealtimeVC2:
|
559 |
+
chunk_store: list[Chunk]
|
560 |
+
|
561 |
+
def __init__(self, svc_model: Svc) -> None:
|
562 |
+
self.input_audio_store = np.array([], dtype=np.float32)
|
563 |
+
self.chunk_store = []
|
564 |
+
self.svc_model = svc_model
|
565 |
+
|
566 |
+
def process(
|
567 |
+
self,
|
568 |
+
input_audio: np.ndarray[Any, np.dtype[np.float32]],
|
569 |
+
# svc config
|
570 |
+
speaker: int | str,
|
571 |
+
transpose: int,
|
572 |
+
cluster_infer_ratio: float = 0,
|
573 |
+
auto_predict_f0: bool = False,
|
574 |
+
noise_scale: float = 0.4,
|
575 |
+
f0_method: Literal[
|
576 |
+
"crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
|
577 |
+
] = "dio",
|
578 |
+
# slice config
|
579 |
+
db_thresh: int = -40,
|
580 |
+
chunk_seconds: float = 0.5,
|
581 |
+
) -> ndarray[Any, dtype[float32]]:
|
582 |
+
def infer(audio: ndarray[Any, dtype[float32]]) -> ndarray[Any, dtype[float32]]:
|
583 |
+
infered_audio_c, _ = self.svc_model.infer(
|
584 |
+
speaker=speaker,
|
585 |
+
transpose=transpose,
|
586 |
+
audio=audio,
|
587 |
+
cluster_infer_ratio=cluster_infer_ratio,
|
588 |
+
auto_predict_f0=auto_predict_f0,
|
589 |
+
noise_scale=noise_scale,
|
590 |
+
f0_method=f0_method,
|
591 |
+
)
|
592 |
+
return infered_audio_c.cpu().numpy()
|
593 |
+
|
594 |
+
self.input_audio_store = np.concatenate([self.input_audio_store, input_audio])
|
595 |
+
LOG.info(f"input_audio_store: {self.input_audio_store.shape}")
|
596 |
+
sr = self.svc_model.target_sample
|
597 |
+
chunk_length_min = (
|
598 |
+
int(min(sr / so_vits_svc_fork.f0.f0_min * 20 + 1, chunk_seconds * sr)) // 2
|
599 |
+
)
|
600 |
+
LOG.info(f"Chunk length min: {chunk_length_min}")
|
601 |
+
chunk_list = list(
|
602 |
+
split_silence(
|
603 |
+
self.input_audio_store,
|
604 |
+
-db_thresh,
|
605 |
+
frame_length=chunk_length_min * 2,
|
606 |
+
hop_length=chunk_length_min,
|
607 |
+
ref=1, # use absolute threshold
|
608 |
+
)
|
609 |
+
)
|
610 |
+
assert len(chunk_list) > 0
|
611 |
+
LOG.info(f"Chunk list: {chunk_list}")
|
612 |
+
# do not infer LAST incomplete is_speech chunk and save to store
|
613 |
+
if chunk_list[-1].is_speech:
|
614 |
+
self.input_audio_store = chunk_list.pop().audio
|
615 |
+
else:
|
616 |
+
self.input_audio_store = np.array([], dtype=np.float32)
|
617 |
+
|
618 |
+
# infer complete is_speech chunk and save to store
|
619 |
+
self.chunk_store.extend(
|
620 |
+
[
|
621 |
+
attrs.evolve(c, audio=infer(c.audio) if c.is_speech else c.audio)
|
622 |
+
for c in chunk_list
|
623 |
+
]
|
624 |
+
)
|
625 |
+
|
626 |
+
# calculate lengths and determine compress rate
|
627 |
+
total_speech_len = sum(
|
628 |
+
[c.duration if c.is_speech else 0 for c in self.chunk_store]
|
629 |
+
)
|
630 |
+
total_silence_len = sum(
|
631 |
+
[c.duration if not c.is_speech else 0 for c in self.chunk_store]
|
632 |
+
)
|
633 |
+
input_audio_len = input_audio.shape[0]
|
634 |
+
silence_compress_rate = total_silence_len / max(
|
635 |
+
0, input_audio_len - total_speech_len
|
636 |
+
)
|
637 |
+
LOG.info(
|
638 |
+
f"Total speech len: {total_speech_len}, silence len: {total_silence_len}, silence compress rate: {silence_compress_rate}"
|
639 |
+
)
|
640 |
+
|
641 |
+
# generate output audio
|
642 |
+
output_audio = np.array([], dtype=np.float32)
|
643 |
+
break_flag = False
|
644 |
+
LOG.info(f"Chunk store: {self.chunk_store}")
|
645 |
+
for chunk in deepcopy(self.chunk_store):
|
646 |
+
compress_rate = 1 if chunk.is_speech else silence_compress_rate
|
647 |
+
left_len = input_audio_len - output_audio.shape[0]
|
648 |
+
# calculate chunk duration
|
649 |
+
chunk_duration_output = int(min(chunk.duration / compress_rate, left_len))
|
650 |
+
chunk_duration_input = int(min(chunk.duration, left_len * compress_rate))
|
651 |
+
LOG.info(
|
652 |
+
f"Chunk duration output: {chunk_duration_output}, input: {chunk_duration_input}, left len: {left_len}"
|
653 |
+
)
|
654 |
+
|
655 |
+
# remove chunk from store
|
656 |
+
self.chunk_store.pop(0)
|
657 |
+
if chunk.duration > chunk_duration_input:
|
658 |
+
left_chunk = attrs.evolve(
|
659 |
+
chunk, audio=chunk.audio[chunk_duration_input:]
|
660 |
+
)
|
661 |
+
chunk = attrs.evolve(chunk, audio=chunk.audio[:chunk_duration_input])
|
662 |
+
|
663 |
+
self.chunk_store.insert(0, left_chunk)
|
664 |
+
break_flag = True
|
665 |
+
|
666 |
+
if chunk.is_speech:
|
667 |
+
# if is_speech, just concat
|
668 |
+
output_audio = np.concatenate([output_audio, chunk.audio])
|
669 |
+
else:
|
670 |
+
# if is_silence, concat with zeros and compress with silence_compress_rate
|
671 |
+
output_audio = np.concatenate(
|
672 |
+
[
|
673 |
+
output_audio,
|
674 |
+
np.zeros(
|
675 |
+
chunk_duration_output,
|
676 |
+
dtype=np.float32,
|
677 |
+
),
|
678 |
+
]
|
679 |
+
)
|
680 |
+
|
681 |
+
if break_flag:
|
682 |
+
break
|
683 |
+
LOG.info(f"Chunk store: {self.chunk_store}, output_audio: {output_audio.shape}")
|
684 |
+
# make same length (errors)
|
685 |
+
output_audio = output_audio[:input_audio_len]
|
686 |
+
output_audio = np.concatenate(
|
687 |
+
[
|
688 |
+
output_audio,
|
689 |
+
np.zeros(input_audio_len - output_audio.shape[0], dtype=np.float32),
|
690 |
+
]
|
691 |
+
)
|
692 |
+
return output_audio
|
so_vits_svc_fork/inference/main.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from logging import getLogger
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Literal, Sequence
|
6 |
+
|
7 |
+
import librosa
|
8 |
+
import numpy as np
|
9 |
+
import soundfile
|
10 |
+
import torch
|
11 |
+
from cm_time import timer
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from so_vits_svc_fork.inference.core import RealtimeVC, RealtimeVC2, Svc
|
15 |
+
from so_vits_svc_fork.utils import get_optimal_device
|
16 |
+
|
17 |
+
LOG = getLogger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
def infer(
|
21 |
+
*,
|
22 |
+
# paths
|
23 |
+
input_path: Path | str | Sequence[Path | str],
|
24 |
+
output_path: Path | str | Sequence[Path | str],
|
25 |
+
model_path: Path | str,
|
26 |
+
config_path: Path | str,
|
27 |
+
recursive: bool = False,
|
28 |
+
# svc config
|
29 |
+
speaker: int | str,
|
30 |
+
cluster_model_path: Path | str | None = None,
|
31 |
+
transpose: int = 0,
|
32 |
+
auto_predict_f0: bool = False,
|
33 |
+
cluster_infer_ratio: float = 0,
|
34 |
+
noise_scale: float = 0.4,
|
35 |
+
f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
|
36 |
+
# slice config
|
37 |
+
db_thresh: int = -40,
|
38 |
+
pad_seconds: float = 0.5,
|
39 |
+
chunk_seconds: float = 0.5,
|
40 |
+
absolute_thresh: bool = False,
|
41 |
+
max_chunk_seconds: float = 40,
|
42 |
+
device: str | torch.device = get_optimal_device(),
|
43 |
+
):
|
44 |
+
if isinstance(input_path, (str, Path)):
|
45 |
+
input_path = [input_path]
|
46 |
+
if isinstance(output_path, (str, Path)):
|
47 |
+
output_path = [output_path]
|
48 |
+
if len(input_path) != len(output_path):
|
49 |
+
raise ValueError(
|
50 |
+
f"input_path and output_path must have same length, but got {len(input_path)} and {len(output_path)}"
|
51 |
+
)
|
52 |
+
|
53 |
+
model_path = Path(model_path)
|
54 |
+
config_path = Path(config_path)
|
55 |
+
output_path = [Path(p) for p in output_path]
|
56 |
+
input_path = [Path(p) for p in input_path]
|
57 |
+
output_paths = []
|
58 |
+
input_paths = []
|
59 |
+
|
60 |
+
for input_path, output_path in zip(input_path, output_path):
|
61 |
+
if input_path.is_dir():
|
62 |
+
if not recursive:
|
63 |
+
raise ValueError(
|
64 |
+
f"input_path is a directory, but recursive is False: {input_path}"
|
65 |
+
)
|
66 |
+
input_paths.extend(list(input_path.rglob("*.*")))
|
67 |
+
output_paths.extend(
|
68 |
+
[output_path / p.relative_to(input_path) for p in input_paths]
|
69 |
+
)
|
70 |
+
continue
|
71 |
+
input_paths.append(input_path)
|
72 |
+
output_paths.append(output_path)
|
73 |
+
|
74 |
+
cluster_model_path = Path(cluster_model_path) if cluster_model_path else None
|
75 |
+
svc_model = Svc(
|
76 |
+
net_g_path=model_path.as_posix(),
|
77 |
+
config_path=config_path.as_posix(),
|
78 |
+
cluster_model_path=cluster_model_path.as_posix()
|
79 |
+
if cluster_model_path
|
80 |
+
else None,
|
81 |
+
device=device,
|
82 |
+
)
|
83 |
+
|
84 |
+
try:
|
85 |
+
pbar = tqdm(list(zip(input_paths, output_paths)), disable=len(input_paths) == 1)
|
86 |
+
for input_path, output_path in pbar:
|
87 |
+
pbar.set_description(f"{input_path}")
|
88 |
+
try:
|
89 |
+
audio, _ = librosa.load(str(input_path), sr=svc_model.target_sample)
|
90 |
+
except Exception as e:
|
91 |
+
LOG.error(f"Failed to load {input_path}")
|
92 |
+
LOG.exception(e)
|
93 |
+
continue
|
94 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
95 |
+
audio = svc_model.infer_silence(
|
96 |
+
audio.astype(np.float32),
|
97 |
+
speaker=speaker,
|
98 |
+
transpose=transpose,
|
99 |
+
auto_predict_f0=auto_predict_f0,
|
100 |
+
cluster_infer_ratio=cluster_infer_ratio,
|
101 |
+
noise_scale=noise_scale,
|
102 |
+
f0_method=f0_method,
|
103 |
+
db_thresh=db_thresh,
|
104 |
+
pad_seconds=pad_seconds,
|
105 |
+
chunk_seconds=chunk_seconds,
|
106 |
+
absolute_thresh=absolute_thresh,
|
107 |
+
max_chunk_seconds=max_chunk_seconds,
|
108 |
+
)
|
109 |
+
soundfile.write(str(output_path), audio, svc_model.target_sample)
|
110 |
+
finally:
|
111 |
+
del svc_model
|
112 |
+
torch.cuda.empty_cache()
|
113 |
+
|
114 |
+
|
115 |
+
def realtime(
|
116 |
+
*,
|
117 |
+
# paths
|
118 |
+
model_path: Path | str,
|
119 |
+
config_path: Path | str,
|
120 |
+
# svc config
|
121 |
+
speaker: str,
|
122 |
+
cluster_model_path: Path | str | None = None,
|
123 |
+
transpose: int = 0,
|
124 |
+
auto_predict_f0: bool = False,
|
125 |
+
cluster_infer_ratio: float = 0,
|
126 |
+
noise_scale: float = 0.4,
|
127 |
+
f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
|
128 |
+
# slice config
|
129 |
+
db_thresh: int = -40,
|
130 |
+
pad_seconds: float = 0.5,
|
131 |
+
chunk_seconds: float = 0.5,
|
132 |
+
# realtime config
|
133 |
+
crossfade_seconds: float = 0.05,
|
134 |
+
additional_infer_before_seconds: float = 0.2,
|
135 |
+
additional_infer_after_seconds: float = 0.1,
|
136 |
+
block_seconds: float = 0.5,
|
137 |
+
version: int = 2,
|
138 |
+
input_device: int | str | None = None,
|
139 |
+
output_device: int | str | None = None,
|
140 |
+
device: str | torch.device = get_optimal_device(),
|
141 |
+
passthrough_original: bool = False,
|
142 |
+
):
|
143 |
+
import sounddevice as sd
|
144 |
+
|
145 |
+
model_path = Path(model_path)
|
146 |
+
config_path = Path(config_path)
|
147 |
+
cluster_model_path = Path(cluster_model_path) if cluster_model_path else None
|
148 |
+
svc_model = Svc(
|
149 |
+
net_g_path=model_path.as_posix(),
|
150 |
+
config_path=config_path.as_posix(),
|
151 |
+
cluster_model_path=cluster_model_path.as_posix()
|
152 |
+
if cluster_model_path
|
153 |
+
else None,
|
154 |
+
device=device,
|
155 |
+
)
|
156 |
+
|
157 |
+
LOG.info("Creating realtime model...")
|
158 |
+
if version == 1:
|
159 |
+
model = RealtimeVC(
|
160 |
+
svc_model=svc_model,
|
161 |
+
crossfade_len=int(crossfade_seconds * svc_model.target_sample),
|
162 |
+
additional_infer_before_len=int(
|
163 |
+
additional_infer_before_seconds * svc_model.target_sample
|
164 |
+
),
|
165 |
+
additional_infer_after_len=int(
|
166 |
+
additional_infer_after_seconds * svc_model.target_sample
|
167 |
+
),
|
168 |
+
)
|
169 |
+
else:
|
170 |
+
model = RealtimeVC2(
|
171 |
+
svc_model=svc_model,
|
172 |
+
)
|
173 |
+
|
174 |
+
# LOG all device info
|
175 |
+
devices = sd.query_devices()
|
176 |
+
LOG.info(f"Device: {devices}")
|
177 |
+
if isinstance(input_device, str):
|
178 |
+
input_device_candidates = [
|
179 |
+
i for i, d in enumerate(devices) if d["name"] == input_device
|
180 |
+
]
|
181 |
+
if len(input_device_candidates) == 0:
|
182 |
+
LOG.warning(f"Input device {input_device} not found, using default")
|
183 |
+
input_device = None
|
184 |
+
else:
|
185 |
+
input_device = input_device_candidates[0]
|
186 |
+
if isinstance(output_device, str):
|
187 |
+
output_device_candidates = [
|
188 |
+
i for i, d in enumerate(devices) if d["name"] == output_device
|
189 |
+
]
|
190 |
+
if len(output_device_candidates) == 0:
|
191 |
+
LOG.warning(f"Output device {output_device} not found, using default")
|
192 |
+
output_device = None
|
193 |
+
else:
|
194 |
+
output_device = output_device_candidates[0]
|
195 |
+
if input_device is None or input_device >= len(devices):
|
196 |
+
input_device = sd.default.device[0]
|
197 |
+
if output_device is None or output_device >= len(devices):
|
198 |
+
output_device = sd.default.device[1]
|
199 |
+
LOG.info(
|
200 |
+
f"Input Device: {devices[input_device]['name']}, Output Device: {devices[output_device]['name']}"
|
201 |
+
)
|
202 |
+
|
203 |
+
# the model RTL is somewhat significantly high only in the first inference
|
204 |
+
# there could be no better way to warm up the model than to do a dummy inference
|
205 |
+
# (there are not differences in the behavior of the model between the first and the later inferences)
|
206 |
+
# so we do a dummy inference to warm up the model (1 second of audio)
|
207 |
+
LOG.info("Warming up the model...")
|
208 |
+
svc_model.infer(
|
209 |
+
speaker=speaker,
|
210 |
+
transpose=transpose,
|
211 |
+
auto_predict_f0=auto_predict_f0,
|
212 |
+
cluster_infer_ratio=cluster_infer_ratio,
|
213 |
+
noise_scale=noise_scale,
|
214 |
+
f0_method=f0_method,
|
215 |
+
audio=np.zeros(svc_model.target_sample, dtype=np.float32),
|
216 |
+
)
|
217 |
+
|
218 |
+
def callback(
|
219 |
+
indata: np.ndarray,
|
220 |
+
outdata: np.ndarray,
|
221 |
+
frames: int,
|
222 |
+
time: int,
|
223 |
+
status: sd.CallbackFlags,
|
224 |
+
) -> None:
|
225 |
+
LOG.debug(
|
226 |
+
f"Frames: {frames}, Status: {status}, Shape: {indata.shape}, Time: {time}"
|
227 |
+
)
|
228 |
+
|
229 |
+
kwargs = dict(
|
230 |
+
input_audio=indata.mean(axis=1).astype(np.float32),
|
231 |
+
# svc config
|
232 |
+
speaker=speaker,
|
233 |
+
transpose=transpose,
|
234 |
+
auto_predict_f0=auto_predict_f0,
|
235 |
+
cluster_infer_ratio=cluster_infer_ratio,
|
236 |
+
noise_scale=noise_scale,
|
237 |
+
f0_method=f0_method,
|
238 |
+
# slice config
|
239 |
+
db_thresh=db_thresh,
|
240 |
+
# pad_seconds=pad_seconds,
|
241 |
+
chunk_seconds=chunk_seconds,
|
242 |
+
)
|
243 |
+
if version == 1:
|
244 |
+
kwargs["pad_seconds"] = pad_seconds
|
245 |
+
with timer() as t:
|
246 |
+
inference = model.process(
|
247 |
+
**kwargs,
|
248 |
+
).reshape(-1, 1)
|
249 |
+
if passthrough_original:
|
250 |
+
outdata[:] = (indata + inference) / 2
|
251 |
+
else:
|
252 |
+
outdata[:] = inference
|
253 |
+
rtf = t.elapsed / block_seconds
|
254 |
+
LOG.info(f"Realtime inference time: {t.elapsed:.3f}s, RTF: {rtf:.3f}")
|
255 |
+
if rtf > 1:
|
256 |
+
LOG.warning("RTF is too high, consider increasing block_seconds")
|
257 |
+
|
258 |
+
try:
|
259 |
+
with sd.Stream(
|
260 |
+
device=(input_device, output_device),
|
261 |
+
channels=1,
|
262 |
+
callback=callback,
|
263 |
+
samplerate=svc_model.target_sample,
|
264 |
+
blocksize=int(block_seconds * svc_model.target_sample),
|
265 |
+
latency="low",
|
266 |
+
) as stream:
|
267 |
+
LOG.info(f"Latency: {stream.latency}")
|
268 |
+
while True:
|
269 |
+
sd.sleep(1000)
|
270 |
+
finally:
|
271 |
+
# del model, svc_model
|
272 |
+
torch.cuda.empty_cache()
|
so_vits_svc_fork/logger.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
from logging import DEBUG, INFO, StreamHandler, basicConfig, captureWarnings, getLogger
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
from rich.logging import RichHandler
|
7 |
+
|
8 |
+
LOGGER_INIT = False
|
9 |
+
|
10 |
+
|
11 |
+
def init_logger() -> None:
|
12 |
+
global LOGGER_INIT
|
13 |
+
if LOGGER_INIT:
|
14 |
+
return
|
15 |
+
|
16 |
+
IS_TEST = "test" in Path.cwd().stem
|
17 |
+
package_name = sys.modules[__name__].__package__
|
18 |
+
basicConfig(
|
19 |
+
level=INFO,
|
20 |
+
format="%(asctime)s %(message)s",
|
21 |
+
datefmt="[%X]",
|
22 |
+
handlers=[
|
23 |
+
StreamHandler() if is_notebook() else RichHandler(),
|
24 |
+
# FileHandler(f"{package_name}.log"),
|
25 |
+
],
|
26 |
+
)
|
27 |
+
if IS_TEST:
|
28 |
+
getLogger(package_name).setLevel(DEBUG)
|
29 |
+
captureWarnings(True)
|
30 |
+
LOGGER_INIT = True
|
31 |
+
|
32 |
+
|
33 |
+
def is_notebook():
|
34 |
+
try:
|
35 |
+
from IPython import get_ipython
|
36 |
+
|
37 |
+
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
|
38 |
+
raise ImportError("console")
|
39 |
+
return False
|
40 |
+
if "VSCODE_PID" in os.environ: # pragma: no cover
|
41 |
+
raise ImportError("vscode")
|
42 |
+
return False
|
43 |
+
except Exception:
|
44 |
+
return False
|
45 |
+
else: # pragma: no cover
|
46 |
+
return True
|
so_vits_svc_fork/modules/__init__.py
ADDED
File without changes
|
so_vits_svc_fork/modules/attentions.py
ADDED
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
from so_vits_svc_fork.modules import commons
|
8 |
+
from so_vits_svc_fork.modules.modules import LayerNorm
|
9 |
+
|
10 |
+
|
11 |
+
class FFT(nn.Module):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
hidden_channels,
|
15 |
+
filter_channels,
|
16 |
+
n_heads,
|
17 |
+
n_layers=1,
|
18 |
+
kernel_size=1,
|
19 |
+
p_dropout=0.0,
|
20 |
+
proximal_bias=False,
|
21 |
+
proximal_init=True,
|
22 |
+
**kwargs
|
23 |
+
):
|
24 |
+
super().__init__()
|
25 |
+
self.hidden_channels = hidden_channels
|
26 |
+
self.filter_channels = filter_channels
|
27 |
+
self.n_heads = n_heads
|
28 |
+
self.n_layers = n_layers
|
29 |
+
self.kernel_size = kernel_size
|
30 |
+
self.p_dropout = p_dropout
|
31 |
+
self.proximal_bias = proximal_bias
|
32 |
+
self.proximal_init = proximal_init
|
33 |
+
|
34 |
+
self.drop = nn.Dropout(p_dropout)
|
35 |
+
self.self_attn_layers = nn.ModuleList()
|
36 |
+
self.norm_layers_0 = nn.ModuleList()
|
37 |
+
self.ffn_layers = nn.ModuleList()
|
38 |
+
self.norm_layers_1 = nn.ModuleList()
|
39 |
+
for i in range(self.n_layers):
|
40 |
+
self.self_attn_layers.append(
|
41 |
+
MultiHeadAttention(
|
42 |
+
hidden_channels,
|
43 |
+
hidden_channels,
|
44 |
+
n_heads,
|
45 |
+
p_dropout=p_dropout,
|
46 |
+
proximal_bias=proximal_bias,
|
47 |
+
proximal_init=proximal_init,
|
48 |
+
)
|
49 |
+
)
|
50 |
+
self.norm_layers_0.append(LayerNorm(hidden_channels))
|
51 |
+
self.ffn_layers.append(
|
52 |
+
FFN(
|
53 |
+
hidden_channels,
|
54 |
+
hidden_channels,
|
55 |
+
filter_channels,
|
56 |
+
kernel_size,
|
57 |
+
p_dropout=p_dropout,
|
58 |
+
causal=True,
|
59 |
+
)
|
60 |
+
)
|
61 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
62 |
+
|
63 |
+
def forward(self, x, x_mask):
|
64 |
+
"""
|
65 |
+
x: decoder input
|
66 |
+
h: encoder output
|
67 |
+
"""
|
68 |
+
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
|
69 |
+
device=x.device, dtype=x.dtype
|
70 |
+
)
|
71 |
+
x = x * x_mask
|
72 |
+
for i in range(self.n_layers):
|
73 |
+
y = self.self_attn_layers[i](x, x, self_attn_mask)
|
74 |
+
y = self.drop(y)
|
75 |
+
x = self.norm_layers_0[i](x + y)
|
76 |
+
|
77 |
+
y = self.ffn_layers[i](x, x_mask)
|
78 |
+
y = self.drop(y)
|
79 |
+
x = self.norm_layers_1[i](x + y)
|
80 |
+
x = x * x_mask
|
81 |
+
return x
|
82 |
+
|
83 |
+
|
84 |
+
class Encoder(nn.Module):
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
hidden_channels,
|
88 |
+
filter_channels,
|
89 |
+
n_heads,
|
90 |
+
n_layers,
|
91 |
+
kernel_size=1,
|
92 |
+
p_dropout=0.0,
|
93 |
+
window_size=4,
|
94 |
+
**kwargs
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
self.hidden_channels = hidden_channels
|
98 |
+
self.filter_channels = filter_channels
|
99 |
+
self.n_heads = n_heads
|
100 |
+
self.n_layers = n_layers
|
101 |
+
self.kernel_size = kernel_size
|
102 |
+
self.p_dropout = p_dropout
|
103 |
+
self.window_size = window_size
|
104 |
+
|
105 |
+
self.drop = nn.Dropout(p_dropout)
|
106 |
+
self.attn_layers = nn.ModuleList()
|
107 |
+
self.norm_layers_1 = nn.ModuleList()
|
108 |
+
self.ffn_layers = nn.ModuleList()
|
109 |
+
self.norm_layers_2 = nn.ModuleList()
|
110 |
+
for i in range(self.n_layers):
|
111 |
+
self.attn_layers.append(
|
112 |
+
MultiHeadAttention(
|
113 |
+
hidden_channels,
|
114 |
+
hidden_channels,
|
115 |
+
n_heads,
|
116 |
+
p_dropout=p_dropout,
|
117 |
+
window_size=window_size,
|
118 |
+
)
|
119 |
+
)
|
120 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
121 |
+
self.ffn_layers.append(
|
122 |
+
FFN(
|
123 |
+
hidden_channels,
|
124 |
+
hidden_channels,
|
125 |
+
filter_channels,
|
126 |
+
kernel_size,
|
127 |
+
p_dropout=p_dropout,
|
128 |
+
)
|
129 |
+
)
|
130 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
131 |
+
|
132 |
+
def forward(self, x, x_mask):
|
133 |
+
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
134 |
+
x = x * x_mask
|
135 |
+
for i in range(self.n_layers):
|
136 |
+
y = self.attn_layers[i](x, x, attn_mask)
|
137 |
+
y = self.drop(y)
|
138 |
+
x = self.norm_layers_1[i](x + y)
|
139 |
+
|
140 |
+
y = self.ffn_layers[i](x, x_mask)
|
141 |
+
y = self.drop(y)
|
142 |
+
x = self.norm_layers_2[i](x + y)
|
143 |
+
x = x * x_mask
|
144 |
+
return x
|
145 |
+
|
146 |
+
|
147 |
+
class Decoder(nn.Module):
|
148 |
+
def __init__(
|
149 |
+
self,
|
150 |
+
hidden_channels,
|
151 |
+
filter_channels,
|
152 |
+
n_heads,
|
153 |
+
n_layers,
|
154 |
+
kernel_size=1,
|
155 |
+
p_dropout=0.0,
|
156 |
+
proximal_bias=False,
|
157 |
+
proximal_init=True,
|
158 |
+
**kwargs
|
159 |
+
):
|
160 |
+
super().__init__()
|
161 |
+
self.hidden_channels = hidden_channels
|
162 |
+
self.filter_channels = filter_channels
|
163 |
+
self.n_heads = n_heads
|
164 |
+
self.n_layers = n_layers
|
165 |
+
self.kernel_size = kernel_size
|
166 |
+
self.p_dropout = p_dropout
|
167 |
+
self.proximal_bias = proximal_bias
|
168 |
+
self.proximal_init = proximal_init
|
169 |
+
|
170 |
+
self.drop = nn.Dropout(p_dropout)
|
171 |
+
self.self_attn_layers = nn.ModuleList()
|
172 |
+
self.norm_layers_0 = nn.ModuleList()
|
173 |
+
self.encdec_attn_layers = nn.ModuleList()
|
174 |
+
self.norm_layers_1 = nn.ModuleList()
|
175 |
+
self.ffn_layers = nn.ModuleList()
|
176 |
+
self.norm_layers_2 = nn.ModuleList()
|
177 |
+
for i in range(self.n_layers):
|
178 |
+
self.self_attn_layers.append(
|
179 |
+
MultiHeadAttention(
|
180 |
+
hidden_channels,
|
181 |
+
hidden_channels,
|
182 |
+
n_heads,
|
183 |
+
p_dropout=p_dropout,
|
184 |
+
proximal_bias=proximal_bias,
|
185 |
+
proximal_init=proximal_init,
|
186 |
+
)
|
187 |
+
)
|
188 |
+
self.norm_layers_0.append(LayerNorm(hidden_channels))
|
189 |
+
self.encdec_attn_layers.append(
|
190 |
+
MultiHeadAttention(
|
191 |
+
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
|
192 |
+
)
|
193 |
+
)
|
194 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
195 |
+
self.ffn_layers.append(
|
196 |
+
FFN(
|
197 |
+
hidden_channels,
|
198 |
+
hidden_channels,
|
199 |
+
filter_channels,
|
200 |
+
kernel_size,
|
201 |
+
p_dropout=p_dropout,
|
202 |
+
causal=True,
|
203 |
+
)
|
204 |
+
)
|
205 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
206 |
+
|
207 |
+
def forward(self, x, x_mask, h, h_mask):
|
208 |
+
"""
|
209 |
+
x: decoder input
|
210 |
+
h: encoder output
|
211 |
+
"""
|
212 |
+
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
|
213 |
+
device=x.device, dtype=x.dtype
|
214 |
+
)
|
215 |
+
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
216 |
+
x = x * x_mask
|
217 |
+
for i in range(self.n_layers):
|
218 |
+
y = self.self_attn_layers[i](x, x, self_attn_mask)
|
219 |
+
y = self.drop(y)
|
220 |
+
x = self.norm_layers_0[i](x + y)
|
221 |
+
|
222 |
+
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
|
223 |
+
y = self.drop(y)
|
224 |
+
x = self.norm_layers_1[i](x + y)
|
225 |
+
|
226 |
+
y = self.ffn_layers[i](x, x_mask)
|
227 |
+
y = self.drop(y)
|
228 |
+
x = self.norm_layers_2[i](x + y)
|
229 |
+
x = x * x_mask
|
230 |
+
return x
|
231 |
+
|
232 |
+
|
233 |
+
class MultiHeadAttention(nn.Module):
|
234 |
+
def __init__(
|
235 |
+
self,
|
236 |
+
channels,
|
237 |
+
out_channels,
|
238 |
+
n_heads,
|
239 |
+
p_dropout=0.0,
|
240 |
+
window_size=None,
|
241 |
+
heads_share=True,
|
242 |
+
block_length=None,
|
243 |
+
proximal_bias=False,
|
244 |
+
proximal_init=False,
|
245 |
+
):
|
246 |
+
super().__init__()
|
247 |
+
assert channels % n_heads == 0
|
248 |
+
|
249 |
+
self.channels = channels
|
250 |
+
self.out_channels = out_channels
|
251 |
+
self.n_heads = n_heads
|
252 |
+
self.p_dropout = p_dropout
|
253 |
+
self.window_size = window_size
|
254 |
+
self.heads_share = heads_share
|
255 |
+
self.block_length = block_length
|
256 |
+
self.proximal_bias = proximal_bias
|
257 |
+
self.proximal_init = proximal_init
|
258 |
+
self.attn = None
|
259 |
+
|
260 |
+
self.k_channels = channels // n_heads
|
261 |
+
self.conv_q = nn.Conv1d(channels, channels, 1)
|
262 |
+
self.conv_k = nn.Conv1d(channels, channels, 1)
|
263 |
+
self.conv_v = nn.Conv1d(channels, channels, 1)
|
264 |
+
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
265 |
+
self.drop = nn.Dropout(p_dropout)
|
266 |
+
|
267 |
+
if window_size is not None:
|
268 |
+
n_heads_rel = 1 if heads_share else n_heads
|
269 |
+
rel_stddev = self.k_channels**-0.5
|
270 |
+
self.emb_rel_k = nn.Parameter(
|
271 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
272 |
+
* rel_stddev
|
273 |
+
)
|
274 |
+
self.emb_rel_v = nn.Parameter(
|
275 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
276 |
+
* rel_stddev
|
277 |
+
)
|
278 |
+
|
279 |
+
nn.init.xavier_uniform_(self.conv_q.weight)
|
280 |
+
nn.init.xavier_uniform_(self.conv_k.weight)
|
281 |
+
nn.init.xavier_uniform_(self.conv_v.weight)
|
282 |
+
if proximal_init:
|
283 |
+
with torch.no_grad():
|
284 |
+
self.conv_k.weight.copy_(self.conv_q.weight)
|
285 |
+
self.conv_k.bias.copy_(self.conv_q.bias)
|
286 |
+
|
287 |
+
def forward(self, x, c, attn_mask=None):
|
288 |
+
q = self.conv_q(x)
|
289 |
+
k = self.conv_k(c)
|
290 |
+
v = self.conv_v(c)
|
291 |
+
|
292 |
+
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
293 |
+
|
294 |
+
x = self.conv_o(x)
|
295 |
+
return x
|
296 |
+
|
297 |
+
def attention(self, query, key, value, mask=None):
|
298 |
+
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
299 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
300 |
+
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
301 |
+
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
302 |
+
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
303 |
+
|
304 |
+
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
305 |
+
if self.window_size is not None:
|
306 |
+
assert (
|
307 |
+
t_s == t_t
|
308 |
+
), "Relative attention is only available for self-attention."
|
309 |
+
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
310 |
+
rel_logits = self._matmul_with_relative_keys(
|
311 |
+
query / math.sqrt(self.k_channels), key_relative_embeddings
|
312 |
+
)
|
313 |
+
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
314 |
+
scores = scores + scores_local
|
315 |
+
if self.proximal_bias:
|
316 |
+
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
317 |
+
scores = scores + self._attention_bias_proximal(t_s).to(
|
318 |
+
device=scores.device, dtype=scores.dtype
|
319 |
+
)
|
320 |
+
if mask is not None:
|
321 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
322 |
+
if self.block_length is not None:
|
323 |
+
assert (
|
324 |
+
t_s == t_t
|
325 |
+
), "Local attention is only available for self-attention."
|
326 |
+
block_mask = (
|
327 |
+
torch.ones_like(scores)
|
328 |
+
.triu(-self.block_length)
|
329 |
+
.tril(self.block_length)
|
330 |
+
)
|
331 |
+
scores = scores.masked_fill(block_mask == 0, -1e4)
|
332 |
+
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
333 |
+
p_attn = self.drop(p_attn)
|
334 |
+
output = torch.matmul(p_attn, value)
|
335 |
+
if self.window_size is not None:
|
336 |
+
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
337 |
+
value_relative_embeddings = self._get_relative_embeddings(
|
338 |
+
self.emb_rel_v, t_s
|
339 |
+
)
|
340 |
+
output = output + self._matmul_with_relative_values(
|
341 |
+
relative_weights, value_relative_embeddings
|
342 |
+
)
|
343 |
+
output = (
|
344 |
+
output.transpose(2, 3).contiguous().view(b, d, t_t)
|
345 |
+
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
346 |
+
return output, p_attn
|
347 |
+
|
348 |
+
def _matmul_with_relative_values(self, x, y):
|
349 |
+
"""
|
350 |
+
x: [b, h, l, m]
|
351 |
+
y: [h or 1, m, d]
|
352 |
+
ret: [b, h, l, d]
|
353 |
+
"""
|
354 |
+
ret = torch.matmul(x, y.unsqueeze(0))
|
355 |
+
return ret
|
356 |
+
|
357 |
+
def _matmul_with_relative_keys(self, x, y):
|
358 |
+
"""
|
359 |
+
x: [b, h, l, d]
|
360 |
+
y: [h or 1, m, d]
|
361 |
+
ret: [b, h, l, m]
|
362 |
+
"""
|
363 |
+
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
364 |
+
return ret
|
365 |
+
|
366 |
+
def _get_relative_embeddings(self, relative_embeddings, length):
|
367 |
+
2 * self.window_size + 1
|
368 |
+
# Pad first before slice to avoid using cond ops.
|
369 |
+
pad_length = max(length - (self.window_size + 1), 0)
|
370 |
+
slice_start_position = max((self.window_size + 1) - length, 0)
|
371 |
+
slice_end_position = slice_start_position + 2 * length - 1
|
372 |
+
if pad_length > 0:
|
373 |
+
padded_relative_embeddings = F.pad(
|
374 |
+
relative_embeddings,
|
375 |
+
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
376 |
+
)
|
377 |
+
else:
|
378 |
+
padded_relative_embeddings = relative_embeddings
|
379 |
+
used_relative_embeddings = padded_relative_embeddings[
|
380 |
+
:, slice_start_position:slice_end_position
|
381 |
+
]
|
382 |
+
return used_relative_embeddings
|
383 |
+
|
384 |
+
def _relative_position_to_absolute_position(self, x):
|
385 |
+
"""
|
386 |
+
x: [b, h, l, 2*l-1]
|
387 |
+
ret: [b, h, l, l]
|
388 |
+
"""
|
389 |
+
batch, heads, length, _ = x.size()
|
390 |
+
# Concat columns of pad to shift from relative to absolute indexing.
|
391 |
+
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
|
392 |
+
|
393 |
+
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
394 |
+
x_flat = x.view([batch, heads, length * 2 * length])
|
395 |
+
x_flat = F.pad(
|
396 |
+
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
397 |
+
)
|
398 |
+
|
399 |
+
# Reshape and slice out the padded elements.
|
400 |
+
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
|
401 |
+
:, :, :length, length - 1 :
|
402 |
+
]
|
403 |
+
return x_final
|
404 |
+
|
405 |
+
def _absolute_position_to_relative_position(self, x):
|
406 |
+
"""
|
407 |
+
x: [b, h, l, l]
|
408 |
+
ret: [b, h, l, 2*l-1]
|
409 |
+
"""
|
410 |
+
batch, heads, length, _ = x.size()
|
411 |
+
# pad along column
|
412 |
+
x = F.pad(
|
413 |
+
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
414 |
+
)
|
415 |
+
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
416 |
+
# add 0's in the beginning that will skew the elements after reshape
|
417 |
+
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
418 |
+
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
419 |
+
return x_final
|
420 |
+
|
421 |
+
def _attention_bias_proximal(self, length):
|
422 |
+
"""Bias for self-attention to encourage attention to close positions.
|
423 |
+
Args:
|
424 |
+
length: an integer scalar.
|
425 |
+
Returns:
|
426 |
+
a Tensor with shape [1, 1, length, length]
|
427 |
+
"""
|
428 |
+
r = torch.arange(length, dtype=torch.float32)
|
429 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
430 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
431 |
+
|
432 |
+
|
433 |
+
class FFN(nn.Module):
|
434 |
+
def __init__(
|
435 |
+
self,
|
436 |
+
in_channels,
|
437 |
+
out_channels,
|
438 |
+
filter_channels,
|
439 |
+
kernel_size,
|
440 |
+
p_dropout=0.0,
|
441 |
+
activation=None,
|
442 |
+
causal=False,
|
443 |
+
):
|
444 |
+
super().__init__()
|
445 |
+
self.in_channels = in_channels
|
446 |
+
self.out_channels = out_channels
|
447 |
+
self.filter_channels = filter_channels
|
448 |
+
self.kernel_size = kernel_size
|
449 |
+
self.p_dropout = p_dropout
|
450 |
+
self.activation = activation
|
451 |
+
self.causal = causal
|
452 |
+
|
453 |
+
if causal:
|
454 |
+
self.padding = self._causal_padding
|
455 |
+
else:
|
456 |
+
self.padding = self._same_padding
|
457 |
+
|
458 |
+
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
459 |
+
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
460 |
+
self.drop = nn.Dropout(p_dropout)
|
461 |
+
|
462 |
+
def forward(self, x, x_mask):
|
463 |
+
x = self.conv_1(self.padding(x * x_mask))
|
464 |
+
if self.activation == "gelu":
|
465 |
+
x = x * torch.sigmoid(1.702 * x)
|
466 |
+
else:
|
467 |
+
x = torch.relu(x)
|
468 |
+
x = self.drop(x)
|
469 |
+
x = self.conv_2(self.padding(x * x_mask))
|
470 |
+
return x * x_mask
|
471 |
+
|
472 |
+
def _causal_padding(self, x):
|
473 |
+
if self.kernel_size == 1:
|
474 |
+
return x
|
475 |
+
pad_l = self.kernel_size - 1
|
476 |
+
pad_r = 0
|
477 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
478 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
479 |
+
return x
|
480 |
+
|
481 |
+
def _same_padding(self, x):
|
482 |
+
if self.kernel_size == 1:
|
483 |
+
return x
|
484 |
+
pad_l = (self.kernel_size - 1) // 2
|
485 |
+
pad_r = self.kernel_size // 2
|
486 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
487 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
488 |
+
return x
|
so_vits_svc_fork/modules/commons.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import Tensor
|
6 |
+
|
7 |
+
|
8 |
+
def slice_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
|
9 |
+
if length is None:
|
10 |
+
return x
|
11 |
+
length = min(length, x.size(-1))
|
12 |
+
x_slice = torch.zeros((x.size()[:-1] + (length,)), dtype=x.dtype, device=x.device)
|
13 |
+
ends = starts + length
|
14 |
+
for i, (start, end) in enumerate(zip(starts, ends)):
|
15 |
+
# LOG.debug(i, start, end, x.size(), x[i, ..., start:end].size(), x_slice.size())
|
16 |
+
# x_slice[i, ...] = x[i, ..., start:end] need to pad
|
17 |
+
# x_slice[i, ..., :end - start] = x[i, ..., start:end] this does not work
|
18 |
+
x_slice[i, ...] = F.pad(x[i, ..., start:end], (0, max(0, length - x.size(-1))))
|
19 |
+
return x_slice
|
20 |
+
|
21 |
+
|
22 |
+
def rand_slice_segments_with_pitch(
|
23 |
+
x: Tensor, f0: Tensor, x_lengths: Tensor | int | None, segment_size: int | None
|
24 |
+
):
|
25 |
+
if segment_size is None:
|
26 |
+
return x, f0, torch.arange(x.size(0), device=x.device)
|
27 |
+
if x_lengths is None:
|
28 |
+
x_lengths = x.size(-1) * torch.ones(
|
29 |
+
x.size(0), dtype=torch.long, device=x.device
|
30 |
+
)
|
31 |
+
# slice_starts = (torch.rand(z.size(0), device=z.device) * (z_lengths - segment_size)).long()
|
32 |
+
slice_starts = (
|
33 |
+
torch.rand(x.size(0), device=x.device)
|
34 |
+
* torch.max(
|
35 |
+
x_lengths - segment_size, torch.zeros_like(x_lengths, device=x.device)
|
36 |
+
)
|
37 |
+
).long()
|
38 |
+
z_slice = slice_segments(x, slice_starts, segment_size)
|
39 |
+
f0_slice = slice_segments(f0, slice_starts, segment_size)
|
40 |
+
return z_slice, f0_slice, slice_starts
|
41 |
+
|
42 |
+
|
43 |
+
def slice_2d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
|
44 |
+
batch_size, num_features, seq_len = x.shape
|
45 |
+
ends = starts + length
|
46 |
+
idxs = (
|
47 |
+
torch.arange(seq_len, device=x.device)
|
48 |
+
.unsqueeze(0)
|
49 |
+
.unsqueeze(1)
|
50 |
+
.repeat(batch_size, num_features, 1)
|
51 |
+
)
|
52 |
+
mask = (idxs >= starts.unsqueeze(-1).unsqueeze(-1)) & (
|
53 |
+
idxs < ends.unsqueeze(-1).unsqueeze(-1)
|
54 |
+
)
|
55 |
+
return x[mask].reshape(batch_size, num_features, length)
|
56 |
+
|
57 |
+
|
58 |
+
def slice_1d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
|
59 |
+
batch_size, seq_len = x.shape
|
60 |
+
ends = starts + length
|
61 |
+
idxs = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1)
|
62 |
+
mask = (idxs >= starts.unsqueeze(-1)) & (idxs < ends.unsqueeze(-1))
|
63 |
+
return x[mask].reshape(batch_size, length)
|
64 |
+
|
65 |
+
|
66 |
+
def _slice_segments_v3(x: Tensor, starts: Tensor, length: int) -> Tensor:
|
67 |
+
shape = x.shape[:-1] + (length,)
|
68 |
+
ends = starts + length
|
69 |
+
idxs = torch.arange(x.shape[-1], device=x.device).unsqueeze(0).unsqueeze(0)
|
70 |
+
unsqueeze_dims = len(shape) - len(
|
71 |
+
x.shape
|
72 |
+
) # calculate number of dimensions to unsqueeze
|
73 |
+
starts = starts.reshape(starts.shape + (1,) * unsqueeze_dims)
|
74 |
+
ends = ends.reshape(ends.shape + (1,) * unsqueeze_dims)
|
75 |
+
mask = (idxs >= starts) & (idxs < ends)
|
76 |
+
return x[mask].reshape(shape)
|
77 |
+
|
78 |
+
|
79 |
+
def init_weights(m, mean=0.0, std=0.01):
|
80 |
+
classname = m.__class__.__name__
|
81 |
+
if classname.find("Conv") != -1:
|
82 |
+
m.weight.data.normal_(mean, std)
|
83 |
+
|
84 |
+
|
85 |
+
def get_padding(kernel_size, dilation=1):
|
86 |
+
return int((kernel_size * dilation - dilation) / 2)
|
87 |
+
|
88 |
+
|
89 |
+
def convert_pad_shape(pad_shape):
|
90 |
+
l = pad_shape[::-1]
|
91 |
+
pad_shape = [item for sublist in l for item in sublist]
|
92 |
+
return pad_shape
|
93 |
+
|
94 |
+
|
95 |
+
def subsequent_mask(length):
|
96 |
+
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
|
97 |
+
return mask
|
98 |
+
|
99 |
+
|
100 |
+
@torch.jit.script
|
101 |
+
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
102 |
+
n_channels_int = n_channels[0]
|
103 |
+
in_act = input_a + input_b
|
104 |
+
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
105 |
+
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
106 |
+
acts = t_act * s_act
|
107 |
+
return acts
|
108 |
+
|
109 |
+
|
110 |
+
def sequence_mask(length, max_length=None):
|
111 |
+
if max_length is None:
|
112 |
+
max_length = length.max()
|
113 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
114 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
115 |
+
|
116 |
+
|
117 |
+
def clip_grad_value_(parameters, clip_value, norm_type=2):
|
118 |
+
if isinstance(parameters, torch.Tensor):
|
119 |
+
parameters = [parameters]
|
120 |
+
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
121 |
+
norm_type = float(norm_type)
|
122 |
+
if clip_value is not None:
|
123 |
+
clip_value = float(clip_value)
|
124 |
+
|
125 |
+
total_norm = 0
|
126 |
+
for p in parameters:
|
127 |
+
param_norm = p.grad.data.norm(norm_type)
|
128 |
+
total_norm += param_norm.item() ** norm_type
|
129 |
+
if clip_value is not None:
|
130 |
+
p.grad.data.clamp_(min=-clip_value, max=clip_value)
|
131 |
+
total_norm = total_norm ** (1.0 / norm_type)
|
132 |
+
return total_norm
|
so_vits_svc_fork/modules/decoders/__init__.py
ADDED
File without changes
|
so_vits_svc_fork/modules/decoders/f0.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
from so_vits_svc_fork.modules import attentions as attentions
|
5 |
+
|
6 |
+
|
7 |
+
class F0Decoder(nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
out_channels,
|
11 |
+
hidden_channels,
|
12 |
+
filter_channels,
|
13 |
+
n_heads,
|
14 |
+
n_layers,
|
15 |
+
kernel_size,
|
16 |
+
p_dropout,
|
17 |
+
spk_channels=0,
|
18 |
+
):
|
19 |
+
super().__init__()
|
20 |
+
self.out_channels = out_channels
|
21 |
+
self.hidden_channels = hidden_channels
|
22 |
+
self.filter_channels = filter_channels
|
23 |
+
self.n_heads = n_heads
|
24 |
+
self.n_layers = n_layers
|
25 |
+
self.kernel_size = kernel_size
|
26 |
+
self.p_dropout = p_dropout
|
27 |
+
self.spk_channels = spk_channels
|
28 |
+
|
29 |
+
self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1)
|
30 |
+
self.decoder = attentions.FFT(
|
31 |
+
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
32 |
+
)
|
33 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
34 |
+
self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1)
|
35 |
+
self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)
|
36 |
+
|
37 |
+
def forward(self, x, norm_f0, x_mask, spk_emb=None):
|
38 |
+
x = torch.detach(x)
|
39 |
+
if spk_emb is not None:
|
40 |
+
spk_emb = torch.detach(spk_emb)
|
41 |
+
x = x + self.cond(spk_emb)
|
42 |
+
x += self.f0_prenet(norm_f0)
|
43 |
+
x = self.prenet(x) * x_mask
|
44 |
+
x = self.decoder(x * x_mask, x_mask)
|
45 |
+
x = self.proj(x) * x_mask
|
46 |
+
return x
|
so_vits_svc_fork/modules/decoders/hifigan/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from ._models import NSFHifiGANGenerator
|
2 |
+
|
3 |
+
__all__ = ["NSFHifiGANGenerator"]
|
so_vits_svc_fork/modules/decoders/hifigan/_models.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from logging import getLogger
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
8 |
+
from torch.nn.utils import remove_weight_norm, weight_norm
|
9 |
+
|
10 |
+
from ...modules import ResBlock1, ResBlock2
|
11 |
+
from ._utils import init_weights
|
12 |
+
|
13 |
+
LOG = getLogger(__name__)
|
14 |
+
|
15 |
+
LRELU_SLOPE = 0.1
|
16 |
+
|
17 |
+
|
18 |
+
def padDiff(x):
|
19 |
+
return F.pad(
|
20 |
+
F.pad(x, (0, 0, -1, 1), "constant", 0) - x, (0, 0, 0, -1), "constant", 0
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
class SineGen(torch.nn.Module):
|
25 |
+
"""Definition of sine generator
|
26 |
+
SineGen(samp_rate, harmonic_num = 0,
|
27 |
+
sine_amp = 0.1, noise_std = 0.003,
|
28 |
+
voiced_threshold = 0,
|
29 |
+
flag_for_pulse=False)
|
30 |
+
samp_rate: sampling rate in Hz
|
31 |
+
harmonic_num: number of harmonic overtones (default 0)
|
32 |
+
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
33 |
+
noise_std: std of Gaussian noise (default 0.003)
|
34 |
+
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
35 |
+
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
36 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
37 |
+
segment is always sin(np.pi) or cos(0)
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
samp_rate,
|
43 |
+
harmonic_num=0,
|
44 |
+
sine_amp=0.1,
|
45 |
+
noise_std=0.003,
|
46 |
+
voiced_threshold=0,
|
47 |
+
flag_for_pulse=False,
|
48 |
+
):
|
49 |
+
super().__init__()
|
50 |
+
self.sine_amp = sine_amp
|
51 |
+
self.noise_std = noise_std
|
52 |
+
self.harmonic_num = harmonic_num
|
53 |
+
self.dim = self.harmonic_num + 1
|
54 |
+
self.sampling_rate = samp_rate
|
55 |
+
self.voiced_threshold = voiced_threshold
|
56 |
+
self.flag_for_pulse = flag_for_pulse
|
57 |
+
|
58 |
+
def _f02uv(self, f0):
|
59 |
+
# generate uv signal
|
60 |
+
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
61 |
+
return uv
|
62 |
+
|
63 |
+
def _f02sine(self, f0_values):
|
64 |
+
"""f0_values: (batchsize, length, dim)
|
65 |
+
where dim indicates fundamental tone and overtones
|
66 |
+
"""
|
67 |
+
# convert to F0 in rad. The integer part n can be ignored
|
68 |
+
# because 2 * np.pi * n doesn't affect phase
|
69 |
+
rad_values = (f0_values / self.sampling_rate) % 1
|
70 |
+
|
71 |
+
# initial phase noise (no noise for fundamental component)
|
72 |
+
rand_ini = torch.rand(
|
73 |
+
f0_values.shape[0], f0_values.shape[2], device=f0_values.device
|
74 |
+
)
|
75 |
+
rand_ini[:, 0] = 0
|
76 |
+
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
77 |
+
|
78 |
+
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
|
79 |
+
if not self.flag_for_pulse:
|
80 |
+
# for normal case
|
81 |
+
|
82 |
+
# To prevent torch.cumsum numerical overflow,
|
83 |
+
# it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
|
84 |
+
# Buffer tmp_over_one_idx indicates the time step to add -1.
|
85 |
+
# This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
|
86 |
+
tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
87 |
+
tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
88 |
+
cumsum_shift = torch.zeros_like(rad_values)
|
89 |
+
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
90 |
+
|
91 |
+
sines = torch.sin(
|
92 |
+
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
|
93 |
+
)
|
94 |
+
else:
|
95 |
+
# If necessary, make sure that the first time step of every
|
96 |
+
# voiced segments is sin(pi) or cos(0)
|
97 |
+
# This is used for pulse-train generation
|
98 |
+
|
99 |
+
# identify the last time step in unvoiced segments
|
100 |
+
uv = self._f02uv(f0_values)
|
101 |
+
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
102 |
+
uv_1[:, -1, :] = 1
|
103 |
+
u_loc = (uv < 1) * (uv_1 > 0)
|
104 |
+
|
105 |
+
# get the instantanouse phase
|
106 |
+
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
107 |
+
# different batch needs to be processed differently
|
108 |
+
for idx in range(f0_values.shape[0]):
|
109 |
+
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
110 |
+
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
111 |
+
# stores the accumulation of i.phase within
|
112 |
+
# each voiced segments
|
113 |
+
tmp_cumsum[idx, :, :] = 0
|
114 |
+
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
115 |
+
|
116 |
+
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
117 |
+
# within the previous voiced segment.
|
118 |
+
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
119 |
+
|
120 |
+
# get the sines
|
121 |
+
sines = torch.cos(i_phase * 2 * np.pi)
|
122 |
+
return sines
|
123 |
+
|
124 |
+
def forward(self, f0):
|
125 |
+
"""sine_tensor, uv = forward(f0)
|
126 |
+
input F0: tensor(batchsize=1, length, dim=1)
|
127 |
+
f0 for unvoiced steps should be 0
|
128 |
+
output sine_tensor: tensor(batchsize=1, length, dim)
|
129 |
+
output uv: tensor(batchsize=1, length, 1)
|
130 |
+
"""
|
131 |
+
with torch.no_grad():
|
132 |
+
# f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
|
133 |
+
# fundamental component
|
134 |
+
# fn = torch.multiply(
|
135 |
+
# f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
|
136 |
+
# )
|
137 |
+
fn = torch.multiply(
|
138 |
+
f0, torch.arange(1, self.harmonic_num + 2).to(f0.device).to(f0.dtype)
|
139 |
+
)
|
140 |
+
|
141 |
+
# generate sine waveforms
|
142 |
+
sine_waves = self._f02sine(fn) * self.sine_amp
|
143 |
+
|
144 |
+
# generate uv signal
|
145 |
+
# uv = torch.ones(f0.shape)
|
146 |
+
# uv = uv * (f0 > self.voiced_threshold)
|
147 |
+
uv = self._f02uv(f0)
|
148 |
+
|
149 |
+
# noise: for unvoiced should be similar to sine_amp
|
150 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
151 |
+
# . for voiced regions is self.noise_std
|
152 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
153 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
154 |
+
|
155 |
+
# first: set the unvoiced part to 0 by uv
|
156 |
+
# then: additive noise
|
157 |
+
sine_waves = sine_waves * uv + noise
|
158 |
+
return sine_waves, uv, noise
|
159 |
+
|
160 |
+
|
161 |
+
class SourceModuleHnNSF(torch.nn.Module):
|
162 |
+
"""SourceModule for hn-nsf
|
163 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
164 |
+
add_noise_std=0.003, voiced_threshod=0)
|
165 |
+
sampling_rate: sampling_rate in Hz
|
166 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
167 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
168 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
169 |
+
note that amplitude of noise in unvoiced is decided
|
170 |
+
by sine_amp
|
171 |
+
voiced_threshold: threshold to set U/V given F0 (default: 0)
|
172 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
173 |
+
F0_sampled (batchsize, length, 1)
|
174 |
+
Sine_source (batchsize, length, 1)
|
175 |
+
noise_source (batchsize, length 1)
|
176 |
+
uv (batchsize, length, 1)
|
177 |
+
"""
|
178 |
+
|
179 |
+
def __init__(
|
180 |
+
self,
|
181 |
+
sampling_rate,
|
182 |
+
harmonic_num=0,
|
183 |
+
sine_amp=0.1,
|
184 |
+
add_noise_std=0.003,
|
185 |
+
voiced_threshod=0,
|
186 |
+
):
|
187 |
+
super().__init__()
|
188 |
+
|
189 |
+
self.sine_amp = sine_amp
|
190 |
+
self.noise_std = add_noise_std
|
191 |
+
|
192 |
+
# to produce sine waveforms
|
193 |
+
self.l_sin_gen = SineGen(
|
194 |
+
sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
|
195 |
+
)
|
196 |
+
|
197 |
+
# to merge source harmonics into a single excitation
|
198 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
199 |
+
self.l_tanh = torch.nn.Tanh()
|
200 |
+
|
201 |
+
def forward(self, x):
|
202 |
+
"""
|
203 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
204 |
+
F0_sampled (batchsize, length, 1)
|
205 |
+
Sine_source (batchsize, length, 1)
|
206 |
+
noise_source (batchsize, length 1)
|
207 |
+
"""
|
208 |
+
# source for harmonic branch
|
209 |
+
sine_wavs, uv, _ = self.l_sin_gen(x)
|
210 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
211 |
+
|
212 |
+
# source for noise branch, in the same shape as uv
|
213 |
+
noise = torch.randn_like(uv) * self.sine_amp / 3
|
214 |
+
return sine_merge, noise, uv
|
215 |
+
|
216 |
+
|
217 |
+
class NSFHifiGANGenerator(torch.nn.Module):
|
218 |
+
def __init__(self, h):
|
219 |
+
super().__init__()
|
220 |
+
self.h = h
|
221 |
+
|
222 |
+
self.num_kernels = len(h["resblock_kernel_sizes"])
|
223 |
+
self.num_upsamples = len(h["upsample_rates"])
|
224 |
+
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h["upsample_rates"]))
|
225 |
+
self.m_source = SourceModuleHnNSF(
|
226 |
+
sampling_rate=h["sampling_rate"], harmonic_num=8
|
227 |
+
)
|
228 |
+
self.noise_convs = nn.ModuleList()
|
229 |
+
self.conv_pre = weight_norm(
|
230 |
+
Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3)
|
231 |
+
)
|
232 |
+
resblock = ResBlock1 if h["resblock"] == "1" else ResBlock2
|
233 |
+
self.ups = nn.ModuleList()
|
234 |
+
for i, (u, k) in enumerate(
|
235 |
+
zip(h["upsample_rates"], h["upsample_kernel_sizes"])
|
236 |
+
):
|
237 |
+
c_cur = h["upsample_initial_channel"] // (2 ** (i + 1))
|
238 |
+
self.ups.append(
|
239 |
+
weight_norm(
|
240 |
+
ConvTranspose1d(
|
241 |
+
h["upsample_initial_channel"] // (2**i),
|
242 |
+
h["upsample_initial_channel"] // (2 ** (i + 1)),
|
243 |
+
k,
|
244 |
+
u,
|
245 |
+
padding=(k - u) // 2,
|
246 |
+
)
|
247 |
+
)
|
248 |
+
)
|
249 |
+
if i + 1 < len(h["upsample_rates"]): #
|
250 |
+
stride_f0 = np.prod(h["upsample_rates"][i + 1 :])
|
251 |
+
self.noise_convs.append(
|
252 |
+
Conv1d(
|
253 |
+
1,
|
254 |
+
c_cur,
|
255 |
+
kernel_size=stride_f0 * 2,
|
256 |
+
stride=stride_f0,
|
257 |
+
padding=stride_f0 // 2,
|
258 |
+
)
|
259 |
+
)
|
260 |
+
else:
|
261 |
+
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
|
262 |
+
self.resblocks = nn.ModuleList()
|
263 |
+
for i in range(len(self.ups)):
|
264 |
+
ch = h["upsample_initial_channel"] // (2 ** (i + 1))
|
265 |
+
for j, (k, d) in enumerate(
|
266 |
+
zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])
|
267 |
+
):
|
268 |
+
self.resblocks.append(resblock(ch, k, d))
|
269 |
+
|
270 |
+
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
271 |
+
self.ups.apply(init_weights)
|
272 |
+
self.conv_post.apply(init_weights)
|
273 |
+
self.cond = nn.Conv1d(h["gin_channels"], h["upsample_initial_channel"], 1)
|
274 |
+
|
275 |
+
def forward(self, x, f0, g=None):
|
276 |
+
# LOG.info(1,x.shape,f0.shape,f0[:, None].shape)
|
277 |
+
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
278 |
+
# LOG.info(2,f0.shape)
|
279 |
+
har_source, noi_source, uv = self.m_source(f0)
|
280 |
+
har_source = har_source.transpose(1, 2)
|
281 |
+
x = self.conv_pre(x)
|
282 |
+
x = x + self.cond(g)
|
283 |
+
# LOG.info(124,x.shape,har_source.shape)
|
284 |
+
for i in range(self.num_upsamples):
|
285 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
286 |
+
# LOG.info(3,x.shape)
|
287 |
+
x = self.ups[i](x)
|
288 |
+
x_source = self.noise_convs[i](har_source)
|
289 |
+
# LOG.info(4,x_source.shape,har_source.shape,x.shape)
|
290 |
+
x = x + x_source
|
291 |
+
xs = None
|
292 |
+
for j in range(self.num_kernels):
|
293 |
+
if xs is None:
|
294 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
295 |
+
else:
|
296 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
297 |
+
x = xs / self.num_kernels
|
298 |
+
x = F.leaky_relu(x)
|
299 |
+
x = self.conv_post(x)
|
300 |
+
x = torch.tanh(x)
|
301 |
+
|
302 |
+
return x
|
303 |
+
|
304 |
+
def remove_weight_norm(self):
|
305 |
+
LOG.info("Removing weight norm...")
|
306 |
+
for l in self.ups:
|
307 |
+
remove_weight_norm(l)
|
308 |
+
for l in self.resblocks:
|
309 |
+
l.remove_weight_norm()
|
310 |
+
remove_weight_norm(self.conv_pre)
|
311 |
+
remove_weight_norm(self.conv_post)
|
so_vits_svc_fork/modules/decoders/hifigan/_utils.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from logging import getLogger
|
2 |
+
|
3 |
+
# matplotlib.use("Agg")
|
4 |
+
|
5 |
+
LOG = getLogger(__name__)
|
6 |
+
|
7 |
+
|
8 |
+
def init_weights(m, mean=0.0, std=0.01):
|
9 |
+
classname = m.__class__.__name__
|
10 |
+
if classname.find("Conv") != -1:
|
11 |
+
m.weight.data.normal_(mean, std)
|
12 |
+
|
13 |
+
|
14 |
+
def get_padding(kernel_size, dilation=1):
|
15 |
+
return int((kernel_size * dilation - dilation) / 2)
|
so_vits_svc_fork/modules/decoders/mb_istft/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ._generators import (
|
2 |
+
Multiband_iSTFT_Generator,
|
3 |
+
Multistream_iSTFT_Generator,
|
4 |
+
iSTFT_Generator,
|
5 |
+
)
|
6 |
+
from ._loss import subband_stft_loss
|
7 |
+
from ._pqmf import PQMF
|
8 |
+
|
9 |
+
__all__ = [
|
10 |
+
"subband_stft_loss",
|
11 |
+
"PQMF",
|
12 |
+
"iSTFT_Generator",
|
13 |
+
"Multiband_iSTFT_Generator",
|
14 |
+
"Multistream_iSTFT_Generator",
|
15 |
+
]
|
so_vits_svc_fork/modules/decoders/mb_istft/_generators.py
ADDED
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
6 |
+
from torch.nn import functional as F
|
7 |
+
from torch.nn.utils import remove_weight_norm, weight_norm
|
8 |
+
|
9 |
+
from ....modules import modules
|
10 |
+
from ....modules.commons import get_padding, init_weights
|
11 |
+
from ._pqmf import PQMF
|
12 |
+
from ._stft import TorchSTFT
|
13 |
+
|
14 |
+
|
15 |
+
class iSTFT_Generator(torch.nn.Module):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
initial_channel,
|
19 |
+
resblock,
|
20 |
+
resblock_kernel_sizes,
|
21 |
+
resblock_dilation_sizes,
|
22 |
+
upsample_rates,
|
23 |
+
upsample_initial_channel,
|
24 |
+
upsample_kernel_sizes,
|
25 |
+
gen_istft_n_fft,
|
26 |
+
gen_istft_hop_size,
|
27 |
+
gin_channels=0,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
# self.h = h
|
31 |
+
self.gen_istft_n_fft = gen_istft_n_fft
|
32 |
+
self.gen_istft_hop_size = gen_istft_hop_size
|
33 |
+
|
34 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
35 |
+
self.num_upsamples = len(upsample_rates)
|
36 |
+
self.conv_pre = weight_norm(
|
37 |
+
Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
38 |
+
)
|
39 |
+
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
40 |
+
|
41 |
+
self.ups = nn.ModuleList()
|
42 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
43 |
+
self.ups.append(
|
44 |
+
weight_norm(
|
45 |
+
ConvTranspose1d(
|
46 |
+
upsample_initial_channel // (2**i),
|
47 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
48 |
+
k,
|
49 |
+
u,
|
50 |
+
padding=(k - u) // 2,
|
51 |
+
)
|
52 |
+
)
|
53 |
+
)
|
54 |
+
|
55 |
+
self.resblocks = nn.ModuleList()
|
56 |
+
for i in range(len(self.ups)):
|
57 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
58 |
+
for j, (k, d) in enumerate(
|
59 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
60 |
+
):
|
61 |
+
self.resblocks.append(resblock(ch, k, d))
|
62 |
+
|
63 |
+
self.post_n_fft = self.gen_istft_n_fft
|
64 |
+
self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
|
65 |
+
self.ups.apply(init_weights)
|
66 |
+
self.conv_post.apply(init_weights)
|
67 |
+
self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
|
68 |
+
self.stft = TorchSTFT(
|
69 |
+
filter_length=self.gen_istft_n_fft,
|
70 |
+
hop_length=self.gen_istft_hop_size,
|
71 |
+
win_length=self.gen_istft_n_fft,
|
72 |
+
)
|
73 |
+
|
74 |
+
def forward(self, x, g=None):
|
75 |
+
x = self.conv_pre(x)
|
76 |
+
for i in range(self.num_upsamples):
|
77 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
78 |
+
x = self.ups[i](x)
|
79 |
+
xs = None
|
80 |
+
for j in range(self.num_kernels):
|
81 |
+
if xs is None:
|
82 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
83 |
+
else:
|
84 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
85 |
+
x = xs / self.num_kernels
|
86 |
+
x = F.leaky_relu(x)
|
87 |
+
x = self.reflection_pad(x)
|
88 |
+
x = self.conv_post(x)
|
89 |
+
spec = torch.exp(x[:, : self.post_n_fft // 2 + 1, :])
|
90 |
+
phase = math.pi * torch.sin(x[:, self.post_n_fft // 2 + 1 :, :])
|
91 |
+
out = self.stft.inverse(spec, phase).to(x.device)
|
92 |
+
return out, None
|
93 |
+
|
94 |
+
def remove_weight_norm(self):
|
95 |
+
print("Removing weight norm...")
|
96 |
+
for l in self.ups:
|
97 |
+
remove_weight_norm(l)
|
98 |
+
for l in self.resblocks:
|
99 |
+
l.remove_weight_norm()
|
100 |
+
remove_weight_norm(self.conv_pre)
|
101 |
+
remove_weight_norm(self.conv_post)
|
102 |
+
|
103 |
+
|
104 |
+
class Multiband_iSTFT_Generator(torch.nn.Module):
|
105 |
+
def __init__(
|
106 |
+
self,
|
107 |
+
initial_channel,
|
108 |
+
resblock,
|
109 |
+
resblock_kernel_sizes,
|
110 |
+
resblock_dilation_sizes,
|
111 |
+
upsample_rates,
|
112 |
+
upsample_initial_channel,
|
113 |
+
upsample_kernel_sizes,
|
114 |
+
gen_istft_n_fft,
|
115 |
+
gen_istft_hop_size,
|
116 |
+
subbands,
|
117 |
+
gin_channels=0,
|
118 |
+
):
|
119 |
+
super().__init__()
|
120 |
+
# self.h = h
|
121 |
+
self.subbands = subbands
|
122 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
123 |
+
self.num_upsamples = len(upsample_rates)
|
124 |
+
self.conv_pre = weight_norm(
|
125 |
+
Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
126 |
+
)
|
127 |
+
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
128 |
+
|
129 |
+
self.ups = nn.ModuleList()
|
130 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
131 |
+
self.ups.append(
|
132 |
+
weight_norm(
|
133 |
+
ConvTranspose1d(
|
134 |
+
upsample_initial_channel // (2**i),
|
135 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
136 |
+
k,
|
137 |
+
u,
|
138 |
+
padding=(k - u) // 2,
|
139 |
+
)
|
140 |
+
)
|
141 |
+
)
|
142 |
+
|
143 |
+
self.resblocks = nn.ModuleList()
|
144 |
+
for i in range(len(self.ups)):
|
145 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
146 |
+
for j, (k, d) in enumerate(
|
147 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
148 |
+
):
|
149 |
+
self.resblocks.append(resblock(ch, k, d))
|
150 |
+
|
151 |
+
self.post_n_fft = gen_istft_n_fft
|
152 |
+
self.ups.apply(init_weights)
|
153 |
+
self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
|
154 |
+
self.reshape_pixelshuffle = []
|
155 |
+
|
156 |
+
self.subband_conv_post = weight_norm(
|
157 |
+
Conv1d(ch, self.subbands * (self.post_n_fft + 2), 7, 1, padding=3)
|
158 |
+
)
|
159 |
+
|
160 |
+
self.subband_conv_post.apply(init_weights)
|
161 |
+
|
162 |
+
self.gen_istft_n_fft = gen_istft_n_fft
|
163 |
+
self.gen_istft_hop_size = gen_istft_hop_size
|
164 |
+
|
165 |
+
def forward(self, x, g=None):
|
166 |
+
stft = TorchSTFT(
|
167 |
+
filter_length=self.gen_istft_n_fft,
|
168 |
+
hop_length=self.gen_istft_hop_size,
|
169 |
+
win_length=self.gen_istft_n_fft,
|
170 |
+
).to(x.device)
|
171 |
+
pqmf = PQMF(x.device, subbands=self.subbands).to(x.device, dtype=x.dtype)
|
172 |
+
|
173 |
+
x = self.conv_pre(x) # [B, ch, length]
|
174 |
+
|
175 |
+
for i in range(self.num_upsamples):
|
176 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
177 |
+
x = self.ups[i](x)
|
178 |
+
|
179 |
+
xs = None
|
180 |
+
for j in range(self.num_kernels):
|
181 |
+
if xs is None:
|
182 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
183 |
+
else:
|
184 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
185 |
+
x = xs / self.num_kernels
|
186 |
+
|
187 |
+
x = F.leaky_relu(x)
|
188 |
+
x = self.reflection_pad(x)
|
189 |
+
x = self.subband_conv_post(x)
|
190 |
+
x = torch.reshape(
|
191 |
+
x, (x.shape[0], self.subbands, x.shape[1] // self.subbands, x.shape[-1])
|
192 |
+
)
|
193 |
+
|
194 |
+
spec = torch.exp(x[:, :, : self.post_n_fft // 2 + 1, :])
|
195 |
+
phase = math.pi * torch.sin(x[:, :, self.post_n_fft // 2 + 1 :, :])
|
196 |
+
|
197 |
+
y_mb_hat = stft.inverse(
|
198 |
+
torch.reshape(
|
199 |
+
spec,
|
200 |
+
(
|
201 |
+
spec.shape[0] * self.subbands,
|
202 |
+
self.gen_istft_n_fft // 2 + 1,
|
203 |
+
spec.shape[-1],
|
204 |
+
),
|
205 |
+
),
|
206 |
+
torch.reshape(
|
207 |
+
phase,
|
208 |
+
(
|
209 |
+
phase.shape[0] * self.subbands,
|
210 |
+
self.gen_istft_n_fft // 2 + 1,
|
211 |
+
phase.shape[-1],
|
212 |
+
),
|
213 |
+
),
|
214 |
+
)
|
215 |
+
y_mb_hat = torch.reshape(
|
216 |
+
y_mb_hat, (x.shape[0], self.subbands, 1, y_mb_hat.shape[-1])
|
217 |
+
)
|
218 |
+
y_mb_hat = y_mb_hat.squeeze(-2)
|
219 |
+
|
220 |
+
y_g_hat = pqmf.synthesis(y_mb_hat)
|
221 |
+
|
222 |
+
return y_g_hat, y_mb_hat
|
223 |
+
|
224 |
+
def remove_weight_norm(self):
|
225 |
+
print("Removing weight norm...")
|
226 |
+
for l in self.ups:
|
227 |
+
remove_weight_norm(l)
|
228 |
+
for l in self.resblocks:
|
229 |
+
l.remove_weight_norm()
|
230 |
+
|
231 |
+
|
232 |
+
class Multistream_iSTFT_Generator(torch.nn.Module):
|
233 |
+
def __init__(
|
234 |
+
self,
|
235 |
+
initial_channel,
|
236 |
+
resblock,
|
237 |
+
resblock_kernel_sizes,
|
238 |
+
resblock_dilation_sizes,
|
239 |
+
upsample_rates,
|
240 |
+
upsample_initial_channel,
|
241 |
+
upsample_kernel_sizes,
|
242 |
+
gen_istft_n_fft,
|
243 |
+
gen_istft_hop_size,
|
244 |
+
subbands,
|
245 |
+
gin_channels=0,
|
246 |
+
):
|
247 |
+
super().__init__()
|
248 |
+
# self.h = h
|
249 |
+
self.subbands = subbands
|
250 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
251 |
+
self.num_upsamples = len(upsample_rates)
|
252 |
+
self.conv_pre = weight_norm(
|
253 |
+
Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
254 |
+
)
|
255 |
+
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
256 |
+
|
257 |
+
self.ups = nn.ModuleList()
|
258 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
259 |
+
self.ups.append(
|
260 |
+
weight_norm(
|
261 |
+
ConvTranspose1d(
|
262 |
+
upsample_initial_channel // (2**i),
|
263 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
264 |
+
k,
|
265 |
+
u,
|
266 |
+
padding=(k - u) // 2,
|
267 |
+
)
|
268 |
+
)
|
269 |
+
)
|
270 |
+
|
271 |
+
self.resblocks = nn.ModuleList()
|
272 |
+
for i in range(len(self.ups)):
|
273 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
274 |
+
for j, (k, d) in enumerate(
|
275 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
276 |
+
):
|
277 |
+
self.resblocks.append(resblock(ch, k, d))
|
278 |
+
|
279 |
+
self.post_n_fft = gen_istft_n_fft
|
280 |
+
self.ups.apply(init_weights)
|
281 |
+
self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
|
282 |
+
self.reshape_pixelshuffle = []
|
283 |
+
|
284 |
+
self.subband_conv_post = weight_norm(
|
285 |
+
Conv1d(ch, self.subbands * (self.post_n_fft + 2), 7, 1, padding=3)
|
286 |
+
)
|
287 |
+
|
288 |
+
self.subband_conv_post.apply(init_weights)
|
289 |
+
|
290 |
+
self.gen_istft_n_fft = gen_istft_n_fft
|
291 |
+
self.gen_istft_hop_size = gen_istft_hop_size
|
292 |
+
|
293 |
+
updown_filter = torch.zeros(
|
294 |
+
(self.subbands, self.subbands, self.subbands)
|
295 |
+
).float()
|
296 |
+
for k in range(self.subbands):
|
297 |
+
updown_filter[k, k, 0] = 1.0
|
298 |
+
self.register_buffer("updown_filter", updown_filter)
|
299 |
+
self.multistream_conv_post = weight_norm(
|
300 |
+
Conv1d(
|
301 |
+
self.subbands, 1, kernel_size=63, bias=False, padding=get_padding(63, 1)
|
302 |
+
)
|
303 |
+
)
|
304 |
+
self.multistream_conv_post.apply(init_weights)
|
305 |
+
|
306 |
+
def forward(self, x, g=None):
|
307 |
+
stft = TorchSTFT(
|
308 |
+
filter_length=self.gen_istft_n_fft,
|
309 |
+
hop_length=self.gen_istft_hop_size,
|
310 |
+
win_length=self.gen_istft_n_fft,
|
311 |
+
).to(x.device)
|
312 |
+
# pqmf = PQMF(x.device)
|
313 |
+
|
314 |
+
x = self.conv_pre(x) # [B, ch, length]
|
315 |
+
|
316 |
+
for i in range(self.num_upsamples):
|
317 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
318 |
+
x = self.ups[i](x)
|
319 |
+
|
320 |
+
xs = None
|
321 |
+
for j in range(self.num_kernels):
|
322 |
+
if xs is None:
|
323 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
324 |
+
else:
|
325 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
326 |
+
x = xs / self.num_kernels
|
327 |
+
|
328 |
+
x = F.leaky_relu(x)
|
329 |
+
x = self.reflection_pad(x)
|
330 |
+
x = self.subband_conv_post(x)
|
331 |
+
x = torch.reshape(
|
332 |
+
x, (x.shape[0], self.subbands, x.shape[1] // self.subbands, x.shape[-1])
|
333 |
+
)
|
334 |
+
|
335 |
+
spec = torch.exp(x[:, :, : self.post_n_fft // 2 + 1, :])
|
336 |
+
phase = math.pi * torch.sin(x[:, :, self.post_n_fft // 2 + 1 :, :])
|
337 |
+
|
338 |
+
y_mb_hat = stft.inverse(
|
339 |
+
torch.reshape(
|
340 |
+
spec,
|
341 |
+
(
|
342 |
+
spec.shape[0] * self.subbands,
|
343 |
+
self.gen_istft_n_fft // 2 + 1,
|
344 |
+
spec.shape[-1],
|
345 |
+
),
|
346 |
+
),
|
347 |
+
torch.reshape(
|
348 |
+
phase,
|
349 |
+
(
|
350 |
+
phase.shape[0] * self.subbands,
|
351 |
+
self.gen_istft_n_fft // 2 + 1,
|
352 |
+
phase.shape[-1],
|
353 |
+
),
|
354 |
+
),
|
355 |
+
)
|
356 |
+
y_mb_hat = torch.reshape(
|
357 |
+
y_mb_hat, (x.shape[0], self.subbands, 1, y_mb_hat.shape[-1])
|
358 |
+
)
|
359 |
+
y_mb_hat = y_mb_hat.squeeze(-2)
|
360 |
+
|
361 |
+
y_mb_hat = F.conv_transpose1d(
|
362 |
+
y_mb_hat,
|
363 |
+
self.updown_filter.to(x.device) * self.subbands,
|
364 |
+
stride=self.subbands,
|
365 |
+
)
|
366 |
+
|
367 |
+
y_g_hat = self.multistream_conv_post(y_mb_hat)
|
368 |
+
|
369 |
+
return y_g_hat, y_mb_hat
|
370 |
+
|
371 |
+
def remove_weight_norm(self):
|
372 |
+
print("Removing weight norm...")
|
373 |
+
for l in self.ups:
|
374 |
+
remove_weight_norm(l)
|
375 |
+
for l in self.resblocks:
|
376 |
+
l.remove_weight_norm()
|
so_vits_svc_fork/modules/decoders/mb_istft/_loss.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ._stft_loss import MultiResolutionSTFTLoss
|
2 |
+
|
3 |
+
|
4 |
+
def subband_stft_loss(h, y_mb, y_hat_mb):
|
5 |
+
sub_stft_loss = MultiResolutionSTFTLoss(
|
6 |
+
h.train.fft_sizes, h.train.hop_sizes, h.train.win_lengths
|
7 |
+
)
|
8 |
+
y_mb = y_mb.view(-1, y_mb.size(2))
|
9 |
+
y_hat_mb = y_hat_mb.view(-1, y_hat_mb.size(2))
|
10 |
+
sub_sc_loss, sub_mag_loss = sub_stft_loss(y_hat_mb[:, : y_mb.size(-1)], y_mb)
|
11 |
+
return sub_sc_loss + sub_mag_loss
|
so_vits_svc_fork/modules/decoders/mb_istft/_pqmf.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Tomoki Hayashi
|
2 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
3 |
+
|
4 |
+
"""Pseudo QMF modules."""
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from scipy.signal import kaiser
|
10 |
+
|
11 |
+
|
12 |
+
def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
|
13 |
+
"""Design prototype filter for PQMF.
|
14 |
+
This method is based on `A Kaiser window approach for the design of prototype
|
15 |
+
filters of cosine modulated filterbanks`_.
|
16 |
+
Args:
|
17 |
+
taps (int): The number of filter taps.
|
18 |
+
cutoff_ratio (float): Cut-off frequency ratio.
|
19 |
+
beta (float): Beta coefficient for kaiser window.
|
20 |
+
Returns:
|
21 |
+
ndarray: Impluse response of prototype filter (taps + 1,).
|
22 |
+
.. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
|
23 |
+
https://ieeexplore.ieee.org/abstract/document/681427
|
24 |
+
"""
|
25 |
+
# check the arguments are valid
|
26 |
+
assert taps % 2 == 0, "The number of taps mush be even number."
|
27 |
+
assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
|
28 |
+
|
29 |
+
# make initial filter
|
30 |
+
omega_c = np.pi * cutoff_ratio
|
31 |
+
with np.errstate(invalid="ignore"):
|
32 |
+
h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
|
33 |
+
np.pi * (np.arange(taps + 1) - 0.5 * taps)
|
34 |
+
)
|
35 |
+
h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
|
36 |
+
|
37 |
+
# apply kaiser window
|
38 |
+
w = kaiser(taps + 1, beta)
|
39 |
+
h = h_i * w
|
40 |
+
|
41 |
+
return h
|
42 |
+
|
43 |
+
|
44 |
+
class PQMF(torch.nn.Module):
|
45 |
+
"""PQMF module.
|
46 |
+
This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
|
47 |
+
.. _`Near-perfect-reconstruction pseudo-QMF banks`:
|
48 |
+
https://ieeexplore.ieee.org/document/258122
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(self, device, subbands=8, taps=62, cutoff_ratio=0.15, beta=9.0):
|
52 |
+
"""Initialize PQMF module.
|
53 |
+
Args:
|
54 |
+
subbands (int): The number of subbands.
|
55 |
+
taps (int): The number of filter taps.
|
56 |
+
cutoff_ratio (float): Cut-off frequency ratio.
|
57 |
+
beta (float): Beta coefficient for kaiser window.
|
58 |
+
"""
|
59 |
+
super().__init__()
|
60 |
+
|
61 |
+
# define filter coefficient
|
62 |
+
h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
|
63 |
+
h_analysis = np.zeros((subbands, len(h_proto)))
|
64 |
+
h_synthesis = np.zeros((subbands, len(h_proto)))
|
65 |
+
for k in range(subbands):
|
66 |
+
h_analysis[k] = (
|
67 |
+
2
|
68 |
+
* h_proto
|
69 |
+
* np.cos(
|
70 |
+
(2 * k + 1)
|
71 |
+
* (np.pi / (2 * subbands))
|
72 |
+
* (np.arange(taps + 1) - ((taps - 1) / 2))
|
73 |
+
+ (-1) ** k * np.pi / 4
|
74 |
+
)
|
75 |
+
)
|
76 |
+
h_synthesis[k] = (
|
77 |
+
2
|
78 |
+
* h_proto
|
79 |
+
* np.cos(
|
80 |
+
(2 * k + 1)
|
81 |
+
* (np.pi / (2 * subbands))
|
82 |
+
* (np.arange(taps + 1) - ((taps - 1) / 2))
|
83 |
+
- (-1) ** k * np.pi / 4
|
84 |
+
)
|
85 |
+
)
|
86 |
+
|
87 |
+
# convert to tensor
|
88 |
+
analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1).to(device)
|
89 |
+
synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0).to(device)
|
90 |
+
|
91 |
+
# register coefficients as buffer
|
92 |
+
self.register_buffer("analysis_filter", analysis_filter)
|
93 |
+
self.register_buffer("synthesis_filter", synthesis_filter)
|
94 |
+
|
95 |
+
# filter for downsampling & upsampling
|
96 |
+
updown_filter = torch.zeros((subbands, subbands, subbands)).float().to(device)
|
97 |
+
for k in range(subbands):
|
98 |
+
updown_filter[k, k, 0] = 1.0
|
99 |
+
self.register_buffer("updown_filter", updown_filter)
|
100 |
+
self.subbands = subbands
|
101 |
+
|
102 |
+
# keep padding info
|
103 |
+
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
|
104 |
+
|
105 |
+
def analysis(self, x):
|
106 |
+
"""Analysis with PQMF.
|
107 |
+
Args:
|
108 |
+
x (Tensor): Input tensor (B, 1, T).
|
109 |
+
Returns:
|
110 |
+
Tensor: Output tensor (B, subbands, T // subbands).
|
111 |
+
"""
|
112 |
+
x = F.conv1d(self.pad_fn(x), self.analysis_filter)
|
113 |
+
return F.conv1d(x, self.updown_filter, stride=self.subbands)
|
114 |
+
|
115 |
+
def synthesis(self, x):
|
116 |
+
"""Synthesis with PQMF.
|
117 |
+
Args:
|
118 |
+
x (Tensor): Input tensor (B, subbands, T // subbands).
|
119 |
+
Returns:
|
120 |
+
Tensor: Output tensor (B, 1, T).
|
121 |
+
"""
|
122 |
+
# NOTE(kan-bayashi): Power will be dreased so here multiply by # subbands.
|
123 |
+
# Not sure this is the correct way, it is better to check again.
|
124 |
+
# TODO(kan-bayashi): Understand the reconstruction procedure
|
125 |
+
x = F.conv_transpose1d(
|
126 |
+
x, self.updown_filter * self.subbands, stride=self.subbands
|
127 |
+
)
|
128 |
+
return F.conv1d(self.pad_fn(x), self.synthesis_filter)
|
so_vits_svc_fork/modules/decoders/mb_istft/_stft.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
BSD 3-Clause License
|
3 |
+
Copyright (c) 2017, Prem Seetharaman
|
4 |
+
All rights reserved.
|
5 |
+
* Redistribution and use in source and binary forms, with or without
|
6 |
+
modification, are permitted provided that the following conditions are met:
|
7 |
+
* Redistributions of source code must retain the above copyright notice,
|
8 |
+
this list of conditions and the following disclaimer.
|
9 |
+
* Redistributions in binary form must reproduce the above copyright notice, this
|
10 |
+
list of conditions and the following disclaimer in the
|
11 |
+
documentation and/or other materials provided with the distribution.
|
12 |
+
* Neither the name of the copyright holder nor the names of its
|
13 |
+
contributors may be used to endorse or promote products derived from this
|
14 |
+
software without specific prior written permission.
|
15 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
16 |
+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
17 |
+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
18 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
19 |
+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
20 |
+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
21 |
+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
22 |
+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
23 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
24 |
+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
25 |
+
"""
|
26 |
+
|
27 |
+
import librosa.util as librosa_util
|
28 |
+
import numpy as np
|
29 |
+
import torch
|
30 |
+
import torch.nn.functional as F
|
31 |
+
from librosa.util import pad_center, tiny
|
32 |
+
from scipy.signal import get_window
|
33 |
+
from torch.autograd import Variable
|
34 |
+
|
35 |
+
|
36 |
+
def window_sumsquare(
|
37 |
+
window,
|
38 |
+
n_frames,
|
39 |
+
hop_length=200,
|
40 |
+
win_length=800,
|
41 |
+
n_fft=800,
|
42 |
+
dtype=np.float32,
|
43 |
+
norm=None,
|
44 |
+
):
|
45 |
+
"""
|
46 |
+
# from librosa 0.6
|
47 |
+
Compute the sum-square envelope of a window function at a given hop length.
|
48 |
+
This is used to estimate modulation effects induced by windowing
|
49 |
+
observations in short-time fourier transforms.
|
50 |
+
Parameters
|
51 |
+
----------
|
52 |
+
window : string, tuple, number, callable, or list-like
|
53 |
+
Window specification, as in `get_window`
|
54 |
+
n_frames : int > 0
|
55 |
+
The number of analysis frames
|
56 |
+
hop_length : int > 0
|
57 |
+
The number of samples to advance between frames
|
58 |
+
win_length : [optional]
|
59 |
+
The length of the window function. By default, this matches `n_fft`.
|
60 |
+
n_fft : int > 0
|
61 |
+
The length of each analysis frame.
|
62 |
+
dtype : np.dtype
|
63 |
+
The data type of the output
|
64 |
+
Returns
|
65 |
+
-------
|
66 |
+
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
|
67 |
+
The sum-squared envelope of the window function
|
68 |
+
"""
|
69 |
+
if win_length is None:
|
70 |
+
win_length = n_fft
|
71 |
+
|
72 |
+
n = n_fft + hop_length * (n_frames - 1)
|
73 |
+
x = np.zeros(n, dtype=dtype)
|
74 |
+
|
75 |
+
# Compute the squared window at the desired length
|
76 |
+
win_sq = get_window(window, win_length, fftbins=True)
|
77 |
+
win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
|
78 |
+
win_sq = librosa_util.pad_center(win_sq, n_fft)
|
79 |
+
|
80 |
+
# Fill the envelope
|
81 |
+
for i in range(n_frames):
|
82 |
+
sample = i * hop_length
|
83 |
+
x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
class STFT(torch.nn.Module):
|
88 |
+
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
89 |
+
|
90 |
+
def __init__(
|
91 |
+
self, filter_length=800, hop_length=200, win_length=800, window="hann"
|
92 |
+
):
|
93 |
+
super().__init__()
|
94 |
+
self.filter_length = filter_length
|
95 |
+
self.hop_length = hop_length
|
96 |
+
self.win_length = win_length
|
97 |
+
self.window = window
|
98 |
+
self.forward_transform = None
|
99 |
+
scale = self.filter_length / self.hop_length
|
100 |
+
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
101 |
+
|
102 |
+
cutoff = int(self.filter_length / 2 + 1)
|
103 |
+
fourier_basis = np.vstack(
|
104 |
+
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
|
105 |
+
)
|
106 |
+
|
107 |
+
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
|
108 |
+
inverse_basis = torch.FloatTensor(
|
109 |
+
np.linalg.pinv(scale * fourier_basis).T[:, None, :]
|
110 |
+
)
|
111 |
+
|
112 |
+
if window is not None:
|
113 |
+
assert filter_length >= win_length
|
114 |
+
# get window and zero center pad it to filter_length
|
115 |
+
fft_window = get_window(window, win_length, fftbins=True)
|
116 |
+
fft_window = pad_center(fft_window, filter_length)
|
117 |
+
fft_window = torch.from_numpy(fft_window).float()
|
118 |
+
|
119 |
+
# window the bases
|
120 |
+
forward_basis *= fft_window
|
121 |
+
inverse_basis *= fft_window
|
122 |
+
|
123 |
+
self.register_buffer("forward_basis", forward_basis.float())
|
124 |
+
self.register_buffer("inverse_basis", inverse_basis.float())
|
125 |
+
|
126 |
+
def transform(self, input_data):
|
127 |
+
num_batches = input_data.size(0)
|
128 |
+
num_samples = input_data.size(1)
|
129 |
+
|
130 |
+
self.num_samples = num_samples
|
131 |
+
|
132 |
+
# similar to librosa, reflect-pad the input
|
133 |
+
input_data = input_data.view(num_batches, 1, num_samples)
|
134 |
+
input_data = F.pad(
|
135 |
+
input_data.unsqueeze(1),
|
136 |
+
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
|
137 |
+
mode="reflect",
|
138 |
+
)
|
139 |
+
input_data = input_data.squeeze(1)
|
140 |
+
|
141 |
+
forward_transform = F.conv1d(
|
142 |
+
input_data,
|
143 |
+
Variable(self.forward_basis, requires_grad=False),
|
144 |
+
stride=self.hop_length,
|
145 |
+
padding=0,
|
146 |
+
)
|
147 |
+
|
148 |
+
cutoff = int((self.filter_length / 2) + 1)
|
149 |
+
real_part = forward_transform[:, :cutoff, :]
|
150 |
+
imag_part = forward_transform[:, cutoff:, :]
|
151 |
+
|
152 |
+
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
153 |
+
phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
|
154 |
+
|
155 |
+
return magnitude, phase
|
156 |
+
|
157 |
+
def inverse(self, magnitude, phase):
|
158 |
+
recombine_magnitude_phase = torch.cat(
|
159 |
+
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
|
160 |
+
)
|
161 |
+
|
162 |
+
inverse_transform = F.conv_transpose1d(
|
163 |
+
recombine_magnitude_phase,
|
164 |
+
Variable(self.inverse_basis, requires_grad=False),
|
165 |
+
stride=self.hop_length,
|
166 |
+
padding=0,
|
167 |
+
)
|
168 |
+
|
169 |
+
if self.window is not None:
|
170 |
+
window_sum = window_sumsquare(
|
171 |
+
self.window,
|
172 |
+
magnitude.size(-1),
|
173 |
+
hop_length=self.hop_length,
|
174 |
+
win_length=self.win_length,
|
175 |
+
n_fft=self.filter_length,
|
176 |
+
dtype=np.float32,
|
177 |
+
)
|
178 |
+
# remove modulation effects
|
179 |
+
approx_nonzero_indices = torch.from_numpy(
|
180 |
+
np.where(window_sum > tiny(window_sum))[0]
|
181 |
+
)
|
182 |
+
window_sum = torch.autograd.Variable(
|
183 |
+
torch.from_numpy(window_sum), requires_grad=False
|
184 |
+
)
|
185 |
+
window_sum = window_sum.to(inverse_transform.device())
|
186 |
+
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
|
187 |
+
approx_nonzero_indices
|
188 |
+
]
|
189 |
+
|
190 |
+
# scale by hop ratio
|
191 |
+
inverse_transform *= float(self.filter_length) / self.hop_length
|
192 |
+
|
193 |
+
inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
|
194 |
+
inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
|
195 |
+
|
196 |
+
return inverse_transform
|
197 |
+
|
198 |
+
def forward(self, input_data):
|
199 |
+
self.magnitude, self.phase = self.transform(input_data)
|
200 |
+
reconstruction = self.inverse(self.magnitude, self.phase)
|
201 |
+
return reconstruction
|
202 |
+
|
203 |
+
|
204 |
+
class TorchSTFT(torch.nn.Module):
|
205 |
+
def __init__(
|
206 |
+
self, filter_length=800, hop_length=200, win_length=800, window="hann"
|
207 |
+
):
|
208 |
+
super().__init__()
|
209 |
+
self.filter_length = filter_length
|
210 |
+
self.hop_length = hop_length
|
211 |
+
self.win_length = win_length
|
212 |
+
self.window = torch.from_numpy(
|
213 |
+
get_window(window, win_length, fftbins=True).astype(np.float32)
|
214 |
+
)
|
215 |
+
|
216 |
+
def transform(self, input_data):
|
217 |
+
forward_transform = torch.stft(
|
218 |
+
input_data,
|
219 |
+
self.filter_length,
|
220 |
+
self.hop_length,
|
221 |
+
self.win_length,
|
222 |
+
window=self.window,
|
223 |
+
return_complex=True,
|
224 |
+
)
|
225 |
+
|
226 |
+
return torch.abs(forward_transform), torch.angle(forward_transform)
|
227 |
+
|
228 |
+
def inverse(self, magnitude, phase):
|
229 |
+
inverse_transform = torch.istft(
|
230 |
+
magnitude * torch.exp(phase * 1j),
|
231 |
+
self.filter_length,
|
232 |
+
self.hop_length,
|
233 |
+
self.win_length,
|
234 |
+
window=self.window.to(magnitude.device),
|
235 |
+
)
|
236 |
+
|
237 |
+
return inverse_transform.unsqueeze(
|
238 |
+
-2
|
239 |
+
) # unsqueeze to stay consistent with conv_transpose1d implementation
|
240 |
+
|
241 |
+
def forward(self, input_data):
|
242 |
+
self.magnitude, self.phase = self.transform(input_data)
|
243 |
+
reconstruction = self.inverse(self.magnitude, self.phase)
|
244 |
+
return reconstruction
|
so_vits_svc_fork/modules/decoders/mb_istft/_stft_loss.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019 Tomoki Hayashi
|
2 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
3 |
+
|
4 |
+
"""STFT-based Loss modules."""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
|
10 |
+
def stft(x, fft_size, hop_size, win_length, window):
|
11 |
+
"""Perform STFT and convert to magnitude spectrogram.
|
12 |
+
Args:
|
13 |
+
x (Tensor): Input signal tensor (B, T).
|
14 |
+
fft_size (int): FFT size.
|
15 |
+
hop_size (int): Hop size.
|
16 |
+
win_length (int): Window length.
|
17 |
+
window (str): Window function type.
|
18 |
+
Returns:
|
19 |
+
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
20 |
+
"""
|
21 |
+
x_stft = torch.stft(
|
22 |
+
x, fft_size, hop_size, win_length, window.to(x.device), return_complex=False
|
23 |
+
)
|
24 |
+
real = x_stft[..., 0]
|
25 |
+
imag = x_stft[..., 1]
|
26 |
+
|
27 |
+
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
|
28 |
+
return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
|
29 |
+
|
30 |
+
|
31 |
+
class SpectralConvergengeLoss(torch.nn.Module):
|
32 |
+
"""Spectral convergence loss module."""
|
33 |
+
|
34 |
+
def __init__(self):
|
35 |
+
"""Initialize spectral convergence loss module."""
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
def forward(self, x_mag, y_mag):
|
39 |
+
"""Calculate forward propagation.
|
40 |
+
Args:
|
41 |
+
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
42 |
+
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
43 |
+
Returns:
|
44 |
+
Tensor: Spectral convergence loss value.
|
45 |
+
"""
|
46 |
+
return torch.norm(y_mag - x_mag) / torch.norm(
|
47 |
+
y_mag
|
48 |
+
) # MB-iSTFT-VITS changed here due to codespell
|
49 |
+
|
50 |
+
|
51 |
+
class LogSTFTMagnitudeLoss(torch.nn.Module):
|
52 |
+
"""Log STFT magnitude loss module."""
|
53 |
+
|
54 |
+
def __init__(self):
|
55 |
+
"""Initialize los STFT magnitude loss module."""
|
56 |
+
super().__init__()
|
57 |
+
|
58 |
+
def forward(self, x_mag, y_mag):
|
59 |
+
"""Calculate forward propagation.
|
60 |
+
Args:
|
61 |
+
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
62 |
+
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
63 |
+
Returns:
|
64 |
+
Tensor: Log STFT magnitude loss value.
|
65 |
+
"""
|
66 |
+
return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
|
67 |
+
|
68 |
+
|
69 |
+
class STFTLoss(torch.nn.Module):
|
70 |
+
"""STFT loss module."""
|
71 |
+
|
72 |
+
def __init__(
|
73 |
+
self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"
|
74 |
+
):
|
75 |
+
"""Initialize STFT loss module."""
|
76 |
+
super().__init__()
|
77 |
+
self.fft_size = fft_size
|
78 |
+
self.shift_size = shift_size
|
79 |
+
self.win_length = win_length
|
80 |
+
self.window = getattr(torch, window)(win_length)
|
81 |
+
self.spectral_convergenge_loss = SpectralConvergengeLoss()
|
82 |
+
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
|
83 |
+
|
84 |
+
def forward(self, x, y):
|
85 |
+
"""Calculate forward propagation.
|
86 |
+
Args:
|
87 |
+
x (Tensor): Predicted signal (B, T).
|
88 |
+
y (Tensor): Groundtruth signal (B, T).
|
89 |
+
Returns:
|
90 |
+
Tensor: Spectral convergence loss value.
|
91 |
+
Tensor: Log STFT magnitude loss value.
|
92 |
+
"""
|
93 |
+
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
|
94 |
+
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
|
95 |
+
sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
|
96 |
+
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
97 |
+
|
98 |
+
return sc_loss, mag_loss
|
99 |
+
|
100 |
+
|
101 |
+
class MultiResolutionSTFTLoss(torch.nn.Module):
|
102 |
+
"""Multi resolution STFT loss module."""
|
103 |
+
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
fft_sizes=[1024, 2048, 512],
|
107 |
+
hop_sizes=[120, 240, 50],
|
108 |
+
win_lengths=[600, 1200, 240],
|
109 |
+
window="hann_window",
|
110 |
+
):
|
111 |
+
"""Initialize Multi resolution STFT loss module.
|
112 |
+
Args:
|
113 |
+
fft_sizes (list): List of FFT sizes.
|
114 |
+
hop_sizes (list): List of hop sizes.
|
115 |
+
win_lengths (list): List of window lengths.
|
116 |
+
window (str): Window function type.
|
117 |
+
"""
|
118 |
+
super().__init__()
|
119 |
+
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
|
120 |
+
self.stft_losses = torch.nn.ModuleList()
|
121 |
+
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
|
122 |
+
self.stft_losses += [STFTLoss(fs, ss, wl, window)]
|
123 |
+
|
124 |
+
def forward(self, x, y):
|
125 |
+
"""Calculate forward propagation.
|
126 |
+
Args:
|
127 |
+
x (Tensor): Predicted signal (B, T).
|
128 |
+
y (Tensor): Groundtruth signal (B, T).
|
129 |
+
Returns:
|
130 |
+
Tensor: Multi resolution spectral convergence loss value.
|
131 |
+
Tensor: Multi resolution log STFT magnitude loss value.
|
132 |
+
"""
|
133 |
+
sc_loss = 0.0
|
134 |
+
mag_loss = 0.0
|
135 |
+
for f in self.stft_losses:
|
136 |
+
sc_l, mag_l = f(x, y)
|
137 |
+
sc_loss += sc_l
|
138 |
+
mag_loss += mag_l
|
139 |
+
sc_loss /= len(self.stft_losses)
|
140 |
+
mag_loss /= len(self.stft_losses)
|
141 |
+
|
142 |
+
return sc_loss, mag_loss
|
so_vits_svc_fork/modules/descriminators.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import AvgPool1d, Conv1d, Conv2d
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from torch.nn.utils import spectral_norm, weight_norm
|
6 |
+
|
7 |
+
from so_vits_svc_fork.modules import modules as modules
|
8 |
+
from so_vits_svc_fork.modules.commons import get_padding
|
9 |
+
|
10 |
+
|
11 |
+
class DiscriminatorP(torch.nn.Module):
|
12 |
+
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
13 |
+
super().__init__()
|
14 |
+
self.period = period
|
15 |
+
self.use_spectral_norm = use_spectral_norm
|
16 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
17 |
+
self.convs = nn.ModuleList(
|
18 |
+
[
|
19 |
+
norm_f(
|
20 |
+
Conv2d(
|
21 |
+
1,
|
22 |
+
32,
|
23 |
+
(kernel_size, 1),
|
24 |
+
(stride, 1),
|
25 |
+
padding=(get_padding(kernel_size, 1), 0),
|
26 |
+
)
|
27 |
+
),
|
28 |
+
norm_f(
|
29 |
+
Conv2d(
|
30 |
+
32,
|
31 |
+
128,
|
32 |
+
(kernel_size, 1),
|
33 |
+
(stride, 1),
|
34 |
+
padding=(get_padding(kernel_size, 1), 0),
|
35 |
+
)
|
36 |
+
),
|
37 |
+
norm_f(
|
38 |
+
Conv2d(
|
39 |
+
128,
|
40 |
+
512,
|
41 |
+
(kernel_size, 1),
|
42 |
+
(stride, 1),
|
43 |
+
padding=(get_padding(kernel_size, 1), 0),
|
44 |
+
)
|
45 |
+
),
|
46 |
+
norm_f(
|
47 |
+
Conv2d(
|
48 |
+
512,
|
49 |
+
1024,
|
50 |
+
(kernel_size, 1),
|
51 |
+
(stride, 1),
|
52 |
+
padding=(get_padding(kernel_size, 1), 0),
|
53 |
+
)
|
54 |
+
),
|
55 |
+
norm_f(
|
56 |
+
Conv2d(
|
57 |
+
1024,
|
58 |
+
1024,
|
59 |
+
(kernel_size, 1),
|
60 |
+
1,
|
61 |
+
padding=(get_padding(kernel_size, 1), 0),
|
62 |
+
)
|
63 |
+
),
|
64 |
+
]
|
65 |
+
)
|
66 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
fmap = []
|
70 |
+
|
71 |
+
# 1d to 2d
|
72 |
+
b, c, t = x.shape
|
73 |
+
if t % self.period != 0: # pad first
|
74 |
+
n_pad = self.period - (t % self.period)
|
75 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
76 |
+
t = t + n_pad
|
77 |
+
x = x.view(b, c, t // self.period, self.period)
|
78 |
+
|
79 |
+
for l in self.convs:
|
80 |
+
x = l(x)
|
81 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
82 |
+
fmap.append(x)
|
83 |
+
x = self.conv_post(x)
|
84 |
+
fmap.append(x)
|
85 |
+
x = torch.flatten(x, 1, -1)
|
86 |
+
|
87 |
+
return x, fmap
|
88 |
+
|
89 |
+
|
90 |
+
class DiscriminatorS(torch.nn.Module):
|
91 |
+
def __init__(self, use_spectral_norm=False):
|
92 |
+
super().__init__()
|
93 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
94 |
+
self.convs = nn.ModuleList(
|
95 |
+
[
|
96 |
+
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
97 |
+
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
98 |
+
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
99 |
+
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
100 |
+
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
101 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
102 |
+
]
|
103 |
+
)
|
104 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
105 |
+
|
106 |
+
def forward(self, x):
|
107 |
+
fmap = []
|
108 |
+
|
109 |
+
for l in self.convs:
|
110 |
+
x = l(x)
|
111 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
112 |
+
fmap.append(x)
|
113 |
+
x = self.conv_post(x)
|
114 |
+
fmap.append(x)
|
115 |
+
x = torch.flatten(x, 1, -1)
|
116 |
+
|
117 |
+
return x, fmap
|
118 |
+
|
119 |
+
|
120 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
121 |
+
def __init__(self, use_spectral_norm=False):
|
122 |
+
super().__init__()
|
123 |
+
periods = [2, 3, 5, 7, 11]
|
124 |
+
|
125 |
+
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
126 |
+
discs = discs + [
|
127 |
+
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
|
128 |
+
]
|
129 |
+
self.discriminators = nn.ModuleList(discs)
|
130 |
+
|
131 |
+
def forward(self, y, y_hat):
|
132 |
+
y_d_rs = []
|
133 |
+
y_d_gs = []
|
134 |
+
fmap_rs = []
|
135 |
+
fmap_gs = []
|
136 |
+
for i, d in enumerate(self.discriminators):
|
137 |
+
y_d_r, fmap_r = d(y)
|
138 |
+
y_d_g, fmap_g = d(y_hat)
|
139 |
+
y_d_rs.append(y_d_r)
|
140 |
+
y_d_gs.append(y_d_g)
|
141 |
+
fmap_rs.append(fmap_r)
|
142 |
+
fmap_gs.append(fmap_g)
|
143 |
+
|
144 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
145 |
+
|
146 |
+
|
147 |
+
class MultiScaleDiscriminator(torch.nn.Module):
|
148 |
+
def __init__(self):
|
149 |
+
super().__init__()
|
150 |
+
self.discriminators = nn.ModuleList(
|
151 |
+
[
|
152 |
+
DiscriminatorS(use_spectral_norm=True),
|
153 |
+
DiscriminatorS(),
|
154 |
+
DiscriminatorS(),
|
155 |
+
]
|
156 |
+
)
|
157 |
+
self.meanpools = nn.ModuleList(
|
158 |
+
[AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
|
159 |
+
)
|
160 |
+
|
161 |
+
def forward(self, y, y_hat):
|
162 |
+
y_d_rs = []
|
163 |
+
y_d_gs = []
|
164 |
+
fmap_rs = []
|
165 |
+
fmap_gs = []
|
166 |
+
for i, d in enumerate(self.discriminators):
|
167 |
+
if i != 0:
|
168 |
+
y = self.meanpools[i - 1](y)
|
169 |
+
y_hat = self.meanpools[i - 1](y_hat)
|
170 |
+
y_d_r, fmap_r = d(y)
|
171 |
+
y_d_g, fmap_g = d(y_hat)
|
172 |
+
y_d_rs.append(y_d_r)
|
173 |
+
fmap_rs.append(fmap_r)
|
174 |
+
y_d_gs.append(y_d_g)
|
175 |
+
fmap_gs.append(fmap_g)
|
176 |
+
|
177 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
so_vits_svc_fork/modules/encoders.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
from so_vits_svc_fork.modules import attentions as attentions
|
5 |
+
from so_vits_svc_fork.modules import commons as commons
|
6 |
+
from so_vits_svc_fork.modules import modules as modules
|
7 |
+
|
8 |
+
|
9 |
+
class SpeakerEncoder(torch.nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
mel_n_channels=80,
|
13 |
+
model_num_layers=3,
|
14 |
+
model_hidden_size=256,
|
15 |
+
model_embedding_size=256,
|
16 |
+
):
|
17 |
+
super().__init__()
|
18 |
+
self.lstm = nn.LSTM(
|
19 |
+
mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
|
20 |
+
)
|
21 |
+
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
|
22 |
+
self.relu = nn.ReLU()
|
23 |
+
|
24 |
+
def forward(self, mels):
|
25 |
+
self.lstm.flatten_parameters()
|
26 |
+
_, (hidden, _) = self.lstm(mels)
|
27 |
+
embeds_raw = self.relu(self.linear(hidden[-1]))
|
28 |
+
return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
|
29 |
+
|
30 |
+
def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
|
31 |
+
mel_slices = []
|
32 |
+
for i in range(0, total_frames - partial_frames, partial_hop):
|
33 |
+
mel_range = torch.arange(i, i + partial_frames)
|
34 |
+
mel_slices.append(mel_range)
|
35 |
+
|
36 |
+
return mel_slices
|
37 |
+
|
38 |
+
def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
|
39 |
+
mel_len = mel.size(1)
|
40 |
+
last_mel = mel[:, -partial_frames:]
|
41 |
+
|
42 |
+
if mel_len > partial_frames:
|
43 |
+
mel_slices = self.compute_partial_slices(
|
44 |
+
mel_len, partial_frames, partial_hop
|
45 |
+
)
|
46 |
+
mels = list(mel[:, s] for s in mel_slices)
|
47 |
+
mels.append(last_mel)
|
48 |
+
mels = torch.stack(tuple(mels), 0).squeeze(1)
|
49 |
+
|
50 |
+
with torch.no_grad():
|
51 |
+
partial_embeds = self(mels)
|
52 |
+
embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
|
53 |
+
# embed = embed / torch.linalg.norm(embed, 2)
|
54 |
+
else:
|
55 |
+
with torch.no_grad():
|
56 |
+
embed = self(last_mel)
|
57 |
+
|
58 |
+
return embed
|
59 |
+
|
60 |
+
|
61 |
+
class Encoder(nn.Module):
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
in_channels,
|
65 |
+
out_channels,
|
66 |
+
hidden_channels,
|
67 |
+
kernel_size,
|
68 |
+
dilation_rate,
|
69 |
+
n_layers,
|
70 |
+
gin_channels=0,
|
71 |
+
):
|
72 |
+
super().__init__()
|
73 |
+
self.in_channels = in_channels
|
74 |
+
self.out_channels = out_channels
|
75 |
+
self.hidden_channels = hidden_channels
|
76 |
+
self.kernel_size = kernel_size
|
77 |
+
self.dilation_rate = dilation_rate
|
78 |
+
self.n_layers = n_layers
|
79 |
+
self.gin_channels = gin_channels
|
80 |
+
|
81 |
+
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
82 |
+
self.enc = modules.WN(
|
83 |
+
hidden_channels,
|
84 |
+
kernel_size,
|
85 |
+
dilation_rate,
|
86 |
+
n_layers,
|
87 |
+
gin_channels=gin_channels,
|
88 |
+
)
|
89 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
90 |
+
|
91 |
+
def forward(self, x, x_lengths, g=None):
|
92 |
+
# print(x.shape,x_lengths.shape)
|
93 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
94 |
+
x.dtype
|
95 |
+
)
|
96 |
+
x = self.pre(x) * x_mask
|
97 |
+
x = self.enc(x, x_mask, g=g)
|
98 |
+
stats = self.proj(x) * x_mask
|
99 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
100 |
+
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
101 |
+
return z, m, logs, x_mask
|
102 |
+
|
103 |
+
|
104 |
+
class TextEncoder(nn.Module):
|
105 |
+
def __init__(
|
106 |
+
self,
|
107 |
+
out_channels,
|
108 |
+
hidden_channels,
|
109 |
+
kernel_size,
|
110 |
+
n_layers,
|
111 |
+
gin_channels=0,
|
112 |
+
filter_channels=None,
|
113 |
+
n_heads=None,
|
114 |
+
p_dropout=None,
|
115 |
+
):
|
116 |
+
super().__init__()
|
117 |
+
self.out_channels = out_channels
|
118 |
+
self.hidden_channels = hidden_channels
|
119 |
+
self.kernel_size = kernel_size
|
120 |
+
self.n_layers = n_layers
|
121 |
+
self.gin_channels = gin_channels
|
122 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
123 |
+
self.f0_emb = nn.Embedding(256, hidden_channels)
|
124 |
+
|
125 |
+
self.enc_ = attentions.Encoder(
|
126 |
+
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
127 |
+
)
|
128 |
+
|
129 |
+
def forward(self, x, x_mask, f0=None, noice_scale=1):
|
130 |
+
x = x + self.f0_emb(f0).transpose(1, 2)
|
131 |
+
x = self.enc_(x * x_mask, x_mask)
|
132 |
+
stats = self.proj(x) * x_mask
|
133 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
134 |
+
z = (m + torch.randn_like(m) * torch.exp(logs) * noice_scale) * x_mask
|
135 |
+
|
136 |
+
return z, m, logs, x_mask
|
so_vits_svc_fork/modules/flows.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
from so_vits_svc_fork.modules import modules as modules
|
4 |
+
|
5 |
+
|
6 |
+
class ResidualCouplingBlock(nn.Module):
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
channels,
|
10 |
+
hidden_channels,
|
11 |
+
kernel_size,
|
12 |
+
dilation_rate,
|
13 |
+
n_layers,
|
14 |
+
n_flows=4,
|
15 |
+
gin_channels=0,
|
16 |
+
):
|
17 |
+
super().__init__()
|
18 |
+
self.channels = channels
|
19 |
+
self.hidden_channels = hidden_channels
|
20 |
+
self.kernel_size = kernel_size
|
21 |
+
self.dilation_rate = dilation_rate
|
22 |
+
self.n_layers = n_layers
|
23 |
+
self.n_flows = n_flows
|
24 |
+
self.gin_channels = gin_channels
|
25 |
+
|
26 |
+
self.flows = nn.ModuleList()
|
27 |
+
for i in range(n_flows):
|
28 |
+
self.flows.append(
|
29 |
+
modules.ResidualCouplingLayer(
|
30 |
+
channels,
|
31 |
+
hidden_channels,
|
32 |
+
kernel_size,
|
33 |
+
dilation_rate,
|
34 |
+
n_layers,
|
35 |
+
gin_channels=gin_channels,
|
36 |
+
mean_only=True,
|
37 |
+
)
|
38 |
+
)
|
39 |
+
self.flows.append(modules.Flip())
|
40 |
+
|
41 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
42 |
+
if not reverse:
|
43 |
+
for flow in self.flows:
|
44 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
45 |
+
else:
|
46 |
+
for flow in reversed(self.flows):
|
47 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
48 |
+
return x
|
so_vits_svc_fork/modules/losses.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def feature_loss(fmap_r, fmap_g):
|
5 |
+
loss = 0
|
6 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
7 |
+
for rl, gl in zip(dr, dg):
|
8 |
+
rl = rl.float().detach()
|
9 |
+
gl = gl.float()
|
10 |
+
loss += torch.mean(torch.abs(rl - gl))
|
11 |
+
|
12 |
+
return loss * 2
|
13 |
+
|
14 |
+
|
15 |
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
16 |
+
loss = 0
|
17 |
+
r_losses = []
|
18 |
+
g_losses = []
|
19 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
20 |
+
dr = dr.float()
|
21 |
+
dg = dg.float()
|
22 |
+
r_loss = torch.mean((1 - dr) ** 2)
|
23 |
+
g_loss = torch.mean(dg**2)
|
24 |
+
loss += r_loss + g_loss
|
25 |
+
r_losses.append(r_loss.item())
|
26 |
+
g_losses.append(g_loss.item())
|
27 |
+
|
28 |
+
return loss, r_losses, g_losses
|
29 |
+
|
30 |
+
|
31 |
+
def generator_loss(disc_outputs):
|
32 |
+
loss = 0
|
33 |
+
gen_losses = []
|
34 |
+
for dg in disc_outputs:
|
35 |
+
dg = dg.float()
|
36 |
+
l = torch.mean((1 - dg) ** 2)
|
37 |
+
gen_losses.append(l)
|
38 |
+
loss += l
|
39 |
+
|
40 |
+
return loss, gen_losses
|
41 |
+
|
42 |
+
|
43 |
+
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
|
44 |
+
"""
|
45 |
+
z_p, logs_q: [b, h, t_t]
|
46 |
+
m_p, logs_p: [b, h, t_t]
|
47 |
+
"""
|
48 |
+
z_p = z_p.float()
|
49 |
+
logs_q = logs_q.float()
|
50 |
+
m_p = m_p.float()
|
51 |
+
logs_p = logs_p.float()
|
52 |
+
z_mask = z_mask.float()
|
53 |
+
# print(logs_p)
|
54 |
+
kl = logs_p - logs_q - 0.5
|
55 |
+
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
|
56 |
+
kl = torch.sum(kl * z_mask)
|
57 |
+
l = kl / torch.sum(z_mask)
|
58 |
+
return l
|
so_vits_svc_fork/modules/mel_processing.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""from logging import getLogger
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.utils.data
|
5 |
+
import torchaudio
|
6 |
+
|
7 |
+
LOG = getLogger(__name__)
|
8 |
+
|
9 |
+
|
10 |
+
from ..hparams import HParams
|
11 |
+
|
12 |
+
|
13 |
+
def spectrogram_torch(audio: torch.Tensor, hps: HParams) -> torch.Tensor:
|
14 |
+
return torchaudio.transforms.Spectrogram(
|
15 |
+
n_fft=hps.data.filter_length,
|
16 |
+
win_length=hps.data.win_length,
|
17 |
+
hop_length=hps.data.hop_length,
|
18 |
+
power=1.0,
|
19 |
+
window_fn=torch.hann_window,
|
20 |
+
normalized=False,
|
21 |
+
).to(audio.device)(audio)
|
22 |
+
|
23 |
+
|
24 |
+
def spec_to_mel_torch(spec: torch.Tensor, hps: HParams) -> torch.Tensor:
|
25 |
+
return torchaudio.transforms.MelScale(
|
26 |
+
n_mels=hps.data.n_mel_channels,
|
27 |
+
sample_rate=hps.data.sampling_rate,
|
28 |
+
f_min=hps.data.mel_fmin,
|
29 |
+
f_max=hps.data.mel_fmax,
|
30 |
+
).to(spec.device)(spec)
|
31 |
+
|
32 |
+
|
33 |
+
def mel_spectrogram_torch(audio: torch.Tensor, hps: HParams) -> torch.Tensor:
|
34 |
+
return torchaudio.transforms.MelSpectrogram(
|
35 |
+
sample_rate=hps.data.sampling_rate,
|
36 |
+
n_fft=hps.data.filter_length,
|
37 |
+
n_mels=hps.data.n_mel_channels,
|
38 |
+
win_length=hps.data.win_length,
|
39 |
+
hop_length=hps.data.hop_length,
|
40 |
+
f_min=hps.data.mel_fmin,
|
41 |
+
f_max=hps.data.mel_fmax,
|
42 |
+
power=1.0,
|
43 |
+
window_fn=torch.hann_window,
|
44 |
+
normalized=False,
|
45 |
+
).to(audio.device)(audio)"""
|
46 |
+
|
47 |
+
from logging import getLogger
|
48 |
+
|
49 |
+
import torch
|
50 |
+
import torch.utils.data
|
51 |
+
from librosa.filters import mel as librosa_mel_fn
|
52 |
+
|
53 |
+
LOG = getLogger(__name__)
|
54 |
+
|
55 |
+
MAX_WAV_VALUE = 32768.0
|
56 |
+
|
57 |
+
|
58 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
59 |
+
"""
|
60 |
+
PARAMS
|
61 |
+
------
|
62 |
+
C: compression factor
|
63 |
+
"""
|
64 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
65 |
+
|
66 |
+
|
67 |
+
def dynamic_range_decompression_torch(x, C=1):
|
68 |
+
"""
|
69 |
+
PARAMS
|
70 |
+
------
|
71 |
+
C: compression factor used to compress
|
72 |
+
"""
|
73 |
+
return torch.exp(x) / C
|
74 |
+
|
75 |
+
|
76 |
+
def spectral_normalize_torch(magnitudes):
|
77 |
+
output = dynamic_range_compression_torch(magnitudes)
|
78 |
+
return output
|
79 |
+
|
80 |
+
|
81 |
+
def spectral_de_normalize_torch(magnitudes):
|
82 |
+
output = dynamic_range_decompression_torch(magnitudes)
|
83 |
+
return output
|
84 |
+
|
85 |
+
|
86 |
+
mel_basis = {}
|
87 |
+
hann_window = {}
|
88 |
+
|
89 |
+
|
90 |
+
def spectrogram_torch(y, hps, center=False):
|
91 |
+
if torch.min(y) < -1.0:
|
92 |
+
LOG.info("min value is ", torch.min(y))
|
93 |
+
if torch.max(y) > 1.0:
|
94 |
+
LOG.info("max value is ", torch.max(y))
|
95 |
+
n_fft = hps.data.filter_length
|
96 |
+
hop_size = hps.data.hop_length
|
97 |
+
win_size = hps.data.win_length
|
98 |
+
global hann_window
|
99 |
+
dtype_device = str(y.dtype) + "_" + str(y.device)
|
100 |
+
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
101 |
+
if wnsize_dtype_device not in hann_window:
|
102 |
+
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
|
103 |
+
dtype=y.dtype, device=y.device
|
104 |
+
)
|
105 |
+
|
106 |
+
y = torch.nn.functional.pad(
|
107 |
+
y.unsqueeze(1),
|
108 |
+
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
109 |
+
mode="reflect",
|
110 |
+
)
|
111 |
+
y = y.squeeze(1)
|
112 |
+
|
113 |
+
spec = torch.stft(
|
114 |
+
y,
|
115 |
+
n_fft,
|
116 |
+
hop_length=hop_size,
|
117 |
+
win_length=win_size,
|
118 |
+
window=hann_window[wnsize_dtype_device],
|
119 |
+
center=center,
|
120 |
+
pad_mode="reflect",
|
121 |
+
normalized=False,
|
122 |
+
onesided=True,
|
123 |
+
return_complex=False,
|
124 |
+
)
|
125 |
+
|
126 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
127 |
+
return spec
|
128 |
+
|
129 |
+
|
130 |
+
def spec_to_mel_torch(spec, hps):
|
131 |
+
sampling_rate = hps.data.sampling_rate
|
132 |
+
n_fft = hps.data.filter_length
|
133 |
+
num_mels = hps.data.n_mel_channels
|
134 |
+
fmin = hps.data.mel_fmin
|
135 |
+
fmax = hps.data.mel_fmax
|
136 |
+
global mel_basis
|
137 |
+
dtype_device = str(spec.dtype) + "_" + str(spec.device)
|
138 |
+
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
139 |
+
if fmax_dtype_device not in mel_basis:
|
140 |
+
mel = librosa_mel_fn(
|
141 |
+
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
142 |
+
)
|
143 |
+
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
|
144 |
+
dtype=spec.dtype, device=spec.device
|
145 |
+
)
|
146 |
+
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
147 |
+
spec = spectral_normalize_torch(spec)
|
148 |
+
return spec
|
149 |
+
|
150 |
+
|
151 |
+
def mel_spectrogram_torch(y, hps, center=False):
|
152 |
+
sampling_rate = hps.data.sampling_rate
|
153 |
+
n_fft = hps.data.filter_length
|
154 |
+
num_mels = hps.data.n_mel_channels
|
155 |
+
fmin = hps.data.mel_fmin
|
156 |
+
fmax = hps.data.mel_fmax
|
157 |
+
hop_size = hps.data.hop_length
|
158 |
+
win_size = hps.data.win_length
|
159 |
+
if torch.min(y) < -1.0:
|
160 |
+
LOG.info(f"min value is {torch.min(y)}")
|
161 |
+
if torch.max(y) > 1.0:
|
162 |
+
LOG.info(f"max value is {torch.max(y)}")
|
163 |
+
|
164 |
+
global mel_basis, hann_window
|
165 |
+
dtype_device = str(y.dtype) + "_" + str(y.device)
|
166 |
+
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
167 |
+
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
168 |
+
if fmax_dtype_device not in mel_basis:
|
169 |
+
mel = librosa_mel_fn(
|
170 |
+
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
171 |
+
)
|
172 |
+
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
|
173 |
+
dtype=y.dtype, device=y.device
|
174 |
+
)
|
175 |
+
if wnsize_dtype_device not in hann_window:
|
176 |
+
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
|
177 |
+
dtype=y.dtype, device=y.device
|
178 |
+
)
|
179 |
+
|
180 |
+
y = torch.nn.functional.pad(
|
181 |
+
y.unsqueeze(1),
|
182 |
+
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
183 |
+
mode="reflect",
|
184 |
+
)
|
185 |
+
y = y.squeeze(1)
|
186 |
+
|
187 |
+
spec = torch.stft(
|
188 |
+
y,
|
189 |
+
n_fft,
|
190 |
+
hop_length=hop_size,
|
191 |
+
win_length=win_size,
|
192 |
+
window=hann_window[wnsize_dtype_device],
|
193 |
+
center=center,
|
194 |
+
pad_mode="reflect",
|
195 |
+
normalized=False,
|
196 |
+
onesided=True,
|
197 |
+
return_complex=False,
|
198 |
+
)
|
199 |
+
|
200 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
201 |
+
|
202 |
+
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
203 |
+
spec = spectral_normalize_torch(spec)
|
204 |
+
|
205 |
+
return spec
|
so_vits_svc_fork/modules/modules.py
ADDED
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import Conv1d
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from torch.nn.utils import remove_weight_norm, weight_norm
|
6 |
+
|
7 |
+
from so_vits_svc_fork.modules import commons
|
8 |
+
from so_vits_svc_fork.modules.commons import get_padding, init_weights
|
9 |
+
|
10 |
+
LRELU_SLOPE = 0.1
|
11 |
+
|
12 |
+
|
13 |
+
class LayerNorm(nn.Module):
|
14 |
+
def __init__(self, channels, eps=1e-5):
|
15 |
+
super().__init__()
|
16 |
+
self.channels = channels
|
17 |
+
self.eps = eps
|
18 |
+
|
19 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
20 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
x = x.transpose(1, -1)
|
24 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
25 |
+
return x.transpose(1, -1)
|
26 |
+
|
27 |
+
|
28 |
+
class ConvReluNorm(nn.Module):
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
in_channels,
|
32 |
+
hidden_channels,
|
33 |
+
out_channels,
|
34 |
+
kernel_size,
|
35 |
+
n_layers,
|
36 |
+
p_dropout,
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
self.in_channels = in_channels
|
40 |
+
self.hidden_channels = hidden_channels
|
41 |
+
self.out_channels = out_channels
|
42 |
+
self.kernel_size = kernel_size
|
43 |
+
self.n_layers = n_layers
|
44 |
+
self.p_dropout = p_dropout
|
45 |
+
assert n_layers > 1, "Number of layers should be larger than 0."
|
46 |
+
|
47 |
+
self.conv_layers = nn.ModuleList()
|
48 |
+
self.norm_layers = nn.ModuleList()
|
49 |
+
self.conv_layers.append(
|
50 |
+
nn.Conv1d(
|
51 |
+
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
|
52 |
+
)
|
53 |
+
)
|
54 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
55 |
+
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
|
56 |
+
for _ in range(n_layers - 1):
|
57 |
+
self.conv_layers.append(
|
58 |
+
nn.Conv1d(
|
59 |
+
hidden_channels,
|
60 |
+
hidden_channels,
|
61 |
+
kernel_size,
|
62 |
+
padding=kernel_size // 2,
|
63 |
+
)
|
64 |
+
)
|
65 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
66 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
67 |
+
self.proj.weight.data.zero_()
|
68 |
+
self.proj.bias.data.zero_()
|
69 |
+
|
70 |
+
def forward(self, x, x_mask):
|
71 |
+
x_org = x
|
72 |
+
for i in range(self.n_layers):
|
73 |
+
x = self.conv_layers[i](x * x_mask)
|
74 |
+
x = self.norm_layers[i](x)
|
75 |
+
x = self.relu_drop(x)
|
76 |
+
x = x_org + self.proj(x)
|
77 |
+
return x * x_mask
|
78 |
+
|
79 |
+
|
80 |
+
class DDSConv(nn.Module):
|
81 |
+
"""
|
82 |
+
Dialted and Depth-Separable Convolution
|
83 |
+
"""
|
84 |
+
|
85 |
+
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
|
86 |
+
super().__init__()
|
87 |
+
self.channels = channels
|
88 |
+
self.kernel_size = kernel_size
|
89 |
+
self.n_layers = n_layers
|
90 |
+
self.p_dropout = p_dropout
|
91 |
+
|
92 |
+
self.drop = nn.Dropout(p_dropout)
|
93 |
+
self.convs_sep = nn.ModuleList()
|
94 |
+
self.convs_1x1 = nn.ModuleList()
|
95 |
+
self.norms_1 = nn.ModuleList()
|
96 |
+
self.norms_2 = nn.ModuleList()
|
97 |
+
for i in range(n_layers):
|
98 |
+
dilation = kernel_size**i
|
99 |
+
padding = (kernel_size * dilation - dilation) // 2
|
100 |
+
self.convs_sep.append(
|
101 |
+
nn.Conv1d(
|
102 |
+
channels,
|
103 |
+
channels,
|
104 |
+
kernel_size,
|
105 |
+
groups=channels,
|
106 |
+
dilation=dilation,
|
107 |
+
padding=padding,
|
108 |
+
)
|
109 |
+
)
|
110 |
+
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
111 |
+
self.norms_1.append(LayerNorm(channels))
|
112 |
+
self.norms_2.append(LayerNorm(channels))
|
113 |
+
|
114 |
+
def forward(self, x, x_mask, g=None):
|
115 |
+
if g is not None:
|
116 |
+
x = x + g
|
117 |
+
for i in range(self.n_layers):
|
118 |
+
y = self.convs_sep[i](x * x_mask)
|
119 |
+
y = self.norms_1[i](y)
|
120 |
+
y = F.gelu(y)
|
121 |
+
y = self.convs_1x1[i](y)
|
122 |
+
y = self.norms_2[i](y)
|
123 |
+
y = F.gelu(y)
|
124 |
+
y = self.drop(y)
|
125 |
+
x = x + y
|
126 |
+
return x * x_mask
|
127 |
+
|
128 |
+
|
129 |
+
class WN(torch.nn.Module):
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
hidden_channels,
|
133 |
+
kernel_size,
|
134 |
+
dilation_rate,
|
135 |
+
n_layers,
|
136 |
+
gin_channels=0,
|
137 |
+
p_dropout=0,
|
138 |
+
):
|
139 |
+
super().__init__()
|
140 |
+
assert kernel_size % 2 == 1
|
141 |
+
self.hidden_channels = hidden_channels
|
142 |
+
self.kernel_size = (kernel_size,)
|
143 |
+
self.dilation_rate = dilation_rate
|
144 |
+
self.n_layers = n_layers
|
145 |
+
self.gin_channels = gin_channels
|
146 |
+
self.p_dropout = p_dropout
|
147 |
+
|
148 |
+
self.in_layers = torch.nn.ModuleList()
|
149 |
+
self.res_skip_layers = torch.nn.ModuleList()
|
150 |
+
self.drop = nn.Dropout(p_dropout)
|
151 |
+
|
152 |
+
if gin_channels != 0:
|
153 |
+
cond_layer = torch.nn.Conv1d(
|
154 |
+
gin_channels, 2 * hidden_channels * n_layers, 1
|
155 |
+
)
|
156 |
+
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
|
157 |
+
|
158 |
+
for i in range(n_layers):
|
159 |
+
dilation = dilation_rate**i
|
160 |
+
padding = int((kernel_size * dilation - dilation) / 2)
|
161 |
+
in_layer = torch.nn.Conv1d(
|
162 |
+
hidden_channels,
|
163 |
+
2 * hidden_channels,
|
164 |
+
kernel_size,
|
165 |
+
dilation=dilation,
|
166 |
+
padding=padding,
|
167 |
+
)
|
168 |
+
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
|
169 |
+
self.in_layers.append(in_layer)
|
170 |
+
|
171 |
+
# last one is not necessary
|
172 |
+
if i < n_layers - 1:
|
173 |
+
res_skip_channels = 2 * hidden_channels
|
174 |
+
else:
|
175 |
+
res_skip_channels = hidden_channels
|
176 |
+
|
177 |
+
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
178 |
+
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
|
179 |
+
self.res_skip_layers.append(res_skip_layer)
|
180 |
+
|
181 |
+
def forward(self, x, x_mask, g=None, **kwargs):
|
182 |
+
output = torch.zeros_like(x)
|
183 |
+
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
184 |
+
|
185 |
+
if g is not None:
|
186 |
+
g = self.cond_layer(g)
|
187 |
+
|
188 |
+
for i in range(self.n_layers):
|
189 |
+
x_in = self.in_layers[i](x)
|
190 |
+
if g is not None:
|
191 |
+
cond_offset = i * 2 * self.hidden_channels
|
192 |
+
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
193 |
+
else:
|
194 |
+
g_l = torch.zeros_like(x_in)
|
195 |
+
|
196 |
+
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
|
197 |
+
acts = self.drop(acts)
|
198 |
+
|
199 |
+
res_skip_acts = self.res_skip_layers[i](acts)
|
200 |
+
if i < self.n_layers - 1:
|
201 |
+
res_acts = res_skip_acts[:, : self.hidden_channels, :]
|
202 |
+
x = (x + res_acts) * x_mask
|
203 |
+
output = output + res_skip_acts[:, self.hidden_channels :, :]
|
204 |
+
else:
|
205 |
+
output = output + res_skip_acts
|
206 |
+
return output * x_mask
|
207 |
+
|
208 |
+
def remove_weight_norm(self):
|
209 |
+
if self.gin_channels != 0:
|
210 |
+
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
211 |
+
for l in self.in_layers:
|
212 |
+
torch.nn.utils.remove_weight_norm(l)
|
213 |
+
for l in self.res_skip_layers:
|
214 |
+
torch.nn.utils.remove_weight_norm(l)
|
215 |
+
|
216 |
+
|
217 |
+
class ResBlock1(torch.nn.Module):
|
218 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
219 |
+
super().__init__()
|
220 |
+
self.convs1 = nn.ModuleList(
|
221 |
+
[
|
222 |
+
weight_norm(
|
223 |
+
Conv1d(
|
224 |
+
channels,
|
225 |
+
channels,
|
226 |
+
kernel_size,
|
227 |
+
1,
|
228 |
+
dilation=dilation[0],
|
229 |
+
padding=get_padding(kernel_size, dilation[0]),
|
230 |
+
)
|
231 |
+
),
|
232 |
+
weight_norm(
|
233 |
+
Conv1d(
|
234 |
+
channels,
|
235 |
+
channels,
|
236 |
+
kernel_size,
|
237 |
+
1,
|
238 |
+
dilation=dilation[1],
|
239 |
+
padding=get_padding(kernel_size, dilation[1]),
|
240 |
+
)
|
241 |
+
),
|
242 |
+
weight_norm(
|
243 |
+
Conv1d(
|
244 |
+
channels,
|
245 |
+
channels,
|
246 |
+
kernel_size,
|
247 |
+
1,
|
248 |
+
dilation=dilation[2],
|
249 |
+
padding=get_padding(kernel_size, dilation[2]),
|
250 |
+
)
|
251 |
+
),
|
252 |
+
]
|
253 |
+
)
|
254 |
+
self.convs1.apply(init_weights)
|
255 |
+
|
256 |
+
self.convs2 = nn.ModuleList(
|
257 |
+
[
|
258 |
+
weight_norm(
|
259 |
+
Conv1d(
|
260 |
+
channels,
|
261 |
+
channels,
|
262 |
+
kernel_size,
|
263 |
+
1,
|
264 |
+
dilation=1,
|
265 |
+
padding=get_padding(kernel_size, 1),
|
266 |
+
)
|
267 |
+
),
|
268 |
+
weight_norm(
|
269 |
+
Conv1d(
|
270 |
+
channels,
|
271 |
+
channels,
|
272 |
+
kernel_size,
|
273 |
+
1,
|
274 |
+
dilation=1,
|
275 |
+
padding=get_padding(kernel_size, 1),
|
276 |
+
)
|
277 |
+
),
|
278 |
+
weight_norm(
|
279 |
+
Conv1d(
|
280 |
+
channels,
|
281 |
+
channels,
|
282 |
+
kernel_size,
|
283 |
+
1,
|
284 |
+
dilation=1,
|
285 |
+
padding=get_padding(kernel_size, 1),
|
286 |
+
)
|
287 |
+
),
|
288 |
+
]
|
289 |
+
)
|
290 |
+
self.convs2.apply(init_weights)
|
291 |
+
|
292 |
+
def forward(self, x, x_mask=None):
|
293 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
294 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
295 |
+
if x_mask is not None:
|
296 |
+
xt = xt * x_mask
|
297 |
+
xt = c1(xt)
|
298 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
299 |
+
if x_mask is not None:
|
300 |
+
xt = xt * x_mask
|
301 |
+
xt = c2(xt)
|
302 |
+
x = xt + x
|
303 |
+
if x_mask is not None:
|
304 |
+
x = x * x_mask
|
305 |
+
return x
|
306 |
+
|
307 |
+
def remove_weight_norm(self):
|
308 |
+
for l in self.convs1:
|
309 |
+
remove_weight_norm(l)
|
310 |
+
for l in self.convs2:
|
311 |
+
remove_weight_norm(l)
|
312 |
+
|
313 |
+
|
314 |
+
class ResBlock2(torch.nn.Module):
|
315 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
316 |
+
super().__init__()
|
317 |
+
self.convs = nn.ModuleList(
|
318 |
+
[
|
319 |
+
weight_norm(
|
320 |
+
Conv1d(
|
321 |
+
channels,
|
322 |
+
channels,
|
323 |
+
kernel_size,
|
324 |
+
1,
|
325 |
+
dilation=dilation[0],
|
326 |
+
padding=get_padding(kernel_size, dilation[0]),
|
327 |
+
)
|
328 |
+
),
|
329 |
+
weight_norm(
|
330 |
+
Conv1d(
|
331 |
+
channels,
|
332 |
+
channels,
|
333 |
+
kernel_size,
|
334 |
+
1,
|
335 |
+
dilation=dilation[1],
|
336 |
+
padding=get_padding(kernel_size, dilation[1]),
|
337 |
+
)
|
338 |
+
),
|
339 |
+
]
|
340 |
+
)
|
341 |
+
self.convs.apply(init_weights)
|
342 |
+
|
343 |
+
def forward(self, x, x_mask=None):
|
344 |
+
for c in self.convs:
|
345 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
346 |
+
if x_mask is not None:
|
347 |
+
xt = xt * x_mask
|
348 |
+
xt = c(xt)
|
349 |
+
x = xt + x
|
350 |
+
if x_mask is not None:
|
351 |
+
x = x * x_mask
|
352 |
+
return x
|
353 |
+
|
354 |
+
def remove_weight_norm(self):
|
355 |
+
for l in self.convs:
|
356 |
+
remove_weight_norm(l)
|
357 |
+
|
358 |
+
|
359 |
+
class Log(nn.Module):
|
360 |
+
def forward(self, x, x_mask, reverse=False, **kwargs):
|
361 |
+
if not reverse:
|
362 |
+
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
|
363 |
+
logdet = torch.sum(-y, [1, 2])
|
364 |
+
return y, logdet
|
365 |
+
else:
|
366 |
+
x = torch.exp(x) * x_mask
|
367 |
+
return x
|
368 |
+
|
369 |
+
|
370 |
+
class Flip(nn.Module):
|
371 |
+
def forward(self, x, *args, reverse=False, **kwargs):
|
372 |
+
x = torch.flip(x, [1])
|
373 |
+
if not reverse:
|
374 |
+
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
|
375 |
+
return x, logdet
|
376 |
+
else:
|
377 |
+
return x
|
378 |
+
|
379 |
+
|
380 |
+
class ElementwiseAffine(nn.Module):
|
381 |
+
def __init__(self, channels):
|
382 |
+
super().__init__()
|
383 |
+
self.channels = channels
|
384 |
+
self.m = nn.Parameter(torch.zeros(channels, 1))
|
385 |
+
self.logs = nn.Parameter(torch.zeros(channels, 1))
|
386 |
+
|
387 |
+
def forward(self, x, x_mask, reverse=False, **kwargs):
|
388 |
+
if not reverse:
|
389 |
+
y = self.m + torch.exp(self.logs) * x
|
390 |
+
y = y * x_mask
|
391 |
+
logdet = torch.sum(self.logs * x_mask, [1, 2])
|
392 |
+
return y, logdet
|
393 |
+
else:
|
394 |
+
x = (x - self.m) * torch.exp(-self.logs) * x_mask
|
395 |
+
return x
|
396 |
+
|
397 |
+
|
398 |
+
class ResidualCouplingLayer(nn.Module):
|
399 |
+
def __init__(
|
400 |
+
self,
|
401 |
+
channels,
|
402 |
+
hidden_channels,
|
403 |
+
kernel_size,
|
404 |
+
dilation_rate,
|
405 |
+
n_layers,
|
406 |
+
p_dropout=0,
|
407 |
+
gin_channels=0,
|
408 |
+
mean_only=False,
|
409 |
+
):
|
410 |
+
assert channels % 2 == 0, "channels should be divisible by 2"
|
411 |
+
super().__init__()
|
412 |
+
self.channels = channels
|
413 |
+
self.hidden_channels = hidden_channels
|
414 |
+
self.kernel_size = kernel_size
|
415 |
+
self.dilation_rate = dilation_rate
|
416 |
+
self.n_layers = n_layers
|
417 |
+
self.half_channels = channels // 2
|
418 |
+
self.mean_only = mean_only
|
419 |
+
|
420 |
+
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
421 |
+
self.enc = WN(
|
422 |
+
hidden_channels,
|
423 |
+
kernel_size,
|
424 |
+
dilation_rate,
|
425 |
+
n_layers,
|
426 |
+
p_dropout=p_dropout,
|
427 |
+
gin_channels=gin_channels,
|
428 |
+
)
|
429 |
+
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
430 |
+
self.post.weight.data.zero_()
|
431 |
+
self.post.bias.data.zero_()
|
432 |
+
|
433 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
434 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
435 |
+
h = self.pre(x0) * x_mask
|
436 |
+
h = self.enc(h, x_mask, g=g)
|
437 |
+
stats = self.post(h) * x_mask
|
438 |
+
if not self.mean_only:
|
439 |
+
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
440 |
+
else:
|
441 |
+
m = stats
|
442 |
+
logs = torch.zeros_like(m)
|
443 |
+
|
444 |
+
if not reverse:
|
445 |
+
x1 = m + x1 * torch.exp(logs) * x_mask
|
446 |
+
x = torch.cat([x0, x1], 1)
|
447 |
+
logdet = torch.sum(logs, [1, 2])
|
448 |
+
return x, logdet
|
449 |
+
else:
|
450 |
+
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
451 |
+
x = torch.cat([x0, x1], 1)
|
452 |
+
return x
|
so_vits_svc_fork/modules/synthesizers.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
from logging import getLogger
|
3 |
+
from typing import Any, Literal, Sequence
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
import so_vits_svc_fork.f0
|
9 |
+
from so_vits_svc_fork.f0 import f0_to_coarse
|
10 |
+
from so_vits_svc_fork.modules import commons as commons
|
11 |
+
from so_vits_svc_fork.modules.decoders.f0 import F0Decoder
|
12 |
+
from so_vits_svc_fork.modules.decoders.hifigan import NSFHifiGANGenerator
|
13 |
+
from so_vits_svc_fork.modules.decoders.mb_istft import (
|
14 |
+
Multiband_iSTFT_Generator,
|
15 |
+
Multistream_iSTFT_Generator,
|
16 |
+
iSTFT_Generator,
|
17 |
+
)
|
18 |
+
from so_vits_svc_fork.modules.encoders import Encoder, TextEncoder
|
19 |
+
from so_vits_svc_fork.modules.flows import ResidualCouplingBlock
|
20 |
+
|
21 |
+
LOG = getLogger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
class SynthesizerTrn(nn.Module):
|
25 |
+
"""
|
26 |
+
Synthesizer for Training
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
spec_channels: int,
|
32 |
+
segment_size: int,
|
33 |
+
inter_channels: int,
|
34 |
+
hidden_channels: int,
|
35 |
+
filter_channels: int,
|
36 |
+
n_heads: int,
|
37 |
+
n_layers: int,
|
38 |
+
kernel_size: int,
|
39 |
+
p_dropout: int,
|
40 |
+
resblock: str,
|
41 |
+
resblock_kernel_sizes: Sequence[int],
|
42 |
+
resblock_dilation_sizes: Sequence[Sequence[int]],
|
43 |
+
upsample_rates: Sequence[int],
|
44 |
+
upsample_initial_channel: int,
|
45 |
+
upsample_kernel_sizes: Sequence[int],
|
46 |
+
gin_channels: int,
|
47 |
+
ssl_dim: int,
|
48 |
+
n_speakers: int,
|
49 |
+
sampling_rate: int = 44100,
|
50 |
+
type_: Literal["hifi-gan", "istft", "ms-istft", "mb-istft"] = "hifi-gan",
|
51 |
+
gen_istft_n_fft: int = 16,
|
52 |
+
gen_istft_hop_size: int = 4,
|
53 |
+
subbands: int = 4,
|
54 |
+
**kwargs: Any,
|
55 |
+
):
|
56 |
+
super().__init__()
|
57 |
+
self.spec_channels = spec_channels
|
58 |
+
self.inter_channels = inter_channels
|
59 |
+
self.hidden_channels = hidden_channels
|
60 |
+
self.filter_channels = filter_channels
|
61 |
+
self.n_heads = n_heads
|
62 |
+
self.n_layers = n_layers
|
63 |
+
self.kernel_size = kernel_size
|
64 |
+
self.p_dropout = p_dropout
|
65 |
+
self.resblock = resblock
|
66 |
+
self.resblock_kernel_sizes = resblock_kernel_sizes
|
67 |
+
self.resblock_dilation_sizes = resblock_dilation_sizes
|
68 |
+
self.upsample_rates = upsample_rates
|
69 |
+
self.upsample_initial_channel = upsample_initial_channel
|
70 |
+
self.upsample_kernel_sizes = upsample_kernel_sizes
|
71 |
+
self.segment_size = segment_size
|
72 |
+
self.gin_channels = gin_channels
|
73 |
+
self.ssl_dim = ssl_dim
|
74 |
+
self.n_speakers = n_speakers
|
75 |
+
self.sampling_rate = sampling_rate
|
76 |
+
self.type_ = type_
|
77 |
+
self.gen_istft_n_fft = gen_istft_n_fft
|
78 |
+
self.gen_istft_hop_size = gen_istft_hop_size
|
79 |
+
self.subbands = subbands
|
80 |
+
if kwargs:
|
81 |
+
warnings.warn(f"Unused arguments: {kwargs}")
|
82 |
+
|
83 |
+
self.emb_g = nn.Embedding(n_speakers, gin_channels)
|
84 |
+
|
85 |
+
if ssl_dim is None:
|
86 |
+
self.pre = nn.LazyConv1d(hidden_channels, kernel_size=5, padding=2)
|
87 |
+
else:
|
88 |
+
self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
|
89 |
+
|
90 |
+
self.enc_p = TextEncoder(
|
91 |
+
inter_channels,
|
92 |
+
hidden_channels,
|
93 |
+
filter_channels=filter_channels,
|
94 |
+
n_heads=n_heads,
|
95 |
+
n_layers=n_layers,
|
96 |
+
kernel_size=kernel_size,
|
97 |
+
p_dropout=p_dropout,
|
98 |
+
)
|
99 |
+
|
100 |
+
LOG.info(f"Decoder type: {type_}")
|
101 |
+
if type_ == "hifi-gan":
|
102 |
+
hps = {
|
103 |
+
"sampling_rate": sampling_rate,
|
104 |
+
"inter_channels": inter_channels,
|
105 |
+
"resblock": resblock,
|
106 |
+
"resblock_kernel_sizes": resblock_kernel_sizes,
|
107 |
+
"resblock_dilation_sizes": resblock_dilation_sizes,
|
108 |
+
"upsample_rates": upsample_rates,
|
109 |
+
"upsample_initial_channel": upsample_initial_channel,
|
110 |
+
"upsample_kernel_sizes": upsample_kernel_sizes,
|
111 |
+
"gin_channels": gin_channels,
|
112 |
+
}
|
113 |
+
self.dec = NSFHifiGANGenerator(h=hps)
|
114 |
+
self.mb = False
|
115 |
+
else:
|
116 |
+
hps = {
|
117 |
+
"initial_channel": inter_channels,
|
118 |
+
"resblock": resblock,
|
119 |
+
"resblock_kernel_sizes": resblock_kernel_sizes,
|
120 |
+
"resblock_dilation_sizes": resblock_dilation_sizes,
|
121 |
+
"upsample_rates": upsample_rates,
|
122 |
+
"upsample_initial_channel": upsample_initial_channel,
|
123 |
+
"upsample_kernel_sizes": upsample_kernel_sizes,
|
124 |
+
"gin_channels": gin_channels,
|
125 |
+
"gen_istft_n_fft": gen_istft_n_fft,
|
126 |
+
"gen_istft_hop_size": gen_istft_hop_size,
|
127 |
+
"subbands": subbands,
|
128 |
+
}
|
129 |
+
|
130 |
+
# gen_istft_n_fft, gen_istft_hop_size, subbands
|
131 |
+
if type_ == "istft":
|
132 |
+
del hps["subbands"]
|
133 |
+
self.dec = iSTFT_Generator(**hps)
|
134 |
+
elif type_ == "ms-istft":
|
135 |
+
self.dec = Multistream_iSTFT_Generator(**hps)
|
136 |
+
elif type_ == "mb-istft":
|
137 |
+
self.dec = Multiband_iSTFT_Generator(**hps)
|
138 |
+
else:
|
139 |
+
raise ValueError(f"Unknown type: {type_}")
|
140 |
+
self.mb = True
|
141 |
+
|
142 |
+
self.enc_q = Encoder(
|
143 |
+
spec_channels,
|
144 |
+
inter_channels,
|
145 |
+
hidden_channels,
|
146 |
+
5,
|
147 |
+
1,
|
148 |
+
16,
|
149 |
+
gin_channels=gin_channels,
|
150 |
+
)
|
151 |
+
self.flow = ResidualCouplingBlock(
|
152 |
+
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
|
153 |
+
)
|
154 |
+
self.f0_decoder = F0Decoder(
|
155 |
+
1,
|
156 |
+
hidden_channels,
|
157 |
+
filter_channels,
|
158 |
+
n_heads,
|
159 |
+
n_layers,
|
160 |
+
kernel_size,
|
161 |
+
p_dropout,
|
162 |
+
spk_channels=gin_channels,
|
163 |
+
)
|
164 |
+
self.emb_uv = nn.Embedding(2, hidden_channels)
|
165 |
+
|
166 |
+
def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None):
|
167 |
+
g = self.emb_g(g).transpose(1, 2)
|
168 |
+
# ssl prenet
|
169 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(
|
170 |
+
c.dtype
|
171 |
+
)
|
172 |
+
x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2)
|
173 |
+
|
174 |
+
# f0 predict
|
175 |
+
lf0 = 2595.0 * torch.log10(1.0 + f0.unsqueeze(1) / 700.0) / 500
|
176 |
+
norm_lf0 = so_vits_svc_fork.f0.normalize_f0(lf0, x_mask, uv)
|
177 |
+
pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
|
178 |
+
|
179 |
+
# encoder
|
180 |
+
z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask, f0=f0_to_coarse(f0))
|
181 |
+
z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
|
182 |
+
|
183 |
+
# flow
|
184 |
+
z_p = self.flow(z, spec_mask, g=g)
|
185 |
+
z_slice, pitch_slice, ids_slice = commons.rand_slice_segments_with_pitch(
|
186 |
+
z, f0, spec_lengths, self.segment_size
|
187 |
+
)
|
188 |
+
|
189 |
+
# MB-iSTFT-VITS
|
190 |
+
if self.mb:
|
191 |
+
o, o_mb = self.dec(z_slice, g=g)
|
192 |
+
# HiFi-GAN
|
193 |
+
else:
|
194 |
+
o = self.dec(z_slice, g=g, f0=pitch_slice)
|
195 |
+
o_mb = None
|
196 |
+
return (
|
197 |
+
o,
|
198 |
+
o_mb,
|
199 |
+
ids_slice,
|
200 |
+
spec_mask,
|
201 |
+
(z, z_p, m_p, logs_p, m_q, logs_q),
|
202 |
+
pred_lf0,
|
203 |
+
norm_lf0,
|
204 |
+
lf0,
|
205 |
+
)
|
206 |
+
|
207 |
+
def infer(self, c, f0, uv, g=None, noice_scale=0.35, predict_f0=False):
|
208 |
+
c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
|
209 |
+
g = self.emb_g(g).transpose(1, 2)
|
210 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(
|
211 |
+
c.dtype
|
212 |
+
)
|
213 |
+
x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2)
|
214 |
+
|
215 |
+
if predict_f0:
|
216 |
+
lf0 = 2595.0 * torch.log10(1.0 + f0.unsqueeze(1) / 700.0) / 500
|
217 |
+
norm_lf0 = so_vits_svc_fork.f0.normalize_f0(
|
218 |
+
lf0, x_mask, uv, random_scale=False
|
219 |
+
)
|
220 |
+
pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
|
221 |
+
f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1)
|
222 |
+
|
223 |
+
z_p, m_p, logs_p, c_mask = self.enc_p(
|
224 |
+
x, x_mask, f0=f0_to_coarse(f0), noice_scale=noice_scale
|
225 |
+
)
|
226 |
+
z = self.flow(z_p, c_mask, g=g, reverse=True)
|
227 |
+
|
228 |
+
# MB-iSTFT-VITS
|
229 |
+
if self.mb:
|
230 |
+
o, o_mb = self.dec(z * c_mask, g=g)
|
231 |
+
else:
|
232 |
+
o = self.dec(z * c_mask, g=g, f0=f0)
|
233 |
+
return o
|
so_vits_svc_fork/preprocessing/__init__.py
ADDED
File without changes
|
so_vits_svc_fork/preprocessing/config_templates/quickvc.json
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 100,
|
4 |
+
"eval_interval": 200,
|
5 |
+
"seed": 1234,
|
6 |
+
"epochs": 10000,
|
7 |
+
"learning_rate": 0.0001,
|
8 |
+
"betas": [0.8, 0.99],
|
9 |
+
"eps": 1e-9,
|
10 |
+
"batch_size": 16,
|
11 |
+
"fp16_run": false,
|
12 |
+
"bf16_run": false,
|
13 |
+
"lr_decay": 0.999875,
|
14 |
+
"segment_size": 10240,
|
15 |
+
"init_lr_ratio": 1,
|
16 |
+
"warmup_epochs": 0,
|
17 |
+
"c_mel": 45,
|
18 |
+
"c_kl": 1.0,
|
19 |
+
"use_sr": true,
|
20 |
+
"max_speclen": 512,
|
21 |
+
"port": "8001",
|
22 |
+
"keep_ckpts": 3,
|
23 |
+
"fft_sizes": [768, 1366, 342],
|
24 |
+
"hop_sizes": [60, 120, 20],
|
25 |
+
"win_lengths": [300, 600, 120],
|
26 |
+
"window": "hann_window",
|
27 |
+
"num_workers": 4,
|
28 |
+
"log_version": 0,
|
29 |
+
"ckpt_name_by_step": false,
|
30 |
+
"accumulate_grad_batches": 1
|
31 |
+
},
|
32 |
+
"data": {
|
33 |
+
"training_files": "filelists/44k/train.txt",
|
34 |
+
"validation_files": "filelists/44k/val.txt",
|
35 |
+
"max_wav_value": 32768.0,
|
36 |
+
"sampling_rate": 44100,
|
37 |
+
"filter_length": 2048,
|
38 |
+
"hop_length": 512,
|
39 |
+
"win_length": 2048,
|
40 |
+
"n_mel_channels": 80,
|
41 |
+
"mel_fmin": 0.0,
|
42 |
+
"mel_fmax": 22050,
|
43 |
+
"contentvec_final_proj": false
|
44 |
+
},
|
45 |
+
"model": {
|
46 |
+
"inter_channels": 192,
|
47 |
+
"hidden_channels": 192,
|
48 |
+
"filter_channels": 768,
|
49 |
+
"n_heads": 2,
|
50 |
+
"n_layers": 6,
|
51 |
+
"kernel_size": 3,
|
52 |
+
"p_dropout": 0.1,
|
53 |
+
"resblock": "1",
|
54 |
+
"resblock_kernel_sizes": [3, 7, 11],
|
55 |
+
"resblock_dilation_sizes": [
|
56 |
+
[1, 3, 5],
|
57 |
+
[1, 3, 5],
|
58 |
+
[1, 3, 5]
|
59 |
+
],
|
60 |
+
"upsample_rates": [8, 4],
|
61 |
+
"upsample_initial_channel": 512,
|
62 |
+
"upsample_kernel_sizes": [32, 16],
|
63 |
+
"n_layers_q": 3,
|
64 |
+
"use_spectral_norm": false,
|
65 |
+
"gin_channels": 256,
|
66 |
+
"ssl_dim": 768,
|
67 |
+
"n_speakers": 200,
|
68 |
+
"type_": "ms-istft",
|
69 |
+
"gen_istft_n_fft": 16,
|
70 |
+
"gen_istft_hop_size": 4,
|
71 |
+
"subbands": 4,
|
72 |
+
"pretrained": {
|
73 |
+
"D_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth",
|
74 |
+
"G_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth"
|
75 |
+
}
|
76 |
+
},
|
77 |
+
"spk": {}
|
78 |
+
}
|
so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 200,
|
4 |
+
"eval_interval": 800,
|
5 |
+
"seed": 1234,
|
6 |
+
"epochs": 10000,
|
7 |
+
"learning_rate": 0.0001,
|
8 |
+
"betas": [0.8, 0.99],
|
9 |
+
"eps": 1e-9,
|
10 |
+
"batch_size": 16,
|
11 |
+
"fp16_run": false,
|
12 |
+
"bf16_run": false,
|
13 |
+
"lr_decay": 0.999875,
|
14 |
+
"segment_size": 10240,
|
15 |
+
"init_lr_ratio": 1,
|
16 |
+
"warmup_epochs": 0,
|
17 |
+
"c_mel": 45,
|
18 |
+
"c_kl": 1.0,
|
19 |
+
"use_sr": true,
|
20 |
+
"max_speclen": 512,
|
21 |
+
"port": "8001",
|
22 |
+
"keep_ckpts": 3,
|
23 |
+
"num_workers": 4,
|
24 |
+
"log_version": 0,
|
25 |
+
"ckpt_name_by_step": false,
|
26 |
+
"accumulate_grad_batches": 1
|
27 |
+
},
|
28 |
+
"data": {
|
29 |
+
"training_files": "filelists/44k/train.txt",
|
30 |
+
"validation_files": "filelists/44k/val.txt",
|
31 |
+
"max_wav_value": 32768.0,
|
32 |
+
"sampling_rate": 44100,
|
33 |
+
"filter_length": 2048,
|
34 |
+
"hop_length": 512,
|
35 |
+
"win_length": 2048,
|
36 |
+
"n_mel_channels": 80,
|
37 |
+
"mel_fmin": 0.0,
|
38 |
+
"mel_fmax": 22050
|
39 |
+
},
|
40 |
+
"model": {
|
41 |
+
"inter_channels": 192,
|
42 |
+
"hidden_channels": 192,
|
43 |
+
"filter_channels": 768,
|
44 |
+
"n_heads": 2,
|
45 |
+
"n_layers": 6,
|
46 |
+
"kernel_size": 3,
|
47 |
+
"p_dropout": 0.1,
|
48 |
+
"resblock": "1",
|
49 |
+
"resblock_kernel_sizes": [3, 7, 11],
|
50 |
+
"resblock_dilation_sizes": [
|
51 |
+
[1, 3, 5],
|
52 |
+
[1, 3, 5],
|
53 |
+
[1, 3, 5]
|
54 |
+
],
|
55 |
+
"upsample_rates": [8, 8, 2, 2, 2],
|
56 |
+
"upsample_initial_channel": 512,
|
57 |
+
"upsample_kernel_sizes": [16, 16, 4, 4, 4],
|
58 |
+
"n_layers_q": 3,
|
59 |
+
"use_spectral_norm": false,
|
60 |
+
"gin_channels": 256,
|
61 |
+
"ssl_dim": 256,
|
62 |
+
"n_speakers": 200,
|
63 |
+
"pretrained": {
|
64 |
+
"D_0.pth": "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/D_0.pth",
|
65 |
+
"G_0.pth": "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/G_0.pth"
|
66 |
+
}
|
67 |
+
},
|
68 |
+
"spk": {}
|
69 |
+
}
|
so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 100,
|
4 |
+
"eval_interval": 200,
|
5 |
+
"seed": 1234,
|
6 |
+
"epochs": 10000,
|
7 |
+
"learning_rate": 0.0001,
|
8 |
+
"betas": [0.8, 0.99],
|
9 |
+
"eps": 1e-9,
|
10 |
+
"batch_size": 16,
|
11 |
+
"fp16_run": false,
|
12 |
+
"bf16_run": false,
|
13 |
+
"lr_decay": 0.999875,
|
14 |
+
"segment_size": 10240,
|
15 |
+
"init_lr_ratio": 1,
|
16 |
+
"warmup_epochs": 0,
|
17 |
+
"c_mel": 45,
|
18 |
+
"c_kl": 1.0,
|
19 |
+
"use_sr": true,
|
20 |
+
"max_speclen": 512,
|
21 |
+
"port": "8001",
|
22 |
+
"keep_ckpts": 3,
|
23 |
+
"num_workers": 4,
|
24 |
+
"log_version": 0,
|
25 |
+
"ckpt_name_by_step": false,
|
26 |
+
"accumulate_grad_batches": 1
|
27 |
+
},
|
28 |
+
"data": {
|
29 |
+
"training_files": "filelists/44k/train.txt",
|
30 |
+
"validation_files": "filelists/44k/val.txt",
|
31 |
+
"max_wav_value": 32768.0,
|
32 |
+
"sampling_rate": 44100,
|
33 |
+
"filter_length": 2048,
|
34 |
+
"hop_length": 512,
|
35 |
+
"win_length": 2048,
|
36 |
+
"n_mel_channels": 80,
|
37 |
+
"mel_fmin": 0.0,
|
38 |
+
"mel_fmax": 22050,
|
39 |
+
"contentvec_final_proj": false
|
40 |
+
},
|
41 |
+
"model": {
|
42 |
+
"inter_channels": 192,
|
43 |
+
"hidden_channels": 192,
|
44 |
+
"filter_channels": 768,
|
45 |
+
"n_heads": 2,
|
46 |
+
"n_layers": 6,
|
47 |
+
"kernel_size": 3,
|
48 |
+
"p_dropout": 0.1,
|
49 |
+
"resblock": "1",
|
50 |
+
"resblock_kernel_sizes": [3, 7, 11],
|
51 |
+
"resblock_dilation_sizes": [
|
52 |
+
[1, 3, 5],
|
53 |
+
[1, 3, 5],
|
54 |
+
[1, 3, 5]
|
55 |
+
],
|
56 |
+
"upsample_rates": [8, 8, 2, 2, 2],
|
57 |
+
"upsample_initial_channel": 512,
|
58 |
+
"upsample_kernel_sizes": [16, 16, 4, 4, 4],
|
59 |
+
"n_layers_q": 3,
|
60 |
+
"use_spectral_norm": false,
|
61 |
+
"gin_channels": 256,
|
62 |
+
"ssl_dim": 768,
|
63 |
+
"n_speakers": 200,
|
64 |
+
"type_": "hifi-gan",
|
65 |
+
"pretrained": {
|
66 |
+
"D_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth",
|
67 |
+
"G_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth"
|
68 |
+
}
|
69 |
+
},
|
70 |
+
"spk": {}
|
71 |
+
}
|
so_vits_svc_fork/preprocessing/preprocess_classify.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from logging import getLogger
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import keyboard
|
7 |
+
import librosa
|
8 |
+
import sounddevice as sd
|
9 |
+
import soundfile as sf
|
10 |
+
from rich.console import Console
|
11 |
+
from tqdm.rich import tqdm
|
12 |
+
|
13 |
+
LOG = getLogger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
def preprocess_classify(
|
17 |
+
input_dir: Path | str, output_dir: Path | str, create_new: bool = True
|
18 |
+
) -> None:
|
19 |
+
# paths
|
20 |
+
input_dir_ = Path(input_dir)
|
21 |
+
output_dir_ = Path(output_dir)
|
22 |
+
speed = 1
|
23 |
+
if not input_dir_.is_dir():
|
24 |
+
raise ValueError(f"{input_dir} is not a directory.")
|
25 |
+
output_dir_.mkdir(exist_ok=True)
|
26 |
+
|
27 |
+
console = Console()
|
28 |
+
# get audio paths and folders
|
29 |
+
audio_paths = list(input_dir_.glob("*.*"))
|
30 |
+
last_folders = [x for x in output_dir_.glob("*") if x.is_dir()]
|
31 |
+
console.print("Press ↑ or ↓ to change speed. Press any other key to classify.")
|
32 |
+
console.print(f"Folders: {[x.name for x in last_folders]}")
|
33 |
+
|
34 |
+
pbar_description = ""
|
35 |
+
|
36 |
+
pbar = tqdm(audio_paths)
|
37 |
+
for audio_path in pbar:
|
38 |
+
# read file
|
39 |
+
audio, sr = sf.read(audio_path)
|
40 |
+
|
41 |
+
# update description
|
42 |
+
duration = librosa.get_duration(y=audio, sr=sr)
|
43 |
+
pbar_description = f"{duration:.1f} {pbar_description}"
|
44 |
+
pbar.set_description(pbar_description)
|
45 |
+
|
46 |
+
while True:
|
47 |
+
# start playing
|
48 |
+
sd.play(librosa.effects.time_stretch(audio, rate=speed), sr, loop=True)
|
49 |
+
|
50 |
+
# wait for key press
|
51 |
+
key = str(keyboard.read_key())
|
52 |
+
if key == "down":
|
53 |
+
speed /= 1.1
|
54 |
+
console.print(f"Speed: {speed:.2f}")
|
55 |
+
elif key == "up":
|
56 |
+
speed *= 1.1
|
57 |
+
console.print(f"Speed: {speed:.2f}")
|
58 |
+
else:
|
59 |
+
break
|
60 |
+
|
61 |
+
# stop playing
|
62 |
+
sd.stop()
|
63 |
+
|
64 |
+
# print if folder changed
|
65 |
+
folders = [x for x in output_dir_.glob("*") if x.is_dir()]
|
66 |
+
if folders != last_folders:
|
67 |
+
console.print(f"Folders updated: {[x.name for x in folders]}")
|
68 |
+
last_folders = folders
|
69 |
+
|
70 |
+
# get folder
|
71 |
+
folder_candidates = [x for x in folders if x.name.startswith(key)]
|
72 |
+
if len(folder_candidates) == 0:
|
73 |
+
if create_new:
|
74 |
+
folder = output_dir_ / key
|
75 |
+
else:
|
76 |
+
console.print(f"No folder starts with {key}.")
|
77 |
+
continue
|
78 |
+
else:
|
79 |
+
if len(folder_candidates) > 1:
|
80 |
+
LOG.warning(
|
81 |
+
f"Multiple folders ({[x.name for x in folder_candidates]}) start with {key}. "
|
82 |
+
f"Using first one ({folder_candidates[0].name})."
|
83 |
+
)
|
84 |
+
folder = folder_candidates[0]
|
85 |
+
folder.mkdir(exist_ok=True)
|
86 |
+
|
87 |
+
# move file
|
88 |
+
new_path = folder / audio_path.name
|
89 |
+
audio_path.rename(new_path)
|
90 |
+
|
91 |
+
# update description
|
92 |
+
pbar_description = f"Last: {audio_path.name} -> {folder.name}"
|
93 |
+
|
94 |
+
# yield result
|
95 |
+
# yield audio_path, key, folder, new_path
|
so_vits_svc_fork/preprocessing/preprocess_flist_config.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
from copy import deepcopy
|
6 |
+
from logging import getLogger
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
from librosa import get_duration
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
LOG = getLogger(__name__)
|
14 |
+
CONFIG_TEMPLATE_DIR = Path(__file__).parent / "config_templates"
|
15 |
+
|
16 |
+
|
17 |
+
def preprocess_config(
|
18 |
+
input_dir: Path | str,
|
19 |
+
train_list_path: Path | str,
|
20 |
+
val_list_path: Path | str,
|
21 |
+
test_list_path: Path | str,
|
22 |
+
config_path: Path | str,
|
23 |
+
config_name: str,
|
24 |
+
):
|
25 |
+
input_dir = Path(input_dir)
|
26 |
+
train_list_path = Path(train_list_path)
|
27 |
+
val_list_path = Path(val_list_path)
|
28 |
+
test_list_path = Path(test_list_path)
|
29 |
+
config_path = Path(config_path)
|
30 |
+
train = []
|
31 |
+
val = []
|
32 |
+
test = []
|
33 |
+
spk_dict = {}
|
34 |
+
spk_id = 0
|
35 |
+
random = np.random.RandomState(1234)
|
36 |
+
for speaker in os.listdir(input_dir):
|
37 |
+
spk_dict[speaker] = spk_id
|
38 |
+
spk_id += 1
|
39 |
+
paths = []
|
40 |
+
for path in tqdm(list((input_dir / speaker).rglob("*.wav"))):
|
41 |
+
if get_duration(filename=path) < 0.3:
|
42 |
+
LOG.warning(f"skip {path} because it is too short.")
|
43 |
+
continue
|
44 |
+
paths.append(path)
|
45 |
+
random.shuffle(paths)
|
46 |
+
if len(paths) <= 4:
|
47 |
+
raise ValueError(
|
48 |
+
f"too few files in {input_dir / speaker} (expected at least 5)."
|
49 |
+
)
|
50 |
+
train += paths[2:-2]
|
51 |
+
val += paths[:2]
|
52 |
+
test += paths[-2:]
|
53 |
+
|
54 |
+
LOG.info(f"Writing {train_list_path}")
|
55 |
+
train_list_path.parent.mkdir(parents=True, exist_ok=True)
|
56 |
+
train_list_path.write_text(
|
57 |
+
"\n".join([x.as_posix() for x in train]), encoding="utf-8"
|
58 |
+
)
|
59 |
+
|
60 |
+
LOG.info(f"Writing {val_list_path}")
|
61 |
+
val_list_path.parent.mkdir(parents=True, exist_ok=True)
|
62 |
+
val_list_path.write_text("\n".join([x.as_posix() for x in val]), encoding="utf-8")
|
63 |
+
|
64 |
+
LOG.info(f"Writing {test_list_path}")
|
65 |
+
test_list_path.parent.mkdir(parents=True, exist_ok=True)
|
66 |
+
test_list_path.write_text("\n".join([x.as_posix() for x in test]), encoding="utf-8")
|
67 |
+
|
68 |
+
config = deepcopy(
|
69 |
+
json.loads(
|
70 |
+
(
|
71 |
+
CONFIG_TEMPLATE_DIR
|
72 |
+
/ (
|
73 |
+
config_name
|
74 |
+
if config_name.endswith(".json")
|
75 |
+
else config_name + ".json"
|
76 |
+
)
|
77 |
+
).read_text(encoding="utf-8")
|
78 |
+
)
|
79 |
+
)
|
80 |
+
config["spk"] = spk_dict
|
81 |
+
config["data"]["training_files"] = train_list_path.as_posix()
|
82 |
+
config["data"]["validation_files"] = val_list_path.as_posix()
|
83 |
+
LOG.info(f"Writing {config_path}")
|
84 |
+
config_path.parent.mkdir(parents=True, exist_ok=True)
|
85 |
+
with config_path.open("w", encoding="utf-8") as f:
|
86 |
+
json.dump(config, f, indent=2)
|
so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from logging import getLogger
|
4 |
+
from pathlib import Path
|
5 |
+
from random import shuffle
|
6 |
+
from typing import Iterable, Literal
|
7 |
+
|
8 |
+
import librosa
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torchaudio
|
12 |
+
from joblib import Parallel, cpu_count, delayed
|
13 |
+
from tqdm import tqdm
|
14 |
+
from transformers import HubertModel
|
15 |
+
|
16 |
+
import so_vits_svc_fork.f0
|
17 |
+
from so_vits_svc_fork import utils
|
18 |
+
|
19 |
+
from ..hparams import HParams
|
20 |
+
from ..modules.mel_processing import spec_to_mel_torch, spectrogram_torch
|
21 |
+
from ..utils import get_optimal_device, get_total_gpu_memory
|
22 |
+
from .preprocess_utils import check_hubert_min_duration
|
23 |
+
|
24 |
+
LOG = getLogger(__name__)
|
25 |
+
HUBERT_MEMORY = 2900
|
26 |
+
HUBERT_MEMORY_CREPE = 3900
|
27 |
+
|
28 |
+
|
29 |
+
def _process_one(
|
30 |
+
*,
|
31 |
+
filepath: Path,
|
32 |
+
content_model: HubertModel,
|
33 |
+
device: torch.device | str = get_optimal_device(),
|
34 |
+
f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
|
35 |
+
force_rebuild: bool = False,
|
36 |
+
hps: HParams,
|
37 |
+
):
|
38 |
+
audio, sr = librosa.load(filepath, sr=hps.data.sampling_rate, mono=True)
|
39 |
+
|
40 |
+
if not check_hubert_min_duration(audio, sr):
|
41 |
+
LOG.info(f"Skip {filepath} because it is too short.")
|
42 |
+
return
|
43 |
+
|
44 |
+
data_path = filepath.parent / (filepath.name + ".data.pt")
|
45 |
+
if data_path.exists() and not force_rebuild:
|
46 |
+
return
|
47 |
+
|
48 |
+
# Compute f0
|
49 |
+
f0 = so_vits_svc_fork.f0.compute_f0(
|
50 |
+
audio, sampling_rate=sr, hop_length=hps.data.hop_length, method=f0_method
|
51 |
+
)
|
52 |
+
f0, uv = so_vits_svc_fork.f0.interpolate_f0(f0)
|
53 |
+
f0 = torch.from_numpy(f0).float()
|
54 |
+
uv = torch.from_numpy(uv).float()
|
55 |
+
|
56 |
+
# Compute HuBERT content
|
57 |
+
audio = torch.from_numpy(audio).float().to(device)
|
58 |
+
c = utils.get_content(
|
59 |
+
content_model,
|
60 |
+
audio,
|
61 |
+
device,
|
62 |
+
sr=sr,
|
63 |
+
legacy_final_proj=hps.data.get("contentvec_final_proj", True),
|
64 |
+
)
|
65 |
+
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0])
|
66 |
+
torch.cuda.empty_cache()
|
67 |
+
|
68 |
+
# Compute spectrogram
|
69 |
+
audio, sr = torchaudio.load(filepath)
|
70 |
+
spec = spectrogram_torch(audio, hps).squeeze(0)
|
71 |
+
mel_spec = spec_to_mel_torch(spec, hps)
|
72 |
+
torch.cuda.empty_cache()
|
73 |
+
|
74 |
+
# fix lengths
|
75 |
+
lmin = min(spec.shape[1], mel_spec.shape[1], f0.shape[0], uv.shape[0], c.shape[1])
|
76 |
+
spec, mel_spec, f0, uv, c = (
|
77 |
+
spec[:, :lmin],
|
78 |
+
mel_spec[:, :lmin],
|
79 |
+
f0[:lmin],
|
80 |
+
uv[:lmin],
|
81 |
+
c[:, :lmin],
|
82 |
+
)
|
83 |
+
|
84 |
+
# get speaker id
|
85 |
+
spk_name = filepath.parent.name
|
86 |
+
spk = hps.spk.__dict__[spk_name]
|
87 |
+
spk = torch.tensor(spk).long()
|
88 |
+
assert (
|
89 |
+
spec.shape[1] == mel_spec.shape[1] == f0.shape[0] == uv.shape[0] == c.shape[1]
|
90 |
+
), (spec.shape, mel_spec.shape, f0.shape, uv.shape, c.shape)
|
91 |
+
data = {
|
92 |
+
"spec": spec,
|
93 |
+
"mel_spec": mel_spec,
|
94 |
+
"f0": f0,
|
95 |
+
"uv": uv,
|
96 |
+
"content": c,
|
97 |
+
"audio": audio,
|
98 |
+
"spk": spk,
|
99 |
+
}
|
100 |
+
data = {k: v.cpu() for k, v in data.items()}
|
101 |
+
with data_path.open("wb") as f:
|
102 |
+
torch.save(data, f)
|
103 |
+
|
104 |
+
|
105 |
+
def _process_batch(filepaths: Iterable[Path], pbar_position: int, **kwargs):
|
106 |
+
hps = kwargs["hps"]
|
107 |
+
content_model = utils.get_hubert_model(
|
108 |
+
get_optimal_device(), hps.data.get("contentvec_final_proj", True)
|
109 |
+
)
|
110 |
+
|
111 |
+
for filepath in tqdm(filepaths, position=pbar_position):
|
112 |
+
_process_one(
|
113 |
+
content_model=content_model,
|
114 |
+
filepath=filepath,
|
115 |
+
**kwargs,
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
def preprocess_hubert_f0(
|
120 |
+
input_dir: Path | str,
|
121 |
+
config_path: Path | str,
|
122 |
+
n_jobs: int | None = None,
|
123 |
+
f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
|
124 |
+
force_rebuild: bool = False,
|
125 |
+
):
|
126 |
+
input_dir = Path(input_dir)
|
127 |
+
config_path = Path(config_path)
|
128 |
+
hps = utils.get_hparams(config_path)
|
129 |
+
if n_jobs is None:
|
130 |
+
# add cpu_count() to avoid SIGKILL
|
131 |
+
memory = get_total_gpu_memory("total")
|
132 |
+
n_jobs = min(
|
133 |
+
max(
|
134 |
+
memory
|
135 |
+
// (HUBERT_MEMORY_CREPE if f0_method == "crepe" else HUBERT_MEMORY)
|
136 |
+
if memory is not None
|
137 |
+
else 1,
|
138 |
+
1,
|
139 |
+
),
|
140 |
+
cpu_count(),
|
141 |
+
)
|
142 |
+
LOG.info(f"n_jobs automatically set to {n_jobs}, memory: {memory} MiB")
|
143 |
+
|
144 |
+
filepaths = list(input_dir.rglob("*.wav"))
|
145 |
+
n_jobs = min(len(filepaths) // 16 + 1, n_jobs)
|
146 |
+
shuffle(filepaths)
|
147 |
+
filepath_chunks = np.array_split(filepaths, n_jobs)
|
148 |
+
Parallel(n_jobs=n_jobs)(
|
149 |
+
delayed(_process_batch)(
|
150 |
+
filepaths=chunk,
|
151 |
+
pbar_position=pbar_position,
|
152 |
+
f0_method=f0_method,
|
153 |
+
force_rebuild=force_rebuild,
|
154 |
+
hps=hps,
|
155 |
+
)
|
156 |
+
for (pbar_position, chunk) in enumerate(filepath_chunks)
|
157 |
+
)
|
so_vits_svc_fork/preprocessing/preprocess_resample.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import warnings
|
4 |
+
from logging import getLogger
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Iterable
|
7 |
+
|
8 |
+
import librosa
|
9 |
+
import soundfile
|
10 |
+
from joblib import Parallel, delayed
|
11 |
+
from tqdm_joblib import tqdm_joblib
|
12 |
+
|
13 |
+
from .preprocess_utils import check_hubert_min_duration
|
14 |
+
|
15 |
+
LOG = getLogger(__name__)
|
16 |
+
|
17 |
+
# input_dir and output_dir exists.
|
18 |
+
# write code to convert input dir audio files to output dir audio files,
|
19 |
+
# without changing folder structure. Use joblib to parallelize.
|
20 |
+
# Converting audio files includes:
|
21 |
+
# - resampling to specified sampling rate
|
22 |
+
# - trim silence
|
23 |
+
# - adjust volume in a smart way
|
24 |
+
# - save as 16-bit wav file
|
25 |
+
|
26 |
+
|
27 |
+
def _get_unique_filename(path: Path, existing_paths: Iterable[Path]) -> Path:
|
28 |
+
"""Return a unique path by appending a number to the original path."""
|
29 |
+
if path not in existing_paths:
|
30 |
+
return path
|
31 |
+
i = 1
|
32 |
+
while True:
|
33 |
+
new_path = path.parent / f"{path.stem}_{i}{path.suffix}"
|
34 |
+
if new_path not in existing_paths:
|
35 |
+
return new_path
|
36 |
+
i += 1
|
37 |
+
|
38 |
+
|
39 |
+
def is_relative_to(path: Path, *other):
|
40 |
+
"""Return True if the path is relative to another path or False.
|
41 |
+
Python 3.9+ has Path.is_relative_to() method, but we need to support Python 3.8.
|
42 |
+
"""
|
43 |
+
try:
|
44 |
+
path.relative_to(*other)
|
45 |
+
return True
|
46 |
+
except ValueError:
|
47 |
+
return False
|
48 |
+
|
49 |
+
|
50 |
+
def _preprocess_one(
|
51 |
+
input_path: Path,
|
52 |
+
output_path: Path,
|
53 |
+
sr: int,
|
54 |
+
*,
|
55 |
+
top_db: int,
|
56 |
+
frame_seconds: float,
|
57 |
+
hop_seconds: float,
|
58 |
+
) -> None:
|
59 |
+
"""Preprocess one audio file."""
|
60 |
+
|
61 |
+
try:
|
62 |
+
audio, sr = librosa.load(input_path, sr=sr, mono=True)
|
63 |
+
|
64 |
+
# Audioread is the last backend it will attempt, so this is the exception thrown on failure
|
65 |
+
except Exception as e:
|
66 |
+
# Failure due to attempting to load a file that is not audio, so return early
|
67 |
+
LOG.warning(f"Failed to load {input_path} due to {e}")
|
68 |
+
return
|
69 |
+
|
70 |
+
if not check_hubert_min_duration(audio, sr):
|
71 |
+
LOG.info(f"Skip {input_path} because it is too short.")
|
72 |
+
return
|
73 |
+
|
74 |
+
# Adjust volume
|
75 |
+
audio /= max(audio.max(), -audio.min())
|
76 |
+
|
77 |
+
# Trim silence
|
78 |
+
audio, _ = librosa.effects.trim(
|
79 |
+
audio,
|
80 |
+
top_db=top_db,
|
81 |
+
frame_length=int(frame_seconds * sr),
|
82 |
+
hop_length=int(hop_seconds * sr),
|
83 |
+
)
|
84 |
+
|
85 |
+
if not check_hubert_min_duration(audio, sr):
|
86 |
+
LOG.info(f"Skip {input_path} because it is too short.")
|
87 |
+
return
|
88 |
+
|
89 |
+
soundfile.write(output_path, audio, samplerate=sr, subtype="PCM_16")
|
90 |
+
|
91 |
+
|
92 |
+
def preprocess_resample(
|
93 |
+
input_dir: Path | str,
|
94 |
+
output_dir: Path | str,
|
95 |
+
sampling_rate: int,
|
96 |
+
n_jobs: int = -1,
|
97 |
+
*,
|
98 |
+
top_db: int = 30,
|
99 |
+
frame_seconds: float = 0.1,
|
100 |
+
hop_seconds: float = 0.05,
|
101 |
+
) -> None:
|
102 |
+
input_dir = Path(input_dir)
|
103 |
+
output_dir = Path(output_dir)
|
104 |
+
"""Preprocess audio files in input_dir and save them to output_dir."""
|
105 |
+
|
106 |
+
out_paths = []
|
107 |
+
in_paths = list(input_dir.rglob("*.*"))
|
108 |
+
if not in_paths:
|
109 |
+
raise ValueError(f"No audio files found in {input_dir}")
|
110 |
+
for in_path in in_paths:
|
111 |
+
in_path_relative = in_path.relative_to(input_dir)
|
112 |
+
if not in_path.is_absolute() and is_relative_to(
|
113 |
+
in_path, Path("dataset_raw") / "44k"
|
114 |
+
):
|
115 |
+
new_in_path_relative = in_path_relative.relative_to("44k")
|
116 |
+
warnings.warn(
|
117 |
+
f"Recommended folder structure has changed since v1.0.0. "
|
118 |
+
"Please move your dataset directly under dataset_raw folder. "
|
119 |
+
f"Recoginzed {in_path_relative} as {new_in_path_relative}"
|
120 |
+
)
|
121 |
+
in_path_relative = new_in_path_relative
|
122 |
+
|
123 |
+
if len(in_path_relative.parts) < 2:
|
124 |
+
continue
|
125 |
+
speaker_name = in_path_relative.parts[0]
|
126 |
+
file_name = in_path_relative.with_suffix(".wav").name
|
127 |
+
out_path = output_dir / speaker_name / file_name
|
128 |
+
out_path = _get_unique_filename(out_path, out_paths)
|
129 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
130 |
+
out_paths.append(out_path)
|
131 |
+
|
132 |
+
in_and_out_paths = list(zip(in_paths, out_paths))
|
133 |
+
|
134 |
+
with tqdm_joblib(desc="Preprocessing", total=len(in_and_out_paths)):
|
135 |
+
Parallel(n_jobs=n_jobs)(
|
136 |
+
delayed(_preprocess_one)(
|
137 |
+
*args,
|
138 |
+
sr=sampling_rate,
|
139 |
+
top_db=top_db,
|
140 |
+
frame_seconds=frame_seconds,
|
141 |
+
hop_seconds=hop_seconds,
|
142 |
+
)
|
143 |
+
for args in in_and_out_paths
|
144 |
+
)
|
so_vits_svc_fork/preprocessing/preprocess_speaker_diarization.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from collections import defaultdict
|
4 |
+
from logging import getLogger
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import librosa
|
8 |
+
import soundfile as sf
|
9 |
+
import torch
|
10 |
+
from joblib import Parallel, delayed
|
11 |
+
from pyannote.audio import Pipeline
|
12 |
+
from tqdm import tqdm
|
13 |
+
from tqdm_joblib import tqdm_joblib
|
14 |
+
|
15 |
+
LOG = getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
def _process_one(
|
19 |
+
input_path: Path,
|
20 |
+
output_dir: Path,
|
21 |
+
sr: int,
|
22 |
+
*,
|
23 |
+
min_speakers: int = 1,
|
24 |
+
max_speakers: int = 1,
|
25 |
+
huggingface_token: str | None = None,
|
26 |
+
) -> None:
|
27 |
+
try:
|
28 |
+
audio, sr = librosa.load(input_path, sr=sr, mono=True)
|
29 |
+
except Exception as e:
|
30 |
+
LOG.warning(f"Failed to read {input_path}: {e}")
|
31 |
+
return
|
32 |
+
pipeline = Pipeline.from_pretrained(
|
33 |
+
"pyannote/speaker-diarization", use_auth_token=huggingface_token
|
34 |
+
)
|
35 |
+
if pipeline is None:
|
36 |
+
raise ValueError("Failed to load pipeline")
|
37 |
+
|
38 |
+
LOG.info(f"Processing {input_path}. This may take a while...")
|
39 |
+
diarization = pipeline(
|
40 |
+
input_path, min_speakers=min_speakers, max_speakers=max_speakers
|
41 |
+
)
|
42 |
+
|
43 |
+
LOG.info(f"Found {len(diarization)} tracks, writing to {output_dir}")
|
44 |
+
speaker_count = defaultdict(int)
|
45 |
+
|
46 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
47 |
+
for segment, track, speaker in tqdm(
|
48 |
+
list(diarization.itertracks(yield_label=True)), desc=f"Writing {input_path}"
|
49 |
+
):
|
50 |
+
if segment.end - segment.start < 1:
|
51 |
+
continue
|
52 |
+
speaker_count[speaker] += 1
|
53 |
+
audio_cut = audio[int(segment.start * sr) : int(segment.end * sr)]
|
54 |
+
sf.write(
|
55 |
+
(output_dir / f"{speaker}_{speaker_count[speaker]}.wav"),
|
56 |
+
audio_cut,
|
57 |
+
sr,
|
58 |
+
)
|
59 |
+
|
60 |
+
LOG.info(f"Speaker count: {speaker_count}")
|
61 |
+
|
62 |
+
|
63 |
+
def preprocess_speaker_diarization(
|
64 |
+
input_dir: Path | str,
|
65 |
+
output_dir: Path | str,
|
66 |
+
sr: int,
|
67 |
+
*,
|
68 |
+
min_speakers: int = 1,
|
69 |
+
max_speakers: int = 1,
|
70 |
+
huggingface_token: str | None = None,
|
71 |
+
n_jobs: int = -1,
|
72 |
+
) -> None:
|
73 |
+
if huggingface_token is not None and not huggingface_token.startswith("hf_"):
|
74 |
+
LOG.warning("Huggingface token probably should start with hf_")
|
75 |
+
if not torch.cuda.is_available():
|
76 |
+
LOG.warning("CUDA is not available. This will be extremely slow.")
|
77 |
+
input_dir = Path(input_dir)
|
78 |
+
output_dir = Path(output_dir)
|
79 |
+
input_dir.mkdir(parents=True, exist_ok=True)
|
80 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
81 |
+
input_paths = list(input_dir.rglob("*.*"))
|
82 |
+
with tqdm_joblib(desc="Preprocessing speaker diarization", total=len(input_paths)):
|
83 |
+
Parallel(n_jobs=n_jobs)(
|
84 |
+
delayed(_process_one)(
|
85 |
+
input_path,
|
86 |
+
output_dir / input_path.relative_to(input_dir).parent / input_path.stem,
|
87 |
+
sr,
|
88 |
+
max_speakers=max_speakers,
|
89 |
+
min_speakers=min_speakers,
|
90 |
+
huggingface_token=huggingface_token,
|
91 |
+
)
|
92 |
+
for input_path in input_paths
|
93 |
+
)
|
so_vits_svc_fork/preprocessing/preprocess_split.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from logging import getLogger
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import librosa
|
7 |
+
import soundfile as sf
|
8 |
+
from joblib import Parallel, delayed
|
9 |
+
from tqdm import tqdm
|
10 |
+
from tqdm_joblib import tqdm_joblib
|
11 |
+
|
12 |
+
LOG = getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
def _process_one(
|
16 |
+
input_path: Path,
|
17 |
+
output_dir: Path,
|
18 |
+
sr: int,
|
19 |
+
*,
|
20 |
+
max_length: float = 10.0,
|
21 |
+
top_db: int = 30,
|
22 |
+
frame_seconds: float = 0.5,
|
23 |
+
hop_seconds: float = 0.1,
|
24 |
+
):
|
25 |
+
try:
|
26 |
+
audio, sr = librosa.load(input_path, sr=sr, mono=True)
|
27 |
+
except Exception as e:
|
28 |
+
LOG.warning(f"Failed to read {input_path}: {e}")
|
29 |
+
return
|
30 |
+
intervals = librosa.effects.split(
|
31 |
+
audio,
|
32 |
+
top_db=top_db,
|
33 |
+
frame_length=int(sr * frame_seconds),
|
34 |
+
hop_length=int(sr * hop_seconds),
|
35 |
+
)
|
36 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
37 |
+
for start, end in tqdm(intervals, desc=f"Writing {input_path}"):
|
38 |
+
for sub_start in range(start, end, int(sr * max_length)):
|
39 |
+
sub_end = min(sub_start + int(sr * max_length), end)
|
40 |
+
audio_cut = audio[sub_start:sub_end]
|
41 |
+
sf.write(
|
42 |
+
(
|
43 |
+
output_dir
|
44 |
+
/ f"{input_path.stem}_{sub_start / sr:.3f}_{sub_end / sr:.3f}.wav"
|
45 |
+
),
|
46 |
+
audio_cut,
|
47 |
+
sr,
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
def preprocess_split(
|
52 |
+
input_dir: Path | str,
|
53 |
+
output_dir: Path | str,
|
54 |
+
sr: int,
|
55 |
+
*,
|
56 |
+
max_length: float = 10.0,
|
57 |
+
top_db: int = 30,
|
58 |
+
frame_seconds: float = 0.5,
|
59 |
+
hop_seconds: float = 0.1,
|
60 |
+
n_jobs: int = -1,
|
61 |
+
):
|
62 |
+
input_dir = Path(input_dir)
|
63 |
+
output_dir = Path(output_dir)
|
64 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
65 |
+
input_paths = list(input_dir.rglob("*.*"))
|
66 |
+
with tqdm_joblib(desc="Splitting", total=len(input_paths)):
|
67 |
+
Parallel(n_jobs=n_jobs)(
|
68 |
+
delayed(_process_one)(
|
69 |
+
input_path,
|
70 |
+
output_dir / input_path.relative_to(input_dir).parent,
|
71 |
+
sr,
|
72 |
+
max_length=max_length,
|
73 |
+
top_db=top_db,
|
74 |
+
frame_seconds=frame_seconds,
|
75 |
+
hop_seconds=hop_seconds,
|
76 |
+
)
|
77 |
+
for input_path in input_paths
|
78 |
+
)
|
so_vits_svc_fork/preprocessing/preprocess_utils.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from numpy import ndarray
|
2 |
+
|
3 |
+
|
4 |
+
def check_hubert_min_duration(audio: ndarray, sr: int) -> bool:
|
5 |
+
return len(audio) / sr >= 0.3
|
so_vits_svc_fork/py.typed
ADDED
File without changes
|
so_vits_svc_fork/train.py
ADDED
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
import warnings
|
5 |
+
from logging import getLogger
|
6 |
+
from multiprocessing import cpu_count
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Any
|
9 |
+
|
10 |
+
import lightning.pytorch as pl
|
11 |
+
import torch
|
12 |
+
from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator
|
13 |
+
from lightning.pytorch.callbacks import DeviceStatsMonitor
|
14 |
+
from lightning.pytorch.loggers import TensorBoardLogger
|
15 |
+
from lightning.pytorch.strategies.ddp import DDPStrategy
|
16 |
+
from lightning.pytorch.tuner import Tuner
|
17 |
+
from torch.cuda.amp import autocast
|
18 |
+
from torch.nn import functional as F
|
19 |
+
from torch.utils.data import DataLoader
|
20 |
+
from torch.utils.tensorboard.writer import SummaryWriter
|
21 |
+
|
22 |
+
import so_vits_svc_fork.f0
|
23 |
+
import so_vits_svc_fork.modules.commons as commons
|
24 |
+
import so_vits_svc_fork.utils
|
25 |
+
|
26 |
+
from . import utils
|
27 |
+
from .dataset import TextAudioCollate, TextAudioDataset
|
28 |
+
from .logger import is_notebook
|
29 |
+
from .modules.descriminators import MultiPeriodDiscriminator
|
30 |
+
from .modules.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
|
31 |
+
from .modules.mel_processing import mel_spectrogram_torch
|
32 |
+
from .modules.synthesizers import SynthesizerTrn
|
33 |
+
|
34 |
+
LOG = getLogger(__name__)
|
35 |
+
torch.set_float32_matmul_precision("high")
|
36 |
+
|
37 |
+
|
38 |
+
class VCDataModule(pl.LightningDataModule):
|
39 |
+
batch_size: int
|
40 |
+
|
41 |
+
def __init__(self, hparams: Any):
|
42 |
+
super().__init__()
|
43 |
+
self.__hparams = hparams
|
44 |
+
self.batch_size = hparams.train.batch_size
|
45 |
+
if not isinstance(self.batch_size, int):
|
46 |
+
self.batch_size = 1
|
47 |
+
self.collate_fn = TextAudioCollate()
|
48 |
+
|
49 |
+
# these should be called in setup(), but we need to calculate check_val_every_n_epoch
|
50 |
+
self.train_dataset = TextAudioDataset(self.__hparams, is_validation=False)
|
51 |
+
self.val_dataset = TextAudioDataset(self.__hparams, is_validation=True)
|
52 |
+
|
53 |
+
def train_dataloader(self):
|
54 |
+
return DataLoader(
|
55 |
+
self.train_dataset,
|
56 |
+
num_workers=min(cpu_count(), self.__hparams.train.get("num_workers", 8)),
|
57 |
+
batch_size=self.batch_size,
|
58 |
+
collate_fn=self.collate_fn,
|
59 |
+
persistent_workers=True,
|
60 |
+
)
|
61 |
+
|
62 |
+
def val_dataloader(self):
|
63 |
+
return DataLoader(
|
64 |
+
self.val_dataset,
|
65 |
+
batch_size=1,
|
66 |
+
collate_fn=self.collate_fn,
|
67 |
+
)
|
68 |
+
|
69 |
+
|
70 |
+
def train(
|
71 |
+
config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False
|
72 |
+
):
|
73 |
+
config_path = Path(config_path)
|
74 |
+
model_path = Path(model_path)
|
75 |
+
|
76 |
+
hparams = utils.get_backup_hparams(config_path, model_path)
|
77 |
+
utils.ensure_pretrained_model(
|
78 |
+
model_path,
|
79 |
+
hparams.model.get(
|
80 |
+
"pretrained",
|
81 |
+
{
|
82 |
+
"D_0.pth": "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/D_0.pth",
|
83 |
+
"G_0.pth": "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/G_0.pth",
|
84 |
+
},
|
85 |
+
),
|
86 |
+
)
|
87 |
+
|
88 |
+
datamodule = VCDataModule(hparams)
|
89 |
+
strategy = (
|
90 |
+
(
|
91 |
+
"ddp_find_unused_parameters_true"
|
92 |
+
if os.name != "nt"
|
93 |
+
else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo")
|
94 |
+
)
|
95 |
+
if torch.cuda.device_count() > 1
|
96 |
+
else "auto"
|
97 |
+
)
|
98 |
+
LOG.info(f"Using strategy: {strategy}")
|
99 |
+
trainer = pl.Trainer(
|
100 |
+
logger=TensorBoardLogger(
|
101 |
+
model_path, "lightning_logs", hparams.train.get("log_version", 0)
|
102 |
+
),
|
103 |
+
# profiler="simple",
|
104 |
+
val_check_interval=hparams.train.eval_interval,
|
105 |
+
max_epochs=hparams.train.epochs,
|
106 |
+
check_val_every_n_epoch=None,
|
107 |
+
precision="16-mixed"
|
108 |
+
if hparams.train.fp16_run
|
109 |
+
else "bf16-mixed"
|
110 |
+
if hparams.train.get("bf16_run", False)
|
111 |
+
else 32,
|
112 |
+
strategy=strategy,
|
113 |
+
callbacks=([pl.callbacks.RichProgressBar()] if not is_notebook() else [])
|
114 |
+
+ [DeviceStatsMonitor()],
|
115 |
+
benchmark=True,
|
116 |
+
enable_checkpointing=False,
|
117 |
+
)
|
118 |
+
tuner = Tuner(trainer)
|
119 |
+
model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)
|
120 |
+
|
121 |
+
# automatic batch size scaling
|
122 |
+
batch_size = hparams.train.batch_size
|
123 |
+
batch_split = str(batch_size).split("-")
|
124 |
+
batch_size = batch_split[0]
|
125 |
+
init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
|
126 |
+
max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
|
127 |
+
if batch_size == "auto":
|
128 |
+
batch_size = "binsearch"
|
129 |
+
if batch_size in ["power", "binsearch"]:
|
130 |
+
model.tuning = True
|
131 |
+
tuner.scale_batch_size(
|
132 |
+
model,
|
133 |
+
mode=batch_size,
|
134 |
+
datamodule=datamodule,
|
135 |
+
steps_per_trial=1,
|
136 |
+
init_val=init_val,
|
137 |
+
max_trials=max_trials,
|
138 |
+
)
|
139 |
+
model.tuning = False
|
140 |
+
else:
|
141 |
+
batch_size = int(batch_size)
|
142 |
+
# automatic learning rate scaling is not supported for multiple optimizers
|
143 |
+
"""if hparams.train.learning_rate == "auto":
|
144 |
+
lr_finder = tuner.lr_find(model)
|
145 |
+
LOG.info(lr_finder.results)
|
146 |
+
fig = lr_finder.plot(suggest=True)
|
147 |
+
fig.savefig(model_path / "lr_finder.png")"""
|
148 |
+
|
149 |
+
trainer.fit(model, datamodule=datamodule)
|
150 |
+
|
151 |
+
|
152 |
+
class VitsLightning(pl.LightningModule):
|
153 |
+
def __init__(self, reset_optimizer: bool = False, **hparams: Any):
|
154 |
+
super().__init__()
|
155 |
+
self._temp_epoch = 0 # Add this line to initialize the _temp_epoch attribute
|
156 |
+
self.save_hyperparameters("reset_optimizer")
|
157 |
+
self.save_hyperparameters(*[k for k in hparams.keys()])
|
158 |
+
torch.manual_seed(self.hparams.train.seed)
|
159 |
+
self.net_g = SynthesizerTrn(
|
160 |
+
self.hparams.data.filter_length // 2 + 1,
|
161 |
+
self.hparams.train.segment_size // self.hparams.data.hop_length,
|
162 |
+
**self.hparams.model,
|
163 |
+
)
|
164 |
+
self.net_d = MultiPeriodDiscriminator(self.hparams.model.use_spectral_norm)
|
165 |
+
self.automatic_optimization = False
|
166 |
+
self.learning_rate = self.hparams.train.learning_rate
|
167 |
+
self.optim_g = torch.optim.AdamW(
|
168 |
+
self.net_g.parameters(),
|
169 |
+
self.learning_rate,
|
170 |
+
betas=self.hparams.train.betas,
|
171 |
+
eps=self.hparams.train.eps,
|
172 |
+
)
|
173 |
+
self.optim_d = torch.optim.AdamW(
|
174 |
+
self.net_d.parameters(),
|
175 |
+
self.learning_rate,
|
176 |
+
betas=self.hparams.train.betas,
|
177 |
+
eps=self.hparams.train.eps,
|
178 |
+
)
|
179 |
+
self.scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
180 |
+
self.optim_g, gamma=self.hparams.train.lr_decay
|
181 |
+
)
|
182 |
+
self.scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
183 |
+
self.optim_d, gamma=self.hparams.train.lr_decay
|
184 |
+
)
|
185 |
+
self.optimizers_count = 2
|
186 |
+
self.load(reset_optimizer)
|
187 |
+
self.tuning = False
|
188 |
+
|
189 |
+
def on_train_start(self) -> None:
|
190 |
+
if not self.tuning:
|
191 |
+
self.set_current_epoch(self._temp_epoch)
|
192 |
+
total_batch_idx = self._temp_epoch * len(self.trainer.train_dataloader)
|
193 |
+
self.set_total_batch_idx(total_batch_idx)
|
194 |
+
global_step = total_batch_idx * self.optimizers_count
|
195 |
+
self.set_global_step(global_step)
|
196 |
+
|
197 |
+
# check if using tpu or mps
|
198 |
+
if isinstance(self.trainer.accelerator, (TPUAccelerator, MPSAccelerator)):
|
199 |
+
# patch torch.stft to use cpu
|
200 |
+
LOG.warning("Using TPU/MPS. Patching torch.stft to use cpu.")
|
201 |
+
|
202 |
+
def stft(
|
203 |
+
input: torch.Tensor,
|
204 |
+
n_fft: int,
|
205 |
+
hop_length: int | None = None,
|
206 |
+
win_length: int | None = None,
|
207 |
+
window: torch.Tensor | None = None,
|
208 |
+
center: bool = True,
|
209 |
+
pad_mode: str = "reflect",
|
210 |
+
normalized: bool = False,
|
211 |
+
onesided: bool | None = None,
|
212 |
+
return_complex: bool | None = None,
|
213 |
+
) -> torch.Tensor:
|
214 |
+
device = input.device
|
215 |
+
input = input.cpu()
|
216 |
+
if window is not None:
|
217 |
+
window = window.cpu()
|
218 |
+
return torch.functional.stft(
|
219 |
+
input,
|
220 |
+
n_fft,
|
221 |
+
hop_length,
|
222 |
+
win_length,
|
223 |
+
window,
|
224 |
+
center,
|
225 |
+
pad_mode,
|
226 |
+
normalized,
|
227 |
+
onesided,
|
228 |
+
return_complex,
|
229 |
+
).to(device)
|
230 |
+
|
231 |
+
torch.stft = stft
|
232 |
+
|
233 |
+
elif "bf" in self.trainer.precision:
|
234 |
+
LOG.warning("Using bf. Patching torch.stft to use fp32.")
|
235 |
+
|
236 |
+
def stft(
|
237 |
+
input: torch.Tensor,
|
238 |
+
n_fft: int,
|
239 |
+
hop_length: int | None = None,
|
240 |
+
win_length: int | None = None,
|
241 |
+
window: torch.Tensor | None = None,
|
242 |
+
center: bool = True,
|
243 |
+
pad_mode: str = "reflect",
|
244 |
+
normalized: bool = False,
|
245 |
+
onesided: bool | None = None,
|
246 |
+
return_complex: bool | None = None,
|
247 |
+
) -> torch.Tensor:
|
248 |
+
dtype = input.dtype
|
249 |
+
input = input.float()
|
250 |
+
if window is not None:
|
251 |
+
window = window.float()
|
252 |
+
return torch.functional.stft(
|
253 |
+
input,
|
254 |
+
n_fft,
|
255 |
+
hop_length,
|
256 |
+
win_length,
|
257 |
+
window,
|
258 |
+
center,
|
259 |
+
pad_mode,
|
260 |
+
normalized,
|
261 |
+
onesided,
|
262 |
+
return_complex,
|
263 |
+
).to(dtype)
|
264 |
+
|
265 |
+
torch.stft = stft
|
266 |
+
|
267 |
+
def on_train_end(self) -> None:
|
268 |
+
self.save_checkpoints(adjust=0)
|
269 |
+
|
270 |
+
def save_checkpoints(self, adjust=1):
|
271 |
+
if self.tuning or self.trainer.sanity_checking:
|
272 |
+
return
|
273 |
+
|
274 |
+
# only save checkpoints if we are on the main device
|
275 |
+
if (
|
276 |
+
hasattr(self.device, "index")
|
277 |
+
and self.device.index != None
|
278 |
+
and self.device.index != 0
|
279 |
+
):
|
280 |
+
return
|
281 |
+
|
282 |
+
# `on_train_end` will be the actual epoch, not a -1, so we have to call it with `adjust = 0`
|
283 |
+
current_epoch = self.current_epoch + adjust
|
284 |
+
total_batch_idx = self.total_batch_idx - 1 + adjust
|
285 |
+
|
286 |
+
utils.save_checkpoint(
|
287 |
+
self.net_g,
|
288 |
+
self.optim_g,
|
289 |
+
self.learning_rate,
|
290 |
+
current_epoch,
|
291 |
+
Path(self.hparams.model_dir)
|
292 |
+
/ f"G_{total_batch_idx if self.hparams.train.get('ckpt_name_by_step', False) else current_epoch}.pth",
|
293 |
+
)
|
294 |
+
utils.save_checkpoint(
|
295 |
+
self.net_d,
|
296 |
+
self.optim_d,
|
297 |
+
self.learning_rate,
|
298 |
+
current_epoch,
|
299 |
+
Path(self.hparams.model_dir)
|
300 |
+
/ f"D_{total_batch_idx if self.hparams.train.get('ckpt_name_by_step', False) else current_epoch}.pth",
|
301 |
+
)
|
302 |
+
keep_ckpts = self.hparams.train.get("keep_ckpts", 0)
|
303 |
+
if keep_ckpts > 0:
|
304 |
+
utils.clean_checkpoints(
|
305 |
+
path_to_models=self.hparams.model_dir,
|
306 |
+
n_ckpts_to_keep=keep_ckpts,
|
307 |
+
sort_by_time=True,
|
308 |
+
)
|
309 |
+
|
310 |
+
def set_current_epoch(self, epoch: int):
|
311 |
+
LOG.info(f"Setting current epoch to {epoch}")
|
312 |
+
self.trainer.fit_loop.epoch_progress.current.completed = epoch
|
313 |
+
self.trainer.fit_loop.epoch_progress.current.processed = epoch
|
314 |
+
assert self.current_epoch == epoch, f"{self.current_epoch} != {epoch}"
|
315 |
+
|
316 |
+
def set_global_step(self, global_step: int):
|
317 |
+
LOG.info(f"Setting global step to {global_step}")
|
318 |
+
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed = (
|
319 |
+
global_step
|
320 |
+
)
|
321 |
+
self.trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = (
|
322 |
+
global_step
|
323 |
+
)
|
324 |
+
assert self.global_step == global_step, f"{self.global_step} != {global_step}"
|
325 |
+
|
326 |
+
def set_total_batch_idx(self, total_batch_idx: int):
|
327 |
+
LOG.info(f"Setting total batch idx to {total_batch_idx}")
|
328 |
+
self.trainer.fit_loop.epoch_loop.batch_progress.total.ready = (
|
329 |
+
total_batch_idx + 1
|
330 |
+
)
|
331 |
+
self.trainer.fit_loop.epoch_loop.batch_progress.total.completed = (
|
332 |
+
total_batch_idx
|
333 |
+
)
|
334 |
+
assert (
|
335 |
+
self.total_batch_idx == total_batch_idx + 1
|
336 |
+
), f"{self.total_batch_idx} != {total_batch_idx + 1}"
|
337 |
+
|
338 |
+
@property
|
339 |
+
def total_batch_idx(self) -> int:
|
340 |
+
return self.trainer.fit_loop.epoch_loop.total_batch_idx + 1
|
341 |
+
|
342 |
+
def load(self, reset_optimizer: bool = False):
|
343 |
+
latest_g_path = utils.latest_checkpoint_path(self.hparams.model_dir, "G_*.pth")
|
344 |
+
latest_d_path = utils.latest_checkpoint_path(self.hparams.model_dir, "D_*.pth")
|
345 |
+
if latest_g_path is not None and latest_d_path is not None:
|
346 |
+
try:
|
347 |
+
_, _, _, epoch = utils.load_checkpoint(
|
348 |
+
latest_g_path,
|
349 |
+
self.net_g,
|
350 |
+
self.optim_g,
|
351 |
+
reset_optimizer,
|
352 |
+
)
|
353 |
+
_, _, _, epoch = utils.load_checkpoint(
|
354 |
+
latest_d_path,
|
355 |
+
self.net_d,
|
356 |
+
self.optim_d,
|
357 |
+
reset_optimizer,
|
358 |
+
)
|
359 |
+
self._temp_epoch = epoch
|
360 |
+
self.scheduler_g.last_epoch = epoch - 1
|
361 |
+
self.scheduler_d.last_epoch = epoch - 1
|
362 |
+
except Exception as e:
|
363 |
+
raise RuntimeError("Failed to load checkpoint") from e
|
364 |
+
else:
|
365 |
+
LOG.warning("No checkpoint found. Start from scratch.")
|
366 |
+
|
367 |
+
def configure_optimizers(self):
|
368 |
+
return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d]
|
369 |
+
|
370 |
+
def log_image_dict(
|
371 |
+
self, image_dict: dict[str, Any], dataformats: str = "HWC"
|
372 |
+
) -> None:
|
373 |
+
if not isinstance(self.logger, TensorBoardLogger):
|
374 |
+
warnings.warn("Image logging is only supported with TensorBoardLogger.")
|
375 |
+
return
|
376 |
+
writer: SummaryWriter = self.logger.experiment
|
377 |
+
for k, v in image_dict.items():
|
378 |
+
try:
|
379 |
+
writer.add_image(k, v, self.total_batch_idx, dataformats=dataformats)
|
380 |
+
except Exception as e:
|
381 |
+
warnings.warn(f"Failed to log image {k}: {e}")
|
382 |
+
|
383 |
+
def log_audio_dict(self, audio_dict: dict[str, Any]) -> None:
|
384 |
+
if not isinstance(self.logger, TensorBoardLogger):
|
385 |
+
warnings.warn("Audio logging is only supported with TensorBoardLogger.")
|
386 |
+
return
|
387 |
+
writer: SummaryWriter = self.logger.experiment
|
388 |
+
for k, v in audio_dict.items():
|
389 |
+
writer.add_audio(
|
390 |
+
k,
|
391 |
+
v.float(),
|
392 |
+
self.total_batch_idx,
|
393 |
+
sample_rate=self.hparams.data.sampling_rate,
|
394 |
+
)
|
395 |
+
|
396 |
+
def log_dict_(self, log_dict: dict[str, Any], **kwargs) -> None:
|
397 |
+
if not isinstance(self.logger, TensorBoardLogger):
|
398 |
+
warnings.warn("Logging is only supported with TensorBoardLogger.")
|
399 |
+
return
|
400 |
+
writer: SummaryWriter = self.logger.experiment
|
401 |
+
for k, v in log_dict.items():
|
402 |
+
writer.add_scalar(k, v, self.total_batch_idx)
|
403 |
+
kwargs["logger"] = False
|
404 |
+
self.log_dict(log_dict, **kwargs)
|
405 |
+
|
406 |
+
def log_(self, key: str, value: Any, **kwargs) -> None:
|
407 |
+
self.log_dict_({key: value}, **kwargs)
|
408 |
+
|
409 |
+
def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
|
410 |
+
self.net_g.train()
|
411 |
+
self.net_d.train()
|
412 |
+
|
413 |
+
# get optims
|
414 |
+
optim_g, optim_d = self.optimizers()
|
415 |
+
|
416 |
+
# Generator
|
417 |
+
# train
|
418 |
+
self.toggle_optimizer(optim_g)
|
419 |
+
c, f0, spec, mel, y, g, lengths, uv = batch
|
420 |
+
(
|
421 |
+
y_hat,
|
422 |
+
y_hat_mb,
|
423 |
+
ids_slice,
|
424 |
+
z_mask,
|
425 |
+
(z, z_p, m_p, logs_p, m_q, logs_q),
|
426 |
+
pred_lf0,
|
427 |
+
norm_lf0,
|
428 |
+
lf0,
|
429 |
+
) = self.net_g(c, f0, uv, spec, g=g, c_lengths=lengths, spec_lengths=lengths)
|
430 |
+
y_mel = commons.slice_segments(
|
431 |
+
mel,
|
432 |
+
ids_slice,
|
433 |
+
self.hparams.train.segment_size // self.hparams.data.hop_length,
|
434 |
+
)
|
435 |
+
y_hat_mel = mel_spectrogram_torch(y_hat.squeeze(1), self.hparams)
|
436 |
+
y_mel = y_mel[..., : y_hat_mel.shape[-1]]
|
437 |
+
y = commons.slice_segments(
|
438 |
+
y,
|
439 |
+
ids_slice * self.hparams.data.hop_length,
|
440 |
+
self.hparams.train.segment_size,
|
441 |
+
)
|
442 |
+
y = y[..., : y_hat.shape[-1]]
|
443 |
+
|
444 |
+
# generator loss
|
445 |
+
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.net_d(y, y_hat)
|
446 |
+
|
447 |
+
with autocast(enabled=False):
|
448 |
+
loss_mel = F.l1_loss(y_mel, y_hat_mel) * self.hparams.train.c_mel
|
449 |
+
loss_kl = (
|
450 |
+
kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * self.hparams.train.c_kl
|
451 |
+
)
|
452 |
+
loss_fm = feature_loss(fmap_r, fmap_g)
|
453 |
+
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
454 |
+
loss_lf0 = F.mse_loss(pred_lf0, lf0)
|
455 |
+
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_lf0
|
456 |
+
|
457 |
+
# MB-iSTFT-VITS
|
458 |
+
loss_subband = torch.tensor(0.0)
|
459 |
+
if self.hparams.model.get("type_") == "mb-istft":
|
460 |
+
from .modules.decoders.mb_istft import PQMF, subband_stft_loss
|
461 |
+
|
462 |
+
y_mb = PQMF(y.device, self.hparams.model.subbands).analysis(y)
|
463 |
+
loss_subband = subband_stft_loss(self.hparams, y_mb, y_hat_mb)
|
464 |
+
loss_gen_all += loss_subband
|
465 |
+
|
466 |
+
# log loss
|
467 |
+
self.log_("lr", self.optim_g.param_groups[0]["lr"])
|
468 |
+
self.log_dict_(
|
469 |
+
{
|
470 |
+
"loss/g/total": loss_gen_all,
|
471 |
+
"loss/g/fm": loss_fm,
|
472 |
+
"loss/g/mel": loss_mel,
|
473 |
+
"loss/g/kl": loss_kl,
|
474 |
+
"loss/g/lf0": loss_lf0,
|
475 |
+
},
|
476 |
+
prog_bar=True,
|
477 |
+
)
|
478 |
+
if self.hparams.model.get("type_") == "mb-istft":
|
479 |
+
self.log_("loss/g/subband", loss_subband)
|
480 |
+
if self.total_batch_idx % self.hparams.train.log_interval == 0:
|
481 |
+
self.log_image_dict(
|
482 |
+
{
|
483 |
+
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
484 |
+
y_mel[0].data.cpu().float().numpy()
|
485 |
+
),
|
486 |
+
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
487 |
+
y_hat_mel[0].data.cpu().float().numpy()
|
488 |
+
),
|
489 |
+
"all/mel": utils.plot_spectrogram_to_numpy(
|
490 |
+
mel[0].data.cpu().float().numpy()
|
491 |
+
),
|
492 |
+
"all/lf0": so_vits_svc_fork.utils.plot_data_to_numpy(
|
493 |
+
lf0[0, 0, :].cpu().float().numpy(),
|
494 |
+
pred_lf0[0, 0, :].detach().cpu().float().numpy(),
|
495 |
+
),
|
496 |
+
"all/norm_lf0": so_vits_svc_fork.utils.plot_data_to_numpy(
|
497 |
+
lf0[0, 0, :].cpu().float().numpy(),
|
498 |
+
norm_lf0[0, 0, :].detach().cpu().float().numpy(),
|
499 |
+
),
|
500 |
+
}
|
501 |
+
)
|
502 |
+
|
503 |
+
accumulate_grad_batches = self.hparams.train.get("accumulate_grad_batches", 1)
|
504 |
+
should_update = (
|
505 |
+
batch_idx + 1
|
506 |
+
) % accumulate_grad_batches == 0 or self.trainer.is_last_batch
|
507 |
+
# optimizer
|
508 |
+
self.manual_backward(loss_gen_all / accumulate_grad_batches)
|
509 |
+
if should_update:
|
510 |
+
self.log_(
|
511 |
+
"grad_norm_g", commons.clip_grad_value_(self.net_g.parameters(), None)
|
512 |
+
)
|
513 |
+
optim_g.step()
|
514 |
+
optim_g.zero_grad()
|
515 |
+
self.untoggle_optimizer(optim_g)
|
516 |
+
|
517 |
+
# Discriminator
|
518 |
+
# train
|
519 |
+
self.toggle_optimizer(optim_d)
|
520 |
+
y_d_hat_r, y_d_hat_g, _, _ = self.net_d(y, y_hat.detach())
|
521 |
+
|
522 |
+
# discriminator loss
|
523 |
+
with autocast(enabled=False):
|
524 |
+
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
525 |
+
y_d_hat_r, y_d_hat_g
|
526 |
+
)
|
527 |
+
loss_disc_all = loss_disc
|
528 |
+
|
529 |
+
# log loss
|
530 |
+
self.log_("loss/d/total", loss_disc_all, prog_bar=True)
|
531 |
+
|
532 |
+
# optimizer
|
533 |
+
self.manual_backward(loss_disc_all / accumulate_grad_batches)
|
534 |
+
if should_update:
|
535 |
+
self.log_(
|
536 |
+
"grad_norm_d", commons.clip_grad_value_(self.net_d.parameters(), None)
|
537 |
+
)
|
538 |
+
optim_d.step()
|
539 |
+
optim_d.zero_grad()
|
540 |
+
self.untoggle_optimizer(optim_d)
|
541 |
+
|
542 |
+
# end of epoch
|
543 |
+
if self.trainer.is_last_batch:
|
544 |
+
self.scheduler_g.step()
|
545 |
+
self.scheduler_d.step()
|
546 |
+
|
547 |
+
def validation_step(self, batch, batch_idx):
|
548 |
+
# avoid logging with wrong global step
|
549 |
+
if self.global_step == 0:
|
550 |
+
return
|
551 |
+
with torch.no_grad():
|
552 |
+
self.net_g.eval()
|
553 |
+
c, f0, _, mel, y, g, _, uv = batch
|
554 |
+
y_hat = self.net_g.infer(c, f0, uv, g=g)
|
555 |
+
y_hat_mel = mel_spectrogram_torch(y_hat.squeeze(1).float(), self.hparams)
|
556 |
+
self.log_audio_dict(
|
557 |
+
{f"gen/audio_{batch_idx}": y_hat[0], f"gt/audio_{batch_idx}": y[0]}
|
558 |
+
)
|
559 |
+
self.log_image_dict(
|
560 |
+
{
|
561 |
+
"gen/mel": utils.plot_spectrogram_to_numpy(
|
562 |
+
y_hat_mel[0].cpu().float().numpy()
|
563 |
+
),
|
564 |
+
"gt/mel": utils.plot_spectrogram_to_numpy(
|
565 |
+
mel[0].cpu().float().numpy()
|
566 |
+
),
|
567 |
+
}
|
568 |
+
)
|
569 |
+
|
570 |
+
def on_validation_end(self) -> None:
|
571 |
+
self.save_checkpoints()
|
so_vits_svc_fork/utils.py
ADDED
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import subprocess
|
7 |
+
import warnings
|
8 |
+
from itertools import groupby
|
9 |
+
from logging import getLogger
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Any, Literal, Sequence
|
12 |
+
|
13 |
+
import matplotlib
|
14 |
+
import matplotlib.pylab as plt
|
15 |
+
import numpy as np
|
16 |
+
import requests
|
17 |
+
import torch
|
18 |
+
import torch.backends.mps
|
19 |
+
import torch.nn as nn
|
20 |
+
import torchaudio
|
21 |
+
from cm_time import timer
|
22 |
+
from numpy import ndarray
|
23 |
+
from tqdm import tqdm
|
24 |
+
from transformers import HubertModel
|
25 |
+
|
26 |
+
from so_vits_svc_fork.hparams import HParams
|
27 |
+
|
28 |
+
LOG = getLogger(__name__)
|
29 |
+
HUBERT_SAMPLING_RATE = 16000
|
30 |
+
IS_COLAB = os.getenv("COLAB_RELEASE_TAG", False)
|
31 |
+
|
32 |
+
|
33 |
+
def get_optimal_device(index: int = 0) -> torch.device:
|
34 |
+
if torch.cuda.is_available():
|
35 |
+
return torch.device(f"cuda:{index % torch.cuda.device_count()}")
|
36 |
+
elif torch.backends.mps.is_available():
|
37 |
+
return torch.device("mps")
|
38 |
+
else:
|
39 |
+
try:
|
40 |
+
import torch_xla.core.xla_model as xm # noqa
|
41 |
+
|
42 |
+
if xm.xrt_world_size() > 0:
|
43 |
+
return torch.device("xla")
|
44 |
+
# return xm.xla_device()
|
45 |
+
except ImportError:
|
46 |
+
pass
|
47 |
+
return torch.device("cpu")
|
48 |
+
|
49 |
+
|
50 |
+
def download_file(
|
51 |
+
url: str,
|
52 |
+
filepath: Path | str,
|
53 |
+
chunk_size: int = 64 * 1024,
|
54 |
+
tqdm_cls: type = tqdm,
|
55 |
+
skip_if_exists: bool = False,
|
56 |
+
overwrite: bool = False,
|
57 |
+
**tqdm_kwargs: Any,
|
58 |
+
):
|
59 |
+
if skip_if_exists is True and overwrite is True:
|
60 |
+
raise ValueError("skip_if_exists and overwrite cannot be both True")
|
61 |
+
filepath = Path(filepath)
|
62 |
+
filepath.parent.mkdir(parents=True, exist_ok=True)
|
63 |
+
temppath = filepath.parent / f"{filepath.name}.download"
|
64 |
+
if filepath.exists():
|
65 |
+
if skip_if_exists:
|
66 |
+
return
|
67 |
+
elif not overwrite:
|
68 |
+
filepath.unlink()
|
69 |
+
else:
|
70 |
+
raise FileExistsError(f"{filepath} already exists")
|
71 |
+
temppath.unlink(missing_ok=True)
|
72 |
+
resp = requests.get(url, stream=True)
|
73 |
+
total = int(resp.headers.get("content-length", 0))
|
74 |
+
kwargs = dict(
|
75 |
+
total=total,
|
76 |
+
unit="iB",
|
77 |
+
unit_scale=True,
|
78 |
+
unit_divisor=1024,
|
79 |
+
desc=f"Downloading {filepath.name}",
|
80 |
+
)
|
81 |
+
kwargs.update(tqdm_kwargs)
|
82 |
+
with temppath.open("wb") as f, tqdm_cls(**kwargs) as pbar:
|
83 |
+
for data in resp.iter_content(chunk_size=chunk_size):
|
84 |
+
size = f.write(data)
|
85 |
+
pbar.update(size)
|
86 |
+
temppath.rename(filepath)
|
87 |
+
|
88 |
+
|
89 |
+
PRETRAINED_MODEL_URLS = {
|
90 |
+
"hifi-gan": [
|
91 |
+
[
|
92 |
+
"https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/D_0.pth",
|
93 |
+
"https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/G_0.pth",
|
94 |
+
],
|
95 |
+
[
|
96 |
+
"https://huggingface.co/Himawari00/so-vits-svc4.0-pretrain-models/resolve/main/D_0.pth",
|
97 |
+
"https://huggingface.co/Himawari00/so-vits-svc4.0-pretrain-models/resolve/main/G_0.pth",
|
98 |
+
],
|
99 |
+
],
|
100 |
+
"contentvec": [
|
101 |
+
[
|
102 |
+
"https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/checkpoint_best_legacy_500.pt"
|
103 |
+
],
|
104 |
+
[
|
105 |
+
"https://huggingface.co/Himawari00/so-vits-svc4.0-pretrain-models/resolve/main/checkpoint_best_legacy_500.pt"
|
106 |
+
],
|
107 |
+
[
|
108 |
+
"http://obs.cstcloud.cn/share/obs/sankagenkeshi/checkpoint_best_legacy_500.pt"
|
109 |
+
],
|
110 |
+
],
|
111 |
+
}
|
112 |
+
from joblib import Parallel, delayed
|
113 |
+
|
114 |
+
|
115 |
+
def ensure_pretrained_model(
|
116 |
+
folder_path: Path | str, type_: str | dict[str, str], **tqdm_kwargs: Any
|
117 |
+
) -> tuple[Path, ...] | None:
|
118 |
+
folder_path = Path(folder_path)
|
119 |
+
|
120 |
+
# new code
|
121 |
+
if not isinstance(type_, str):
|
122 |
+
try:
|
123 |
+
Parallel(n_jobs=len(type_))(
|
124 |
+
[
|
125 |
+
delayed(download_file)(
|
126 |
+
url,
|
127 |
+
folder_path / filename,
|
128 |
+
position=i,
|
129 |
+
skip_if_exists=True,
|
130 |
+
**tqdm_kwargs,
|
131 |
+
)
|
132 |
+
for i, (filename, url) in enumerate(type_.items())
|
133 |
+
]
|
134 |
+
)
|
135 |
+
return tuple(folder_path / filename for filename in type_.values())
|
136 |
+
except Exception as e:
|
137 |
+
LOG.error(f"Failed to download {type_}")
|
138 |
+
LOG.exception(e)
|
139 |
+
|
140 |
+
# old code
|
141 |
+
models_candidates = PRETRAINED_MODEL_URLS.get(type_, None)
|
142 |
+
if models_candidates is None:
|
143 |
+
LOG.warning(f"Unknown pretrained model type: {type_}")
|
144 |
+
return
|
145 |
+
for model_urls in models_candidates:
|
146 |
+
paths = [folder_path / model_url.split("/")[-1] for model_url in model_urls]
|
147 |
+
try:
|
148 |
+
Parallel(n_jobs=len(paths))(
|
149 |
+
[
|
150 |
+
delayed(download_file)(
|
151 |
+
url, path, position=i, skip_if_exists=True, **tqdm_kwargs
|
152 |
+
)
|
153 |
+
for i, (url, path) in enumerate(zip(model_urls, paths))
|
154 |
+
]
|
155 |
+
)
|
156 |
+
return tuple(paths)
|
157 |
+
except Exception as e:
|
158 |
+
LOG.error(f"Failed to download {model_urls}")
|
159 |
+
LOG.exception(e)
|
160 |
+
|
161 |
+
|
162 |
+
class HubertModelWithFinalProj(HubertModel):
|
163 |
+
def __init__(self, config):
|
164 |
+
super().__init__(config)
|
165 |
+
|
166 |
+
# The final projection layer is only used for backward compatibility.
|
167 |
+
# Following https://github.com/auspicious3000/contentvec/issues/6
|
168 |
+
# Remove this layer is necessary to achieve the desired outcome.
|
169 |
+
self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
170 |
+
|
171 |
+
|
172 |
+
def remove_weight_norm_if_exists(module, name: str = "weight"):
|
173 |
+
r"""Removes the weight normalization reparameterization from a module.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
module (Module): containing module
|
177 |
+
name (str, optional): name of weight parameter
|
178 |
+
|
179 |
+
Example:
|
180 |
+
>>> m = weight_norm(nn.Linear(20, 40))
|
181 |
+
>>> remove_weight_norm(m)
|
182 |
+
"""
|
183 |
+
from torch.nn.utils.weight_norm import WeightNorm
|
184 |
+
|
185 |
+
for k, hook in module._forward_pre_hooks.items():
|
186 |
+
if isinstance(hook, WeightNorm) and hook.name == name:
|
187 |
+
hook.remove(module)
|
188 |
+
del module._forward_pre_hooks[k]
|
189 |
+
return module
|
190 |
+
|
191 |
+
|
192 |
+
def get_hubert_model(
|
193 |
+
device: str | torch.device, final_proj: bool = True
|
194 |
+
) -> HubertModel:
|
195 |
+
if final_proj:
|
196 |
+
model = HubertModelWithFinalProj.from_pretrained("lengyue233/content-vec-best")
|
197 |
+
else:
|
198 |
+
model = HubertModel.from_pretrained("lengyue233/content-vec-best")
|
199 |
+
# Hubert is always used in inference mode, we can safely remove weight-norms
|
200 |
+
for m in model.modules():
|
201 |
+
if isinstance(m, (nn.Conv2d, nn.Conv1d)):
|
202 |
+
remove_weight_norm_if_exists(m)
|
203 |
+
|
204 |
+
return model.to(device)
|
205 |
+
|
206 |
+
|
207 |
+
def get_content(
|
208 |
+
cmodel: HubertModel,
|
209 |
+
audio: torch.Tensor | ndarray[Any, Any],
|
210 |
+
device: torch.device | str,
|
211 |
+
sr: int,
|
212 |
+
legacy_final_proj: bool = False,
|
213 |
+
) -> torch.Tensor:
|
214 |
+
audio = torch.as_tensor(audio)
|
215 |
+
if sr != HUBERT_SAMPLING_RATE:
|
216 |
+
audio = (
|
217 |
+
torchaudio.transforms.Resample(sr, HUBERT_SAMPLING_RATE)
|
218 |
+
.to(audio.device)(audio)
|
219 |
+
.to(device)
|
220 |
+
)
|
221 |
+
if audio.ndim == 1:
|
222 |
+
audio = audio.unsqueeze(0)
|
223 |
+
with torch.no_grad(), timer() as t:
|
224 |
+
if legacy_final_proj:
|
225 |
+
warnings.warn("legacy_final_proj is deprecated")
|
226 |
+
if not hasattr(cmodel, "final_proj"):
|
227 |
+
raise ValueError("HubertModel does not have final_proj")
|
228 |
+
c = cmodel(audio, output_hidden_states=True)["hidden_states"][9]
|
229 |
+
c = cmodel.final_proj(c)
|
230 |
+
else:
|
231 |
+
c = cmodel(audio)["last_hidden_state"]
|
232 |
+
c = c.transpose(1, 2)
|
233 |
+
wav_len = audio.shape[-1] / HUBERT_SAMPLING_RATE
|
234 |
+
LOG.info(
|
235 |
+
f"HuBERT inference time : {t.elapsed:.3f}s, RTF: {t.elapsed / wav_len:.3f}"
|
236 |
+
)
|
237 |
+
return c
|
238 |
+
|
239 |
+
|
240 |
+
def _substitute_if_same_shape(to_: dict[str, Any], from_: dict[str, Any]) -> None:
|
241 |
+
not_in_to = list(filter(lambda x: x not in to_, from_.keys()))
|
242 |
+
not_in_from = list(filter(lambda x: x not in from_, to_.keys()))
|
243 |
+
if not_in_to:
|
244 |
+
warnings.warn(f"Keys not found in model state dict:" f"{not_in_to}")
|
245 |
+
if not_in_from:
|
246 |
+
warnings.warn(f"Keys not found in checkpoint state dict:" f"{not_in_from}")
|
247 |
+
shape_missmatch = []
|
248 |
+
for k, v in from_.items():
|
249 |
+
if k not in to_:
|
250 |
+
pass
|
251 |
+
elif hasattr(v, "shape"):
|
252 |
+
if not hasattr(to_[k], "shape"):
|
253 |
+
raise ValueError(f"Key {k} is not a tensor")
|
254 |
+
if to_[k].shape == v.shape:
|
255 |
+
to_[k] = v
|
256 |
+
else:
|
257 |
+
shape_missmatch.append((k, to_[k].shape, v.shape))
|
258 |
+
elif isinstance(v, dict):
|
259 |
+
assert isinstance(to_[k], dict)
|
260 |
+
_substitute_if_same_shape(to_[k], v)
|
261 |
+
else:
|
262 |
+
to_[k] = v
|
263 |
+
if shape_missmatch:
|
264 |
+
warnings.warn(
|
265 |
+
f"Shape mismatch: {[f'{k}: {v1} -> {v2}' for k, v1, v2 in shape_missmatch]}"
|
266 |
+
)
|
267 |
+
|
268 |
+
|
269 |
+
def safe_load(model: torch.nn.Module, state_dict: dict[str, Any]) -> None:
|
270 |
+
model_state_dict = model.state_dict()
|
271 |
+
_substitute_if_same_shape(model_state_dict, state_dict)
|
272 |
+
model.load_state_dict(model_state_dict)
|
273 |
+
|
274 |
+
|
275 |
+
def load_checkpoint(
|
276 |
+
checkpoint_path: Path | str,
|
277 |
+
model: torch.nn.Module,
|
278 |
+
optimizer: torch.optim.Optimizer | None = None,
|
279 |
+
skip_optimizer: bool = False,
|
280 |
+
) -> tuple[torch.nn.Module, torch.optim.Optimizer | None, float, int]:
|
281 |
+
if not Path(checkpoint_path).is_file():
|
282 |
+
raise FileNotFoundError(f"File {checkpoint_path} not found")
|
283 |
+
with Path(checkpoint_path).open("rb") as f:
|
284 |
+
with warnings.catch_warnings():
|
285 |
+
warnings.filterwarnings(
|
286 |
+
"ignore", category=UserWarning, message="TypedStorage is deprecated"
|
287 |
+
)
|
288 |
+
checkpoint_dict = torch.load(f, map_location="cpu", weights_only=True)
|
289 |
+
iteration = checkpoint_dict["iteration"]
|
290 |
+
learning_rate = checkpoint_dict["learning_rate"]
|
291 |
+
|
292 |
+
# safe load module
|
293 |
+
if hasattr(model, "module"):
|
294 |
+
safe_load(model.module, checkpoint_dict["model"])
|
295 |
+
else:
|
296 |
+
safe_load(model, checkpoint_dict["model"])
|
297 |
+
# safe load optim
|
298 |
+
if (
|
299 |
+
optimizer is not None
|
300 |
+
and not skip_optimizer
|
301 |
+
and checkpoint_dict["optimizer"] is not None
|
302 |
+
):
|
303 |
+
with warnings.catch_warnings():
|
304 |
+
warnings.simplefilter("ignore")
|
305 |
+
safe_load(optimizer, checkpoint_dict["optimizer"])
|
306 |
+
|
307 |
+
LOG.info(f"Loaded checkpoint '{checkpoint_path}' (epoch {iteration})")
|
308 |
+
return model, optimizer, learning_rate, iteration
|
309 |
+
|
310 |
+
|
311 |
+
def save_checkpoint(
|
312 |
+
model: torch.nn.Module,
|
313 |
+
optimizer: torch.optim.Optimizer,
|
314 |
+
learning_rate: float,
|
315 |
+
iteration: int,
|
316 |
+
checkpoint_path: Path | str,
|
317 |
+
) -> None:
|
318 |
+
LOG.info(
|
319 |
+
"Saving model and optimizer state at epoch {} to {}".format(
|
320 |
+
iteration, checkpoint_path
|
321 |
+
)
|
322 |
+
)
|
323 |
+
if hasattr(model, "module"):
|
324 |
+
state_dict = model.module.state_dict()
|
325 |
+
else:
|
326 |
+
state_dict = model.state_dict()
|
327 |
+
with Path(checkpoint_path).open("wb") as f:
|
328 |
+
torch.save(
|
329 |
+
{
|
330 |
+
"model": state_dict,
|
331 |
+
"iteration": iteration,
|
332 |
+
"optimizer": optimizer.state_dict(),
|
333 |
+
"learning_rate": learning_rate,
|
334 |
+
},
|
335 |
+
f,
|
336 |
+
)
|
337 |
+
|
338 |
+
|
339 |
+
def clean_checkpoints(
|
340 |
+
path_to_models: Path | str, n_ckpts_to_keep: int = 2, sort_by_time: bool = True
|
341 |
+
) -> None:
|
342 |
+
"""Freeing up space by deleting saved ckpts
|
343 |
+
|
344 |
+
Arguments:
|
345 |
+
path_to_models -- Path to the model directory
|
346 |
+
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
|
347 |
+
sort_by_time -- True -> chronologically delete ckpts
|
348 |
+
False -> lexicographically delete ckpts
|
349 |
+
"""
|
350 |
+
LOG.info("Cleaning old checkpoints...")
|
351 |
+
path_to_models = Path(path_to_models)
|
352 |
+
|
353 |
+
# Define sort key functions
|
354 |
+
name_key = lambda p: int(re.match(r"[GD]_(\d+)", p.stem).group(1))
|
355 |
+
time_key = lambda p: p.stat().st_mtime
|
356 |
+
path_key = lambda p: (p.stem[0], time_key(p) if sort_by_time else name_key(p))
|
357 |
+
|
358 |
+
models = list(
|
359 |
+
filter(
|
360 |
+
lambda p: (
|
361 |
+
p.is_file()
|
362 |
+
and re.match(r"[GD]_\d+", p.stem)
|
363 |
+
and not p.stem.endswith("_0")
|
364 |
+
),
|
365 |
+
path_to_models.glob("*.pth"),
|
366 |
+
)
|
367 |
+
)
|
368 |
+
|
369 |
+
models_sorted = sorted(models, key=path_key)
|
370 |
+
|
371 |
+
models_sorted_grouped = groupby(models_sorted, lambda p: p.stem[0])
|
372 |
+
|
373 |
+
for group_name, group_items in models_sorted_grouped:
|
374 |
+
to_delete_list = list(group_items)[:-n_ckpts_to_keep]
|
375 |
+
|
376 |
+
for to_delete in to_delete_list:
|
377 |
+
if to_delete.exists():
|
378 |
+
LOG.info(f"Removing {to_delete}")
|
379 |
+
if IS_COLAB:
|
380 |
+
to_delete.write_text("")
|
381 |
+
to_delete.unlink()
|
382 |
+
|
383 |
+
|
384 |
+
def latest_checkpoint_path(dir_path: Path | str, regex: str = "G_*.pth") -> Path | None:
|
385 |
+
dir_path = Path(dir_path)
|
386 |
+
name_key = lambda p: int(re.match(r"._(\d+)\.pth", p.name).group(1))
|
387 |
+
paths = list(sorted(dir_path.glob(regex), key=name_key))
|
388 |
+
if len(paths) == 0:
|
389 |
+
return None
|
390 |
+
return paths[-1]
|
391 |
+
|
392 |
+
|
393 |
+
def plot_spectrogram_to_numpy(spectrogram: ndarray) -> ndarray:
|
394 |
+
matplotlib.use("Agg")
|
395 |
+
fig, ax = plt.subplots(figsize=(10, 2))
|
396 |
+
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
397 |
+
plt.colorbar(im, ax=ax)
|
398 |
+
plt.xlabel("Frames")
|
399 |
+
plt.ylabel("Channels")
|
400 |
+
plt.tight_layout()
|
401 |
+
|
402 |
+
fig.canvas.draw()
|
403 |
+
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
404 |
+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
405 |
+
plt.close()
|
406 |
+
return data
|
407 |
+
|
408 |
+
|
409 |
+
def get_backup_hparams(
|
410 |
+
config_path: Path, model_path: Path, init: bool = True
|
411 |
+
) -> HParams:
|
412 |
+
model_path.mkdir(parents=True, exist_ok=True)
|
413 |
+
config_save_path = model_path / "config.json"
|
414 |
+
if init:
|
415 |
+
with config_path.open() as f:
|
416 |
+
data = f.read()
|
417 |
+
with config_save_path.open("w") as f:
|
418 |
+
f.write(data)
|
419 |
+
else:
|
420 |
+
with config_save_path.open() as f:
|
421 |
+
data = f.read()
|
422 |
+
config = json.loads(data)
|
423 |
+
|
424 |
+
hparams = HParams(**config)
|
425 |
+
hparams.model_dir = model_path.as_posix()
|
426 |
+
return hparams
|
427 |
+
|
428 |
+
|
429 |
+
def get_hparams(config_path: Path | str) -> HParams:
|
430 |
+
config = json.loads(Path(config_path).read_text("utf-8"))
|
431 |
+
hparams = HParams(**config)
|
432 |
+
return hparams
|
433 |
+
|
434 |
+
|
435 |
+
def repeat_expand_2d(content: torch.Tensor, target_len: int) -> torch.Tensor:
|
436 |
+
# content : [h, t]
|
437 |
+
src_len = content.shape[-1]
|
438 |
+
if target_len < src_len:
|
439 |
+
return content[:, :target_len]
|
440 |
+
else:
|
441 |
+
return torch.nn.functional.interpolate(
|
442 |
+
content.unsqueeze(0), size=target_len, mode="nearest"
|
443 |
+
).squeeze(0)
|
444 |
+
|
445 |
+
|
446 |
+
def plot_data_to_numpy(x: ndarray, y: ndarray) -> ndarray:
|
447 |
+
matplotlib.use("Agg")
|
448 |
+
fig, ax = plt.subplots(figsize=(10, 2))
|
449 |
+
plt.plot(x)
|
450 |
+
plt.plot(y)
|
451 |
+
plt.tight_layout()
|
452 |
+
|
453 |
+
fig.canvas.draw()
|
454 |
+
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
455 |
+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
456 |
+
plt.close()
|
457 |
+
return data
|
458 |
+
|
459 |
+
|
460 |
+
def get_gpu_memory(type_: Literal["total", "free", "used"]) -> Sequence[int] | None:
|
461 |
+
command = f"nvidia-smi --query-gpu=memory.{type_} --format=csv"
|
462 |
+
try:
|
463 |
+
memory_free_info = (
|
464 |
+
subprocess.check_output(command.split())
|
465 |
+
.decode("ascii")
|
466 |
+
.split("\n")[:-1][1:]
|
467 |
+
)
|
468 |
+
memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
|
469 |
+
return memory_free_values
|
470 |
+
except Exception:
|
471 |
+
return
|
472 |
+
|
473 |
+
|
474 |
+
def get_total_gpu_memory(type_: Literal["total", "free", "used"]) -> int | None:
|
475 |
+
memories = get_gpu_memory(type_)
|
476 |
+
if memories is None:
|
477 |
+
return
|
478 |
+
return sum(memories)
|