XXXXRT666 committed on
Commit d4d21ad · 1 Parent(s): d0754c2
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +37 -0
  2. .gitignore +196 -0
  3. .pre-commit-config.yaml +15 -0
  4. GPT_SoVITS/Accelerate/MLX/__init__.py +12 -0
  5. GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py +181 -0
  6. GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py +99 -0
  7. GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py +103 -0
  8. GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py +65 -0
  9. GPT_SoVITS/Accelerate/MLX/structs_mlx.py +152 -0
  10. GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py +238 -0
  11. GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py +530 -0
  12. GPT_SoVITS/Accelerate/PyTorch/__init__.py +30 -0
  13. GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py +158 -0
  14. GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py +166 -0
  15. GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py +175 -0
  16. GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py +166 -0
  17. GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py +145 -0
  18. GPT_SoVITS/Accelerate/PyTorch/export.py +467 -0
  19. GPT_SoVITS/Accelerate/PyTorch/nn.py +69 -0
  20. GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py +67 -0
  21. GPT_SoVITS/Accelerate/PyTorch/structs.py +151 -0
  22. GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py +223 -0
  23. GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py +672 -0
  24. GPT_SoVITS/Accelerate/__init__.py +30 -0
  25. GPT_SoVITS/Accelerate/logger.py +203 -0
  26. GPT_SoVITS/configs/.gitignore +1 -0
  27. GPT_SoVITS/configs/s2.json +91 -0
  28. GPT_SoVITS/configs/s2v2Pro.json +91 -0
  29. GPT_SoVITS/configs/s2v2ProPlus.json +91 -0
  30. GPT_SoVITS/eres2net/ERes2NetV2.py +252 -0
  31. GPT_SoVITS/eres2net/fusion.py +27 -0
  32. GPT_SoVITS/eres2net/kaldi.py +844 -0
  33. GPT_SoVITS/eres2net/pooling_layers.py +101 -0
  34. GPT_SoVITS/f5_tts/model/__init__.py +3 -0
  35. GPT_SoVITS/f5_tts/model/backbones/README.md +20 -0
  36. GPT_SoVITS/f5_tts/model/backbones/dit.py +193 -0
  37. GPT_SoVITS/f5_tts/model/backbones/mmdit.py +144 -0
  38. GPT_SoVITS/f5_tts/model/backbones/unett.py +218 -0
  39. GPT_SoVITS/f5_tts/model/modules.py +665 -0
  40. GPT_SoVITS/feature_extractor/__init__.py +3 -0
  41. GPT_SoVITS/feature_extractor/cnhubert.py +46 -0
  42. GPT_SoVITS/inference_webui.py +1104 -0
  43. GPT_SoVITS/module/attentions.py +658 -0
  44. GPT_SoVITS/module/attentions_onnx.py +385 -0
  45. GPT_SoVITS/module/commons.py +185 -0
  46. GPT_SoVITS/module/core_vq.py +365 -0
  47. GPT_SoVITS/module/data_utils.py +1073 -0
  48. GPT_SoVITS/module/losses.py +70 -0
  49. GPT_SoVITS/module/mel_processing.py +142 -0
  50. GPT_SoVITS/module/models.py +1411 -0
.gitattributes CHANGED
@@ -1 +1,38 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+GPT_SoVITS/text/G2PWModel/* filter=lfs diff=lfs merge=lfs -text
+GPT_SoVITS/text/G2PWModel/** filter=lfs diff=lfs merge=lfs -text
 GPT_SoVITS/text/ja_userdic/userdict.csv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,196 @@
+.DS_Store
+.vscode
+__pycache__
+*.pyc
+env
+runtime
+.idea
+output
+logs
+SoVITS_weights*/
+GPT_weights*/
+TEMP
+weight.json
+ffmpeg*
+ffprobe*
+cfg.json
+speakers.json
+ref_audios
+tools/AP_BWE/24kto48k/*
+!tools/AP_BWE/24kto48k/readme.txt
+onnx_export
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
.pre-commit-config.yaml ADDED
@@ -0,0 +1,15 @@
+ci:
+  autoupdate_schedule: monthly
+
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.11.7
+    hooks:
+      # Run the linter.
+      - id: ruff
+        types_or: [ python, pyi ]
+        args: [ --fix , "--exit-zero" ]
+      # Run the formatter.
+      - id: ruff-format
+        types_or: [ python, pyi ]
+        args: [ --line-length, "120", --target-version, "py310" ]
GPT_SoVITS/Accelerate/MLX/__init__.py ADDED
@@ -0,0 +1,12 @@
+import importlib.util
+import platform
+
+if importlib.util.find_spec("mlx") is not None and platform.system() == "Darwin":
+    from .sample_funcs_mlx import sample_naive as sample_naive_mlx
+    from .t2s_engine_mlx import T2SEngine as T2SEngineMLX
+
+    backends = ["mlx_static", "mlx_quantized_mxfp4", "mlx_quantized_affine", "mlx_varlen"]
+else:
+    backends = []
+
+__all__ = ["T2SEngineMLX", "sample_naive_mlx", "backends"]
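The import guard above only exposes the MLX engine when the mlx package is importable and the host is macOS; elsewhere backends stays empty. A minimal probing sketch, assuming the repository root is on sys.path (the import path is an assumption, not shown in this diff):

import platform

from GPT_SoVITS.Accelerate import MLX  # assumed package layout

if MLX.backends:
    # Apple Silicon with mlx installed: MLX backends are selectable
    print("MLX backends:", MLX.backends)
else:
    print(f"MLX unavailable on {platform.system()}; only the PyTorch backends apply")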
GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py ADDED
@@ -0,0 +1,181 @@
+from __future__ import annotations
+
+from typing import cast
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..structs_mlx import KVCacheQ
+from ..t2s_model_abc import (
+    AttentionABC,
+    KVCache,
+    KVCacheHND,
+    T2SDecoderABC,
+    TransformerBlockABC,
+    TransformerDecoderABC,
+)
+
+Array = mx.array
+
+
+class Attention(AttentionABC):
+    def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
+        super().__init__(n_head, hidden_dim, max_seq_length)
+        self.kc_class = KVCacheHND
+
+    @staticmethod
+    def quantized_scaled_dot_product_attention(
+        queries: Array,
+        q_keys: tuple[Array, Array, Array],
+        q_values: tuple[Array, Array, Array],
+        scale: float,
+        mask: Array,
+        group_size: int = 32,
+        bits: int = 8,
+    ) -> Array:
+        queries *= scale
+
+        scores = mx.quantized_matmul(queries, *q_keys, transpose=True, group_size=group_size, bits=bits)
+        scores = mx.where(mask, scores, -mx.inf)
+        scores = mx.softmax(scores, axis=-1, precise=True)  # type: ignore
+        out = mx.quantized_matmul(scores, *q_values, transpose=False, group_size=group_size, bits=bits)
+
+        return out
+
+    def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
+        bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
+
+        q, k, v = self.in_proj(x).split(3, axis=-1)
+
+        q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
+
+        q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
+
+        kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
+        assert len(kv_cache) == 2
+
+        max_idx = int(input_pos.max())
+
+        q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))
+
+        mask = attn_mask[..., :max_idx]
+
+        attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
+
+        attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
+
+        attn = self.out_proj(attn)
+
+        return attn
+
+    # def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
+    #     bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
+
+    #     q, k, v = self.in_proj(x).split(3, axis=-1)
+
+    #     q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
+
+    #     q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
+
+    #     kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
+
+    #     assert len(kv_cache) == 3
+    #     (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits) = kv_cache
+
+    #     k_q, k_s, k_b, v_q, v_s, v_b = map(lambda x: x[..., : int(input_pos.max()), :], (k_q, k_s, k_b, v_q, v_s, v_b))
+
+    #     mask = attn_mask[..., : int(input_pos.max())]
+
+    #     attn = Attention.quantized_scaled_dot_product_attention(
+    #         q,
+    #         (k_q, k_s, k_b),
+    #         (v_q, v_s, v_b),
+    #         self.scale,
+    #         mask,
+    #         group_size,
+    #         bits,
+    #     )
+
+    #     attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
+
+    #     output = self.out_proj(attn)
+
+    #     return output
+
+
+class TransformerBlock(TransformerBlockABC):
+    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
+        super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length, *args, **kwds)
+
+        self.attention = Attention(n_head, hidden_dim, max_seq_length, *args, **kwds)
+
+
+class TransformerDecoder(TransformerDecoderABC):
+    def __init__(
+        self,
+        hidden_dim: int,
+        n_layer: int,
+        n_head: int,
+        ffn_dim: int,
+        vocab_size: int,
+        max_seq_length: int,
+        max_batch_size: int,
+        *args,
+        **kwds,
+    ) -> None:
+        super().__init__(
+            hidden_dim,
+            n_layer,
+            n_head,
+            ffn_dim,
+            vocab_size,
+            max_seq_length,
+            max_batch_size,
+            *args,
+            **kwds,
+        )
+
+        self.layers = [
+            TransformerBlock(
+                n_head,
+                ffn_dim,
+                hidden_dim,
+                max_seq_length,
+                *args,
+                **kwds,
+            )
+            for _ in range(n_layer)
+        ]
+
+
+class T2SDecoder(T2SDecoderABC):
+    def __init__(
+        self,
+        config: dict,
+        max_seq_length: int = 2000,
+        max_batch_size: int = 10,
+    ) -> None:
+        super().__init__(config, max_seq_length, max_batch_size)
+
+        self.h = TransformerDecoder(
+            self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
+        )
+
+        self.kv_class = KVCacheHND
+        self.group_size = 32
+        self.bits = 8
+        self.mode = "affine"
+
+    def set_mode(self, mode: str):
+        assert mode in ["affine", "mxfp4"]
+        self.mode = mode
+        if self.mode == "mxfp4":
+            self.bits = 4
+        else:
+            self.bits = 8
+
+    def quantized(self):
+        nn.quantize(self, self.group_size, self.bits, mode=self.mode)
+        # for layer in self.h.layers:
+        #     nn.quantize(layer.feed_forward, self.group_size, self.bits)
+        #     nn.quantize(layer.attention, self.group_size, self.bits)
GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py ADDED
@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+from typing import cast
+
+import mlx.core as mx
+
+from ..structs_mlx import KVCache, KVCacheQ
+from ..t2s_model_abc import (
+    AttentionABC,
+    KVCacheHND,
+    T2SDecoderABC,
+    TransformerBlockABC,
+    TransformerDecoderABC,
+)
+
+Array = mx.array
+
+
+class Attention(AttentionABC):
+    def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
+        super().__init__(n_head, hidden_dim, max_seq_length)
+        self.kc_class = KVCacheHND
+
+    def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
+        bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
+
+        q, k, v = self.in_proj(x).split(3, axis=-1)
+
+        q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
+
+        q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
+
+        kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
+        assert len(kv_cache) == 2
+
+        k, v = kv_cache
+
+        attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attn_mask)
+
+        attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
+
+        attn = self.out_proj(attn)
+
+        return attn
+
+
+class TransformerBlock(TransformerBlockABC):
+    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
+        super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
+
+        self.attention = Attention(n_head, hidden_dim, max_seq_length)
+
+
+class TransformerDecoder(TransformerDecoderABC):
+    def __init__(
+        self,
+        hidden_dim: int,
+        n_layer: int,
+        n_head: int,
+        ffn_dim: int,
+        vocab_size: int,
+        max_seq_length: int,
+        max_batch_size: int,
+    ) -> None:
+        super().__init__(
+            hidden_dim,
+            n_layer,
+            n_head,
+            ffn_dim,
+            vocab_size,
+            max_seq_length,
+            max_batch_size,
+        )
+
+        self.layers = [
+            TransformerBlock(
+                n_head,
+                ffn_dim,
+                hidden_dim,
+                max_seq_length,
+            )
+            for _ in range(n_layer)
+        ]
+
+
+class T2SDecoder(T2SDecoderABC):
+    def __init__(
+        self,
+        config: dict,
+        max_seq_length: int = 2000,
+        max_batch_size: int = 10,
+    ) -> None:
+        super().__init__(config, max_seq_length, max_batch_size)
+
+        self.h = TransformerDecoder(
+            self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
+        )
+
+        self.kv_class = KVCacheHND
GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py ADDED
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+from typing import cast
+
+import mlx.core as mx
+
+from ..structs_mlx import KVCache, KVCacheQ
+from ..t2s_model_abc import (
+    AttentionABC,
+    KVCacheHND,
+    T2SDecoderABC,
+    TransformerBlockABC,
+    TransformerDecoderABC,
+)
+
+Array = mx.array
+
+
+class Attention(AttentionABC):
+    def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
+        super().__init__(n_head, hidden_dim, max_seq_length)
+        self.kc_class = KVCacheHND
+
+    def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
+        bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
+
+        q, k, v = self.in_proj(x).split(3, axis=-1)
+
+        q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
+
+        q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
+
+        kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
+        assert len(kv_cache) == 2
+
+        max_idx = int(input_pos.max())
+
+        q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))
+
+        mask = attn_mask[..., :max_idx]
+
+        attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
+
+        attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
+
+        attn = self.out_proj(attn)
+
+        return attn
+
+
+class TransformerBlock(TransformerBlockABC):
+    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
+        super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
+
+        self.attention = Attention(n_head, hidden_dim, max_seq_length)
+
+
+class TransformerDecoder(TransformerDecoderABC):
+    def __init__(
+        self,
+        hidden_dim: int,
+        n_layer: int,
+        n_head: int,
+        ffn_dim: int,
+        vocab_size: int,
+        max_seq_length: int,
+        max_batch_size: int,
+    ) -> None:
+        super().__init__(
+            hidden_dim,
+            n_layer,
+            n_head,
+            ffn_dim,
+            vocab_size,
+            max_seq_length,
+            max_batch_size,
+        )
+
+        self.layers = [
+            TransformerBlock(
+                n_head,
+                ffn_dim,
+                hidden_dim,
+                max_seq_length,
+            )
+            for _ in range(n_layer)
+        ]
+
+
+class T2SDecoder(T2SDecoderABC):
+    def __init__(
+        self,
+        config: dict,
+        max_seq_length: int = 2000,
+        max_batch_size: int = 10,
+    ) -> None:
+        super().__init__(config, max_seq_length, max_batch_size)
+
+        self.h = TransformerDecoder(
+            self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
+        )
+
+        self.kv_class = KVCacheHND
GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py ADDED
@@ -0,0 +1,65 @@
+from typing import Protocol, cast
+
+import mlx.core as mx
+
+Array = mx.array
+
+
+class SampleProtocolMLX(Protocol):
+    @staticmethod
+    def __call__(
+        logits: Array,
+        previous_tokens: Array,
+        temperature: float,
+        top_k: int,
+        top_p: float,
+        repetition_penalty: float,
+    ) -> Array: ...
+
+
+class sample_naive(SampleProtocolMLX):
+    # @partial(mx.compile)
+    @staticmethod
+    def __call__(
+        logits,
+        previous_tokens,
+        temperature,
+        top_k,
+        top_p,
+        repetition_penalty,
+    ):
+        if temperature <= 1e-5:
+            probs = mx.softmax(logits, axis=-1)
+            return mx.argmax(probs, axis=-1, keepdims=True).astype(mx.int32)
+
+        if repetition_penalty != 1.0:
+            batch_idx = mx.arange(cast(tuple[int, ...], previous_tokens.shape)[0])
+            previous_tokens = previous_tokens.astype(mx.int64)
+            selected_logists = logits[batch_idx, previous_tokens]
+            selected_logists = mx.where(
+                selected_logists < 0, selected_logists * repetition_penalty, selected_logists / repetition_penalty
+            )
+            logits[batch_idx, previous_tokens] = selected_logists
+
+        if top_p < 1.0:
+            sorted_indices = mx.argsort(-logits, axis=-1)
+            sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
+            cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1)
+            sorted_indices_to_remove = cum_probs > top_p
+            sorted_indices_to_remove[:, -1] = False
+            indices_to_remove = mx.zeros_like(logits).astype(mx.bool_)
+            batch_indices = mx.arange(cast(tuple[int, ...], logits.shape)[0])[:, None]
+            indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove
+            logits = mx.where(indices_to_remove, -mx.inf, logits)
+
+        if temperature < 1.0:
+            logits = logits / temperature
+
+        v = mx.topk(logits, top_k)
+        pivot = mx.expand_dims(v[:, 0], -1)
+        logits = mx.where(logits < pivot, -mx.inf, logits)
+
+        gumbel_noise = mx.random.gumbel(shape=cast(tuple[int, ...], logits.shape), dtype=logits.dtype)
+        idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True).astype(mx.int32)
+
+        return idx_next
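sample_naive combines repetition penalty, nucleus (top-p) filtering, top-k filtering, and Gumbel-max sampling in a single call. A hedged usage sketch with made-up shapes (batch of 1, vocabulary of 1025; the import path assumes the repo root is on sys.path):

import mlx.core as mx

from GPT_SoVITS.Accelerate.MLX.sample_funcs_mlx import sample_naive  # assumed import path

sample = sample_naive()
logits = mx.random.normal(shape=(1, 1025))          # (batch, vocab_size), dummy logits
previous_tokens = mx.zeros((1, 8), dtype=mx.int32)  # tokens generated so far
next_token = sample(
    logits,
    previous_tokens,
    temperature=1.0,
    top_k=15,
    top_p=1.0,
    repetition_penalty=1.35,
)
print(next_token.shape, next_token.dtype)           # (1, 1) int32 token id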
GPT_SoVITS/Accelerate/MLX/structs_mlx.py ADDED
@@ -0,0 +1,152 @@
+"""
+Modified From https://github.com/XXXXRT666/GPT-SoVITS
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, MutableSequence, Protocol, TypeAlias, cast
+
+import mlx.core as mx
+import torch
+
+from ..PyTorch.structs import T2SRequest
+from .sample_funcs_mlx import SampleProtocolMLX, sample_naive
+
+Tensor = torch.Tensor
+Array = mx.array
+
+
+@dataclass(slots=True)
+class T2SRequestMLX:
+    x: List[Array]
+    x_lens: Array
+    prompts: Array
+    bert_feature: List[Array]
+    valid_length: int
+    top_k: int = 5
+    top_p: float = 1
+    early_stop_num: int = -1
+    temperature: float = 1.0
+    repetition_penalty: float = 1.35
+
+    @classmethod
+    def from_torch(cls, request: T2SRequest) -> T2SRequestMLX:
+        x = list(map(lambda tensor: mx.array(tensor.cpu()), request.x))
+        x_lens = mx.array(request.x_lens.cpu())
+        prompts = mx.array(request.prompts.cpu())
+        bert_feature = list(map(lambda tensor: mx.array(tensor.cpu()), request.bert_feature))
+
+        return cls(
+            x,
+            x_lens,
+            prompts,
+            bert_feature,
+            request.valid_length,
+            request.top_k,
+            request.top_p,
+            request.early_stop_num,
+            request.temperature,
+            request.repetition_penalty,
+        )
+
+
+KVCache: TypeAlias = tuple[Array, Array]
+KVCacheQ: TypeAlias = tuple[tuple[Array, Array, Array], tuple[Array, Array, Array], tuple[int, int]]
+
+
+class KVCacheProtocol(Protocol):
+    @staticmethod
+    def empty(kv_cache: KVCache | KVCacheQ) -> None: ...
+
+    @staticmethod
+    def update_cache(
+        input_pos: Array, k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array
+    ) -> KVCache | KVCacheQ: ...
+
+    @staticmethod
+    def prefill_kv(k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ) -> None: ...
+
+    @staticmethod
+    def init_cache(
+        batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype, *args, **kwds
+    ) -> KVCache | KVCacheQ: ...
+
+
+class T2SDecoderProtocol(Protocol):
+    max_seq_length: int
+    EOS: int
+    n_head: int
+
+    def embed(self, x: list[Array], y: Array, bert_features: list[Array]) -> Array: ...
+
+
+class T2SSessionMLX:
+    def __init__(
+        self,
+        decoder: T2SDecoderProtocol,
+        request_torch: T2SRequest,
+        sample_func: type[SampleProtocolMLX] = sample_naive,
+        device: mx.Device = mx.Device(mx.cpu),
+        dtype: mx.Dtype = mx.float32,
+    ):
+        with mx.stream(device):
+            request = T2SRequestMLX.from_torch(request_torch)
+
+            self.decoder = decoder
+            self.request = request
+            self.device = device
+            self.dtype = dtype
+
+            bsz = len(request.x)
+            y_len: int = cast(tuple[int, ...], request.prompts.shape)[-1]
+            self.bsz = bsz
+            self.y_len = y_len
+
+            # Cache
+            self.kv_cache: MutableSequence[KVCache | KVCacheQ]
+            self.sample = sample_func()
+
+            # Forward args
+            self.x = [i.astype(mx.int32) for i in request.x]
+            self.x_lens = request.x_lens.astype(mx.int32)
+            self.y = mx.zeros((bsz, decoder.max_seq_length)).astype(mx.int32)
+            self.y[:, : cast(tuple[int, ...], request.prompts.shape)[-1]] = request.prompts.astype(mx.int32)
+            self.bert_feature = [i.astype(dtype) for i in request.bert_feature]
+
+            self.prefill_len = self.x_lens + cast(tuple[int, ...], request.prompts.shape)[1]
+
+            self.input_pos = mx.zeros_like(self.prefill_len)
+            self.input_pos += self.prefill_len
+
+            # EOS
+            self.completed = mx.array([False] * len(self.x)).astype(mx.bool_)
+            self.y_results: List[Array] = [None] * len(self.x)  # type: ignore
+
+            self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)
+
+            max_len = int(self.prefill_len.max(-1))
+            attn_mask = mx.zeros(shape=(bsz, max_len, max_len), dtype=mx.bool_)
+
+            for bs in range(bsz):
+                pos = int(self.x_lens[bs])
+                seq_len = pos + y_len
+
+                attn_mask[bs, :seq_len, :pos] = True
+
+                ar_mask = ~mx.triu(
+                    x=mx.ones(
+                        shape=(
+                            y_len,
+                            y_len,
+                        ),
+                        dtype=mx.bool_,
+                    ),
+                    k=1,
+                )
+                attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask
+
+            attn_mask = mx.repeat(mx.expand_dims(attn_mask, 1), decoder.n_head, 1)
+            self.attn_mask = attn_mask
+
+            mx.eval(self.attn_mask)
GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py ADDED
@@ -0,0 +1,238 @@
+import gc
+import os
+import time
+import traceback
+from typing import cast
+
+import mlx.core as mx
+import torch
+from rich.progress import BarColumn, Progress, TextColumn
+
+from ..logger import SpeedColumnToken, console, logger
+from ..PyTorch.structs import T2SEngineProtocol, T2SRequest, T2SResult
+from .backends import mlx_quantized, mlx_static, mlx_varlen
+from .structs_mlx import T2SSessionMLX
+from .t2s_model_abc import T2SDecoderABC
+
+Array = mx.array
+Tensor = torch.Tensor
+
+
+class T2SEngine(T2SEngineProtocol):
+    def __init__(
+        self,
+        decoder_model: T2SDecoderABC,
+        device: mx.Device | str = mx.Device(mx.cpu),
+        dtype: torch.dtype | mx.Dtype = torch.float32,
+    ) -> None:
+        if isinstance(device, str):
+            match device:
+                case "mx.cpu":
+                    device = mx.Device(mx.cpu)
+                case "mx.gpu":
+                    device = mx.Device(mx.gpu)
+
+        match dtype:
+            case torch.float32:
+                dtype = mx.float32
+            case torch.float16:
+                dtype = mx.float16
+            case torch.bfloat16:
+                dtype = mx.bfloat16
+
+        device = cast(mx.Device, device)
+        dtype = cast(mx.Dtype, dtype)
+
+        assert device.type.value in {0, 1}
+        assert dtype in {mx.float16, mx.bfloat16, mx.float32}
+
+        self.device = device
+        self.dtype = dtype
+
+        mx.set_default_device(device)
+        decoder_model.set_dtype(self.dtype)
+
+        self.decoder_model: T2SDecoderABC = decoder_model
+        self.decoder_model.compile()
+
+    def _handle_request(self, request: T2SRequest):
+        decoder = self.decoder_model
+        session = T2SSessionMLX(decoder, request, device=self.device, dtype=self.dtype)
+        batch_idx = mx.arange(session.bsz)
+
+        t1 = 0.0
+        infer_speed = 0.0
+        infer_time = 0.0
+
+        with (
+            mx.stream(session.device),
+            Progress(
+                TextColumn("[cyan]{task.description}"),
+                BarColumn(),
+                TextColumn("{task.completed}/{task.total}"),
+                SpeedColumnToken(show_speed=True),
+                console=console,
+                transient=True,
+            ) as progress,
+        ):
+            max_token = min(2000 - int(session.input_pos.max()), 1500)
+
+            task = progress.add_task("T2S Decoding", total=max_token)
+            for idx in range(1500):
+                progress.update(task, advance=1)
+                if idx == 0:
+                    session.kv_cache = decoder.init_cache(session.bsz)
+                    xy_dec = decoder.h.prefill(
+                        session.xy_pos,
+                        session.attn_mask,
+                        session.kv_cache,
+                    )  # bs, seq_len, embed_dim
+                    xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
+                else:
+                    args, kwds = decoder.pre_forward(session)
+                    xy_dec = decoder.h(
+                        session.input_pos,
+                        session.xy_pos,
+                        session.kv_cache,
+                        batch_idx,
+                        *args,
+                        **kwds,
+                    )
+
+                decoder.post_forward(idx, session)
+                logits = decoder.ar_predict_layer(xy_dec[:, -1])
+                session.input_pos += 1
+
+                if idx == 0:
+                    logits[:, -1] = -mx.inf
+
+                samples = session.sample(
+                    logits=logits,
+                    previous_tokens=session.y[:, : session.y_len + idx],
+                    top_k=request.top_k,
+                    top_p=request.top_p,
+                    repetition_penalty=request.repetition_penalty,
+                    temperature=request.temperature,
+                )
+
+                session.y[batch_idx, session.y_len + idx] = samples
+
+                argmax_token = mx.argmax(logits, axis=-1)
+                sample_token = samples.squeeze(1)
+                EOS_mask = (cast(Array, argmax_token == decoder.EOS)) | (sample_token == decoder.EOS)
+
+                newly_done_mask = EOS_mask & (~session.completed)
+                newly_done_indices = mx.where(newly_done_mask, batch_idx, -1)
+                pos = mx.where(newly_done_indices != -1, batch_idx, session.bsz)
+                pos_sorted = mx.sort(pos, axis=0)
+                valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz))
+                pos_final = pos_sorted[: int(valid_count)]
+                newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0)
+
+                if newly_done_indices.size > 0:
+                    for i in newly_done_indices:
+                        session.y_results[int(i)] = session.y[i, session.y_len : session.y_len + idx]
+                    session.completed[newly_done_indices] = True
+
+                if mx.all(session.completed).item():
+                    if session.y[:, session.y_len :].sum() == 0:
+                        session.y_results = [mx.array([0]) for _ in range(session.bsz)]
+                        logger.error("Bad Zero Prediction")
+                    else:
+                        logger.info(
+                            f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[cast(tuple[int, ...], i.shape)[-1] for i in session.y_results].__str__().strip('[]')}"
+                        )
+                        logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
+                    infer_time = time.perf_counter() - t1
+                    infer_speed = (idx - 1) / infer_time
+                    break
+
+                if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
+                    for j in range(session.bsz):
+                        if not session.completed[j].item():
+                            session.y_results[j] = session.y[[j], session.y_len : session.y_len + 1499]
+                            session.completed[j] = True
+                    logger.error("Bad Full Prediction")
+                    logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
+                    infer_time = time.perf_counter() - t1
+                    infer_speed = (idx - 1) / infer_time
+                    break
+
+                y_emb = decoder.ar_audio_embedding(samples)
+                session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
+                mx.eval(session.xy_pos, session.y)
+
+                if idx == 1:
+                    t1 = time.perf_counter()
+
+                if idx % 100 == 0:
+                    mx.clear_cache()
+
+        match session.device:
+            case mx.gpu:
+                mx.clear_cache()
+            case mx.cpu:
+                gc.collect()
+
+        result_mlx = session.y_results[: request.valid_length]
+        mx.eval(result_mlx)
+        result = [torch.tensor(k) for k in result_mlx]
+        return result, infer_speed, infer_time
+
+    def generate(self, request: T2SRequest):
+        try:
+            result, infer_speed, infer_time = self._handle_request(request)
+            t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
+        except Exception as e:
+            t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
+        return t2s_result
+
+    @staticmethod
+    def replace_key(state_dict: dict[str, Tensor]):
+        state_dict_mlx: list[tuple[str, Array]] = []
+        for key, value in state_dict.items():
+            key = (
+                key.replace("model.", "")
+                .replace("in_proj_", "in_proj.")
+                .replace("self_attn", "attention")
+                .replace("linear", "feed_forward.linear")
+                .replace("norm1", "attention_norm")
+                .replace("norm2", "ffn_norm")
+            )
+            value_mlx = mx.array(value)
+            state_dict_mlx.append((key, value_mlx))
+        return state_dict_mlx
+
+    @staticmethod
+    def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "MLX-Varlen"):
+        logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
+        dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
+        config = dict_s1["config"]
+        match backend:
+            case "MLX-Varlen":
+                decoder_cls: type[T2SDecoderABC] = mlx_varlen.T2SDecoder
+            case "MLX-Static":
+                decoder_cls = mlx_static.T2SDecoder
+            case "MLX-Quantized-Affine" | "MLX-Quantized-MXFP4":
+                decoder_cls = mlx_quantized.T2SDecoder
+            case _:
+                raise RuntimeError(f"Backend {backend} Not Found")
+
+        decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
+        state_dict = dict_s1["weight"]
+        state_dict_mlx = T2SEngine.replace_key(state_dict)
+        decoder.load_weights(state_dict_mlx)
+        decoder.eval()
+        mx.eval(decoder)
+
+        if "Quantized" in backend and isinstance(decoder, mlx_quantized.T2SDecoder):
+            if backend == "MLX-Quantized-Affine":
+                decoder.set_mode("affine")
+            elif backend == "MLX-Quantized-MXFP4":
+                decoder.set_mode("mxfp4")
+            else:
+                raise RuntimeError(f"Quantized Backend {backend} Not Supported")
+            decoder.quantized()
+            mx.eval(decoder)
+
+        return decoder
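End to end, the MLX engine is built from a GPT checkpoint and then driven with a T2SRequest. A rough sketch under assumptions: the checkpoint path is a placeholder, the import path assumes the repo root is on sys.path, and the request construction is elided because T2SRequest lives in GPT_SoVITS/Accelerate/PyTorch/structs.py, outside this diff:

import mlx.core as mx

from GPT_SoVITS.Accelerate.MLX import T2SEngineMLX  # assumed import path

decoder = T2SEngineMLX.load_decoder(
    "GPT_weights/example-e15.ckpt",  # placeholder checkpoint path
    max_batch_size=1,
    backend="MLX-Varlen",            # or MLX-Static / MLX-Quantized-Affine / MLX-Quantized-MXFP4
)
engine = T2SEngineMLX(decoder, device="mx.gpu", dtype=mx.float16)

# result = engine.generate(t2s_request)  # t2s_request: a T2SRequest built from phonemes/BERT features
# semantic_tokens = result.result        # list[torch.Tensor] when result.status == "Success"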
GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py ADDED
@@ -0,0 +1,530 @@
+from __future__ import annotations
+
+import math
+from abc import ABC, abstractmethod
+from typing import MutableSequence, cast
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .structs_mlx import KVCache, KVCacheProtocol, KVCacheQ, T2SDecoderProtocol, T2SSessionMLX
+
+Array = mx.array
+
+
+class TokenEmbedding(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        vocab_size: int,
+    ):
+        super().__init__()
+
+        self.vocab_size = vocab_size
+        self.embedding_dim = embedding_dim
+
+        self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
+
+    @property
+    def weight(self):
+        return self.word_embeddings.weight
+
+    def embedding(self, index: int):
+        return self.word_embeddings.weight[index : index + 1]
+
+    def __call__(self, x: Array):
+        x = self.word_embeddings(x)
+        return x
+
+
+class SinePositionalEmbedding(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        scale: bool = False,
+        max_batch_size: int = 10,
+        max_seq_len: int = 2000,
+    ):
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
+        self.alpha = mx.ones(1)
+        self.max_batch_size = max_batch_size
+        self.max_seq_len = max_seq_len
+
+        self.reverse = False
+        self._pe = mx.zeros((max_batch_size, max_seq_len, embedding_dim))
+        self.compute_pe()
+
+    def compute_pe(self):
+        """Reset the positional encodings."""
+
+        if self.reverse:
+            position = mx.expand_dims(mx.arange(self.max_seq_len - 1, -1, -1.0), axis=1)
+        else:
+            position = mx.expand_dims(mx.arange(self.max_seq_len), axis=1)
+        div_term = mx.exp(
+            mx.arange(
+                0,
+                self.embedding_dim,
+                2,
+            )
+            * -(math.log(10000.0) / self.embedding_dim)
+        )
+        pe = self._pe
+        pe[:, :, 0::2] = mx.sin(position * div_term)
+        pe[:, :, 1::2] = mx.cos(position * div_term)
+
+    def __call__(self, input_pos: Array, x: Array):
+        """
+        Args:
+            input_pos (Array): [batch_size, ]
+            x (Array): [batch_size, 1, embed_dim]
+
+        Returns:
+            embedded_x (Array): [batch_size, 1, embed_dim]
+        """
+
+        batch_size = cast(tuple[int, ...], x.shape)[0]
+        pe_values = self._pe[mx.arange(batch_size), input_pos - 1]  # (batch_size, embed_dim)
+
+        return x * self.x_scale + self.alpha * mx.expand_dims(pe_values, 1)  # (batch_size, 1, embed_dim)
+
+    def prefill(self, x: Array):
+        """
+        Args:
+            x (Array): [batch_size, seq_len, embed_dim]
+
+        Returns:
+            embedded_x (Array): [batch_size, seq_len, embed_dim]
+        """
+        pe_values = self._pe[:, : cast(tuple[int, ...], x.shape)[-2]]
+        return x * self.x_scale + self.alpha * pe_values
+
+
+class KVCacheHND(KVCacheProtocol):
+    @staticmethod
+    def empty(kv_cache):
+        assert len(kv_cache) == 2
+        k_cache, v_cache = kv_cache
+
+        k_cache[:] = 0
+        v_cache[:] = 0
+
+    @staticmethod
+    def update_cache(input_pos, k_val, v_val, kv_cache, cache_idx):
+        # input_pos: [B, ], k_val: [B, H, 1, D]
+        assert len(kv_cache) == 2
+        k_out, v_out = kv_cache
+        ip0 = input_pos - 1
+
+        k_out[cache_idx, :, ip0, None] = k_val
+        v_out[cache_idx, :, ip0, None] = v_val
+
+        return k_out, v_out
+
+    @staticmethod
+    def prefill_kv(k_val, v_val, kv_cache):
+        # k_val: [B, S, H, D]
+        assert len(kv_cache) == 2
+        k_cache, v_cache = kv_cache
+
+        k_cache[..., : cast(tuple[int, ...], k_val.shape)[1], :] = k_val.swapaxes(1, 2)
+        v_cache[..., : cast(tuple[int, ...], v_val.shape)[1], :] = v_val.swapaxes(1, 2)
+
+    @staticmethod
+    def init_cache(batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype) -> KVCache:
+        cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
+
+        return (mx.zeros(cache_shape, dtype=dtype), mx.zeros(cache_shape, dtype=dtype))
+
+
+class KVCacheHNDQuantized(KVCacheProtocol):
+    @staticmethod
+    def _el_per_int(bits: int) -> int:
+        return 32 // bits
+
+    @staticmethod
+    def _packed_dim(head_dim: int, bits: int = 8) -> int:
+        el_per_int = KVCacheHNDQuantized._el_per_int(bits)
+        if head_dim % el_per_int != 0:
+            raise ValueError(f"{head_dim=} is not divisible by {el_per_int=} ({bits=})")
+        return head_dim // el_per_int
+
+    @staticmethod
+    def _group_count(head_dim: int, group_size: int = 32) -> int:
+        assert group_size in {32, 64, 128}
+        if head_dim % group_size != 0:
+            raise ValueError(f"{head_dim} is not divisible by {group_size=}")
+        return head_dim // group_size
+
+    @staticmethod
+    def empty(kv_cache) -> None:
+        assert len(kv_cache) == 3
+        (k_q, k_s, k_b), (v_q, v_s, v_b), (_, __) = kv_cache
+
+        k_q[:] = 0
+        k_s[:] = 0
+        k_b[:] = 0
+        v_q[:] = 0
+        v_s[:] = 0
+        v_b[:] = 0
+
+    @staticmethod
+    def update_cache(
+        input_pos,
+        k_val,
+        v_val,
+        kv_cache,
+        cache_idx,
+    ):
+        # input_pos: [B, ], k_val: [B, H, 1, D]
+
+        assert len(kv_cache) == 3
+        (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
+
+        k_q, k_s, k_b = mx.quantize(k_val, group_size=group_size, bits=bits)
+        v_q, v_s, v_b = mx.quantize(v_val, group_size=group_size, bits=bits)
+
+        ip0 = input_pos - 1
+
+        k_q_out[cache_idx, :, ip0, None] = k_q
+        k_s_out[cache_idx, :, ip0, None] = k_s
+        k_b_out[cache_idx, :, ip0, None] = k_b
+
+        v_q_out[cache_idx, :, ip0, None] = v_q
+        v_s_out[cache_idx, :, ip0, None] = v_s
+        v_b_out[cache_idx, :, ip0, None] = v_b
+
+        return (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits)
+
+    @staticmethod
+    def prefill_kv(
+        k_val,
+        v_val,
+        kv_cache,
+    ) -> None:
+        assert len(kv_cache) == 3
+        (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
+
+        S = cast(tuple[int, ...], k_val.shape)[1]
+
+        k_sw = k_val.swapaxes(1, 2)
+        v_sw = v_val.swapaxes(1, 2)
+
+        k_q, k_s, k_b = mx.quantize(k_sw, group_size=group_size, bits=bits)
+        v_q, v_s, v_b = mx.quantize(v_sw, group_size=group_size, bits=bits)
+
+        k_q_out[..., :S, :] = k_q
+        k_s_out[..., :S, :] = k_s
+        k_b_out[..., :S, :] = k_b
+
+        v_q_out[..., :S, :] = v_q
+        v_s_out[..., :S, :] = v_s
+        v_b_out[..., :S, :] = v_b
+
+    @staticmethod
+    def init_cache(
+        batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        head_dim: int,
+        dtype: mx.Dtype,
+        *,
+        group_size: int = 32,
+        bits: int = 8,
+    ) -> KVCacheQ:
+        packed_dim = KVCacheHNDQuantized._packed_dim(head_dim, bits=bits)
+        group_cnt = KVCacheHNDQuantized._group_count(head_dim, group_size=group_size)
+
+        packed_shape = (batch_size, n_heads, max_seq_length, packed_dim)
+        group_shape = (batch_size, n_heads, max_seq_length, group_cnt)
+
+        k_q = mx.zeros(packed_shape, dtype=mx.uint32)
+        k_s = mx.zeros(group_shape, dtype=dtype)
+        k_b = mx.zeros(group_shape, dtype=dtype)
+
+        v_q = mx.zeros(packed_shape, dtype=mx.uint32)
+        v_s = mx.zeros(group_shape, dtype=dtype)
+        v_b = mx.zeros(group_shape, dtype=dtype)
+
+        return (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits)
+
+
+class AttentionABC(ABC, nn.Module):
+    def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int, *args, **kwds):
+        super().__init__()
+
+        self.n_head = n_head
+        self.hidden_dim = hidden_dim
+        assert hidden_dim % n_head == 0
+        self.head_dim = hidden_dim // n_head
+
+        self.max_seq_length = max_seq_length
+
+        # key, query, value projections for all heads, but in a batch
+        self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
+        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+        self.scale = 1 / math.sqrt(self.head_dim)
+
+        self.kc_class: KVCacheProtocol
+
+    @abstractmethod
+    def __call__(
+        self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array
+    ) -> Array: ...
+
+    def prefill(self, x: Array, kv_cache: KVCache | KVCacheQ, attn_mask: Array):
+        bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
+
+        q, k, v = self.in_proj(x).split(3, axis=-1)
+
+        q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
+
+        self.kc_class.prefill_kv(k, v, kv_cache)
+
+        q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
+
+        attn = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale)
+
+        attn = mx.nan_to_num(attn)
+
+        attn = attn.swapaxes(1, 2).reshape(1, -1, self.hidden_dim)
+
+        output = self.out_proj(attn)
+
+        return output
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim: int, hidden_dim: int) -> None:
+        super().__init__()
+
+        self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
+        self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
+
+    def __call__(self, x: Array):
+        return self.linear2(nn.relu(self.linear1(x)))
+
+
+class TransformerBlockABC(nn.Module):
+    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
+        super().__init__()
+
+        self.hidden_dim = hidden_dim
+        self.max_seq_length = max_seq_length
+
+        self.attention: AttentionABC
+
+        self.feed_forward = FeedForward(hidden_dim, ffn_dim)
+        self.attention_norm = nn.LayerNorm(self.hidden_dim)
+        self.ffn_norm = nn.LayerNorm(self.hidden_dim)
+
+    def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
+        h = self.attention_norm(
+            x
+            + self.attention(
+                x,
+                input_pos,
+                kv_cache,
+                cache_idx,
+                attn_mask,
+            )
+        )
+        out = self.ffn_norm(h + self.feed_forward(h))
+        return out
+
+    def prefill(self, x: Array, attn_mask: Array, kv_cache: KVCache | KVCacheQ):
+        h = self.attention_norm(
+            x
+            + self.attention.prefill(
+                x,
+                kv_cache,
+                attn_mask,
+            )
+        )
+        out = self.ffn_norm(h + self.feed_forward(h))
+
+        return out
+
+
+class TransformerDecoderABC(nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int,
+        n_layer: int,
+        n_head: int,
+        ffn_dim: int,
+        vocab_size: int,
+        max_seq_length: int,
+        max_batch_size: int,
+        *args,
+        **kwds,
+    ) -> None:
+        super().__init__()
+
+        self.hidden_dim = hidden_dim
+        self.n_head = n_head
+        assert hidden_dim % n_head == 0
+
+        self.head_dim = hidden_dim // n_head
+        self.vocab_size = vocab_size
+
+        self.n_layer = n_layer
+
+        self.layers: MutableSequence[TransformerBlockABC]
+
+        self.max_seq_length = max_seq_length
+        self.max_batch_size = max_batch_size
+
+    def __call__(
+        self,
+        input_pos: Array,
+        x: Array,
+        kv_caches: MutableSequence[KVCache | KVCacheQ],
+        cache_idx: Array,
+        *args,
+        **kwds,
+    ):
+        for layer, kv_cache in zip(self.layers, kv_caches):
+            x = layer(
+                x,
+                input_pos,
+                kv_cache,
+                cache_idx,
+                *args,
+                **kwds,
+            )
+
+        return x
+
+    def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCache | KVCacheQ]):
+        for layer, kv_cache in zip(self.layers, kv_caches):
+            x = layer.prefill(
+                x,
+                mask,
+                kv_cache,
+            )
+        return x
+
+
+class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
+    def __init__(
+        self,
+        config: dict,
+        max_seq_length: int = 2000,
+        max_batch_size: int = 10,
+    ) -> None:
+        super().__init__()
+
+        hidden_dim: int = config["model"]["hidden_dim"]
+        embedding_dim: int = config["model"]["embedding_dim"]
+        n_head: int = config["model"]["head"]
+        n_layer: int = config["model"]["n_layer"]
+        vocab_size: int = config["model"]["vocab_size"]
+        phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
+        EOS: int = config["model"]["EOS"]
+        ffn_dim: int = hidden_dim * 4
+
+        self.n_layer = int(n_layer)
+        self.hidden_dim = int(hidden_dim)
+        self.n_head = int(n_head)
+        assert hidden_dim % n_head == 0
+
+        self.head_dim = int(hidden_dim // n_head)
+        self.embedding_dim = int(embedding_dim)
+        self.ffn_dim = int(ffn_dim)
+        self.vocab_size = int(vocab_size)
+        self.phoneme_vocab_size = int(phoneme_vocab_size)
+        self.max_seq_length = max_seq_length
+        self.max_batch_size = max_batch_size
+        self.EOS = EOS
+        assert self.EOS == self.vocab_size - 1
+
+        self.bert_proj = nn.Linear(1024, self.embedding_dim)
+        self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
+        self.h: TransformerDecoderABC
+
+        self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
+        self.ar_text_position = SinePositionalEmbedding(
+            self.embedding_dim,
+            scale=False,
+            max_batch_size=max_batch_size,
+            max_seq_len=max_seq_length,
+        )
+        self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
+        self.ar_audio_position = SinePositionalEmbedding(
+            self.embedding_dim,
+            scale=False,
+            max_batch_size=max_batch_size,
+            max_seq_len=max_seq_length,
+        )
+
+        self.kv_class: KVCacheProtocol
+
+    def init_cache(self, bsz: int = 0, *args, **kwds) -> MutableSequence[KVCache | KVCacheQ]:
+        bsz = bsz or self.h.max_batch_size
+        assert bsz <= self.h.max_batch_size
+        seq_lens = self.h.max_seq_length
+        dtype = self.bert_proj.bias.dtype
+        cache: MutableSequence[KVCache | KVCacheQ] = [
+            self.kv_class.init_cache(bsz, seq_lens, self.n_head, self.head_dim, dtype, *args, **kwds)
+            for _ in range(self.n_layer)
+        ]
+        mx.eval(cache)
+        return cache
+
+    def embed(
+        self,
+        x: list[Array],
+        y: Array,
+        bert_features: list[Array],
+    ):
+        x_len: list[int] = [cast(tuple[int, ...], i.shape)[0] for i in x]
+        x_len_max = max(x_len)
+        xy_pos = mx.zeros((len(x), x_len_max + cast(tuple[int, ...], y.shape)[1], self.embedding_dim)).astype(
+            bert_features[0].dtype
+        )
+
+        bert_features = list(map(lambda x: x.swapaxes(0, 1), bert_features))
+
+        y_len = cast(tuple[int, ...], y.shape)[1]
+        y_emb = self.ar_audio_embedding(y)
+        y_pos = self.ar_audio_position.prefill(y_emb)
+
+        for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
+            x_emb = self.ar_text_embedding(x_)
+            bert = self.bert_proj(bert_feature)
+            x_emb = x_emb + bert
+            x_pos = self.ar_text_position.prefill(mx.expand_dims(x_emb, 0))
+            xy_pos[[bs], :len_] = x_pos
+            xy_pos[[bs], len_ : len_ + y_len] = y_pos
+
+        mx.eval(xy_pos)
+        return xy_pos
+
+    def compile(self):
+        setattr(self.h, "__call__", mx.compile(self.h.__call__))
+        # setattr(self.h, "prefill", mx.compile(self.h.prefill, shapeless=True))
+
+    def pre_forward(self, session: T2SSessionMLX):
+        attn_mask = session.attn_mask
+        return list(), dict(attn_mask=attn_mask)
+
+    def post_forward(self, idx: int, session: T2SSessionMLX) -> None:
+        if idx == 0:
+            prefill_len = session.prefill_len
+            bsz = session.bsz
+
+            range_tensor = mx.arange(self.max_seq_length).reshape(1, 1, 1, self.max_seq_length)
+            prefill_len_expanded = prefill_len.reshape(bsz, 1, 1, 1)
+            attn_mask = range_tensor < prefill_len_expanded
+            attn_mask = mx.repeat(attn_mask, self.n_head, 1)
+
+            session.attn_mask = attn_mask
+
+        attn_mask = session.attn_mask
+        input_pos = session.input_pos
+        attn_mask[mx.arange(session.bsz), :, :, input_pos] = True
+        mx.eval(attn_mask)
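KVCacheHND stores keys and values in (batch, head, seq, head_dim) layout, while KVCacheHNDQuantized keeps them as packed uint32 words plus per-group scales and biases. A small sketch of the resulting shapes, using illustrative sizes rather than a real model config (import path assumed):

import mlx.core as mx

from GPT_SoVITS.Accelerate.MLX.t2s_model_abc import KVCacheHND, KVCacheHNDQuantized  # assumed path

B, S, H, D = 2, 2000, 16, 32                       # illustrative batch / seq / heads / head_dim
k_cache, v_cache = KVCacheHND.init_cache(B, S, H, D, mx.float16)
print(k_cache.shape)                               # (2, 16, 2000, 32)

(k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits) = KVCacheHNDQuantized.init_cache(
    B, S, H, D, mx.float16, group_size=32, bits=8
)
print(k_q.shape, k_s.shape)                        # (2, 16, 2000, 8) packed words, (2, 16, 2000, 1) group scales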
GPT_SoVITS/Accelerate/PyTorch/__init__.py ADDED
@@ -0,0 +1,30 @@
+import importlib.util
+
+import torch
+
+from .sample_funcs import sample_naive
+from .structs import T2SRequest, T2SResult
+from .t2s_engine import T2SEngine as T2SEngineTorch
+
+torch.set_grad_enabled(False)
+
+backends = ["torch_varlen"]
+if torch.cuda.is_available():
+    backends.append("torch_static_cuda_graph")
+    # if importlib.util.find_spec("sageattention") is not None:
+    #     for i in range(torch.cuda.device_count()):
+    #         major, minor = torch.cuda.get_device_capability(i)
+    #         sm_version = major + minor / 10.0
+    #         if sm_version >= 7.0:
+    #             backends.append("sage_attn_varlen_cuda_graph")
+    if importlib.util.find_spec("flash_attn") is not None:
+        for i in range(torch.cuda.device_count()):
+            major, minor = torch.cuda.get_device_capability(i)
+            sm_version = major + minor / 10.0
+            if sm_version >= 7.5:
+                backends.append("flash_attn_varlen_cuda_graph")
+# if torch.mps.is_available():
+#     backends.append("mps_flash_attn_varlen")
+
+
+__all__ = ["T2SEngineTorch", "T2SRequest", "sample_naive", "T2SResult", "backends"]
GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py ADDED
@@ -0,0 +1,158 @@
1
+ """
2
+ Modified From https://github.com/XXXXRT666/GPT-SoVITS
3
+ """
4
+
5
+ from typing import Dict, List, Tuple
6
+
7
+ import kernels
8
+ import torch
9
+
10
+ from .. import nn
11
+ from ..structs import T2SSession
12
+ from ..t2s_model_abc import (
13
+ AttentionABC,
14
+ CUDAGraphCacheABC,
15
+ FeedForward,
16
+ KVCacheNHD,
17
+ KVCacheProtocol,
18
+ T2SDecoderABC,
19
+ TransformerBlockABC,
20
+ TransformerDecoderABC,
21
+ )
22
+
23
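+ # Resolve a flash-attention kernel: prefer flash_attn_interface, then the flash_attn package, and finally a prebuilt kernel fetched through the kernels hub.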
+ flash_attn_kernel = None
24
+ try:
25
+ import flash_attn_interface as flash_attn # type: ignore
26
+
27
+ flash_attn_kernel = flash_attn.flash_attn_with_kvcache
28
+ except ModuleNotFoundError:
29
+ try:
30
+ import flash_attn # type: ignore
31
+
32
+ flash_attn_kernel = flash_attn.flash_attn_with_kvcache
33
+
34
+ except ModuleNotFoundError:
35
+ pass
36
+
37
+ if flash_attn_kernel is None:
38
+ flash_attn_kernel = kernels.get_kernel("kernels-community/flash-attn").flash_attn_with_kvcache
39
+
40
+
41
+ Tensor = torch.Tensor
42
+
43
+
44
+ class Attention(AttentionABC):
45
+ def __init__(self, n_head, hidden_dim, max_seq_length):
46
+ super().__init__(n_head, hidden_dim, max_seq_length)
47
+
48
+ self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
49
+ self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
50
+
51
+ def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor:
52
+ bsz, seqlen, _ = x.shape
53
+
54
+ q, k, v = self.in_proj(x).chunk(3, dim=-1)
55
+
56
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
57
+ k = k.view(bsz, seqlen, self.n_head, self.head_dim)
58
+ v = v.view(bsz, seqlen, self.n_head, self.head_dim)
59
+
60
+ attn: Tensor = flash_attn_kernel( # type: ignore
61
+ q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
62
+ )
63
+
64
+ attn = attn.view(bsz, seqlen, self.hidden_dim)
65
+
66
+ attn = self.out_proj(attn)
67
+
68
+ return attn
69
+
70
+
71
+ class TransformerBlock(TransformerBlockABC):
72
+ def __init__(self, n_head, ffn_dim, hidden_dim, max_seq_length) -> None:
73
+ super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
74
+
75
+ self.attention = Attention(n_head, hidden_dim, max_seq_length)
76
+ self.feed_forward = FeedForward(hidden_dim, ffn_dim)
77
+ self.attention_norm = nn.LayerNorm([self.hidden_dim])
78
+ self.ffn_norm = nn.LayerNorm([self.hidden_dim])
79
+
80
+
81
+ class TransformerDecoder(TransformerDecoderABC):
82
+ def __init__(
83
+ self,
84
+ hidden_dim,
85
+ n_layer,
86
+ n_head,
87
+ ffn_dim,
88
+ vocab_size,
89
+ max_seq_length,
90
+ max_batch_size,
91
+ ) -> None:
92
+ super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
93
+
94
+ self.layers = nn.ModuleList( # type: ignore
95
+ TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
96
+ )
97
+
98
+
99
+ class T2SDecoder(T2SDecoderABC):
100
+ def __init__(
101
+ self,
102
+ config,
103
+ max_seq_length=2000,
104
+ max_batch_size=10,
105
+ ) -> None:
106
+ assert torch.cuda.is_available()
107
+ super().__init__(config, max_seq_length, max_batch_size)
108
+
109
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
110
+ self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
111
+ self.h: TransformerDecoderABC = TransformerDecoder(
112
+ self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
113
+ )
114
+
115
+ self.kv_class = KVCacheNHD
116
+
117
+ def post_forward(self, idx: int, session: T2SSession) -> None:
118
+ return super().post_forward(idx, session)
119
+
120
+ def pre_forward(self, session: T2SSession) -> Tuple[List, Dict]:
121
+ return super().pre_forward(session)
122
+
123
+
124
+ class CUDAGraphCache(CUDAGraphCacheABC):
125
+ def __init__(
126
+ self,
127
+ decoder: T2SDecoder,
128
+ ) -> None:
129
+ self.is_applicable = True
130
+ super().__init__(decoder)
131
+
132
+ def release_graph(self, session: T2SSession):
133
+ if session.id == self.id:
134
+ self.assigned = False
135
+ else:
136
+ del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
137
+
138
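+ # Hand the pre-captured CUDA graph and its static buffers to this session; the KV caches are synced into the graph-owned storage so replay sees the session's state.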
+ def get_cache_graph(self, session: T2SSession):
139
+ assert self.graph
140
+ session.graph = self.graph
141
+ session.stream = self.stream
142
+
143
+ session.xy_pos_ = self.xy_pos
144
+ session.xy_dec_ = self.xy_dec
145
+ session.input_pos = self.input_pos.copy_(session.input_pos)
146
+
147
+ for cache, cache_ in zip(self.kv_cache, session.kv_cache):
148
+ cache.sync_cache(cache_)
149
+
150
+ def capture_new_graph(self, session: T2SSession):
151
+ session.xy_pos_ = self.xy_pos.clone()
152
+ session.xy_dec_ = self.xy_dec.clone()
153
+ session.input_pos = self.input_pos.clone().copy_(session.input_pos)
154
+
155
+ args, kwds = self.decoder.pre_forward(session)
156
+ graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
157
+ session.graph = graph
158
+ session.stream = torch.cuda.Stream() # type: ignore
GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ from .. import nn
5
+ from ..structs import KVCacheProtocol, T2SSession
6
+ from ..t2s_model_abc import (
7
+ AttentionABC,
8
+ CUDAGraphCacheABC,
9
+ FeedForward,
10
+ KVCacheHND,
11
+ T2SDecoderABC,
12
+ TransformerBlockABC,
13
+ TransformerDecoderABC,
14
+ )
15
+
16
+ Tensor = torch.Tensor
17
+
18
+
19
+ class Attention(AttentionABC):
20
+ def __init__(self, n_head, hidden_dim, max_seq_length):
21
+ super().__init__(n_head, hidden_dim, max_seq_length)
22
+
23
+ # key, query, value projections for all heads, but in a batch
24
+ self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
25
+ self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
26
+
27
+ def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
28
+ bsz, seqlen, _ = x.shape
29
+
30
+ q, k, v = self.in_proj(x).chunk(3, dim=-1)
31
+
32
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
33
+ k = k.view(bsz, seqlen, self.n_head, self.head_dim)
34
+ v = v.view(bsz, seqlen, self.n_head, self.head_dim)
35
+
36
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
37
+
38
+ k, v = kv_cache.update(input_pos, k, v)
39
+
40
+ attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
41
+
42
+ attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
43
+
44
+ attn = self.out_proj(attn)
45
+
46
+ return attn
47
+
48
+
49
+ class TransformerBlock(TransformerBlockABC):
50
+ def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
51
+ super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
52
+
53
+ self.attention = Attention(n_head, hidden_dim, max_seq_length)
54
+ self.feed_forward = FeedForward(hidden_dim, ffn_dim)
55
+ self.attention_norm = nn.LayerNorm([self.hidden_dim])
56
+ self.ffn_norm = nn.LayerNorm([self.hidden_dim])
57
+
58
+
59
+ class TransformerDecoder(TransformerDecoderABC):
60
+ def __init__(
61
+ self,
62
+ hidden_dim,
63
+ n_layer,
64
+ n_head,
65
+ ffn_dim,
66
+ vocab_size,
67
+ max_seq_length,
68
+ max_batch_size,
69
+ ) -> None:
70
+ super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
71
+
72
+ self.layers = nn.ModuleList( # type: ignore
73
+ TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
74
+ )
75
+
76
+
77
+ class T2SDecoder(T2SDecoderABC):
78
+ def __init__(
79
+ self,
80
+ config,
81
+ max_seq_length=2000,
82
+ max_batch_size=10,
83
+ ) -> None:
84
+ super().__init__(config, max_seq_length, max_batch_size)
85
+
86
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
87
+ self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
88
+ self.h: TransformerDecoderABC = TransformerDecoder(
89
+ self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
90
+ )
91
+
92
+ self.kv_class = KVCacheHND
93
+
94
+ def pre_forward(self, session: T2SSession):
95
+ attn_mask = session.attn_mask
96
+ return list(), dict(attn_mask=attn_mask)
97
+
98
+ def post_forward(self, idx: int, session: T2SSession) -> None:
99
+ if idx == 0:
100
+ prefill_len = session.prefill_len
101
+ bsz = session.bsz
102
+
103
+ range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
104
+ prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
105
+ attn_mask = range_tensor < prefill_len_expanded
106
+ attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
107
+
108
+ session.attn_mask = attn_mask
109
+
110
+ attn_mask = session.attn_mask
111
+ input_pos = session.input_pos
112
+ attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
113
+
114
+
115
+ class CUDAGraphCache(CUDAGraphCacheABC):
116
+ def __init__(
117
+ self,
118
+ decoder,
119
+ ) -> None:
120
+ self.is_applicable = False
121
+ super().__init__(decoder)
122
+ if torch.cuda.is_available():
123
+ self.attn_mask = (
124
+ torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
125
+ .bool()
126
+ .to(self.device, self.dtype)
127
+ )
128
+
129
+ def release_graph(self, session: T2SSession):
130
+ if session.id == self.id:
131
+ self.assigned = False
132
+ else:
133
+ del (
134
+ session.graph,
135
+ session.xy_pos_,
136
+ session.xy_dec_,
137
+ session.input_pos,
138
+ session.kv_cache,
139
+ session.attn_mask,
140
+ )
141
+
142
+ def get_cache_graph(self, session: T2SSession):
143
+ assert self.graph
144
+ session.graph = self.graph
145
+ session.stream = self.stream
146
+
147
+ session.xy_pos_ = self.xy_pos
148
+ session.xy_dec_ = self.xy_dec
149
+ session.input_pos = self.input_pos.copy_(session.input_pos)
150
+
151
+ session.attn_mask = self.attn_mask
152
+
153
+ for cache, cache_ in zip(self.kv_cache, session.kv_cache):
154
+ cache.sync_cache(cache_)
155
+
156
+ def capture_new_graph(self, session: T2SSession):
157
+ session.xy_pos_ = self.xy_pos.clone()
158
+ session.xy_dec_ = self.xy_dec.clone()
159
+ session.input_pos = self.input_pos.clone().copy_(session.input_pos)
160
+
161
+ session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
162
+
163
+ args, kwds = self.decoder.pre_forward(session)
164
+ graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
165
+ session.graph = graph
166
+ session.stream = torch.cuda.Stream() # type: ignore
GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py ADDED
@@ -0,0 +1,175 @@
1
+ import sageattention # type: ignore
2
+ import torch
3
+
4
+ from .. import nn
5
+ from ..structs import T2SSession
6
+ from ..t2s_model_abc import (
7
+ AttentionABC,
8
+ CUDAGraphCacheABC,
9
+ FeedForward,
10
+ KVCacheHND,
11
+ KVCacheProtocol,
12
+ T2SDecoderABC,
13
+ TransformerBlockABC,
14
+ TransformerDecoderABC,
15
+ )
16
+
17
+ Tensor = torch.Tensor
18
+
19
+
20
+ class Attention(AttentionABC):
21
+ def __init__(self, n_head, hidden_dim, max_seq_length):
22
+ super().__init__(n_head, hidden_dim, max_seq_length)
23
+
24
+ # key, query, value projections for all heads, but in a batch
25
+ self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
26
+ self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
27
+
28
+ def __call__(
29
+ self,
30
+ x: Tensor,
31
+ input_pos: Tensor,
32
+ kv_cache: KVCacheProtocol,
33
+ cu_seqlens_q: Tensor,
34
+ cu_seqlens_kv: Tensor,
35
+ ) -> Tensor:
36
+ bsz, seqlen, _ = x.shape
37
+
38
+ q, k, v = self.in_proj(x).chunk(3, dim=-1)
39
+
40
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
41
+ k = k.view(bsz, seqlen, self.n_head, self.head_dim)
42
+ v = v.view(bsz, seqlen, self.n_head, self.head_dim)
43
+
44
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
45
+
46
+ k, v = kv_cache.update(input_pos, k, v)
47
+
48
+ attn: Tensor = sageattention.sageattn_varlen(
49
+ q,
50
+ k,
51
+ v,
52
+ cu_seqlens_q=cu_seqlens_q,
53
+ cu_seqlens_kv=cu_seqlens_kv,
54
+ max_seqlen_q=1,
55
+ max_seqlen_k=self.max_seq_length,
56
+ )
57
+
58
+ attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
59
+
60
+ attn = self.out_proj(attn)
61
+
62
+ return attn
63
+
64
+
65
+ class TransformerBlock(TransformerBlockABC):
66
+ def __init__(self, n_head, ffn_dim, hidden_dim, max_seq_length) -> None:
67
+ super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
68
+
69
+ self.attention = Attention(n_head, hidden_dim, max_seq_length)
70
+ self.feed_forward = FeedForward(hidden_dim, ffn_dim)
71
+ self.attention_norm = nn.LayerNorm([self.hidden_dim])
72
+ self.ffn_norm = nn.LayerNorm([self.hidden_dim])
73
+
74
+
75
+ class TransformerDecoder(TransformerDecoderABC):
76
+ def __init__(
77
+ self,
78
+ hidden_dim,
79
+ n_layer,
80
+ n_head,
81
+ ffn_dim,
82
+ vocab_size,
83
+ max_seq_length,
84
+ max_batch_size,
85
+ ) -> None:
86
+ super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
87
+
88
+ self.layers = nn.ModuleList( # type: ignore
89
+ TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
90
+ )
91
+
92
+
93
+ class T2SDecoder(T2SDecoderABC):
94
+ def __init__(
95
+ self,
96
+ config,
97
+ max_seq_length=2000,
98
+ max_batch_size=10,
99
+ ) -> None:
100
+ super().__init__(config, max_seq_length, max_batch_size)
101
+
102
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
103
+ self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
104
+ self.h: TransformerDecoderABC = TransformerDecoder(
105
+ self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
106
+ )
107
+
108
+ self.kv_class = KVCacheHND
109
+
110
+ def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
111
+ return list(), dict(cu_seqlens_q=session.cu_seqlens_q, cu_seqlens_kv=session.cu_seqlens_kv)
112
+
113
+ def post_forward(self, idx: int, session: T2SSession):
114
+ if idx == 0:
115
+ session.cu_seqlens_q = torch.arange(0, session.bsz + 1, dtype=torch.int32)
116
+ session.cu_seqlens_kv = torch.cat([torch.tensor(0, dtype=torch.int32), session.input_pos])
117
+ else:
118
+ cu_seqlens_q = session.cu_seqlens_q
119
+ cu_seqlens_kv = session.cu_seqlens_kv
120
+ cu_seqlens_kv.add_(cu_seqlens_q)
121
+
122
+
123
+ class CUDAGraphCache(CUDAGraphCacheABC):
124
+ def __init__(
125
+ self,
126
+ decoder: T2SDecoder,
127
+ ) -> None:
128
+ self.is_applicable = False
129
+ super().__init__(decoder)
130
+
131
+ if torch.cuda.is_available():
132
+ self.cu_seqlens_q = torch.arange(0, decoder.max_batch_size + 1, dtype=torch.int32).to(self.device)
133
+ self.cu_seqlens_kv = torch.cat([torch.tensor(0, dtype=torch.int32), self.input_pos]).to(self.device)
134
+
135
+ def release_graph(self, session: T2SSession):
136
+ if session.id == self.id:
137
+ self.assigned = False
138
+ else:
139
+ del (
140
+ session.graph,
141
+ session.xy_pos_,
142
+ session.xy_dec_,
143
+ session.input_pos,
144
+ session.kv_cache,
145
+ session.cu_seqlens_q,
146
+ session.cu_seqlens_kv,
147
+ )
148
+
149
+ def get_cache_graph(self, session: T2SSession):
150
+ assert self.graph
151
+ session.graph = self.graph
152
+ session.stream = self.stream
153
+
154
+ session.xy_pos_ = self.xy_pos
155
+ session.xy_dec_ = self.xy_dec
156
+ session.input_pos = self.input_pos.copy_(session.input_pos)
157
+
158
+ session.cu_seqlens_q = self.cu_seqlens_q
159
+ session.cu_seqlens_kv = self.cu_seqlens_kv
160
+
161
+ for cache, cache_ in zip(self.kv_cache, session.kv_cache):
162
+ cache.sync_cache(cache_)
163
+
164
+ def capture_new_graph(self, session: T2SSession):
165
+ session.xy_pos_ = self.xy_pos.clone()
166
+ session.xy_dec_ = self.xy_dec.clone()
167
+ session.input_pos = self.input_pos.clone().copy_(session.input_pos)
168
+
169
+ session.cu_seqlens_q = self.cu_seqlens_q.clone().copy_(session.cu_seqlens_q)
170
+ session.cu_seqlens_kv = self.cu_seqlens_kv.clone().copy_(session.cu_seqlens_kv)
171
+
172
+ args, kwds = self.decoder.pre_forward(session)
173
+ graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
174
+ session.graph = graph
175
+ session.stream = torch.cuda.Stream() # type: ignore
GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ from .. import nn
5
+ from ..structs import KVCacheProtocol, T2SSession
6
+ from ..t2s_model_abc import (
7
+ AttentionABC,
8
+ CUDAGraphCacheABC,
9
+ FeedForward,
10
+ KVCacheHND,
11
+ T2SDecoderABC,
12
+ TransformerBlockABC,
13
+ TransformerDecoderABC,
14
+ )
15
+
16
+ Tensor = torch.Tensor
17
+
18
+
19
+ class Attention(AttentionABC):
20
+ def __init__(self, n_head, hidden_dim, max_seq_length):
21
+ super().__init__(n_head, hidden_dim, max_seq_length)
22
+
23
+ # key, query, value projections for all heads, but in a batch
24
+ self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
25
+ self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
26
+
27
+ def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
28
+ bsz, seqlen, _ = x.shape
29
+
30
+ q, k, v = self.in_proj(x).chunk(3, dim=-1)
31
+
32
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
33
+ k = k.view(bsz, seqlen, self.n_head, self.head_dim)
34
+ v = v.view(bsz, seqlen, self.n_head, self.head_dim)
35
+
36
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
37
+
38
+ k, v = kv_cache.update(input_pos, k, v)
39
+
40
+ attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
41
+
42
+ attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
43
+
44
+ attn = self.out_proj(attn)
45
+
46
+ return attn
47
+
48
+
49
+ class TransformerBlock(TransformerBlockABC):
50
+ def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
51
+ super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
52
+
53
+ self.attention = Attention(n_head, hidden_dim, max_seq_length)
54
+ self.feed_forward = FeedForward(hidden_dim, ffn_dim)
55
+ self.attention_norm = nn.LayerNorm([self.hidden_dim])
56
+ self.ffn_norm = nn.LayerNorm([self.hidden_dim])
57
+
58
+
59
+ class TransformerDecoder(TransformerDecoderABC):
60
+ def __init__(
61
+ self,
62
+ hidden_dim,
63
+ n_layer,
64
+ n_head,
65
+ ffn_dim,
66
+ vocab_size,
67
+ max_seq_length,
68
+ max_batch_size,
69
+ ) -> None:
70
+ super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
71
+
72
+ self.layers = nn.ModuleList( # type: ignore
73
+ TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
74
+ )
75
+
76
+
77
+ class T2SDecoder(T2SDecoderABC):
78
+ def __init__(
79
+ self,
80
+ config,
81
+ max_seq_length=2000,
82
+ max_batch_size=10,
83
+ ) -> None:
84
+ super().__init__(config, max_seq_length, max_batch_size)
85
+
86
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
87
+ self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
88
+ self.h: TransformerDecoderABC = TransformerDecoder(
89
+ self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
90
+ )
91
+
92
+ self.kv_class = KVCacheHND
93
+
94
+ def pre_forward(self, session: T2SSession):
95
+ attn_mask = session.attn_mask
96
+ return list(), dict(attn_mask=attn_mask)
97
+
98
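+ # Build the prefill attention mask on step 0, then mark each newly decoded position as attendable on every subsequent step.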
+ def post_forward(self, idx: int, session: T2SSession) -> None:
99
+ if idx == 0:
100
+ prefill_len = session.prefill_len
101
+ bsz = session.bsz
102
+
103
+ range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
104
+ prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
105
+ attn_mask = range_tensor < prefill_len_expanded
106
+ attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
107
+
108
+ session.attn_mask = attn_mask
109
+
110
+ attn_mask = session.attn_mask
111
+ input_pos = session.input_pos
112
+ attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
113
+
114
+
115
+ class CUDAGraphCache(CUDAGraphCacheABC):
116
+ def __init__(
117
+ self,
118
+ decoder,
119
+ ) -> None:
120
+ self.is_applicable = True
121
+ super().__init__(decoder)
122
+ if torch.cuda.is_available():
123
+ self.attn_mask = (
124
+ torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
125
+ .bool()
126
+ .to(self.device, self.dtype)
127
+ )
128
+
129
+ def release_graph(self, session: T2SSession):
130
+ if session.id == self.id:
131
+ self.assigned = False
132
+ else:
133
+ del (
134
+ session.graph,
135
+ session.xy_pos_,
136
+ session.xy_dec_,
137
+ session.input_pos,
138
+ session.kv_cache,
139
+ session.attn_mask,
140
+ )
141
+
142
+ def get_cache_graph(self, session: T2SSession):
143
+ assert self.graph
144
+ session.graph = self.graph
145
+ session.stream = self.stream
146
+
147
+ session.xy_pos_ = self.xy_pos
148
+ session.xy_dec_ = self.xy_dec
149
+ session.input_pos = self.input_pos.copy_(session.input_pos)
150
+
151
+ session.attn_mask = self.attn_mask
152
+
153
+ for cache, cache_ in zip(self.kv_cache, session.kv_cache):
154
+ cache.sync_cache(cache_)
155
+
156
+ def capture_new_graph(self, session: T2SSession):
157
+ session.xy_pos_ = self.xy_pos.clone()
158
+ session.xy_dec_ = self.xy_dec.clone()
159
+ session.input_pos = self.input_pos.clone().copy_(session.input_pos)
160
+
161
+ session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
162
+
163
+ args, kwds = self.decoder.pre_forward(session)
164
+ graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
165
+ session.graph = graph
166
+ session.stream = torch.cuda.Stream() # type: ignore
GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py ADDED
@@ -0,0 +1,145 @@
1
+ from typing import NoReturn
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+ from .. import nn
7
+ from ..structs import KVCacheProtocol, T2SSession
8
+ from ..t2s_model_abc import (
9
+ AttentionABC,
10
+ CUDAGraphCacheABC,
11
+ FeedForward,
12
+ KVCacheHNDVarlen,
13
+ T2SDecoderABC,
14
+ TransformerBlockABC,
15
+ TransformerDecoderABC,
16
+ )
17
+
18
+ Tensor = torch.Tensor
19
+
20
+
21
+ class Attention(AttentionABC):
22
+ def __init__(self, n_head, hidden_dim, max_seq_length):
23
+ super().__init__(n_head, hidden_dim, max_seq_length)
24
+
25
+ # key, query, value projections for all heads, but in a batch
26
+ self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
27
+ self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
28
+
29
+ def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
30
+ bsz, seqlen, _ = x.shape
31
+
32
+ q, k, v = self.in_proj(x).chunk(3, dim=-1)
33
+
34
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
35
+ k = k.view(bsz, seqlen, self.n_head, self.head_dim)
36
+ v = v.view(bsz, seqlen, self.n_head, self.head_dim)
37
+
38
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
39
+
40
+ k, v = kv_cache.update(input_pos, k, v)
41
+
42
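+ # Varlen optimisation: attend only over the longest filled prefix of the cache instead of the full max_seq_length buffer.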
+ max_idx = input_pos.max()
43
+
44
+ q, k, v = map(lambda x: x[..., :max_idx, :], (q, k, v))
45
+
46
+ mask = attn_mask[..., :max_idx]
47
+
48
+ attn = F.scaled_dot_product_attention(q, k, v, mask)
49
+
50
+ attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
51
+
52
+ attn = self.out_proj(attn)
53
+
54
+ return attn
55
+
56
+
57
+ class TransformerBlock(TransformerBlockABC):
58
+ def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
59
+ super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
60
+
61
+ self.attention = Attention(n_head, hidden_dim, max_seq_length)
62
+ self.feed_forward = FeedForward(hidden_dim, ffn_dim)
63
+ self.attention_norm = nn.LayerNorm([self.hidden_dim])
64
+ self.ffn_norm = nn.LayerNorm([self.hidden_dim])
65
+
66
+
67
+ class TransformerDecoder(TransformerDecoderABC):
68
+ def __init__(
69
+ self,
70
+ hidden_dim,
71
+ n_layer,
72
+ n_head,
73
+ ffn_dim,
74
+ vocab_size,
75
+ max_seq_length,
76
+ max_batch_size,
77
+ ) -> None:
78
+ super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
79
+
80
+ self.layers = nn.ModuleList( # type: ignore
81
+ TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
82
+ )
83
+
84
+
85
+ class T2SDecoder(T2SDecoderABC):
86
+ def __init__(
87
+ self,
88
+ config,
89
+ max_seq_length=2000,
90
+ max_batch_size=10,
91
+ ) -> None:
92
+ super().__init__(config, max_seq_length, max_batch_size)
93
+
94
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
95
+ self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
96
+ self.h: TransformerDecoderABC = TransformerDecoder(
97
+ self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
98
+ )
99
+
100
+ self.kv_class = KVCacheHNDVarlen
101
+
102
+ def capture(
103
+ self,
104
+ *args,
105
+ **kwds,
106
+ ) -> NoReturn:
107
+ raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
108
+
109
+ def pre_forward(self, session: T2SSession):
110
+ attn_mask = session.attn_mask
111
+ return list(), dict(attn_mask=attn_mask)
112
+
113
+ def post_forward(self, idx: int, session: T2SSession) -> None:
114
+ if idx == 0:
115
+ prefill_len = session.prefill_len
116
+ bsz = session.bsz
117
+
118
+ range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
119
+ prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
120
+ attn_mask = range_tensor < prefill_len_expanded
121
+ attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
122
+
123
+ session.attn_mask = attn_mask
124
+
125
+ attn_mask = session.attn_mask
126
+ input_pos = session.input_pos
127
+ attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
128
+
129
+
130
+ class CUDAGraphCache(CUDAGraphCacheABC):
131
+ def __init__(
132
+ self,
133
+ decoder,
134
+ ) -> None:
135
+ self.is_applicable = False
136
+ super().__init__(decoder)
137
+
138
+ def release_graph(self, session: T2SSession):
139
+ raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
140
+
141
+ def get_cache_graph(self, session: T2SSession):
142
+ raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
143
+
144
+ def capture_new_graph(self, session: T2SSession):
145
+ raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
GPT_SoVITS/Accelerate/PyTorch/export.py ADDED
@@ -0,0 +1,467 @@
1
+ import enum
2
+ import os
3
+ import os.path as osp
4
+ import time
5
+ from pathlib import Path
6
+ from typing import MutableSequence, TypeAlias
7
+
8
+ import torch
9
+ import typer
10
+ from torch.export import Dim
11
+ from torch.nn import functional as F
12
+
13
+ from ..logger import logger
14
+ from . import nn
15
+ from .t2s_model_abc import AttentionABC, FeedForward, T2SDecoderABC, TransformerBlockABC, TransformerDecoderABC
16
+
17
+ Tensor = torch.Tensor
18
+
19
+ KVCache: TypeAlias = tuple[Tensor, Tensor]
20
+
21
+ app = typer.Typer(
22
+ context_settings={"help_option_names": ["-h", "--help"]},
23
+ add_completion=False,
24
+ )
25
+
26
+
27
+ class Stage(str, enum.Enum):
28
+ embed = "embed"
29
+ decode = "decode"
30
+
31
+
32
+ class KVCacheONNX:
33
+ @staticmethod
34
+ def empty(kv_cache):
35
+ assert len(kv_cache) == 2
36
+ k_cache, v_cache = kv_cache
37
+
38
+ k_cache[:] = 0
39
+ v_cache[:] = 0
40
+
41
+ @staticmethod
42
+ def update_cache(
43
+ input_pos: Tensor, k_val: Tensor, v_val: Tensor, kv_cache: tuple[Tensor, Tensor], cache_idx: Tensor
44
+ ):
45
+ # input_pos: [B, ], k_val: [B, H, 1, D]
46
+ k_out, v_out = kv_cache
47
+ ip0 = input_pos - 1
48
+
49
+ k_out[cache_idx, :, ip0, None] = k_val
50
+ v_out[cache_idx, :, ip0, None] = v_val
51
+
52
+ return k_out, v_out
53
+
54
+ @staticmethod
55
+ def prefill_kv(k_val: Tensor, v_val: Tensor, kv_cache: tuple[Tensor, Tensor]):
56
+ # k_val: [B, S, H, D]
57
+ k_cache, v_cache = kv_cache
58
+
59
+ k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
60
+ v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)
61
+
62
+ @staticmethod
63
+ def init_cache(batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: torch.dtype):
64
+ cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
65
+
66
+ return (torch.zeros(cache_shape, dtype=dtype), torch.zeros(cache_shape, dtype=dtype))
67
+
68
+
69
+ class AttentionONNX(AttentionABC):
70
+ def __init__(self, n_heads: int, head_dim: int, max_seq_length: int):
71
+ super().__init__(n_heads, head_dim, max_seq_length)
72
+
73
+ self.in_proj = nn.Linear(self.hidden_dim, self.hidden_dim * 3, bias=True)
74
+ self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=True)
75
+
76
+ def __call__(self, *args, **kwds): # type: ignore
77
+ pass
78
+
79
+ def onnx_prefill(self, x: Tensor, kv_cache: KVCache, attn_mask: Tensor) -> Tensor:
80
+ bsz, seqlen, _ = x.shape
81
+
82
+ torch._check(attn_mask.size(-2) == x.size(-2))
83
+
84
+ q, k, v = self.in_proj(x.unsqueeze(0)).chunk(3, dim=-1)
85
+
86
+ q, k, v = map(lambda x: x.contiguous().view(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
87
+
88
+ KVCacheONNX.prefill_kv(k, v, kv_cache)
89
+
90
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
91
+
92
+ attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
93
+
94
+ attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)
95
+
96
+ output = self.out_proj(attn)
97
+
98
+ return output
99
+
100
+ def onnx_decode(self, x: Tensor, input_pos: Tensor, kv_cache: KVCache, cache_idx: Tensor, attn_mask: Tensor):
101
+ bsz, seqlen, _ = x.shape
102
+
103
+ torch._check(attn_mask.size(-2) == 1)
104
+
105
+ q, k, v = self.in_proj(x).chunk(3, dim=-1)
106
+
107
+ q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
108
+
109
+ q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
110
+
111
+ kv_cache = KVCacheONNX.update_cache(input_pos, k, v, kv_cache, cache_idx)
112
+
113
+ max_idx = int(input_pos.max())
114
+
115
+ q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))
116
+
117
+ mask = attn_mask[..., :max_idx]
118
+
119
+ attn = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
120
+
121
+ attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
122
+
123
+ attn = self.out_proj(attn)
124
+
125
+ return attn
126
+
127
+
128
+ class TransformerBlockONNX(TransformerBlockABC):
129
+ def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
130
+ super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
131
+
132
+ self.attention: AttentionONNX = AttentionONNX(n_head, hidden_dim, max_seq_length) # type: ignore
133
+ self.feed_forward = FeedForward(hidden_dim, ffn_dim)
134
+ self.attention_norm = nn.LayerNorm(self.hidden_dim)
135
+ self.ffn_norm = nn.LayerNorm(self.hidden_dim)
136
+
137
+ def onnx_prefill(self, x: Tensor, attn_mask: Tensor, kv_cache: KVCache):
138
+ h = self.attention_norm(
139
+ x
140
+ + self.attention.onnx_prefill(
141
+ x,
142
+ kv_cache,
143
+ attn_mask,
144
+ )
145
+ )
146
+ out = self.ffn_norm(h + self.feed_forward(h))
147
+
148
+ return out
149
+
150
+ def onnx_decode(self, x: Tensor, input_pos: Tensor, kv_cache: KVCache, cache_idx: Tensor, attn_mask: Tensor):
151
+ h = self.attention_norm(
152
+ x
153
+ + self.attention.onnx_decode(
154
+ x,
155
+ input_pos,
156
+ kv_cache,
157
+ cache_idx,
158
+ attn_mask,
159
+ )
160
+ )
161
+ out = self.ffn_norm(h + self.feed_forward(h))
162
+ return out
163
+
164
+
165
+ class TransformerDecoderONNX(TransformerDecoderABC):
166
+ def __init__(
167
+ self,
168
+ hidden_dim: int,
169
+ n_layer: int,
170
+ n_head: int,
171
+ ffn_dim: int,
172
+ vocab_size: int,
173
+ max_seq_length: int,
174
+ max_batch_size: int,
175
+ ) -> None:
176
+ super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
177
+
178
+ self.layers: MutableSequence[TransformerBlockONNX] = nn.ModuleList( # type: ignore
179
+ TransformerBlockONNX(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
180
+ )
181
+
182
+ def onnx_prefill(self, x: Tensor, mask: Tensor, *kv_caches: KVCache):
183
+ for layer, kv_cache in zip(self.layers, kv_caches):
184
+ x = layer.onnx_prefill(
185
+ x,
186
+ mask,
187
+ kv_cache,
188
+ )
189
+ return x
190
+
191
+ def onnx_decode(
192
+ self,
193
+ input_pos: Tensor,
194
+ x: Tensor,
195
+ cache_idx: Tensor,
196
+ attn_mask: Tensor,
197
+ *kv_caches: KVCache,
198
+ ):
199
+ for layer, kv_cache in zip(self.layers, kv_caches):
200
+ x = layer.onnx_decode(
201
+ x,
202
+ input_pos,
203
+ kv_cache,
204
+ cache_idx,
205
+ attn_mask,
206
+ )
207
+
208
+ return x
209
+
210
+
211
+ class T2SDecoderONNX(T2SDecoderABC):
212
+ def __init__(self, config: dict, max_seq_length: int = 2000, max_batch_size: int = 10) -> None:
213
+ super().__init__(config, max_seq_length, max_batch_size)
214
+
215
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
216
+ self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
217
+
218
+ self.h = TransformerDecoderONNX(
219
+ self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
220
+ )
221
+
222
+ def pre_forward(self, session) -> tuple[list[Tensor], dict[str, Tensor]]:
223
+ return super().pre_forward(session)
224
+
225
+ def post_forward(self, idx: int, session) -> None:
226
+ return super().post_forward(idx, session)
227
+
228
+ def embed_onnx_(
229
+ self,
230
+ x: Tensor,
231
+ x_len: Tensor,
232
+ y: torch.Tensor,
233
+ bert_features: Tensor,
234
+ ):
235
+ B = x.shape[0]
236
+ D = self.embedding_dim
237
+ T_TOTAL = 500
238
+ xy_pos = torch.zeros((B, T_TOTAL, D)).to(bert_features[0].dtype)
239
+
240
+ bert_features = bert_features.transpose(1, 2)
241
+
242
+ y_len = y.shape[1]
243
+ y_emb = self.ar_audio_embedding(y)
244
+ y_pos = self.ar_audio_position.prefill(y_emb)
245
+
246
+ for bs, x_, len_, bert_feature in zip(torch.arange(x.shape[0]), x, x_len, bert_features):
247
+ x_emb = self.ar_text_embedding(x_[:len_])
248
+
249
+ bert = self.bert_proj(bert_feature[:len_])
250
+
251
+ x_emb = x_emb + bert
+ x_pos = self.ar_text_position.prefill(x_emb.unsqueeze(0))
+
+ xy_pos[None, bs, :len_] = x_pos
+ xy_pos[None, bs, len_ : len_ + y_len] = y_pos
263
+
264
+ return xy_pos[: x.shape[0]], x_len
265
+
266
+ def embed_onnx(
267
+ self,
268
+ x: torch.Tensor, # [B, Tx]
269
+ x_len: torch.Tensor, # [B]
270
+ y: torch.Tensor, # [1, Ty, D]
271
+ bert_features: torch.Tensor, # [B, 1024, Tx]
272
+ ):
273
+ # [B, 1024, Tx] -> [B, Tx, 1024]
274
+ bert_features = bert_features.transpose(1, 2)
275
+
276
+ Ty = y.shape[1]
277
+ Tx = x.shape[1]
278
+ B = x.shape[0]
279
+ D = self.embedding_dim
280
+ T_TOTAL = 500
281
+
282
+ # mask_x: [B, Tx], True where column j < x_len[i]
283
+ col = torch.arange(Tx, device=x.device).unsqueeze(0) # [1, Tx]
284
+ mask_x = col < x_len.view(-1, 1) # [B, Tx]
285
+ mask_x3 = mask_x.unsqueeze(-1) # [B, Tx, 1]
286
+
287
+ torch._check((Ty >= 0) and (Ty <= 250), "y_len out of range")
288
+ torch._check((Tx >= 0) and (Tx <= 250), "x_len out of range")
289
+
290
+ y_emb = self.ar_audio_embedding(y) # [1, Ty, D]
291
+ y_pos = self.ar_audio_position.prefill(y_emb) # [1, Ty, D]
292
+
293
+ x_emb_full = self.ar_text_embedding(x) # [B, Tx, D]
294
+ bert_full = self.bert_proj(bert_features) # [B, Tx, D]
+
+ x_sum_full = x_emb_full + bert_full # [B, Tx, D]
+ x_pos_full = self.ar_text_position.prefill(x_sum_full) # [B, Tx, D]
+
+ xy_pos = torch.zeros((B, T_TOTAL, D), dtype=x_pos_full.dtype, device=x_pos_full.device)
+
+ # Only positions within each item's text length receive the positioned text embeddings.
+ xy_pos[:, :Tx, :] = torch.where(
+ mask_x3,
+ x_pos_full.to(xy_pos.dtype),
+ xy_pos[:, :Tx, :],
+ )
312
+
313
+ # Place the audio-prompt (y) embeddings starting at offset x_len for each batch item
314
+ # [Ty] Index: offsets + [0..Ty-1]
315
+ offsets = x_len.clamp(min=0, max=T_TOTAL - Ty) # [B]
316
+ idx_y = offsets.unsqueeze(1) + torch.arange(Ty, device=x_pos_full.device) # [B, Ty]
317
+ # scatter to dim=1
318
+ # expand index to [B, Ty, D]
319
+ idx_y3 = idx_y.unsqueeze(-1).expand(B, Ty, D)
320
+ y_pos_b = y_pos.expand(B, Ty, D).to(xy_pos.dtype) # [B, Ty, D]
321
+ xy_pos = xy_pos.scatter(1, idx_y3, y_pos_b)
322
+
323
+ return xy_pos, x_len
324
+
325
+
326
+ def torchscript_export(model: T2SDecoderONNX, stage="embed"):
327
+ if stage == "embed":
328
+ x = torch.randint(1, 600, (model.max_batch_size, 50))
329
+ x_len = torch.randint(30, 50, (model.max_batch_size,))
330
+ y = torch.randint(1, 600, (1, 200))
331
+ bert_features = torch.rand((model.max_batch_size, 1024, 50))
332
+
333
+ x_len[-1] = 50
334
+
335
+ mask = torch.arange(x_len.max().item(), device=x.device).unsqueeze(0) < x_len.unsqueeze(1)
336
+
337
+ x = x * mask
338
+ bert_features = bert_features * mask.unsqueeze(1)
339
+
340
+ try:
341
+ a, c = model.embed_onnx_(x, x_len, y, bert_features)
342
+ b, d = model.embed_onnx(x, x_len, y, bert_features)
343
347
+ assert torch.allclose(a, b, atol=1e-6, rtol=1e-8), (a - b).square().mean()
348
+
349
+ setattr(model, "forward", model.embed_onnx)
350
+ scripted_model = torch.jit.script(model, example_inputs=[(x, x_len, y, bert_features)])
351
+
352
+ onnx_program = torch.onnx.export(
353
+ scripted_model,
354
+ (x, x_len, y, bert_features),
355
+ input_names=["text", "text_len", "prompt", "bert_features"],
356
+ output_names=["xy_pos", "input_pos"],
357
+ dynamic_axes={
358
+ "text": {0: "Batch_Size", 1: "Sequence_Length_X"},
359
+ "prompt": {0: "Batch_Size", 1: "Sequence_Length_Y"},
360
+ "bert_features": {0: "Batch_Size", 1: "Sequence_Length_X"},
361
+ },
362
+ opset_version=21,
363
+ training=False,
364
+ do_constant_folding=True,
365
+ external_data=False,
366
+ )
367
+ assert onnx_program
368
+ onnx_program.save("onnx_export/AR_Embedding_TorchScript.onnx")
369
+
370
+ except Exception:
371
+ logger.bind(show_locals=False).exception("")
372
+
373
+
374
+ def dynamo_export(model: T2SDecoderONNX, stage="embed"):
375
+ if stage == "embed":
376
+ x = torch.randint(1, 600, (model.max_batch_size, 50))
377
+ x_len = torch.randint(30, 50, (model.max_batch_size,))
378
+ y = torch.randint(1, 600, (1, 200))
379
+ bert_features = torch.rand((model.max_batch_size, 1024, 50))
380
+
381
+ x_len[-1] = 50
382
+
383
+ mask = torch.arange(x_len.max().item(), device=x.device).unsqueeze(0) < x_len.unsqueeze(1)
384
+
385
+ x = x * mask
386
+ bert_features = (bert_features.transpose(1, 2) * mask.unsqueeze(-1)).transpose(1, 2)
387
+
388
+ dynamic_shapes = [
389
+ {
390
+ 0: Dim("Batch_Size", min=1, max=4),
391
+ 1: Dim("Sequence_Length_X", min=1, max=50),
392
+ },
393
+ {
394
+ 0: Dim("Batch_Size", min=1, max=4),
395
+ },
396
+ {
397
+ 1: Dim("Sequence_Length_Y", min=1, max=250),
398
+ },
399
+ {
400
+ 0: Dim("Batch_Size", min=1, max=4),
401
+ 2: Dim("Sequence_Length_X", min=1, max=50),
402
+ },
403
+ ]
404
+ try:
405
+ a = model.embed_onnx_(x, x_len, y, bert_features)[0]
406
+ b = model.embed_onnx(x, x_len, y, bert_features)[0]
407
409
+ assert torch.allclose(a, b, atol=1e-6, rtol=1e-8), (a - b).square().mean()
410
+
411
+ setattr(model, "forward", model.embed_onnx)
412
+ onnx_program = torch.onnx.export(
413
+ model,
414
+ (x, x_len, y, bert_features),
415
+ input_names=["text", "text_len", "prompt", "bert_features"],
416
+ output_names=["xy_pos", "input_pos"],
417
+ dynamo=True,
418
+ dynamic_shapes=dynamic_shapes,
419
+ opset_version=21,
420
+ training=False,
421
+ do_constant_folding=True,
422
+ external_data=False,
423
+ )
424
+ assert onnx_program
425
+ onnx_program.save("onnx_export/AR_Embedding_Dynamo.onnx")
426
+ except Exception:
427
+ logger.bind(show_locals=False).exception("")
428
+
429
+
430
+ @app.command()
431
+ def export(
432
+ ckpt_path: Path = typer.Option(
433
+ ...,
434
+ "--ckpt-path",
435
+ file_okay=True,
436
+ dir_okay=False,
437
+ exists=True,
438
+ readable=True,
439
+ show_default=False,
440
+ help="AR Checkpoint",
441
+ ),
442
+ dynamo: bool = typer.Option(False, is_flag=True, flag_value=True, help="Use Torch Dynamo"),
443
+ stages: list[Stage] = typer.Option([Stage.embed], "--stages", help="Stage to export"),
444
+ ):
445
+ os.makedirs("onnx_export", exist_ok=True)
446
+ dict_s1 = torch.load(ckpt_path, "cpu", mmap=True)
447
+ config = dict_s1["config"]
448
+ model = T2SDecoderONNX(config, 2000, 4)
449
+ state_dict = dict_s1["weight"]
450
+ model.load_state_dict(state_dict)
451
+
452
+ for stage in stages:
453
+ if dynamo:
454
+ dynamo_export(model, stage)
455
+ else:
456
+ torchscript_export(model, stage)
457
+
458
+
459
+ def get_prog_name() -> str:
460
+ script_rel = ".".join(["GPT_SoVITS", "Accelerate", "PyTorch", osp.basename(__file__).removesuffix(".py")])
461
+ return f"python -s -m {script_rel}"
462
+
463
+
464
+ if __name__ == "__main__":
465
+ t = time.perf_counter()
466
+ app(prog_name=get_prog_name())
467
+ logger.info(f"Exec Time: {time.perf_counter() - t:.2f} secs")
GPT_SoVITS/Accelerate/PyTorch/nn.py ADDED
@@ -0,0 +1,69 @@
1
+ """
2
+ Enhanced Type Hint nn.Module
3
+ Modified From https://github.com/labmlai/labml/blob/master/helpers/labml_helpers/module.py
4
+ """
5
+
6
+ from typing import Any
7
+
8
+ import torch.nn
9
+ from torch.nn import (
10
+ functional as functional,
11
+ )
12
+ from torch.nn import (
13
+ utils as utils,
14
+ )
15
+ from torch.nn.modules import * # type: ignore # noqa: F403
16
+ from torch.nn.parameter import (
17
+ Parameter as Parameter,
18
+ )
19
+
20
+ Tensor = torch.Tensor
21
+
22
+
23
+ class Module(torch.nn.Module):
24
+ r"""
25
+ Wraps ``torch.nn.Module`` to overload ``__call__`` instead of
26
+ ``forward`` for better type checking.
27
+
28
+ `PyTorch Github issue for clarification <https://github.com/pytorch/pytorch/issues/44605>`_
29
+ """
30
+
31
+ def _forward_unimplemented(self, *input: Any) -> None:
32
+ # To stop PyTorch from giving abstract methods warning
33
+ pass
34
+
35
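+ # Subclasses that define __call__ get it re-installed as forward, so nn.Module's dispatch still works while type checkers see the precise call signature.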
+ def __init_subclass__(cls, **kwargs):
36
+ if cls.__dict__.get("__call__", None) is None:
37
+ return
38
+
39
+ setattr(cls, "forward", cls.__dict__["__call__"])
40
+ delattr(cls, "__call__")
41
+
42
+ @property
43
+ def device(self) -> torch.device:
44
+ params = self.parameters()
45
+ try:
46
+ sample_param = next(params)
47
+ return sample_param.device
48
+ except StopIteration:
49
+ raise RuntimeError(f"Unable to determine device of {self.__class__.__name__}") from None
50
+
51
+
52
+ class Linear(torch.nn.Linear):
53
+ def __call__(self, input: Tensor) -> Tensor:
54
+ return super().__call__(input)
55
+
56
+
57
+ class Dropout(torch.nn.Dropout):
58
+ def __call__(self, input: Tensor) -> Tensor:
59
+ return super().__call__(input)
60
+
61
+
62
+ class Embedding(torch.nn.Embedding):
63
+ def __call__(self, input: Tensor) -> Tensor:
64
+ return super().__call__(input)
65
+
66
+
67
+ class LayerNorm(torch.nn.LayerNorm):
68
+ def __call__(self, input: Tensor) -> Tensor:
69
+ return super().__call__(input)
GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py ADDED
@@ -0,0 +1,67 @@
1
+ from typing import Protocol
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ Tensor = torch.Tensor
7
+
8
+
9
+ class SampleProtocol(Protocol):
10
+ @staticmethod
11
+ def __call__(
12
+ logits: Tensor,
13
+ previous_tokens: Tensor,
14
+ temperature: float,
15
+ top_k: int,
16
+ top_p: float,
17
+ repetition_penalty: float,
18
+ ) -> Tensor: ...
19
+
20
+
21
+ class sample_naive(SampleProtocol):
22
+ @staticmethod
23
+ def __call__(
24
+ logits: Tensor,
25
+ previous_tokens: Tensor,
26
+ temperature: float,
27
+ top_k: int,
28
+ top_p: float,
29
+ repetition_penalty: float,
30
+ ):
31
+ if temperature <= 1e-5:
32
+ probs = F.softmax(logits, dim=-1)
33
+ return torch.argmax(probs, dim=-1, keepdim=True).to(dtype=torch.int32)
34
+
35
+ if repetition_penalty != 1.0:
36
+ previous_tokens = previous_tokens.long()
37
+ score = torch.gather(logits, dim=1, index=previous_tokens)
38
+ score = torch.where(
39
+ score < 0,
40
+ score * repetition_penalty,
41
+ score / repetition_penalty,
42
+ )
43
+ logits.scatter_(dim=1, index=previous_tokens, src=score)
44
+
45
+ if top_p < 1.0:
46
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
47
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
48
+ cum_probs[cum_probs > 1] = 1
49
+ sorted_indices_to_remove = cum_probs > top_p
50
+ sorted_indices_to_remove[:, 0] = False # keep at least one option
51
+ indices_to_remove = sorted_indices_to_remove.scatter(
52
+ dim=1, index=sorted_indices, src=sorted_indices_to_remove
53
+ )
54
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
55
+
56
+ if temperature < 1.0:
57
+ logits /= temperature
58
+
59
+ v, _ = torch.topk(logits, top_k)
60
+ pivot = v[:, -1].unsqueeze(-1)
61
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
62
+
63
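+ # Gumbel-max style sampling: dividing the probabilities by Exponential(1) noise and taking the argmax draws from the categorical distribution without torch.multinomial.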
+ probs = F.softmax(logits, dim=-1)
64
+ q = -torch.log(torch.rand_like(probs))
65
+ idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
66
+
67
+ return idx_next
GPT_SoVITS/Accelerate/PyTorch/structs.py ADDED
@@ -0,0 +1,151 @@
1
+ """
2
+ Modified From https://github.com/XXXXRT666/GPT-SoVITS
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Literal, MutableSequence, Optional, Protocol
9
+
10
+ import torch
11
+
12
+ from .sample_funcs import SampleProtocol, sample_naive
13
+
14
+ Tensor = torch.Tensor
15
+
16
+
17
+ @dataclass
18
+ class T2SResult:
19
+ result: list[Tensor] | None = None
20
+ infer_speed: tuple[float, float] = (0.0, 0.0)
21
+ status: Literal["Success", "Error"] = "Success"
22
+ exception: Optional[Exception] = None
23
+ traceback: Optional[str] = None
24
+
25
+
26
+ @dataclass
27
+ class T2SRequest:
28
+ x: list[torch.Tensor]
29
+ x_lens: Tensor
30
+ prompts: torch.Tensor
31
+ bert_feature: list[Tensor]
32
+ valid_length: int
33
+ top_k: int = 5
34
+ top_p: float = 1
35
+ early_stop_num: int = -1
36
+ temperature: float = 1.0
37
+ repetition_penalty: float = 1.35
38
+ use_cuda_graph: bool = False
39
+ debug: bool = False
40
+
41
+
42
+ class KVCacheProtocol(Protocol):
43
+ k_cache: Tensor
44
+ v_cache: Tensor
45
+
46
+ def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None: ...
47
+
48
+ def empty(self) -> None: ...
49
+
50
+ def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...
51
+
52
+ def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...
53
+
54
+ def sync_cache(self, kv_cache: KVCacheProtocol) -> None: ...
55
+
56
+
57
+ class T2SDecoderProtocol(Protocol):
58
+ max_seq_length: int
59
+ EOS: int
60
+ n_head: int
61
+
62
+ @property
63
+ def device(self) -> torch.device: ...
64
+
65
+ def embed(self, x: list[Tensor], y: Tensor, bert_features: list[Tensor]) -> Tensor: ...
66
+
67
+
68
+ class T2SEngineProtocol(Protocol):
69
+ def _handle_request(self, request: T2SRequest) -> tuple[list[Tensor], float, float]: ...
70
+
71
+ def generate(self, request: T2SRequest) -> T2SResult: ...
72
+
73
+
74
+ class T2SSession:
75
+ def __init__(
76
+ self,
77
+ decoder: T2SDecoderProtocol,
78
+ request: T2SRequest,
79
+ sample_func: type[SampleProtocol] = sample_naive,
80
+ device: torch.device = torch.device("cpu"),
81
+ dtype: torch.dtype = torch.float32,
82
+ ):
83
+ with device:
84
+ self.decoder = decoder
85
+ self.request = request
86
+ self.device = device
87
+ self.dtype = dtype
88
+
89
+ bsz = len(request.x)
90
+ y_len = request.prompts.size(-1)
91
+ self.bsz = bsz
92
+ self.y_len = y_len
93
+ request.prompts = request.prompts.to(device, torch.int32)
94
+
95
+ # Cache
96
+ self.kv_cache: MutableSequence[KVCacheProtocol]
97
+ self.sample = sample_func()
98
+
99
+ # Forward args
100
+ self.x = [i.to(device) for i in request.x]
101
+ self.x_lens = request.x_lens.to(torch.int32)
102
+ self.y = torch.zeros((bsz, decoder.max_seq_length)).to(torch.int32)
103
+ self.y[:, : request.prompts.shape[-1]] = request.prompts
104
+ self.bert_feature = [i.to(device, dtype) for i in request.bert_feature]
105
+
106
+ self.prefill_len = self.x_lens + request.prompts.size(1)
107
+
108
+ self.input_pos = torch.zeros_like(self.prefill_len)
109
+ self.input_pos.add_(self.prefill_len)
110
+
111
+ # CUDA Graph
112
+ self.stream: Optional[torch.cuda.Stream] = None
113
+ self.graph: Optional[torch.cuda.CUDAGraph] = None
114
+ self.xy_pos_: Tensor
115
+ self.xy_dec_: Tensor
116
+
117
+ # EOS
118
+ self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
119
+ self.y_results: list[Tensor] = [None] * len(self.x) # type: ignore
120
+
121
+ self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)
122
+
123
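+ # Prefill mask: every position may attend to the text prompt, while the audio-prompt block is lower-triangular (causal) within itself.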
+ max_len = int(self.prefill_len.max().item())
124
+ attn_mask = torch.zeros(size=(bsz, max_len, max_len), dtype=torch.bool)
125
+
126
+ for bs in range(bsz):
127
+ pos = int(self.x_lens[bs])
128
+ seq_len = pos + y_len
129
+
130
+ attn_mask[bs, :seq_len, :pos] = True
131
+
132
+ ar_mask = ~torch.triu(
133
+ input=torch.ones(
134
+ size=(
135
+ y_len,
136
+ y_len,
137
+ ),
138
+ dtype=torch.bool,
139
+ ),
140
+ diagonal=1,
141
+ )
142
+ attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask
143
+
144
+ self.attn_mask = attn_mask
145
+ self.attn_mask = attn_mask.unsqueeze(0).expand(-1, decoder.n_head, -1, -1)
146
+
147
+ self.id: int = -1
148
+
149
+ # Sage Attn & Transformer Engine Impl
150
+ self.cu_seqlens_q: Tensor
151
+ self.cu_seqlens_kv: Tensor
GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py ADDED
@@ -0,0 +1,223 @@
1
+ import contextlib
2
+ import gc
3
+ import os
4
+ import sys
5
+ import time
6
+ import traceback
7
+ from importlib import import_module
8
+
9
+ import torch
10
+ from rich.progress import BarColumn, Progress, TextColumn
11
+
12
+ from ..logger import SpeedColumnToken, console, logger
13
+ from .structs import T2SEngineProtocol, T2SRequest, T2SResult, T2SSession
14
+ from .t2s_model_abc import (
15
+ CUDAGraphCacheABC,
16
+ T2SDecoderABC,
17
+ TorchProfiler,
18
+ )
19
+
20
+
21
+ class T2SEngine(T2SEngineProtocol):
22
+ def __init__(
23
+ self,
24
+ decoder_model: T2SDecoderABC,
25
+ device: torch.device = torch.device("cpu"),
26
+ dtype: torch.dtype = torch.float32,
27
+ ) -> None:
28
+ assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}
29
+ assert dtype in {torch.float16, torch.bfloat16, torch.float32}
30
+
31
+ self.device = device
32
+ self.dtype = dtype
33
+
34
+ self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
35
+
36
+ self.graphcache: CUDAGraphCacheABC = self.init_cache()
37
+
38
+ def _handle_request(self, request: T2SRequest):
39
+ with self.device:
40
+ decoder = self.decoder_model
41
+ session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
42
+ batch_idx = torch.arange(session.bsz)
43
+
44
+ t1 = 0.0
45
+ infer_speed = 0.0
46
+ infer_time = 0.0
47
+
48
+ torch_profiler = TorchProfiler(request.debug)
49
+ with (
50
+ torch_profiler.profiler(),
51
+ Progress(
52
+ TextColumn("[cyan]{task.description}"),
53
+ BarColumn(),
54
+ TextColumn("{task.completed}/{task.total} tokens"),
55
+ SpeedColumnToken(show_speed=True),
56
+ console=console,
57
+ transient=True,
58
+ ) as progress,
59
+ ):
60
+ max_token = int(min(2000 - session.input_pos.max(), 1500))
61
+ task = progress.add_task("T2S Decoding", total=max_token)
62
+
63
+ for idx in range(max_token):
64
+ progress.update(task, advance=1)
65
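+ # Step 0 runs the full prefill over text + audio prompt; subsequent steps decode one token at a time, optionally replaying a captured CUDA graph.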
+ if idx == 0:
66
+ session.kv_cache = decoder.init_cache(session.bsz)
67
+ xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
68
+ xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
69
+ else:
70
+ if (
71
+ request.use_cuda_graph
72
+ and session.graph is None
73
+ and self.graphcache.is_applicable
74
+ and torch.cuda.is_available()
75
+ ):
76
+ self.graphcache.assign_graph(session)
77
+
78
+ with torch_profiler.record("AR"):
79
+ if session.graph:
80
+ assert session.stream
81
+ session.stream.wait_stream(torch.cuda.default_stream())
82
+ with torch.cuda.stream(session.stream):
83
+ session.xy_pos_.copy_(session.xy_pos)
84
+ session.graph.replay()
85
+ xy_dec = session.xy_dec_.clone()
86
+ else:
87
+ args, kwds = decoder.pre_forward(session)
88
+ xy_dec = decoder.h(
89
+ session.input_pos,
90
+ session.xy_pos,
91
+ session.kv_cache,
92
+ *args,
93
+ **kwds,
94
+ )
95
+
96
+ with torch.cuda.stream(session.stream) if session.stream is not None else contextlib.nullcontext():
97
+ decoder.post_forward(idx, session)
98
+ logits = decoder.ar_predict_layer(xy_dec[:, -1])
99
+
100
+ if idx == 0:
101
+ logits[:, -1] = float("-inf")
102
+
103
+ with torch_profiler.record("Sampling"):
104
+ samples = session.sample(
105
+ logits=logits,
106
+ previous_tokens=session.y[:, : session.y_len + idx],
107
+ top_k=request.top_k,
108
+ top_p=request.top_p,
109
+ repetition_penalty=request.repetition_penalty,
110
+ temperature=request.temperature,
111
+ )
112
+ session.y[batch_idx, session.y_len + idx] = samples
113
+ session.input_pos.add_(1)
114
+
115
+ with torch_profiler.record("EOS"):
116
+ argmax_token = torch.argmax(logits, dim=-1)
117
+ sample_token = samples.squeeze(1)
118
+ EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
119
+
120
+ newly_done_mask = EOS_mask & (~session.completed)
121
+ newly_done_indices = newly_done_mask.nonzero()
122
+
123
+ if newly_done_indices.numel() > 0:
124
+ for i in newly_done_indices:
125
+ session.y_results[i] = session.y[i, session.y_len : session.y_len + idx]
126
+ session.completed[newly_done_indices] = True
127
+
128
+ if torch.all(session.completed).item():
129
+ if session.y[:, session.y_len :].sum() == 0:
130
+ session.y_results = [torch.tensor(0) for _ in range(session.bsz)]
131
+ logger.error("Bad Zero Prediction")
132
+ else:
133
+ logger.info(
134
+ f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
135
+ )
136
+ logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
137
+ infer_time = time.perf_counter() - t1
138
+ infer_speed = (idx - 1) / infer_time
139
+ break
140
+
141
+ if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
142
+ for i in range(session.bsz):
143
+ if not session.completed[i].item():
144
+ session.y_results[i] = session.y[i, session.y_len : session.y_len + 1499]
145
+ session.completed[i] = True
146
+ logger.error("Bad Full Prediction")
147
+ break
148
+
149
+ with torch_profiler.record("NextPos"):
150
+ y_emb = decoder.ar_audio_embedding(samples)
151
+ session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
152
+
153
+ if idx == 1:
154
+ torch_profiler.start()
155
+ t1 = time.perf_counter()
156
+
157
+ if idx == 51:
158
+ torch_profiler.end()
159
+
160
+ if idx % 100 == 0:
161
+ match session.device.type:
162
+ case "cuda":
163
+ torch.cuda.empty_cache()
164
+ case "mps":
165
+ torch.mps.empty_cache()
166
+ case "xpu":
167
+ torch.xpu.empty_cache()
168
+ case "mtia":
169
+ torch.mtia.empty_cache()
170
+
171
+ match session.device.type:
172
+ case "cuda":
173
+ if session.stream is not None:
174
+ torch.cuda.current_stream().wait_stream(session.stream)
175
+ torch.cuda.empty_cache()
176
+ case "mps":
177
+ torch.mps.empty_cache()
178
+ case "xpu":
179
+ torch.xpu.empty_cache()
180
+ case "mtia":
181
+ torch.mtia.empty_cache()
182
+ case "cpu":
183
+ gc.collect()
184
+
185
+ torch_profiler.end()
186
+ if request.use_cuda_graph and torch.cuda.is_available():
187
+ self.graphcache.release_graph(session)
188
+
189
+ return session.y_results[: request.valid_length], infer_speed, infer_time
190
+
191
+ def generate(self, request: T2SRequest):
192
+ try:
193
+ result, infer_speed, infer_time = self._handle_request(request)
194
+ t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
195
+ except Exception as e:
196
+ t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
197
+ return t2s_result
198
+
199
+ @staticmethod
200
+ def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "Flash-Attn-Varlen-CUDAGraph"):
201
+ logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
202
+ module_path = f".backends.{backend.lower().replace('-', '_').replace('cudagraph', 'cuda_graph')}"
203
+ decoder_cls_name = "T2SDecoder"
204
+ decoder_mod = import_module(module_path, package=__package__)
205
+ decoder_cls: type[T2SDecoderABC] = getattr(decoder_mod, decoder_cls_name)
206
+ dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
207
+ config = dict_s1["config"]
208
+ decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
209
+ state_dict = dict_s1["weight"]
210
+ decoder.load_state_dict(state_dict)
211
+
212
+ return decoder.eval()
213
+
214
+ def init_cache(self):
215
+ assert self.decoder_model
216
+
217
+ module_name = self.decoder_model.__class__.__module__
218
+ module = sys.modules.get(module_name)
219
+ assert module
220
+
221
+ target_class: type[CUDAGraphCacheABC] = getattr(module, "CUDAGraphCache")
222
+
223
+ return target_class(self.decoder_model)
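A hedged usage sketch of the engine defined above: load_decoder reads an s1 checkpoint (a dict with "config" and "weight" keys) and selects a backend module by name, and generate wraps _handle_request into a T2SResult. The checkpoint path and the commented-out request fields are assumptions; the exact T2SRequest signature is defined in structs.py.

    import torch
    from GPT_SoVITS.Accelerate.PyTorch.t2s_engine import T2SEngine

    # "Torch-Varlen" maps to .backends.torch_varlen via the name transform in load_decoder
    decoder = T2SEngine.load_decoder("pretrained/s1.ckpt", max_batch_size=1, backend="Torch-Varlen")
    engine = T2SEngine(decoder, device=torch.device("cpu"), dtype=torch.float32)

    # request = T2SRequest(...)          # phoneme ids, prompt semantics, BERT features, sampling params
    # result = engine.generate(request)  # returns T2SResult; result.status is "Success" or "Error"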
GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py ADDED
@@ -0,0 +1,672 @@
1
+ """
2
+ Modified From https://github.com/XXXXRT666/GPT-SoVITS
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import math
8
+ import os
9
+ import random
10
+ from abc import ABC, abstractmethod
11
+ from contextlib import nullcontext
12
+ from typing import MutableSequence
13
+
14
+ import torch
15
+ import torch._inductor.config
16
+ import torch.nn.functional as F
17
+ from torch.cuda.graphs import CUDAGraph
18
+ from torch.profiler import ProfilerAction, tensorboard_trace_handler
19
+
20
+ from . import nn
21
+ from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession
22
+
23
+ Tensor = torch.Tensor
24
+
25
+
26
+ class TokenEmbedding(nn.Module):
27
+ def __init__(
28
+ self,
29
+ embedding_dim: int,
30
+ vocab_size: int,
31
+ ):
32
+ super().__init__()
33
+
34
+ self.vocab_size = vocab_size
35
+ self.embedding_dim = embedding_dim
36
+
37
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
38
+
39
+ @property
40
+ def weight(self) -> Tensor:
41
+ return self.word_embeddings.weight
42
+
43
+ def embedding(self, index: int) -> Tensor:
44
+ return self.word_embeddings.weight[index : index + 1]
45
+
46
+ def __call__(self, x: Tensor):
47
+ x = self.word_embeddings(x)
48
+ return x
49
+
50
+
51
+ class SinePositionalEmbedding(nn.Module):
52
+ def __init__(
53
+ self,
54
+ embedding_dim: int,
55
+ scale: bool = False,
56
+ alpha: bool = False,
57
+ max_batch_size: int = 10,
58
+ max_seq_len: int = 2000,
59
+ ):
60
+ super().__init__()
61
+ self.embedding_dim = embedding_dim
62
+ self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
63
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
64
+ self.max_batch_size = max_batch_size
65
+ self.max_seq_len = max_seq_len
66
+
67
+ self.reverse = False
68
+ self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
69
+ self.pe: torch.Tensor
70
+ self.compute_pe()
71
+
72
+ def compute_pe(self):
73
+ """Reset the positional encodings."""
74
+ if self.reverse:
75
+ position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
76
+ else:
77
+ position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
78
+ div_term = torch.exp(
79
+ torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
80
+ )
81
+ pe = self.pe
82
+ pe[:, :, 0::2] = torch.sin(position * div_term)
83
+ pe[:, :, 1::2] = torch.cos(position * div_term)
84
+
85
+ def __call__(self, input_pos: Tensor, x: Tensor) -> Tensor:
86
+ """
87
+ Args:
88
+ input_pos (Tensor): [batch_size, ]
89
+ x (Tensor): [batch_size, 1, embed_dim]
90
+
91
+ Returns:
92
+ embedded_x (Tensor): [batch_size, 1, embed_dim]
93
+ """
94
+
95
+ batch_size = x.shape[0]
96
+ pe_values = self.pe[torch.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
97
+
98
+ return x * self.x_scale + self.alpha * pe_values.unsqueeze(1) # (batch_size, 1, embed_dim)
99
+
100
+ def prefill(self, x: Tensor) -> Tensor:
101
+ """
102
+ Args:
103
+ x (Tensor): [batch_size, seq_len, embed_dim]
104
+
105
+ Returns:
106
+ embedded_x (Tensor): [batch_size, seq_len, embed_dim]
107
+ """
108
+
109
+ batch_size = x.shape[0]
110
+ pe_values = self.pe[:batch_size, : x.shape[-2]]
111
+ return x * self.x_scale + self.alpha * pe_values
112
+
113
+
114
+ class KVCacheABC(nn.Module, ABC, KVCacheProtocol):
115
+ def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None:
116
+ super().__init__()
117
+
118
+ self.n_head = n_heads
119
+ self.head_dim = head_dim
120
+ self.batch_size = batch_size
121
+ self.max_seq_length = max_seq_length
122
+
123
+ self.k_cache: Tensor
124
+ self.v_cache: Tensor
125
+
126
+ def empty(self):
127
+ self.k_cache.zero_()
128
+ self.v_cache.zero_()
129
+
130
+ @abstractmethod
131
+ def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...
132
+
133
+ @abstractmethod
134
+ def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...
135
+
136
+ def sync_cache(self, kv_cache: KVCacheProtocol):
137
+ self.k_cache.copy_(kv_cache.k_cache)
138
+ self.v_cache.copy_(kv_cache.v_cache)
139
+
140
+
141
+ class KVCacheNHD(KVCacheABC):
142
+ def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
143
+ super().__init__(batch_size, max_seq_length, n_heads, head_dim)
144
+
145
+ assert batch_size > 0
146
+ cache_shape = (batch_size, max_seq_length, n_heads, head_dim)
147
+
148
+ self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
149
+ self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
150
+
151
+ def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
152
+ # input_pos: [B, ], k_val: [B, 1, H, D]
153
+
154
+ index = (
155
+ (input_pos - 1)
156
+ .unsqueeze(-1)
157
+ .unsqueeze(-1)
158
+ .unsqueeze(-1)
159
+ .expand(
160
+ -1,
161
+ -1,
162
+ self.n_head,
163
+ self.head_dim,
164
+ )
165
+ .to(torch.int64)
166
+ ) # (bs, 1, num_head, head_dim)
167
+
168
+ k_out = self.k_cache
169
+ v_out = self.v_cache
170
+ k_out.scatter_(1, index, k_val)
171
+ v_out.scatter_(1, index, v_val)
172
+
173
+ return k_out, v_out
174
+
175
+ def empty(self):
176
+ self.k_cache.zero_()
177
+ self.v_cache.zero_()
178
+
179
+ def prefill_kv(self, k_val: Tensor, v_val: Tensor):
180
+ # input_pos: int, k_val: [B, S, H, D]
181
+
182
+ self.k_cache[:, : k_val.shape[1]] = k_val
183
+ self.v_cache[:, : v_val.shape[1]] = v_val
184
+
185
+
186
+ class KVCacheHND(KVCacheABC):
187
+ def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
188
+ super().__init__(batch_size, max_seq_length, n_heads, head_dim)
189
+
190
+ cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
191
+
192
+ self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
193
+ self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
194
+
195
+ def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
196
+ # input_pos: [B, ], k_val: [B, H, 1, D]
197
+
198
+ index = (
199
+ (input_pos - 1)
200
+ .unsqueeze(-1)
201
+ .unsqueeze(-1)
202
+ .unsqueeze(-1)
203
+ .expand(
204
+ -1,
205
+ self.n_head,
206
+ -1,
207
+ self.head_dim,
208
+ )
209
+ .to(torch.int64)
210
+ ) # (bs, num_head, 1, head_dim)
211
+
212
+ k_out = self.k_cache
213
+ v_out = self.v_cache
214
+ k_out.scatter_(2, index, k_val)
215
+ v_out.scatter_(2, index, v_val)
216
+
217
+ return k_out, v_out
218
+
219
+ def empty(self):
220
+ self.k_cache.zero_()
221
+ self.v_cache.zero_()
222
+
223
+ def prefill_kv(self, k_val: Tensor, v_val: Tensor):
224
+ # input_pos: int, k_val: [B, S, H, D]
225
+
226
+ self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
227
+ self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)
228
+
229
+
230
+ class KVCacheHNDVarlen(KVCacheABC):
231
+ def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
232
+ super().__init__(batch_size, max_seq_length, n_heads, head_dim)
233
+
234
+ cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
235
+ self.cache_idx: Tensor
236
+
237
+ self.register_buffer("cache_idx", torch.arange(batch_size), persistent=False)
238
+ self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
239
+ self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
240
+
241
+ def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
242
+ # input_pos: [B, ], k_val: [B, H, 1, D]
243
+
244
+ k_out = self.k_cache
245
+ v_out = self.v_cache
246
+
247
+ ip0 = input_pos - 1
248
+
249
+ k_out[self.cache_idx, :, ip0, None] = k_val
250
+ v_out[self.cache_idx, :, ip0, None] = v_val
251
+
252
+ return k_out, v_out
253
+
254
+ def empty(self):
255
+ self.k_cache.zero_()
256
+ self.v_cache.zero_()
257
+
258
+ def prefill_kv(self, k_val: Tensor, v_val: Tensor):
259
+ # input_pos: int, k_val: [B, S, H, D]
260
+
261
+ self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
262
+ self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)
263
+
264
+
265
+ class AttentionABC(nn.Module, ABC):
266
+ def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
267
+ super().__init__()
268
+
269
+ self.n_head = n_head
270
+ self.hidden_dim = hidden_dim
271
+ assert hidden_dim % n_head == 0
272
+ self.head_dim = hidden_dim // n_head
273
+
274
+ self.max_seq_length = max_seq_length
275
+
276
+ # key, query, value projections for all heads, but in a batch
277
+ self.in_proj: nn.Linear
278
+ self.out_proj: nn.Linear
279
+
280
+ self._register_load_state_dict_pre_hook(self.load_hook)
281
+
282
+ def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
283
+ keys_to_modify = [key for key in state_dict if "in_proj_" in key]
284
+ for key in keys_to_modify:
285
+ new_key = key.replace("in_proj_", "in_proj.") # in_proj_ -> in_proj.
286
+ state_dict[new_key] = state_dict.pop(key)
287
+
288
+ @abstractmethod
289
+ def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor: ...
290
+
291
+ def prefill(self, x: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor) -> Tensor:
292
+ bsz, seqlen, _ = x.shape
293
+
294
+ q, k, v = self.in_proj(x).chunk(3, dim=-1)
295
+
296
+ q, k, v = map(lambda x: x.contiguous().view(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
297
+
298
+ kv_cache.prefill_kv(k, v)
299
+
300
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
301
+
302
+ attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
303
+
304
+ attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)
305
+
306
+ output = self.out_proj(attn)
307
+
308
+ return output
309
+
310
+
311
+ class FeedForward(nn.Module):
312
+ def __init__(self, dim: int, hidden_dim: int) -> None:
313
+ super().__init__()
314
+
315
+ self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
316
+ self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
317
+
318
+ def __call__(self, x: Tensor):
319
+ return self.linear2(F.relu(self.linear1(x), inplace=True))
320
+
321
+
322
+ class TransformerBlockABC(nn.Module, ABC):
323
+ def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
324
+ super().__init__()
325
+
326
+ self.hidden_dim = hidden_dim
327
+ self.max_seq_length = max_seq_length
328
+
329
+ self.attention: AttentionABC
330
+ self.feed_forward: FeedForward
331
+ self.attention_norm: nn.LayerNorm
332
+ self.ffn_norm: nn.LayerNorm
333
+
334
+ self._register_load_state_dict_pre_hook(self.load_hook)
335
+
336
+ def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
337
+ for key in list(state_dict.keys()):
338
+ new_key = (
339
+ key.replace("self_attn", "attention")
340
+ .replace("linear", "feed_forward.linear")
341
+ .replace("norm1", "attention_norm")
342
+ .replace("norm2", "ffn_norm")
343
+ )
344
+ state_dict[new_key] = state_dict.pop(key)
345
+
346
+ def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds):
347
+ h = self.attention_norm(
348
+ x
349
+ + self.attention(
350
+ x,
351
+ input_pos,
352
+ kv_cache,
353
+ *args,
354
+ **kwds,
355
+ )
356
+ )
357
+ out = self.ffn_norm(h + self.feed_forward(h))
358
+ return out
359
+
360
+ def prefill(
361
+ self,
362
+ x: Tensor,
363
+ kv_cache: KVCacheProtocol,
364
+ attn_mask: Tensor,
365
+ ) -> Tensor:
366
+ h = self.attention_norm(
367
+ x
368
+ + self.attention.prefill(
369
+ x,
370
+ kv_cache,
371
+ attn_mask,
372
+ )
373
+ )
374
+ out = self.ffn_norm(h + self.feed_forward(h))
375
+ return out
376
+
377
+
378
+ class TransformerDecoderABC(nn.Module, ABC):
379
+ def __init__(
380
+ self,
381
+ hidden_dim: int,
382
+ n_layer: int,
383
+ n_head: int,
384
+ ffn_dim: int,
385
+ vocab_size: int,
386
+ max_seq_length: int,
387
+ max_batch_size: int,
388
+ ) -> None:
389
+ super().__init__()
390
+
391
+ self.hidden_dim = hidden_dim
392
+ self.n_head = n_head
393
+ assert hidden_dim % n_head == 0
394
+
395
+ self.head_dim = hidden_dim // n_head
396
+ self.vocab_size = vocab_size
397
+
398
+ self.n_layer = n_layer
399
+
400
+ self.layers: MutableSequence[TransformerBlockABC]
401
+
402
+ self.max_seq_length = max_seq_length
403
+ self.max_batch_size = max_batch_size
404
+
405
+ def __call__(self, input_pos: Tensor, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds):
406
+ for layer, kv_cache in zip(self.layers, kv_caches):
407
+ x = layer(x, input_pos, kv_cache, *args, **kwds)
408
+ return x
409
+
410
+ def prefill(self, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], attn_mask: Tensor):
411
+ for layer, kv_cache in zip(self.layers, kv_caches):
412
+ x = layer.prefill(x, kv_cache, attn_mask)
413
+ return x
414
+
415
+
416
+ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
417
+ def __init__(
418
+ self,
419
+ config: dict,
420
+ max_seq_length: int = 2000,
421
+ max_batch_size: int = 10,
422
+ ) -> None:
423
+ super().__init__()
424
+
425
+ hidden_dim: int = config["model"]["hidden_dim"]
426
+ embedding_dim: int = config["model"]["embedding_dim"]
427
+ n_head: int = config["model"]["head"]
428
+ n_layer: int = config["model"]["n_layer"]
429
+ vocab_size: int = config["model"]["vocab_size"]
430
+ phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
431
+ EOS: int = config["model"]["EOS"]
432
+ ffn_dim: int = hidden_dim * 4
433
+
434
+ self.n_layer = int(n_layer)
435
+ self.hidden_dim = int(hidden_dim)
436
+ self.n_head = int(n_head)
437
+ assert hidden_dim % n_head == 0
438
+
439
+ self.head_dim = int(hidden_dim // n_head)
440
+ self.embedding_dim = int(embedding_dim)
441
+ self.ffn_dim = int(ffn_dim)
442
+ self.vocab_size = int(vocab_size)
443
+ self.phoneme_vocab_size = int(phoneme_vocab_size)
444
+ self.max_seq_length = max_seq_length
445
+ self.max_batch_size = max_batch_size
446
+ self.EOS = EOS
447
+ assert self.EOS == self.vocab_size - 1
448
+
449
+ self.bert_proj: nn.Linear
450
+ self.ar_predict_layer: nn.Linear
451
+ self.h: TransformerDecoderABC
452
+
453
+ self.kv_class: type[KVCacheABC]
454
+
455
+ self.GraphCache: CUDAGraphCacheABC | None
456
+
457
+ self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
458
+ self.ar_text_position = SinePositionalEmbedding(
459
+ self.embedding_dim,
460
+ scale=False,
461
+ alpha=True,
462
+ max_batch_size=max_batch_size,
463
+ max_seq_len=max_seq_length,
464
+ )
465
+ self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
466
+ self.ar_audio_position = SinePositionalEmbedding(
467
+ self.embedding_dim,
468
+ scale=False,
469
+ alpha=True,
470
+ max_batch_size=max_batch_size,
471
+ max_seq_len=max_seq_length,
472
+ )
473
+
474
+ self._register_load_state_dict_pre_hook(self.load_hook)
475
+
476
+ def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
477
+ model_keys = [key for key in state_dict if key.startswith("model.")]
478
+ for key in model_keys:
479
+ new_key = key[len("model.") :]
480
+ state_dict[new_key] = state_dict.pop(key)
481
+
482
+ def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheProtocol]:
483
+ bsz = bsz or self.h.max_batch_size
484
+ assert bsz <= self.h.max_batch_size
485
+ seq_lens = self.h.max_seq_length
486
+ dtype = self.bert_proj.bias.dtype
487
+ kvclass = self.kv_class
488
+
489
+ return nn.ModuleList(
490
+ [kvclass(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)],
491
+ ).to(self.device, dtype) # type: ignore
492
+
493
+ def embed(
494
+ self,
495
+ x: list[torch.Tensor],
496
+ y: torch.Tensor,
497
+ bert_features: list[torch.Tensor],
498
+ ):
499
+ x_len: list[int] = [i.shape[0] for i in x]
500
+ x_len_max = max(x_len)
501
+ xy_pos = torch.zeros((len(x), x_len_max + y.shape[1], self.embedding_dim)).to(bert_features[0].dtype)
502
+
503
+ bert_features = list(map(lambda x: x.transpose(0, 1), bert_features))
504
+
505
+ y_len = y.shape[1]
506
+ y_emb = self.ar_audio_embedding(y)
507
+ y_pos = self.ar_audio_position.prefill(y_emb)
508
+
509
+ for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
510
+ x_emb = self.ar_text_embedding(x_)
511
+ bert = self.bert_proj(bert_feature)
512
+ x_emb = x_emb + bert
513
+ x_pos = self.ar_text_position.prefill(x_emb.unsqueeze(0))
514
+ xy_pos[[bs], :len_] = x_pos
515
+ xy_pos[[bs], len_ : len_ + y_len] = y_pos
516
+
517
+ return xy_pos
518
+
519
+ def compile(self, *args, **kwds):
520
+ # Experimental features to reduce compilation times, will be on by default in future
521
+ torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
522
+ torch._inductor.config.coordinate_descent_tuning = True
523
+ torch._inductor.config.triton.unique_kernel_names = True
524
+ torch._inductor.config.fx_graph_cache = True
525
+ torch._inductor.config.triton.cudagraph_trees = True
526
+ torch._inductor.config.triton.cudagraph_support_input_mutation = True
527
+ self.h.compile(fullgraph=True, mode="reduce-overhead")
528
+
529
+ def capture(
530
+ self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds
531
+ ) -> CUDAGraph:
532
+ assert torch.cuda.is_available()
533
+ s = torch.cuda.Stream()
534
+ s.wait_stream(torch.cuda.current_stream())
535
+
536
+ graph = torch.cuda.CUDAGraph()
537
+
538
+ with torch.cuda.stream(s):
539
+ for _ in range(5):
540
+ self.h(input_pos, x, kv_caches, *args, **kwds)
541
+ torch.cuda.current_stream().wait_stream(s)
542
+
543
+ with torch.cuda.graph(graph):
544
+ x_dec.copy_(self.h(input_pos, x, kv_caches, *args, **kwds))
545
+ torch.cuda.synchronize()
546
+
547
+ return graph
548
+
549
+ @abstractmethod
550
+ def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
551
+ return list(), dict()
552
+
553
+ @abstractmethod
554
+ def post_forward(self, idx: int, session: T2SSession) -> None:
555
+ return
556
+
557
+
558
+ class CUDAGraphCacheABC(ABC):
559
+ def __init__(
560
+ self,
561
+ decoder: T2SDecoderABC,
562
+ ) -> None:
563
+ self.is_applicable: bool
564
+
565
+ if torch.cuda.is_available() and self.is_applicable:
566
+ self.device: torch.device = decoder.device
567
+ self.dtype = decoder.bert_proj.bias.dtype
568
+
569
+ self.assigned: bool = False
570
+
571
+ self.decoder: T2SDecoderABC = decoder
572
+ self.kv_cache: MutableSequence[KVCacheProtocol] = decoder.init_cache(decoder.max_batch_size)
573
+ self.xy_pos = torch.rand(size=(decoder.max_batch_size, 1, decoder.embedding_dim), device=self.device).to(
574
+ self.dtype
575
+ )
576
+ self.xy_dec = self.xy_pos.clone()
577
+
578
+ self.input_pos = torch.tensor([10] * decoder.max_batch_size, device=self.device).int()
579
+ self.graph: torch.cuda.CUDAGraph | None = None
580
+ self.stream: torch.cuda.Stream | None
581
+
582
+ self.id: int = random.randint(1, 2**32 - 1)
583
+
584
+ def assign_graph(self, session: T2SSession):
585
+ if self.graph is None:
586
+ args, kwds = self.decoder.pre_forward(session)
587
+ graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
588
+ self.graph = graph
589
+ self.stream = torch.cuda.Stream()
590
+
591
+ if self.assigned is False:
592
+ self.get_cache_graph(session)
593
+ session.id = self.id
594
+ self.assigned = True
595
+ else:
596
+ self.capture_new_graph(session)
597
+
598
+ @abstractmethod
599
+ def release_graph(self, session: T2SSession): ...
600
+
601
+ @abstractmethod
602
+ def get_cache_graph(self, session: T2SSession):
603
+ pass
604
+
605
+ @abstractmethod
606
+ def capture_new_graph(self, session: T2SSession):
607
+ pass
608
+
609
+
610
+ class TorchProfiler:
611
+ def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
612
+ self.debug = debug
613
+ self.log_dir = log_dir
614
+ self.__profiler: torch.profiler.profile
615
+
616
+ if self.debug and not os.path.exists(self.log_dir):
617
+ os.makedirs(self.log_dir)
618
+
619
+ self.tensorboard_handler = tensorboard_trace_handler(self.log_dir)
620
+
621
+ def profiler_callback(self, prof: torch.profiler.profile):
622
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
623
+ print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
624
+ self.tensorboard_handler(prof)
625
+
626
+ @staticmethod
627
+ def three_step_schedule(step: int) -> ProfilerAction:
628
+ if step == 0:
629
+ return ProfilerAction.NONE
630
+ elif step == 1:
631
+ return ProfilerAction.RECORD
632
+ elif step == 2:
633
+ return ProfilerAction.RECORD_AND_SAVE
634
+ else:
635
+ return ProfilerAction.NONE
636
+
637
+ def start(self):
638
+ if not self.debug:
639
+ return
640
+ assert self.__profiler is not None
641
+ self.__profiler.step()
642
+
643
+ def end(self):
644
+ if not self.debug:
645
+ return
646
+ assert self.__profiler is not None
647
+ self.__profiler.step()
648
+
649
+ def profiler(self):
650
+ if self.debug:
651
+ activities_list = [torch.profiler.ProfilerActivity.CPU]
652
+ if torch.cuda.is_available():
653
+ activities_list.append(torch.profiler.ProfilerActivity.CUDA)
654
+
655
+ self.__profiler = torch.profiler.profile(
656
+ activities=activities_list,
657
+ record_shapes=True,
658
+ with_stack=True,
659
+ with_modules=True,
660
+ profile_memory=True,
661
+ schedule=self.three_step_schedule,
662
+ on_trace_ready=self.profiler_callback,
663
+ )
664
+ return self.__profiler
665
+ else:
666
+ return nullcontext()
667
+
668
+ def record(self, func_name: str):
669
+ if self.debug:
670
+ return torch.profiler.record_function(func_name)
671
+ else:
672
+ return nullcontext()
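A small sketch of the KV-cache layout used above, assuming the local nn wrappers behave like torch.nn: KVCacheNHD.update writes the single new key/value produced at position input_pos into slot input_pos - 1 of a (B, S, H, D) buffer via scatter_.

    import torch
    from GPT_SoVITS.Accelerate.PyTorch.t2s_model_abc import KVCacheNHD

    cache = KVCacheNHD(batch_size=2, max_seq_length=16, n_heads=4, head_dim=8)
    input_pos = torch.tensor([5, 9])              # 1-based positions of the tokens just produced
    k = torch.randn(2, 1, 4, 8)                   # [B, 1, H, D]
    v = torch.randn(2, 1, 4, 8)
    k_out, v_out = cache.update(input_pos, k, v)
    assert torch.equal(k_out[0, 4], k[0, 0])      # request 0 landed in slot 4
    assert torch.equal(k_out[1, 8], k[1, 0])      # request 1 landed in slot 8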
GPT_SoVITS/Accelerate/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ from . import MLX, PyTorch
2
+ from .logger import console, logger, tb
3
+ from .PyTorch import T2SEngineTorch, T2SRequest, T2SResult
4
+ from .PyTorch.structs import T2SEngineProtocol
5
+
6
+ backends = PyTorch.backends + MLX.backends
7
+
8
+ backends = [
9
+ b.replace("_", "-")
10
+ .title()
11
+ .replace("Mlx", "MLX")
12
+ .replace("Mps", "MPS")
13
+ .replace("Cuda", "CUDA")
14
+ .replace("Mxfp4", "MXFP4")
15
+ for b in backends
16
+ ]
17
+
18
+
19
+ __all__ = [
20
+ "T2SEngineTorch",
21
+ "T2SRequest",
22
+ "T2SResult",
23
+ "backends",
24
+ "MLX",
25
+ "PyTorch",
26
+ "logger",
27
+ "console",
28
+ "tb",
29
+ "T2SEngineProtocol",
30
+ ]
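The list comprehension above turns backend module names into display names, and T2SEngine.load_decoder applies the inverse transform to recover the module path. A worked example, assuming PyTorch.backends lists module names such as "flash_attn_varlen_cuda_graph" (only the replacements relevant to this name are shown):

    name = "flash_attn_varlen_cuda_graph"
    display = name.replace("_", "-").title().replace("Cuda", "CUDA")
    # display == "Flash-Attn-Varlen-CUDA-Graph"
    module = f".backends.{display.lower().replace('-', '_').replace('cudagraph', 'cuda_graph')}"
    # module == ".backends.flash_attn_varlen_cuda_graph"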
GPT_SoVITS/Accelerate/logger.py ADDED
@@ -0,0 +1,203 @@
1
+ import sys
2
+ from typing import Optional
3
+
4
+ from loguru import logger
5
+ from rich.console import Console, JustifyMethod
6
+ from rich.highlighter import Highlighter
7
+ from rich.logging import RichHandler
8
+ from rich.progress import Task, TextColumn
9
+ from rich.style import StyleType
10
+ from rich.table import Column
11
+ from rich.text import Text
12
+ from rich.traceback import Traceback, install
13
+
14
+ console = Console(stderr=False)
15
+ install(console=console)
16
+
17
+
18
+ def loguru_format(record):
19
+ level = record["level"].name
20
+ color = {
21
+ "DEBUG": "green",
22
+ "INFO": "blue",
23
+ "WARNING": "yellow",
24
+ "ERROR": "red",
25
+ "CRITICAL": "bright_red",
26
+ }.get(level, "white")
27
+
28
+ return f"[bold {color}][{level}][/bold {color}] " + "{message}"
29
+
30
+
31
+ handler_with_locals = RichHandler(
32
+ console=console,
33
+ show_time=False,
34
+ show_path=False,
35
+ rich_tracebacks=True,
36
+ tracebacks_show_locals=True,
37
+ show_level=False,
38
+ markup=True,
39
+ )
40
+ handler_without_locals = RichHandler(
41
+ console=console,
42
+ show_time=False,
43
+ show_path=False,
44
+ rich_tracebacks=True,
45
+ tracebacks_show_locals=False,
46
+ show_level=False,
47
+ markup=True,
48
+ )
49
+
50
+
51
+ def local_filter(r):
52
+ return r["extra"].get("show_locals", True)
53
+
54
+
55
+ logger.remove()
56
+ logger.add(handler_with_locals, format=loguru_format, filter=local_filter)
57
+ logger.add(handler_without_locals, format=loguru_format, filter=lambda x: not local_filter(x))
58
+
59
+
60
+ class SpeedColumnToken(TextColumn):
61
+ """Show the task speed in tokens per second (falls back to the percentage format when show_speed is False).
62
+
63
+ Args:
64
+ text_format (str, optional): Format for percentage display. Defaults to "[progress.percentage]{task.percentage:>3.0f}%".
65
+ text_format_no_percentage (str, optional): Format if percentage is unknown. Defaults to "".
66
+ style (StyleType, optional): Style of output. Defaults to "none".
67
+ justify (JustifyMethod, optional): Text justification. Defaults to "left".
68
+ markup (bool, optional): Enable markup. Defaults to True.
69
+ highlighter (Optional[Highlighter], optional): Highlighter to apply to output. Defaults to None.
70
+ table_column (Optional[Column], optional): Table Column to use. Defaults to None.
71
+ show_speed (bool, optional): Show the task speed instead of the percentage. Defaults to True.
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ text_format: str = "[progress.percentage]{task.percentage:>3.0f}%",
77
+ text_format_no_percentage: str = "",
78
+ style: StyleType = "none",
79
+ justify: JustifyMethod = "left",
80
+ markup: bool = True,
81
+ highlighter: Optional[Highlighter] = None,
82
+ table_column: Optional[Column] = None,
83
+ show_speed: bool = True,
84
+ ) -> None:
85
+ self.text_format_no_percentage = text_format_no_percentage
86
+ self.show_speed = show_speed
87
+ super().__init__(
88
+ text_format=text_format,
89
+ style=style,
90
+ justify=justify,
91
+ markup=markup,
92
+ highlighter=highlighter,
93
+ table_column=table_column,
94
+ )
95
+
96
+ @classmethod
97
+ def render_speed(cls, speed: Optional[float]) -> Text:
98
+ """Render the speed in tokens per second.
99
+
100
+ Args:
101
+ speed (Optional[float]): Task speed, in tokens per second, or None if unknown.
102
+
103
+ Returns:
104
+ Text: Text object containing the task speed.
105
+ """
106
+ if speed is None:
107
+ return Text("", style="progress.percentage")
108
+ return Text(f"{speed:.1f} token/s", style="progress.percentage")
109
+
110
+ def render(self, task: Task) -> Text:
111
+ if self.show_speed:
112
+ return self.render_speed(task.finished_speed or task.speed)
113
+ text_format = self.text_format_no_percentage if task.total is None else self.text_format
114
+ _text = text_format.format(task=task)
115
+ if self.markup:
116
+ text = Text.from_markup(_text, style=self.style, justify=self.justify)
117
+ else:
118
+ text = Text(_text, style=self.style, justify=self.justify)
119
+ if self.highlighter:
120
+ self.highlighter.highlight(text)
121
+ return text
122
+
123
+
124
+ class SpeedColumnIteration(TextColumn):
125
+ """Show the task speed in iterations per second (falls back to the percentage format when show_speed is False).
126
+
127
+ Args:
128
+ text_format (str, optional): Format for percentage display. Defaults to "[progress.percentage]{task.percentage:>3.0f}%".
129
+ text_format_no_percentage (str, optional): Format if percentage is unknown. Defaults to "".
130
+ style (StyleType, optional): Style of output. Defaults to "none".
131
+ justify (JustifyMethod, optional): Text justification. Defaults to "left".
132
+ markup (bool, optional): Enable markup. Defaults to True.
133
+ highlighter (Optional[Highlighter], optional): Highlighter to apply to output. Defaults to None.
134
+ table_column (Optional[Column], optional): Table Column to use. Defaults to None.
135
+ show_speed (bool, optional): Show the task speed instead of the percentage. Defaults to True.
136
+ """
137
+
138
+ def __init__(
139
+ self,
140
+ text_format: str = "[progress.percentage]{task.percentage:>3.0f}%",
141
+ text_format_no_percentage: str = "",
142
+ style: StyleType = "none",
143
+ justify: JustifyMethod = "left",
144
+ markup: bool = True,
145
+ highlighter: Optional[Highlighter] = None,
146
+ table_column: Optional[Column] = None,
147
+ show_speed: bool = True,
148
+ ) -> None:
149
+ self.text_format_no_percentage = text_format_no_percentage
150
+ self.show_speed = show_speed
151
+ super().__init__(
152
+ text_format=text_format,
153
+ style=style,
154
+ justify=justify,
155
+ markup=markup,
156
+ highlighter=highlighter,
157
+ table_column=table_column,
158
+ )
159
+
160
+ @classmethod
161
+ def render_speed(cls, speed: Optional[float]) -> Text:
162
+ """Render the speed in iterations per second.
163
+
164
+ Args:
165
+ speed (Optional[float]): Task speed, in iterations per second, or None if unknown.
166
+
167
+ Returns:
168
+ Text: Text object containing the task speed.
169
+ """
170
+ if speed is None:
171
+ return Text("", style="progress.percentage")
172
+ return Text(f"{speed:.1f} it/s", style="progress.percentage")
173
+
174
+ def render(self, task: Task) -> Text:
175
+ if self.show_speed:
176
+ return self.render_speed(task.finished_speed or task.speed)
177
+ text_format = self.text_format_no_percentage if task.total is None else self.text_format
178
+ _text = text_format.format(task=task)
179
+ if self.markup:
180
+ text = Text.from_markup(_text, style=self.style, justify=self.justify)
181
+ else:
182
+ text = Text(_text, style=self.style, justify=self.justify)
183
+ if self.highlighter:
184
+ self.highlighter.highlight(text)
185
+ return text
186
+
187
+
188
+ def tb(show_locals: bool = True):
189
+ exc_type, exc_value, exc_tb = sys.exc_info()
190
+ assert exc_type
191
+ assert exc_value
192
+ tb = Traceback.from_exception(exc_type, exc_value, exc_tb, show_locals=show_locals)
193
+
194
+ return tb
195
+
196
+
197
+ __all__ = ["logger", "console", "tb", "SpeedColumnToken", "SpeedColumnIteration"]
198
+
199
+ if __name__ == "__main__":
200
+ try:
201
+ raise RuntimeError()
202
+ except Exception:
203
+ logger.bind(show_locals=False).exception("TEST")
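For reference, this is how the helpers above are consumed in t2s_engine.py earlier in this commit: a rich Progress bar with the token-speed column, plus the loguru logger routed through the shared console. A minimal standalone sketch:

    import time
    from rich.progress import BarColumn, Progress, TextColumn
    from GPT_SoVITS.Accelerate.logger import SpeedColumnToken, console, logger

    with Progress(
        TextColumn("[cyan]{task.description}"),
        BarColumn(),
        TextColumn("{task.completed}/{task.total} tokens"),
        SpeedColumnToken(show_speed=True),
        console=console,
        transient=True,
    ) as progress:
        task = progress.add_task("T2S Decoding", total=100)
        for _ in range(100):
            progress.update(task, advance=1)
            time.sleep(0.01)
    logger.info("Decoding finished")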
GPT_SoVITS/configs/.gitignore ADDED
@@ -0,0 +1 @@
1
+ *.yaml
GPT_SoVITS/configs/s2.json ADDED
@@ -0,0 +1,91 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 100,
4
+ "eval_interval": 500,
5
+ "seed": 1234,
6
+ "epochs": 100,
7
+ "learning_rate": 0.0001,
8
+ "betas": [
9
+ 0.8,
10
+ 0.99
11
+ ],
12
+ "eps": 1e-09,
13
+ "batch_size": 32,
14
+ "fp16_run": true,
15
+ "lr_decay": 0.999875,
16
+ "segment_size": 20480,
17
+ "init_lr_ratio": 1,
18
+ "warmup_epochs": 0,
19
+ "c_mel": 45,
20
+ "c_kl": 1.0,
21
+ "text_low_lr_rate": 0.4,
22
+ "grad_ckpt": false
23
+ },
24
+ "data": {
25
+ "max_wav_value": 32768.0,
26
+ "sampling_rate": 32000,
27
+ "filter_length": 2048,
28
+ "hop_length": 640,
29
+ "win_length": 2048,
30
+ "n_mel_channels": 128,
31
+ "mel_fmin": 0.0,
32
+ "mel_fmax": null,
33
+ "add_blank": true,
34
+ "n_speakers": 300,
35
+ "cleaned_text": true
36
+ },
37
+ "model": {
38
+ "inter_channels": 192,
39
+ "hidden_channels": 192,
40
+ "filter_channels": 768,
41
+ "n_heads": 2,
42
+ "n_layers": 6,
43
+ "kernel_size": 3,
44
+ "p_dropout": 0.1,
45
+ "resblock": "1",
46
+ "resblock_kernel_sizes": [
47
+ 3,
48
+ 7,
49
+ 11
50
+ ],
51
+ "resblock_dilation_sizes": [
52
+ [
53
+ 1,
54
+ 3,
55
+ 5
56
+ ],
57
+ [
58
+ 1,
59
+ 3,
60
+ 5
61
+ ],
62
+ [
63
+ 1,
64
+ 3,
65
+ 5
66
+ ]
67
+ ],
68
+ "upsample_rates": [
69
+ 10,
70
+ 8,
71
+ 2,
72
+ 2,
73
+ 2
74
+ ],
75
+ "upsample_initial_channel": 512,
76
+ "upsample_kernel_sizes": [
77
+ 16,
78
+ 16,
79
+ 8,
80
+ 2,
81
+ 2
82
+ ],
83
+ "n_layers_q": 3,
84
+ "use_spectral_norm": false,
85
+ "gin_channels": 512,
86
+ "semantic_frame_rate": "25hz",
87
+ "freeze_quantizer": true
88
+ },
89
+ "s2_ckpt_dir": "logs/s2/big2k1",
90
+ "content_module": "cnhubert"
91
+ }
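A quick sanity check on this config (the same relation holds for the v2Pro and v2ProPlus variants below): the vocoder's total upsampling factor, the product of upsample_rates, equals the STFT hop_length, so each mel frame decodes to exactly one hop of 32 kHz audio.

    import json
    import math

    cfg = json.load(open("GPT_SoVITS/configs/s2.json"))
    upsample = math.prod(cfg["model"]["upsample_rates"])   # 10 * 8 * 2 * 2 * 2 = 640
    assert upsample == cfg["data"]["hop_length"]           # 640 samples per frame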
GPT_SoVITS/configs/s2v2Pro.json ADDED
@@ -0,0 +1,91 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 100,
4
+ "eval_interval": 500,
5
+ "seed": 1234,
6
+ "epochs": 100,
7
+ "learning_rate": 0.0001,
8
+ "betas": [
9
+ 0.8,
10
+ 0.99
11
+ ],
12
+ "eps": 1e-09,
13
+ "batch_size": 32,
14
+ "fp16_run": true,
15
+ "lr_decay": 0.999875,
16
+ "segment_size": 20480,
17
+ "init_lr_ratio": 1,
18
+ "warmup_epochs": 0,
19
+ "c_mel": 45,
20
+ "c_kl": 1.0,
21
+ "text_low_lr_rate": 0.4,
22
+ "grad_ckpt": false
23
+ },
24
+ "data": {
25
+ "max_wav_value": 32768.0,
26
+ "sampling_rate": 32000,
27
+ "filter_length": 2048,
28
+ "hop_length": 640,
29
+ "win_length": 2048,
30
+ "n_mel_channels": 128,
31
+ "mel_fmin": 0.0,
32
+ "mel_fmax": null,
33
+ "add_blank": true,
34
+ "n_speakers": 300,
35
+ "cleaned_text": true
36
+ },
37
+ "model": {
38
+ "inter_channels": 192,
39
+ "hidden_channels": 192,
40
+ "filter_channels": 768,
41
+ "n_heads": 2,
42
+ "n_layers": 6,
43
+ "kernel_size": 3,
44
+ "p_dropout": 0.0,
45
+ "resblock": "1",
46
+ "resblock_kernel_sizes": [
47
+ 3,
48
+ 7,
49
+ 11
50
+ ],
51
+ "resblock_dilation_sizes": [
52
+ [
53
+ 1,
54
+ 3,
55
+ 5
56
+ ],
57
+ [
58
+ 1,
59
+ 3,
60
+ 5
61
+ ],
62
+ [
63
+ 1,
64
+ 3,
65
+ 5
66
+ ]
67
+ ],
68
+ "upsample_rates": [
69
+ 10,
70
+ 8,
71
+ 2,
72
+ 2,
73
+ 2
74
+ ],
75
+ "upsample_initial_channel": 512,
76
+ "upsample_kernel_sizes": [
77
+ 16,
78
+ 16,
79
+ 8,
80
+ 2,
81
+ 2
82
+ ],
83
+ "n_layers_q": 3,
84
+ "use_spectral_norm": false,
85
+ "gin_channels": 1024,
86
+ "semantic_frame_rate": "25hz",
87
+ "freeze_quantizer": true
88
+ },
89
+ "s2_ckpt_dir": "logs/s2/big2k1",
90
+ "content_module": "cnhubert"
91
+ }
GPT_SoVITS/configs/s2v2ProPlus.json ADDED
@@ -0,0 +1,91 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 100,
4
+ "eval_interval": 500,
5
+ "seed": 1234,
6
+ "epochs": 100,
7
+ "learning_rate": 0.0001,
8
+ "betas": [
9
+ 0.8,
10
+ 0.99
11
+ ],
12
+ "eps": 1e-09,
13
+ "batch_size": 32,
14
+ "fp16_run": true,
15
+ "lr_decay": 0.999875,
16
+ "segment_size": 20480,
17
+ "init_lr_ratio": 1,
18
+ "warmup_epochs": 0,
19
+ "c_mel": 45,
20
+ "c_kl": 1.0,
21
+ "text_low_lr_rate": 0.4,
22
+ "grad_ckpt": false
23
+ },
24
+ "data": {
25
+ "max_wav_value": 32768.0,
26
+ "sampling_rate": 32000,
27
+ "filter_length": 2048,
28
+ "hop_length": 640,
29
+ "win_length": 2048,
30
+ "n_mel_channels": 128,
31
+ "mel_fmin": 0.0,
32
+ "mel_fmax": null,
33
+ "add_blank": true,
34
+ "n_speakers": 300,
35
+ "cleaned_text": true
36
+ },
37
+ "model": {
38
+ "inter_channels": 192,
39
+ "hidden_channels": 192,
40
+ "filter_channels": 768,
41
+ "n_heads": 2,
42
+ "n_layers": 6,
43
+ "kernel_size": 3,
44
+ "p_dropout": 0.0,
45
+ "resblock": "1",
46
+ "resblock_kernel_sizes": [
47
+ 3,
48
+ 7,
49
+ 11
50
+ ],
51
+ "resblock_dilation_sizes": [
52
+ [
53
+ 1,
54
+ 3,
55
+ 5
56
+ ],
57
+ [
58
+ 1,
59
+ 3,
60
+ 5
61
+ ],
62
+ [
63
+ 1,
64
+ 3,
65
+ 5
66
+ ]
67
+ ],
68
+ "upsample_rates": [
69
+ 10,
70
+ 8,
71
+ 2,
72
+ 2,
73
+ 2
74
+ ],
75
+ "upsample_initial_channel": 768,
76
+ "upsample_kernel_sizes": [
77
+ 20,
78
+ 16,
79
+ 8,
80
+ 2,
81
+ 2
82
+ ],
83
+ "n_layers_q": 3,
84
+ "use_spectral_norm": false,
85
+ "gin_channels": 1024,
86
+ "semantic_frame_rate": "25hz",
87
+ "freeze_quantizer": true
88
+ },
89
+ "s2_ckpt_dir": "logs/s2/big2k1",
90
+ "content_module": "cnhubert"
91
+ }
GPT_SoVITS/eres2net/ERes2NetV2.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """
5
+ To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
6
+ within each stage. However, this modification also increases the number of model parameters and computational complexity.
7
+ To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures, ultimately reducing
8
+ both the model parameters and its computational cost.
9
+ """
10
+
11
+ import math
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from . import pooling_layers as pooling_layers
18
+ from .fusion import AFF
19
+
20
+
21
+ class ReLU(nn.Hardtanh):
22
+ def __init__(self, inplace=False):
23
+ super(ReLU, self).__init__(0, 20, inplace)
24
+
25
+ def __repr__(self):
26
+ inplace_str = "inplace" if self.inplace else ""
27
+ return self.__class__.__name__ + " (" + inplace_str + ")"
28
+
29
+
30
+ class BasicBlockERes2NetV2(nn.Module):
31
+ def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
32
+ super(BasicBlockERes2NetV2, self).__init__()
33
+ width = int(math.floor(planes * (baseWidth / 64.0)))
34
+ self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
35
+ self.bn1 = nn.BatchNorm2d(width * scale)
36
+ self.nums = scale
37
+ self.expansion = expansion
38
+
39
+ convs = []
40
+ bns = []
41
+ for i in range(self.nums):
42
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
43
+ bns.append(nn.BatchNorm2d(width))
44
+ self.convs = nn.ModuleList(convs)
45
+ self.bns = nn.ModuleList(bns)
46
+ self.relu = ReLU(inplace=True)
47
+
48
+ self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
49
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
50
+ self.shortcut = nn.Sequential()
51
+ if stride != 1 or in_planes != self.expansion * planes:
52
+ self.shortcut = nn.Sequential(
53
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
54
+ nn.BatchNorm2d(self.expansion * planes),
55
+ )
56
+ self.stride = stride
57
+ self.width = width
58
+ self.scale = scale
59
+
60
+ def forward(self, x):
61
+ residual = x
62
+
63
+ out = self.conv1(x)
64
+ out = self.bn1(out)
65
+ out = self.relu(out)
66
+ spx = torch.split(out, self.width, 1)
67
+ for i in range(self.nums):
68
+ if i == 0:
69
+ sp = spx[i]
70
+ else:
71
+ sp = sp + spx[i]
72
+ sp = self.convs[i](sp)
73
+ sp = self.relu(self.bns[i](sp))
74
+ if i == 0:
75
+ out = sp
76
+ else:
77
+ out = torch.cat((out, sp), 1)
78
+
79
+ out = self.conv3(out)
80
+ out = self.bn3(out)
81
+
82
+ residual = self.shortcut(x)
83
+ out += residual
84
+ out = self.relu(out)
85
+
86
+ return out
87
+
88
+
89
+ class BasicBlockERes2NetV2AFF(nn.Module):
90
+ def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
91
+ super(BasicBlockERes2NetV2AFF, self).__init__()
92
+ width = int(math.floor(planes * (baseWidth / 64.0)))
93
+ self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
94
+ self.bn1 = nn.BatchNorm2d(width * scale)
95
+ self.nums = scale
96
+ self.expansion = expansion
97
+
98
+ convs = []
99
+ fuse_models = []
100
+ bns = []
101
+ for i in range(self.nums):
102
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
103
+ bns.append(nn.BatchNorm2d(width))
104
+ for j in range(self.nums - 1):
105
+ fuse_models.append(AFF(channels=width, r=4))
106
+
107
+ self.convs = nn.ModuleList(convs)
108
+ self.bns = nn.ModuleList(bns)
109
+ self.fuse_models = nn.ModuleList(fuse_models)
110
+ self.relu = ReLU(inplace=True)
111
+
112
+ self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
113
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
114
+ self.shortcut = nn.Sequential()
115
+ if stride != 1 or in_planes != self.expansion * planes:
116
+ self.shortcut = nn.Sequential(
117
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
118
+ nn.BatchNorm2d(self.expansion * planes),
119
+ )
120
+ self.stride = stride
121
+ self.width = width
122
+ self.scale = scale
123
+
124
+ def forward(self, x):
125
+ residual = x
126
+
127
+ out = self.conv1(x)
128
+ out = self.bn1(out)
129
+ out = self.relu(out)
130
+ spx = torch.split(out, self.width, 1)
131
+ for i in range(self.nums):
132
+ if i == 0:
133
+ sp = spx[i]
134
+ else:
135
+ sp = self.fuse_models[i - 1](sp, spx[i])
136
+
137
+ sp = self.convs[i](sp)
138
+ sp = self.relu(self.bns[i](sp))
139
+ if i == 0:
140
+ out = sp
141
+ else:
142
+ out = torch.cat((out, sp), 1)
143
+
144
+ out = self.conv3(out)
145
+ out = self.bn3(out)
146
+
147
+ residual = self.shortcut(x)
148
+ out += residual
149
+ out = self.relu(out)
150
+
151
+ return out
152
+
153
+
154
+ class ERes2NetV2(nn.Module):
155
+ def __init__(
156
+ self,
157
+ block=BasicBlockERes2NetV2,
158
+ block_fuse=BasicBlockERes2NetV2AFF,
159
+ num_blocks=[3, 4, 6, 3],
160
+ m_channels=64,
161
+ feat_dim=80,
162
+ embedding_size=192,
163
+ baseWidth=26,
164
+ scale=2,
165
+ expansion=2,
166
+ pooling_func="TSTP",
167
+ two_emb_layer=False,
168
+ ):
169
+ super(ERes2NetV2, self).__init__()
170
+ self.in_planes = m_channels
171
+ self.feat_dim = feat_dim
172
+ self.embedding_size = embedding_size
173
+ self.stats_dim = int(feat_dim / 8) * m_channels * 8
174
+ self.two_emb_layer = two_emb_layer
175
+ self.baseWidth = baseWidth
176
+ self.scale = scale
177
+ self.expansion = expansion
178
+
179
+ self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
180
+ self.bn1 = nn.BatchNorm2d(m_channels)
181
+ self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
182
+ self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
183
+ self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
184
+ self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
185
+
186
+ # Downsampling module
187
+ self.layer3_ds = nn.Conv2d(
188
+ m_channels * 4 * self.expansion,
189
+ m_channels * 8 * self.expansion,
190
+ kernel_size=3,
191
+ padding=1,
192
+ stride=2,
193
+ bias=False,
194
+ )
195
+
196
+ # Bottom-up fusion module
197
+ self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
198
+
199
+ self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
200
+ self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * self.expansion)
201
+ self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats, embedding_size)
202
+ if self.two_emb_layer:
203
+ self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
204
+ self.seg_2 = nn.Linear(embedding_size, embedding_size)
205
+ else:
206
+ self.seg_bn_1 = nn.Identity()
207
+ self.seg_2 = nn.Identity()
208
+
209
+ def _make_layer(self, block, planes, num_blocks, stride):
210
+ strides = [stride] + [1] * (num_blocks - 1)
211
+ layers = []
212
+ for stride in strides:
213
+ layers.append(
214
+ block(
215
+ self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion
216
+ )
217
+ )
218
+ self.in_planes = planes * self.expansion
219
+ return nn.Sequential(*layers)
220
+
221
+ def forward(self, x):
222
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
223
+ x = x.unsqueeze_(1)
224
+ out = F.relu(self.bn1(self.conv1(x)))
225
+ out1 = self.layer1(out)
226
+ out2 = self.layer2(out1)
227
+ out3 = self.layer3(out2)
228
+ out4 = self.layer4(out3)
229
+ out3_ds = self.layer3_ds(out3)
230
+ fuse_out34 = self.fuse34(out4, out3_ds)
231
+ stats = self.pool(fuse_out34)
232
+
233
+ embed_a = self.seg_1(stats)
234
+ if self.two_emb_layer:
235
+ out = F.relu(embed_a)
236
+ out = self.seg_bn_1(out)
237
+ embed_b = self.seg_2(out)
238
+ return embed_b
239
+ else:
240
+ return embed_a
241
+
242
+ def forward3(self, x):
243
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
244
+ x = x.unsqueeze_(1)
245
+ out = F.relu(self.bn1(self.conv1(x)))
246
+ out1 = self.layer1(out)
247
+ out2 = self.layer2(out1)
248
+ out3 = self.layer3(out2)
249
+ out4 = self.layer4(out3)
250
+ out3_ds = self.layer3_ds(out3)
251
+ fuse_out34 = self.fuse34(out4, out3_ds)
252
+ return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
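A hypothetical smoke test of the speaker-embedding network above (assuming the eres2net package imports resolve): 80-dim fbank frames in, one fixed-size embedding per utterance out; forward3 instead returns a pooled feature-map variant.

    import torch
    from GPT_SoVITS.eres2net.ERes2NetV2 import ERes2NetV2

    model = ERes2NetV2(feat_dim=80, embedding_size=192).eval()
    fbank = torch.randn(2, 300, 80)      # (batch, frames, mel bins)
    with torch.no_grad():
        emb = model(fbank)               # (2, 192)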
GPT_SoVITS/eres2net/fusion.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ class AFF(nn.Module):
9
+ def __init__(self, channels=64, r=4):
10
+ super(AFF, self).__init__()
11
+ inter_channels = int(channels // r)
12
+
13
+ self.local_att = nn.Sequential(
14
+ nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
15
+ nn.BatchNorm2d(inter_channels),
16
+ nn.SiLU(inplace=True),
17
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
18
+ nn.BatchNorm2d(channels),
19
+ )
20
+
21
+ def forward(self, x, ds_y):
22
+ xa = torch.cat((x, ds_y), dim=1)
23
+ x_att = self.local_att(xa)
24
+ x_att = 1.0 + torch.tanh(x_att)
25
+ xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)
26
+
27
+ return xo
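A minimal shape sketch for the attentional feature fusion block above: both inputs must share the same (B, C, H, W) shape, the gate 1 + tanh(...) lies in (0, 2), and the two branches are blended with complementary weights x_att and 2 - x_att.

    import torch
    from GPT_SoVITS.eres2net.fusion import AFF

    aff = AFF(channels=64, r=4)
    x = torch.randn(2, 64, 10, 25)
    ds_y = torch.randn(2, 64, 10, 25)
    out = aff(x, ds_y)                   # same shape as the inputs: (2, 64, 10, 25)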
GPT_SoVITS/eres2net/kaldi.py ADDED
@@ -0,0 +1,844 @@
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torchaudio
6
+ from torch import Tensor
7
+
8
+ __all__ = [
9
+ "get_mel_banks",
10
+ "inverse_mel_scale",
11
+ "inverse_mel_scale_scalar",
12
+ "mel_scale",
13
+ "mel_scale_scalar",
14
+ "spectrogram",
15
+ "fbank",
16
+ "mfcc",
17
+ "vtln_warp_freq",
18
+ "vtln_warp_mel_freq",
19
+ ]
20
+
21
+ # numeric_limits<float>::epsilon() 1.1920928955078125e-07
22
+ EPSILON = torch.tensor(torch.finfo(torch.float).eps)
23
+ # 1 milliseconds = 0.001 seconds
24
+ MILLISECONDS_TO_SECONDS = 0.001
25
+
26
+ # window types
27
+ HAMMING = "hamming"
28
+ HANNING = "hanning"
29
+ POVEY = "povey"
30
+ RECTANGULAR = "rectangular"
31
+ BLACKMAN = "blackman"
32
+ WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
33
+
34
+
35
+ def _get_epsilon(device, dtype):
36
+ return EPSILON.to(device=device, dtype=dtype)
37
+
38
+
39
+ def _next_power_of_2(x: int) -> int:
40
+ r"""Returns the smallest power of 2 that is greater than x"""
41
+ return 1 if x == 0 else 2 ** (x - 1).bit_length()
42
+
43
+
44
+ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
45
+ r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
46
+ representing how the window is shifted along the waveform. Each row is a frame.
47
+
48
+ Args:
49
+ waveform (Tensor): Tensor of size ``num_samples``
50
+ window_size (int): Frame length
51
+ window_shift (int): Frame shift
52
+ snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
53
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
54
+ depends only on the frame_shift, and we reflect the data at the ends.
55
+
56
+ Returns:
57
+ Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
58
+ """
59
+ assert waveform.dim() == 1
60
+ num_samples = waveform.size(0)
61
+ strides = (window_shift * waveform.stride(0), waveform.stride(0))
62
+
63
+ if snip_edges:
64
+ if num_samples < window_size:
65
+ return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
66
+ else:
67
+ m = 1 + (num_samples - window_size) // window_shift
68
+ else:
69
+ reversed_waveform = torch.flip(waveform, [0])
70
+ m = (num_samples + (window_shift // 2)) // window_shift
71
+ pad = window_size // 2 - window_shift // 2
72
+ pad_right = reversed_waveform
73
+ if pad > 0:
74
+ # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
75
+ # but we want [2, 1, 0, 0, 1, 2]
76
+ pad_left = reversed_waveform[-pad:]
77
+ waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
78
+ else:
79
+ # pad is negative so we want to trim the waveform at the front
80
+ waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
81
+
82
+ sizes = (m, window_size)
83
+ return waveform.as_strided(sizes, strides)
84
+
85
+
86
+ def _feature_window_function(
87
+ window_type: str,
88
+ window_size: int,
89
+ blackman_coeff: float,
90
+ device: torch.device,
91
+ dtype: int,
92
+ ) -> Tensor:
93
+ r"""Returns a window function with the given type and size"""
94
+ if window_type == HANNING:
95
+ return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
96
+ elif window_type == HAMMING:
97
+ return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
98
+ elif window_type == POVEY:
99
+ # like hanning but goes to zero at edges
100
+ return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
101
+ elif window_type == RECTANGULAR:
102
+ return torch.ones(window_size, device=device, dtype=dtype)
103
+ elif window_type == BLACKMAN:
104
+ a = 2 * math.pi / (window_size - 1)
105
+ window_function = torch.arange(window_size, device=device, dtype=dtype)
106
+ # can't use torch.blackman_window as they use different coefficients
107
+ return (
108
+ blackman_coeff
109
+ - 0.5 * torch.cos(a * window_function)
110
+ + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
111
+ ).to(device=device, dtype=dtype)
112
+ else:
113
+ raise Exception("Invalid window type " + window_type)
114
+
115
+
116
+ def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
117
+ r"""Returns the log energy of size (m) for a strided_input (m,*)"""
118
+ device, dtype = strided_input.device, strided_input.dtype
119
+ log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
120
+ if energy_floor == 0.0:
121
+ return log_energy
122
+ return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
123
+
124
+
125
+ def _get_waveform_and_window_properties(
126
+ waveform: Tensor,
127
+ channel: int,
128
+ sample_frequency: float,
129
+ frame_shift: float,
130
+ frame_length: float,
131
+ round_to_power_of_two: bool,
132
+ preemphasis_coefficient: float,
133
+ ) -> Tuple[Tensor, int, int, int]:
134
+ r"""Gets the waveform and window properties"""
135
+ channel = max(channel, 0)
136
+ assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
137
+ waveform = waveform[channel, :] # size (n)
138
+ window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
139
+ window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
140
+ padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
141
+
142
+ assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
143
+ window_size, len(waveform)
144
+ )
145
+ assert 0 < window_shift, "`window_shift` must be greater than 0"
146
+ assert padded_window_size % 2 == 0, (
147
+ "the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
148
+ )
149
+ assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
150
+ assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
151
+ return waveform, window_shift, window_size, padded_window_size
152
+
153
+
154
+ def _get_window(
155
+ waveform: Tensor,
156
+ padded_window_size: int,
157
+ window_size: int,
158
+ window_shift: int,
159
+ window_type: str,
160
+ blackman_coeff: float,
161
+ snip_edges: bool,
162
+ raw_energy: bool,
163
+ energy_floor: float,
164
+ dither: float,
165
+ remove_dc_offset: bool,
166
+ preemphasis_coefficient: float,
167
+ ) -> Tuple[Tensor, Tensor]:
168
+ r"""Gets a window and its log energy
169
+
170
+ Returns:
171
+ (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
172
+ """
173
+ device, dtype = waveform.device, waveform.dtype
174
+ epsilon = _get_epsilon(device, dtype)
175
+
176
+ # size (m, window_size)
177
+ strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
178
+
179
+ if dither != 0.0:
180
+ rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)
181
+ strided_input = strided_input + rand_gauss * dither
182
+
183
+ if remove_dc_offset:
184
+ # Subtract each row/frame by its mean
185
+ row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
186
+ strided_input = strided_input - row_means
187
+
188
+ if raw_energy:
189
+ # Compute the log energy of each row/frame before applying preemphasis and
190
+ # window function
191
+ signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
192
+
193
+ if preemphasis_coefficient != 0.0:
194
+ # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
195
+ offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
196
+ 0
197
+ ) # size (m, window_size + 1)
198
+ strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
199
+
200
+ # Apply window_function to each row/frame
201
+ window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
202
+ 0
203
+ ) # size (1, window_size)
204
+ strided_input = strided_input * window_function # size (m, window_size)
205
+
206
+ # Pad columns with zero until we reach size (m, padded_window_size)
207
+ if padded_window_size != window_size:
208
+ padding_right = padded_window_size - window_size
209
+ strided_input = torch.nn.functional.pad(
210
+ strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
211
+ ).squeeze(0)
212
+
213
+ # Compute energy after window function (not the raw one)
214
+ if not raw_energy:
215
+ signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
216
+
217
+ return strided_input, signal_log_energy
218
+
219
+
220
+ def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
221
+ # subtracts the column mean of the tensor size (m, n) if subtract_mean=True
222
+ # it returns size (m, n)
223
+ if subtract_mean:
224
+ col_means = torch.mean(tensor, dim=0).unsqueeze(0)
225
+ tensor = tensor - col_means
226
+ return tensor
227
+
228
+
229
+ def spectrogram(
230
+ waveform: Tensor,
231
+ blackman_coeff: float = 0.42,
232
+ channel: int = -1,
233
+ dither: float = 0.0,
234
+ energy_floor: float = 1.0,
235
+ frame_length: float = 25.0,
236
+ frame_shift: float = 10.0,
237
+ min_duration: float = 0.0,
238
+ preemphasis_coefficient: float = 0.97,
239
+ raw_energy: bool = True,
240
+ remove_dc_offset: bool = True,
241
+ round_to_power_of_two: bool = True,
242
+ sample_frequency: float = 16000.0,
243
+ snip_edges: bool = True,
244
+ subtract_mean: bool = False,
245
+ window_type: str = POVEY,
246
+ ) -> Tensor:
247
+ r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
248
+ compute-spectrogram-feats.
249
+
250
+ Args:
251
+ waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
252
+ blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
253
+ channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
254
+ dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
255
+ the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
256
+ energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
257
+ this floor is applied to the zeroth component, representing the total signal energy. The floor on the
258
+ individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
259
+ frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
260
+ frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
261
+ min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
262
+ preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
263
+ raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
264
+ remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
265
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
266
+ to FFT. (Default: ``True``)
267
+ sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
268
+ specified there) (Default: ``16000.0``)
269
+ snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
270
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
271
+ depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
272
+ subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
273
+ it this way. (Default: ``False``)
274
+ window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
275
+ (Default: ``'povey'``)
276
+
277
+ Returns:
278
+ Tensor: A spectrogram identical to what Kaldi would output. The shape is
279
+ (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
280
+ """
281
+ device, dtype = waveform.device, waveform.dtype
282
+ epsilon = _get_epsilon(device, dtype)
283
+
284
+ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
285
+ waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
286
+ )
287
+
288
+ if len(waveform) < min_duration * sample_frequency:
289
+ # signal is too short
290
+ return torch.empty(0)
291
+
292
+ strided_input, signal_log_energy = _get_window(
293
+ waveform,
294
+ padded_window_size,
295
+ window_size,
296
+ window_shift,
297
+ window_type,
298
+ blackman_coeff,
299
+ snip_edges,
300
+ raw_energy,
301
+ energy_floor,
302
+ dither,
303
+ remove_dc_offset,
304
+ preemphasis_coefficient,
305
+ )
306
+
307
+ # size (m, padded_window_size // 2 + 1, 2)
308
+ fft = torch.fft.rfft(strided_input)
309
+
310
+ # Convert the FFT into a power spectrum
311
+ power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
312
+ power_spectrum[:, 0] = signal_log_energy
313
+
314
+ power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
315
+ return power_spectrum
316
+
317
+
318
+ def inverse_mel_scale_scalar(mel_freq: float) -> float:
319
+ return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
320
+
321
+
322
+ def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
323
+ return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
324
+
325
+
326
+ def mel_scale_scalar(freq: float) -> float:
327
+ return 1127.0 * math.log(1.0 + freq / 700.0)
328
+
329
+
330
+ def mel_scale(freq: Tensor) -> Tensor:
331
+ return 1127.0 * (1.0 + freq / 700.0).log()
332
+
333
+
334
+ def vtln_warp_freq(
335
+ vtln_low_cutoff: float,
336
+ vtln_high_cutoff: float,
337
+ low_freq: float,
338
+ high_freq: float,
339
+ vtln_warp_factor: float,
340
+ freq: Tensor,
341
+ ) -> Tensor:
342
+ r"""This computes a VTLN warping function that is not the same as HTK's one,
343
+ but has similar inputs (this function has the advantage of never producing
344
+ empty bins).
345
+
346
+ This function computes a warp function F(freq), defined between low_freq
347
+ and high_freq inclusive, with the following properties:
348
+ F(low_freq) == low_freq
349
+ F(high_freq) == high_freq
350
+ The function is continuous and piecewise linear with two inflection
351
+ points.
352
+ The lower inflection point (measured in terms of the unwarped
353
+ frequency) is at frequency l, determined as described below.
354
+ The higher inflection point is at a frequency h, determined as
355
+ described below.
356
+ If l <= f <= h, then F(f) = f/vtln_warp_factor.
357
+ If the higher inflection point (measured in terms of the unwarped
358
+ frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
359
+ Since (by the last point) F(h) == h/vtln_warp_factor, then
360
+ max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
361
+ h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
362
+ = vtln_high_cutoff * min(1, vtln_warp_factor).
363
+ If the lower inflection point (measured in terms of the unwarped
364
+ frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
365
+ This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
366
+ = vtln_low_cutoff * max(1, vtln_warp_factor)
367
+ Args:
368
+ vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
369
+ vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
370
+ low_freq (float): Lower frequency cutoffs in mel computation
371
+ high_freq (float): Upper frequency cutoffs in mel computation
372
+ vtln_warp_factor (float): Vtln warp factor
373
+ freq (Tensor): given frequency in Hz
374
+
375
+ Returns:
376
+ Tensor: Freq after vtln warp
377
+ """
378
+ assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
379
+ assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
380
+ l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
381
+ h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
382
+ scale = 1.0 / vtln_warp_factor
383
+ Fl = scale * l # F(l)
384
+ Fh = scale * h # F(h)
385
+ assert l > low_freq and h < high_freq
386
+ # slope of left part of the 3-piece linear function
387
+ scale_left = (Fl - low_freq) / (l - low_freq)
388
+ # [slope of center part is just "scale"]
389
+
390
+ # slope of right part of the 3-piece linear function
391
+ scale_right = (high_freq - Fh) / (high_freq - h)
392
+
393
+ res = torch.empty_like(freq)
394
+
395
+ outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq
396
+ before_l = torch.lt(freq, l) # freq < l
397
+ before_h = torch.lt(freq, h) # freq < h
398
+ after_h = torch.ge(freq, h) # freq >= h
399
+
400
+ # order of operations matters here (since the frequency regions overlap)
401
+ res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
402
+ res[before_h] = scale * freq[before_h]
403
+ res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
404
+ res[outside_low_high_freq] = freq[outside_low_high_freq]
405
+
406
+ return res
407
+
408
+
409
+ def vtln_warp_mel_freq(
410
+ vtln_low_cutoff: float,
411
+ vtln_high_cutoff: float,
412
+ low_freq,
413
+ high_freq: float,
414
+ vtln_warp_factor: float,
415
+ mel_freq: Tensor,
416
+ ) -> Tensor:
417
+ r"""
418
+ Args:
419
+ vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
420
+ vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
421
+ low_freq (float): Lower frequency cutoffs in mel computation
422
+ high_freq (float): Upper frequency cutoffs in mel computation
423
+ vtln_warp_factor (float): Vtln warp factor
424
+ mel_freq (Tensor): Given frequency in Mel
425
+
426
+ Returns:
427
+ Tensor: ``mel_freq`` after vtln warp
428
+ """
429
+ return mel_scale(
430
+ vtln_warp_freq(
431
+ vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
432
+ )
433
+ )
434
+
435
+
436
+ def get_mel_banks(
437
+ num_bins: int,
438
+ window_length_padded: int,
439
+ sample_freq: float,
440
+ low_freq: float,
441
+ high_freq: float,
442
+ vtln_low: float,
443
+ vtln_high: float,
444
+ vtln_warp_factor: float,
445
+ device=None,
446
+ dtype=None,
447
+ ) -> Tensor:
448
+ """
449
+ Returns:
450
+ Tensor: The mel filterbank ``bins`` of size (``num_bins``, ``num_fft_bins``).
451
+ The bin ``center_freqs`` of size (``num_bins``) are not returned by this
452
+ implementation (see the commented-out line near the end of the function).
453
+ """
454
+ assert num_bins > 3, "Must have more than 3 mel bins"
455
+ assert window_length_padded % 2 == 0
456
+ num_fft_bins = window_length_padded // 2
457
+ nyquist = 0.5 * sample_freq
458
+
459
+ if high_freq <= 0.0:
460
+ high_freq += nyquist
461
+
462
+ assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), (
463
+ "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
464
+ )
465
+
466
+ # fft-bin width [think of it as Nyquist-freq / half-window-length]
467
+ fft_bin_width = sample_freq / window_length_padded
468
+ mel_low_freq = mel_scale_scalar(low_freq)
469
+ mel_high_freq = mel_scale_scalar(high_freq)
470
+
471
+ # divide by num_bins+1 in next line because of end-effects where the bins
472
+ # spread out to the sides.
473
+ mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
474
+
475
+ if vtln_high < 0.0:
476
+ vtln_high += nyquist
477
+
478
+ assert vtln_warp_factor == 1.0 or (
479
+ (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
480
+ ), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format(
481
+ vtln_low, vtln_high, low_freq, high_freq
482
+ )
483
+
484
+ bin = torch.arange(num_bins).unsqueeze(1)
485
+ left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
486
+ center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1)
487
+ right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
488
+
489
+ if vtln_warp_factor != 1.0:
490
+ left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
491
+ center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
492
+ right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)
493
+
494
+ # center_freqs = inverse_mel_scale(center_mel) # size (num_bins)
495
+ # size(1, num_fft_bins)
496
+ mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)
497
+
498
+ # size (num_bins, num_fft_bins)
499
+ up_slope = (mel - left_mel) / (center_mel - left_mel)
500
+ down_slope = (right_mel - mel) / (right_mel - center_mel)
501
+
502
+ if vtln_warp_factor == 1.0:
503
+ # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
504
+ bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
505
+ else:
506
+ # warping can move the order of left_mel, center_mel, right_mel anywhere
507
+ bins = torch.zeros_like(up_slope)
508
+ up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel
509
+ down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel
510
+ bins[up_idx] = up_slope[up_idx]
511
+ bins[down_idx] = down_slope[down_idx]
512
+
513
+ return bins.to(device=device, dtype=dtype) # , center_freqs
514
+
515
+
516
+ cache = {}
517
+
518
+
519
+ def fbank(
520
+ waveform: Tensor,
521
+ blackman_coeff: float = 0.42,
522
+ channel: int = -1,
523
+ dither: float = 0.0,
524
+ energy_floor: float = 1.0,
525
+ frame_length: float = 25.0,
526
+ frame_shift: float = 10.0,
527
+ high_freq: float = 0.0,
528
+ htk_compat: bool = False,
529
+ low_freq: float = 20.0,
530
+ min_duration: float = 0.0,
531
+ num_mel_bins: int = 23,
532
+ preemphasis_coefficient: float = 0.97,
533
+ raw_energy: bool = True,
534
+ remove_dc_offset: bool = True,
535
+ round_to_power_of_two: bool = True,
536
+ sample_frequency: float = 16000.0,
537
+ snip_edges: bool = True,
538
+ subtract_mean: bool = False,
539
+ use_energy: bool = False,
540
+ use_log_fbank: bool = True,
541
+ use_power: bool = True,
542
+ vtln_high: float = -500.0,
543
+ vtln_low: float = 100.0,
544
+ vtln_warp: float = 1.0,
545
+ window_type: str = POVEY,
546
+ ) -> Tensor:
547
+ r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
548
+ compute-fbank-feats.
549
+
550
+ Args:
551
+ waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
552
+ blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
553
+ channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
554
+ dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
555
+ the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
556
+ energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
557
+ this floor is applied to the zeroth component, representing the total signal energy. The floor on the
558
+ individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
559
+ frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
560
+ frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
561
+ high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
562
+ (Default: ``0.0``)
563
+ htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
564
+ (need to change other parameters). (Default: ``False``)
565
+ low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
566
+ min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
567
+ num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
568
+ preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
569
+ raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
570
+ remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
571
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
572
+ to FFT. (Default: ``True``)
573
+ sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
574
+ specified there) (Default: ``16000.0``)
575
+ snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
576
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
577
+ depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
578
+ subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
579
+ it this way. (Default: ``False``)
580
+ use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
581
+ use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``)
582
+ use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
583
+ vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
584
+ negative, offset from high-mel-freq (Default: ``-500.0``)
585
+ vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
586
+ vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
587
+ window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
588
+ (Default: ``'povey'``)
589
+
590
+ Returns:
591
+ Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
592
+ where m is calculated in _get_strided
593
+ """
594
+ device, dtype = waveform.device, waveform.dtype
595
+
596
+ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
597
+ waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
598
+ )
599
+
600
+ if len(waveform) < min_duration * sample_frequency:
601
+ # signal is too short
602
+ return torch.empty(0, device=device, dtype=dtype)
603
+
604
+ # strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
605
+ strided_input, signal_log_energy = _get_window(
606
+ waveform,
607
+ padded_window_size,
608
+ window_size,
609
+ window_shift,
610
+ window_type,
611
+ blackman_coeff,
612
+ snip_edges,
613
+ raw_energy,
614
+ energy_floor,
615
+ dither,
616
+ remove_dc_offset,
617
+ preemphasis_coefficient,
618
+ )
619
+
620
+ # size (m, padded_window_size // 2 + 1)
621
+ spectrum = torch.fft.rfft(strided_input).abs()
622
+ if use_power:
623
+ spectrum = spectrum.pow(2.0)
624
+
625
+ # size (num_mel_bins, padded_window_size // 2)
626
+ # print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
627
+
628
+ cache_key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % (
629
+ num_mel_bins,
630
+ padded_window_size,
631
+ sample_frequency,
632
+ low_freq,
633
+ high_freq,
634
+ vtln_low,
635
+ vtln_high,
636
+ vtln_warp,
637
+ device,
638
+ dtype,
639
+ )
640
+ if cache_key not in cache:
641
+ mel_energies = get_mel_banks(
642
+ num_mel_bins,
643
+ padded_window_size,
644
+ sample_frequency,
645
+ low_freq,
646
+ high_freq,
647
+ vtln_low,
648
+ vtln_high,
649
+ vtln_warp,
650
+ device,
651
+ dtype,
652
+ )
653
+ cache[cache_key] = mel_energies
654
+ else:
655
+ mel_energies = cache[cache_key]
656
+
657
+ # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
658
+ mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
659
+
660
+ # sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
661
+ mel_energies = torch.mm(spectrum, mel_energies.T)
662
+ if use_log_fbank:
663
+ # avoid log of zero (which should be prevented anyway by dithering)
664
+ mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
665
+
666
+ # if use_energy then add it as the last column for htk_compat == true else first column
667
+ if use_energy:
668
+ signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
669
+ # returns size (m, num_mel_bins + 1)
670
+ if htk_compat:
671
+ mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
672
+ else:
673
+ mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)
674
+
675
+ mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
676
+ return mel_energies
677
+
678
+
679
+ def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
680
+ # returns a dct matrix of size (num_mel_bins, num_ceps)
681
+ # size (num_mel_bins, num_mel_bins)
682
+ dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
683
+ # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
684
+ # this would be the first column in the dct_matrix for torchaudio as it expects a
685
+ # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
686
+ # expects a left multiply e.g. dct_matrix * vector).
687
+ dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
688
+ dct_matrix = dct_matrix[:, :num_ceps]
689
+ return dct_matrix
690
+
691
+
692
+ def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
693
+ # returns size (num_ceps)
694
+ # Compute liftering coefficients (scaling on cepstral coeffs)
695
+ # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
696
+ i = torch.arange(num_ceps)
697
+ return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
698
+
699
+
700
+ def mfcc(
701
+ waveform: Tensor,
702
+ blackman_coeff: float = 0.42,
703
+ cepstral_lifter: float = 22.0,
704
+ channel: int = -1,
705
+ dither: float = 0.0,
706
+ energy_floor: float = 1.0,
707
+ frame_length: float = 25.0,
708
+ frame_shift: float = 10.0,
709
+ high_freq: float = 0.0,
710
+ htk_compat: bool = False,
711
+ low_freq: float = 20.0,
712
+ num_ceps: int = 13,
713
+ min_duration: float = 0.0,
714
+ num_mel_bins: int = 23,
715
+ preemphasis_coefficient: float = 0.97,
716
+ raw_energy: bool = True,
717
+ remove_dc_offset: bool = True,
718
+ round_to_power_of_two: bool = True,
719
+ sample_frequency: float = 16000.0,
720
+ snip_edges: bool = True,
721
+ subtract_mean: bool = False,
722
+ use_energy: bool = False,
723
+ vtln_high: float = -500.0,
724
+ vtln_low: float = 100.0,
725
+ vtln_warp: float = 1.0,
726
+ window_type: str = POVEY,
727
+ ) -> Tensor:
728
+ r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
729
+ compute-mfcc-feats.
730
+
731
+ Args:
732
+ waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
733
+ blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
734
+ cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
735
+ channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
736
+ dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
737
+ the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
738
+ energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
739
+ this floor is applied to the zeroth component, representing the total signal energy. The floor on the
740
+ individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
741
+ frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
742
+ frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
743
+ high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
744
+ (Default: ``0.0``)
745
+ htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
746
+ features (need to change other parameters). (Default: ``False``)
747
+ low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
748
+ num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
749
+ min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
750
+ num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
751
+ preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
752
+ raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
753
+ remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
754
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
755
+ to FFT. (Default: ``True``)
756
+ sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
757
+ specified there) (Default: ``16000.0``)
758
+ snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
759
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
760
+ depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
761
+ subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
762
+ it this way. (Default: ``False``)
763
+ use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
764
+ vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
765
+ negative, offset from high-mel-freq (Default: ``-500.0``)
766
+ vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
767
+ vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
768
+ window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
769
+ (Default: ``"povey"``)
770
+
771
+ Returns:
772
+ Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
773
+ where m is calculated in _get_strided
774
+ """
775
+ assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)
776
+
777
+ device, dtype = waveform.device, waveform.dtype
778
+
779
+ # The mel_energies should not be squared (use_power=True), not have mean subtracted
780
+ # (subtract_mean=False), and use log (use_log_fbank=True).
781
+ # size (m, num_mel_bins + use_energy)
782
+ feature = fbank(
783
+ waveform=waveform,
784
+ blackman_coeff=blackman_coeff,
785
+ channel=channel,
786
+ dither=dither,
787
+ energy_floor=energy_floor,
788
+ frame_length=frame_length,
789
+ frame_shift=frame_shift,
790
+ high_freq=high_freq,
791
+ htk_compat=htk_compat,
792
+ low_freq=low_freq,
793
+ min_duration=min_duration,
794
+ num_mel_bins=num_mel_bins,
795
+ preemphasis_coefficient=preemphasis_coefficient,
796
+ raw_energy=raw_energy,
797
+ remove_dc_offset=remove_dc_offset,
798
+ round_to_power_of_two=round_to_power_of_two,
799
+ sample_frequency=sample_frequency,
800
+ snip_edges=snip_edges,
801
+ subtract_mean=False,
802
+ use_energy=use_energy,
803
+ use_log_fbank=True,
804
+ use_power=True,
805
+ vtln_high=vtln_high,
806
+ vtln_low=vtln_low,
807
+ vtln_warp=vtln_warp,
808
+ window_type=window_type,
809
+ )
810
+
811
+ if use_energy:
812
+ # size (m)
813
+ signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
814
+ # offset is 0 if htk_compat==True else 1
815
+ mel_offset = int(not htk_compat)
816
+ feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]
817
+
818
+ # size (num_mel_bins, num_ceps)
819
+ dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)
820
+
821
+ # size (m, num_ceps)
822
+ feature = feature.matmul(dct_matrix)
823
+
824
+ if cepstral_lifter != 0.0:
825
+ # size (1, num_ceps)
826
+ lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
827
+ feature *= lifter_coeffs.to(device=device, dtype=dtype)
828
+
829
+ # if use_energy then replace the last column for htk_compat == true else first column
830
+ if use_energy:
831
+ feature[:, 0] = signal_log_energy
832
+
833
+ if htk_compat:
834
+ energy = feature[:, 0].unsqueeze(1) # size (m, 1)
835
+ feature = feature[:, 1:] # size (m, num_ceps - 1)
836
+ if not use_energy:
837
+ # scale on C0 (actually removing a scale we previously added that's
838
+ # part of one common definition of the cosine transform.)
839
+ energy *= math.sqrt(2)
840
+
841
+ feature = torch.cat((feature, energy), dim=1)
842
+
843
+ feature = _subtract_column_mean(feature, subtract_mean)
844
+ return feature
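Taken together, the functions above reproduce Kaldi-style feature extraction in pure PyTorch. A minimal usage sketch follows; the `GPT_SoVITS.eres2net.kaldi` import path is an assumption based on where this file sits in the commit, and the waveform is random data standing in for real audio:

```python
import torch

from GPT_SoVITS.eres2net import kaldi  # import path assumed from this commit's layout

# one second of fake mono audio at 16 kHz, shaped (channels, samples) as fbank/mfcc expect
waveform = torch.randn(1, 16000)

# Kaldi-style 80-dim log-mel filterbank features (25 ms window, 10 ms shift by default)
feats = kaldi.fbank(waveform, num_mel_bins=80, sample_frequency=16000.0, dither=0.0)
print(feats.shape)  # (98, 80) with snip_edges=True

# 13 MFCCs from the same clip
ceps = kaldi.mfcc(waveform, num_ceps=13, num_mel_bins=23, sample_frequency=16000.0)
print(ceps.shape)  # (98, 13)
```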
GPT_SoVITS/eres2net/pooling_layers.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class TAP(nn.Module):
11
+ """
12
+ Temporal average pooling, only first-order mean is considered
13
+ """
14
+
15
+ def __init__(self, **kwargs):
16
+ super(TAP, self).__init__()
17
+
18
+ def forward(self, x):
19
+ pooling_mean = x.mean(dim=-1)
20
+ # To be compatible with 2D input
21
+ pooling_mean = pooling_mean.flatten(start_dim=1)
22
+ return pooling_mean
23
+
24
+
25
+ class TSDP(nn.Module):
26
+ """
27
+ Temporal standard deviation pooling, only second-order std is considered
28
+ """
29
+
30
+ def __init__(self, **kwargs):
31
+ super(TSDP, self).__init__()
32
+
33
+ def forward(self, x):
34
+ # The last dimension is the temporal axis
35
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
36
+ pooling_std = pooling_std.flatten(start_dim=1)
37
+ return pooling_std
38
+
39
+
40
+ class TSTP(nn.Module):
41
+ """
42
+ Temporal statistics pooling, concatenate mean and std, which is used in
43
+ x-vector
44
+ Comment: simple concatenation can not make full use of both statistics
45
+ """
46
+
47
+ def __init__(self, **kwargs):
48
+ super(TSTP, self).__init__()
49
+
50
+ def forward(self, x):
51
+ # The last dimension is the temporal axis
52
+ pooling_mean = x.mean(dim=-1)
53
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
54
+ pooling_mean = pooling_mean.flatten(start_dim=1)
55
+ pooling_std = pooling_std.flatten(start_dim=1)
56
+
57
+ stats = torch.cat((pooling_mean, pooling_std), 1)
58
+ return stats
59
+
60
+
61
+ class ASTP(nn.Module):
62
+ """Attentive statistics pooling: Channel- and context-dependent
63
+ statistics pooling, first used in ECAPA_TDNN.
64
+ """
65
+
66
+ def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
67
+ super(ASTP, self).__init__()
68
+ self.global_context_att = global_context_att
69
+
70
+ # Use Conv1d with stride == 1 rather than Linear, then we don't
71
+ # need to transpose inputs.
72
+ if global_context_att:
73
+ self.linear1 = nn.Conv1d(in_dim * 3, bottleneck_dim, kernel_size=1) # equals W and b in the paper
74
+ else:
75
+ self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1) # equals W and b in the paper
76
+ self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1) # equals V and k in the paper
77
+
78
+ def forward(self, x):
79
+ """
80
+ x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
81
+ or a 4-dimensional tensor in resnet architecture (B,C,F,T)
82
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
83
+ """
84
+ if len(x.shape) == 4:
85
+ x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
86
+ assert len(x.shape) == 3
87
+
88
+ if self.global_context_att:
89
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
90
+ context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
91
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
92
+ else:
93
+ x_in = x
94
+
95
+ # DON'T use ReLU here! Training may struggle to converge with ReLU.
96
+ alpha = torch.tanh(self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
97
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
98
+ mean = torch.sum(alpha * x, dim=2)
99
+ var = torch.sum(alpha * (x**2), dim=2) - mean**2
100
+ std = torch.sqrt(var.clamp(min=1e-10))
101
+ return torch.cat([mean, std], dim=1)
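These pooling layers collapse the time axis of frame-level features into one utterance-level embedding. A small shape check, assuming a TDNN-style (batch, feature_dim, frames) input and the module path shown in the file header:

```python
import torch

from GPT_SoVITS.eres2net.pooling_layers import ASTP, TSTP  # path taken from the file header above

x = torch.randn(4, 256, 200)  # (batch, feature_dim, frames)

tstp = TSTP()
print(tstp(x).shape)  # torch.Size([4, 512]): mean and std concatenated

astp = ASTP(in_dim=256, bottleneck_dim=128)
print(astp(x).shape)  # torch.Size([4, 512]): attention-weighted mean and std
```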
GPT_SoVITS/f5_tts/model/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .backbones.dit import DiT
2
+
3
+ __all__ = ["DiT"]
GPT_SoVITS/f5_tts/model/backbones/README.md ADDED
@@ -0,0 +1,20 @@
1
+ ## Backbones quick introduction
2
+
3
+
4
+ ### unett.py
5
+ - flat unet transformer
6
+ - structure same as in e2-tts & voicebox paper except using rotary pos emb
7
+ - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
8
+
9
+ ### dit.py
10
+ - adaln-zero dit
11
+ - embedded timestep as condition
12
+ - concatted noised_input + masked_cond + embedded_text, linear proj in (see the shape sketch below this README)
13
+ - possible abs pos emb & convnextv2 blocks for embedded text before concat
14
+ - possible long skip connection (first layer to last layer)
15
+
16
+ ### mmdit.py
17
+ - sd3 structure
18
+ - timestep as condition
19
+ - left stream: text embedded with an abs pos emb applied
20
+ - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
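A shape-level sketch of the input mixing described in the dit.py notes above (concatenate noised input, masked condition, and embedded text, then one linear projection). The dimensions are illustrative; the real logic lives in `InputEmbedding` in dit.py below:

```python
import torch

b, n, mel_dim, text_dim, model_dim = 2, 128, 100, 100, 512  # illustrative sizes

noised_input = torch.randn(b, n, mel_dim)
masked_cond = torch.randn(b, n, mel_dim)
embedded_text = torch.randn(b, n, text_dim)

# channel-wise concat followed by a single linear projection into the model width
x = torch.cat((noised_input, masked_cond, embedded_text), dim=-1)  # (b, n, 2*mel_dim + text_dim)
proj_in = torch.nn.Linear(mel_dim * 2 + text_dim, model_dim)
x = proj_in(x)  # (b, n, model_dim)
```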
GPT_SoVITS/f5_tts/model/backbones/dit.py ADDED
@@ -0,0 +1,193 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+ from torch.utils.checkpoint import checkpoint
15
+ from x_transformers.x_transformers import RotaryEmbedding
16
+
17
+ from GPT_SoVITS.module.commons import sequence_mask
18
+
19
+ from ..modules import (
20
+ AdaLayerNormZero_Final,
21
+ ConvNeXtV2Block,
22
+ ConvPositionEmbedding,
23
+ DiTBlock,
24
+ TimestepEmbedding,
25
+ get_pos_embed_indices,
26
+ precompute_freqs_cis,
27
+ )
28
+
29
+
30
+ class TextEmbedding(nn.Module):
31
+ def __init__(self, text_dim, conv_layers=0, conv_mult=2):
32
+ super().__init__()
33
+ if conv_layers > 0:
34
+ self.extra_modeling = True
35
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
36
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
37
+ self.text_blocks = nn.Sequential(
38
+ *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
39
+ )
40
+ else:
41
+ self.extra_modeling = False
42
+
43
+ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
44
+ batch, text_len = text.shape[0], text.shape[1]
45
+
46
+ if drop_text: # cfg for text
47
+ text = torch.zeros_like(text)
48
+
49
+ # possible extra modeling
50
+ if self.extra_modeling:
51
+ # sinus pos emb
52
+ batch_start = torch.zeros((batch,), dtype=torch.long)
53
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
54
+ text_pos_embed = self.freqs_cis[pos_idx]
55
+
56
+ # print(23333333,text.shape,text_pos_embed.shape)#torch.Size([7, 465, 256]) torch.Size([7, 465, 256])
57
+
58
+ text = text + text_pos_embed
59
+
60
+ # convnextv2 blocks
61
+ text = self.text_blocks(text)
62
+
63
+ return text
64
+
65
+
66
+ # noised input audio and context mixing embedding
67
+
68
+
69
+ class InputEmbedding(nn.Module):
70
+ def __init__(self, mel_dim, text_dim, out_dim):
71
+ super().__init__()
72
+ self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
73
+ self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
74
+
75
+ def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
76
+ if drop_audio_cond: # cfg for cond audio
77
+ cond = torch.zeros_like(cond)
78
+
79
+ x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
80
+ x = self.conv_pos_embed(x) + x
81
+ return x
82
+
83
+
84
+ # Transformer backbone using DiT blocks
85
+
86
+
87
+ class DiT(nn.Module):
88
+ def __init__(
89
+ self,
90
+ *,
91
+ dim,
92
+ depth=8,
93
+ heads=8,
94
+ dim_head=64,
95
+ dropout=0.1,
96
+ ff_mult=4,
97
+ mel_dim=100,
98
+ text_dim=None,
99
+ conv_layers=0,
100
+ long_skip_connection=False,
101
+ ):
102
+ super().__init__()
103
+
104
+ self.time_embed = TimestepEmbedding(dim)
105
+ self.d_embed = TimestepEmbedding(dim)
106
+ if text_dim is None:
107
+ text_dim = mel_dim
108
+ self.text_embed = TextEmbedding(text_dim, conv_layers=conv_layers)
109
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
110
+
111
+ self.rotary_embed = RotaryEmbedding(dim_head)
112
+
113
+ self.dim = dim
114
+ self.depth = depth
115
+
116
+ self.transformer_blocks = nn.ModuleList(
117
+ [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
118
+ )
119
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
120
+
121
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
122
+ self.proj_out = nn.Linear(dim, mel_dim)
123
+
124
+ def ckpt_wrapper(self, module):
125
+ # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
126
+ def ckpt_forward(*inputs):
127
+ outputs = module(*inputs)
128
+ return outputs
129
+
130
+ return ckpt_forward
131
+
132
+ def forward( # x, prompt_x, x_lens, t, style, cond
133
+ self, # d is channel, n is T
134
+ x0: float["b n d"], # noised input audio # noqa: F722
135
+ cond0: float["b n d"], # masked cond audio # noqa: F722
136
+ x_lens,
137
+ time: float["b"] | float[""], # time step # noqa: F821 F722
138
+ dt_base_bootstrap,
139
+ text0, # : int["b nt"] # noqa: F722 # condition feature
140
+ use_grad_ckpt=False, # bool
141
+ ###no-use
142
+ drop_audio_cond=False, # cfg for cond audio
143
+ drop_text=False, # cfg for text
144
+ # mask: bool["b n"] | None = None, # noqa: F722
145
+ infer=False, # bool
146
+ text_cache=None, # torch tensor as text_embed
147
+ dt_cache=None, # torch tensor as dt
148
+ ):
149
+ x = x0.transpose(2, 1)
150
+ cond = cond0.transpose(2, 1)
151
+ text = text0.transpose(2, 1)
152
+ mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device)
153
+
154
+ batch, seq_len = x.shape[0], x.shape[1]
155
+ if time.ndim == 0:
156
+ time = time.repeat(batch)
157
+
158
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
159
+ t = self.time_embed(time)
160
+ if infer and dt_cache is not None:
161
+ dt = dt_cache
162
+ else:
163
+ dt = self.d_embed(dt_base_bootstrap)
164
+ t += dt
165
+
166
+ if infer and text_cache is not None:
167
+ text_embed = text_cache
168
+ else:
169
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text) ###need to change
170
+
171
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
172
+
173
+ rope = self.rotary_embed.forward_from_seq_len(seq_len)
174
+
175
+ if self.long_skip_connection is not None:
176
+ residual = x
177
+
178
+ for block in self.transformer_blocks:
179
+ if use_grad_ckpt:
180
+ x = checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
181
+ else:
182
+ x = block(x, t, mask=mask, rope=rope)
183
+
184
+ if self.long_skip_connection is not None:
185
+ x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
186
+
187
+ x = self.norm_out(x, t)
188
+ output = self.proj_out(x)
189
+
190
+ if infer:
191
+ return output, text_embed, dt
192
+ else:
193
+ return output
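A minimal forward-pass sketch for the DiT backbone above. The channel-first input layout follows the transposes at the top of `forward`; the hyperparameters and shapes are illustrative, not the values this repo trains with:

```python
import torch

from GPT_SoVITS.f5_tts.model import DiT  # exported by the package __init__ above

b, n, mel_dim, text_dim = 2, 256, 100, 100  # illustrative sizes
model = DiT(dim=512, depth=4, heads=8, dim_head=64, mel_dim=mel_dim, text_dim=text_dim, conv_layers=4)

x0 = torch.randn(b, mel_dim, n)      # noised mel, channel-first as expected by forward
cond0 = torch.randn(b, mel_dim, n)   # masked conditioning mel
text0 = torch.randn(b, text_dim, n)  # conditioning features, same length as the mel
x_lens = torch.tensor([n, n // 2])   # per-item valid lengths
t = torch.rand(b)                    # flow-matching time step
dt = torch.rand(b)                   # dt_base_bootstrap

out = model(x0, cond0, x_lens, t, dt, text0)
print(out.shape)  # (b, n, mel_dim)
```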
GPT_SoVITS/f5_tts/model/backbones/mmdit.py ADDED
@@ -0,0 +1,144 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+ from x_transformers.x_transformers import RotaryEmbedding
15
+
16
+ from ..modules import (
17
+ AdaLayerNormZero_Final,
18
+ ConvPositionEmbedding,
19
+ MMDiTBlock,
20
+ TimestepEmbedding,
21
+ get_pos_embed_indices,
22
+ precompute_freqs_cis,
23
+ )
24
+
25
+ # text embedding
26
+
27
+
28
+ class TextEmbedding(nn.Module):
29
+ def __init__(self, out_dim, text_num_embeds):
30
+ super().__init__()
31
+ self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
32
+
33
+ self.precompute_max_pos = 1024
34
+ self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
35
+
36
+ def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
37
+ text = text + 1
38
+ if drop_text:
39
+ text = torch.zeros_like(text)
40
+ text = self.text_embed(text)
41
+
42
+ # sinus pos emb
43
+ batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
44
+ batch_text_len = text.shape[1]
45
+ pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
46
+ text_pos_embed = self.freqs_cis[pos_idx]
47
+
48
+ text = text + text_pos_embed
49
+
50
+ return text
51
+
52
+
53
+ # noised input & masked cond audio embedding
54
+
55
+
56
+ class AudioEmbedding(nn.Module):
57
+ def __init__(self, in_dim, out_dim):
58
+ super().__init__()
59
+ self.linear = nn.Linear(2 * in_dim, out_dim)
60
+ self.conv_pos_embed = ConvPositionEmbedding(out_dim)
61
+
62
+ def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
63
+ if drop_audio_cond:
64
+ cond = torch.zeros_like(cond)
65
+ x = torch.cat((x, cond), dim=-1)
66
+ x = self.linear(x)
67
+ x = self.conv_pos_embed(x) + x
68
+ return x
69
+
70
+
71
+ # Transformer backbone using MM-DiT blocks
72
+
73
+
74
+ class MMDiT(nn.Module):
75
+ def __init__(
76
+ self,
77
+ *,
78
+ dim,
79
+ depth=8,
80
+ heads=8,
81
+ dim_head=64,
82
+ dropout=0.1,
83
+ ff_mult=4,
84
+ text_num_embeds=256,
85
+ mel_dim=100,
86
+ ):
87
+ super().__init__()
88
+
89
+ self.time_embed = TimestepEmbedding(dim)
90
+ self.text_embed = TextEmbedding(dim, text_num_embeds)
91
+ self.audio_embed = AudioEmbedding(mel_dim, dim)
92
+
93
+ self.rotary_embed = RotaryEmbedding(dim_head)
94
+
95
+ self.dim = dim
96
+ self.depth = depth
97
+
98
+ self.transformer_blocks = nn.ModuleList(
99
+ [
100
+ MMDiTBlock(
101
+ dim=dim,
102
+ heads=heads,
103
+ dim_head=dim_head,
104
+ dropout=dropout,
105
+ ff_mult=ff_mult,
106
+ context_pre_only=i == depth - 1,
107
+ )
108
+ for i in range(depth)
109
+ ]
110
+ )
111
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
112
+ self.proj_out = nn.Linear(dim, mel_dim)
113
+
114
+ def forward(
115
+ self,
116
+ x: float["b n d"], # noised input audio # noqa: F722
117
+ cond: float["b n d"], # masked cond audio # noqa: F722
118
+ text: int["b nt"], # text # noqa: F722
119
+ time: float["b"] | float[""], # time step # noqa: F821 F722
120
+ drop_audio_cond, # cfg for cond audio
121
+ drop_text, # cfg for text
122
+ mask: bool["b n"] | None = None, # noqa: F722
123
+ ):
124
+ batch = x.shape[0]
125
+ if time.ndim == 0:
126
+ time = time.repeat(batch)
127
+
128
+ # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
129
+ t = self.time_embed(time)
130
+ c = self.text_embed(text, drop_text=drop_text)
131
+ x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
132
+
133
+ seq_len = x.shape[1]
134
+ text_len = text.shape[1]
135
+ rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
136
+ rope_text = self.rotary_embed.forward_from_seq_len(text_len)
137
+
138
+ for block in self.transformer_blocks:
139
+ c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
140
+
141
+ x = self.norm_out(x, t)
142
+ output = self.proj_out(x)
143
+
144
+ return output
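By contrast with DiT above, MMDiT keeps the text as integer tokens in its own stream. A call sketch with illustrative sizes; the import path is assumed from the file header:

```python
import torch

from GPT_SoVITS.f5_tts.model.backbones.mmdit import MMDiT  # path assumed from the file header above

model = MMDiT(dim=512, depth=4, heads=8, dim_head=64, text_num_embeds=256, mel_dim=100)

b, n, nt = 2, 256, 64
x = torch.randn(b, n, 100)             # noised mel, (b, n, d)
cond = torch.randn(b, n, 100)          # masked conditioning mel
text = torch.randint(0, 256, (b, nt))  # token ids in [0, text_num_embeds)
time = torch.rand(b)

out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
print(out.shape)  # (b, n, 100)
```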
GPT_SoVITS/f5_tts/model/backbones/unett.py ADDED
@@ -0,0 +1,218 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import Literal
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch import nn
17
+ from x_transformers import RMSNorm
18
+ from x_transformers.x_transformers import RotaryEmbedding
19
+
20
+ from ..modules import (
21
+ Attention,
22
+ AttnProcessor,
23
+ ConvNeXtV2Block,
24
+ ConvPositionEmbedding,
25
+ FeedForward,
26
+ TimestepEmbedding,
27
+ get_pos_embed_indices,
28
+ precompute_freqs_cis,
29
+ )
30
+
31
+ # Text embedding
32
+
33
+
34
+ class TextEmbedding(nn.Module):
35
+ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
36
+ super().__init__()
37
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
38
+
39
+ if conv_layers > 0:
40
+ self.extra_modeling = True
41
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
42
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
43
+ self.text_blocks = nn.Sequential(
44
+ *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
45
+ )
46
+ else:
47
+ self.extra_modeling = False
48
+
49
+ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
50
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
51
+ text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
52
+ batch, text_len = text.shape[0], text.shape[1]
53
+ text = F.pad(text, (0, seq_len - text_len), value=0)
54
+
55
+ if drop_text: # cfg for text
56
+ text = torch.zeros_like(text)
57
+
58
+ text = self.text_embed(text) # b n -> b n d
59
+
60
+ # possible extra modeling
61
+ if self.extra_modeling:
62
+ # sinus pos emb
63
+ batch_start = torch.zeros((batch,), dtype=torch.long)
64
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
65
+ text_pos_embed = self.freqs_cis[pos_idx]
66
+ text = text + text_pos_embed
67
+
68
+ # convnextv2 blocks
69
+ text = self.text_blocks(text)
70
+
71
+ return text
72
+
73
+
74
+ # noised input audio and context mixing embedding
75
+
76
+
77
+ class InputEmbedding(nn.Module):
78
+ def __init__(self, mel_dim, text_dim, out_dim):
79
+ super().__init__()
80
+ self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
81
+ self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
82
+
83
+ def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
84
+ if drop_audio_cond: # cfg for cond audio
85
+ cond = torch.zeros_like(cond)
86
+
87
+ x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
88
+ x = self.conv_pos_embed(x) + x
89
+ return x
90
+
91
+
92
+ # Flat UNet Transformer backbone
93
+
94
+
95
+ class UNetT(nn.Module):
96
+ def __init__(
97
+ self,
98
+ *,
99
+ dim,
100
+ depth=8,
101
+ heads=8,
102
+ dim_head=64,
103
+ dropout=0.1,
104
+ ff_mult=4,
105
+ mel_dim=100,
106
+ text_num_embeds=256,
107
+ text_dim=None,
108
+ conv_layers=0,
109
+ skip_connect_type: Literal["add", "concat", "none"] = "concat",
110
+ ):
111
+ super().__init__()
112
+ assert depth % 2 == 0, "UNet-Transformer's depth should be even."
113
+
114
+ self.time_embed = TimestepEmbedding(dim)
115
+ if text_dim is None:
116
+ text_dim = mel_dim
117
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
118
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
119
+
120
+ self.rotary_embed = RotaryEmbedding(dim_head)
121
+
122
+ # transformer layers & skip connections
123
+
124
+ self.dim = dim
125
+ self.skip_connect_type = skip_connect_type
126
+ needs_skip_proj = skip_connect_type == "concat"
127
+
128
+ self.depth = depth
129
+ self.layers = nn.ModuleList([])
130
+
131
+ for idx in range(depth):
132
+ is_later_half = idx >= (depth // 2)
133
+
134
+ attn_norm = RMSNorm(dim)
135
+ attn = Attention(
136
+ processor=AttnProcessor(),
137
+ dim=dim,
138
+ heads=heads,
139
+ dim_head=dim_head,
140
+ dropout=dropout,
141
+ )
142
+
143
+ ff_norm = RMSNorm(dim)
144
+ ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
145
+
146
+ skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
147
+
148
+ self.layers.append(
149
+ nn.ModuleList(
150
+ [
151
+ skip_proj,
152
+ attn_norm,
153
+ attn,
154
+ ff_norm,
155
+ ff,
156
+ ]
157
+ )
158
+ )
159
+
160
+ self.norm_out = RMSNorm(dim)
161
+ self.proj_out = nn.Linear(dim, mel_dim)
162
+
163
+ def forward(
164
+ self,
165
+ x: float["b n d"], # noised input audio # noqa: F722
166
+ cond: float["b n d"], # masked cond audio # noqa: F722
167
+ text: int["b nt"], # text # noqa: F722
168
+ time: float["b"] | float[""], # time step # noqa: F821 F722
169
+ drop_audio_cond, # cfg for cond audio
170
+ drop_text, # cfg for text
171
+ mask: bool["b n"] | None = None, # noqa: F722
172
+ ):
173
+ batch, seq_len = x.shape[0], x.shape[1]
174
+ if time.ndim == 0:
175
+ time = time.repeat(batch)
176
+
177
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
178
+ t = self.time_embed(time)
179
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
180
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
181
+
182
+ # prepend time t to input x as an extra token, [b n d] -> [b n+1 d]
183
+ x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
184
+ if mask is not None:
185
+ mask = F.pad(mask, (1, 0), value=1)
186
+
187
+ rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
188
+
189
+ # flat unet transformer
190
+ skip_connect_type = self.skip_connect_type
191
+ skips = []
192
+ for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
193
+ layer = idx + 1
194
+
195
+ # skip connection logic
196
+ is_first_half = layer <= (self.depth // 2)
197
+ is_later_half = not is_first_half
198
+
199
+ if is_first_half:
200
+ skips.append(x)
201
+
202
+ if is_later_half:
203
+ skip = skips.pop()
204
+ if skip_connect_type == "concat":
205
+ x = torch.cat((x, skip), dim=-1)
206
+ x = maybe_skip_proj(x)
207
+ elif skip_connect_type == "add":
208
+ x = x + skip
209
+
210
+ # attention and feedforward blocks
211
+ x = attn(attn_norm(x), rope=rope, mask=mask) + x
212
+ x = ff(ff_norm(x)) + x
213
+
214
+ assert len(skips) == 0
215
+
216
+ x = self.norm_out(x)[:, 1:, :] # unpack t from x
217
+
218
+ return self.proj_out(x)
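
A minimal shape-check sketch for the UNetT backbone above; the sizes below are illustrative (not the training configuration) and assume UNetT and its submodules are importable from this file:

import torch

model = UNetT(dim=256, depth=8, heads=8, dim_head=64, mel_dim=100, text_num_embeds=256, conv_layers=2)
x = torch.randn(2, 128, 100)            # noised mel frames, "b n d"
cond = torch.randn(2, 128, 100)         # masked conditioning audio, "b n d"
text = torch.randint(0, 256, (2, 32))   # text token ids, "b nt"
time = torch.rand(2)                    # one diffusion time step per batch item, "b"
out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
print(out.shape)                        # expected: torch.Size([2, 128, 100])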
GPT_SoVITS/f5_tts/model/modules.py ADDED
@@ -0,0 +1,665 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import math
13
+ from typing import Optional
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torchaudio
18
+ from librosa.filters import mel as librosa_mel_fn
19
+ from torch import nn
20
+ from x_transformers.x_transformers import apply_rotary_pos_emb
21
+
22
+ # raw wav to mel spec
23
+
24
+
25
+ mel_basis_cache = {}
26
+ hann_window_cache = {}
27
+
28
+
29
+ def get_bigvgan_mel_spectrogram(
30
+ waveform,
31
+ n_fft=1024,
32
+ n_mel_channels=100,
33
+ target_sample_rate=24000,
34
+ hop_length=256,
35
+ win_length=1024,
36
+ fmin=0,
37
+ fmax=None,
38
+ center=False,
39
+ ): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
40
+ device = waveform.device
41
+ key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
42
+
43
+ if key not in mel_basis_cache:
44
+ mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
45
+ mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why is .float() needed here?
46
+ hann_window_cache[key] = torch.hann_window(win_length).to(device)
47
+
48
+ mel_basis = mel_basis_cache[key]
49
+ hann_window = hann_window_cache[key]
50
+
51
+ padding = (n_fft - hop_length) // 2
52
+ waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
53
+
54
+ spec = torch.stft(
55
+ waveform,
56
+ n_fft,
57
+ hop_length=hop_length,
58
+ win_length=win_length,
59
+ window=hann_window,
60
+ center=center,
61
+ pad_mode="reflect",
62
+ normalized=False,
63
+ onesided=True,
64
+ return_complex=True,
65
+ )
66
+ spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
67
+
68
+ mel_spec = torch.matmul(mel_basis, spec)
69
+ mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
70
+
71
+ return mel_spec
72
+
73
+
74
+ def get_vocos_mel_spectrogram(
75
+ waveform,
76
+ n_fft=1024,
77
+ n_mel_channels=100,
78
+ target_sample_rate=24000,
79
+ hop_length=256,
80
+ win_length=1024,
81
+ ):
82
+ mel_stft = torchaudio.transforms.MelSpectrogram(
83
+ sample_rate=target_sample_rate,
84
+ n_fft=n_fft,
85
+ win_length=win_length,
86
+ hop_length=hop_length,
87
+ n_mels=n_mel_channels,
88
+ power=1,
89
+ center=True,
90
+ normalized=False,
91
+ norm=None,
92
+ ).to(waveform.device)
93
+ if len(waveform.shape) == 3:
94
+ waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
95
+
96
+ assert len(waveform.shape) == 2
97
+
98
+ mel = mel_stft(waveform)
99
+ mel = mel.clamp(min=1e-5).log()
100
+ return mel
101
+
102
+
103
+ class MelSpec(nn.Module):
104
+ def __init__(
105
+ self,
106
+ n_fft=1024,
107
+ hop_length=256,
108
+ win_length=1024,
109
+ n_mel_channels=100,
110
+ target_sample_rate=24_000,
111
+ mel_spec_type="vocos",
112
+ ):
113
+ super().__init__()
114
+ assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan")
115
+
116
+ self.n_fft = n_fft
117
+ self.hop_length = hop_length
118
+ self.win_length = win_length
119
+ self.n_mel_channels = n_mel_channels
120
+ self.target_sample_rate = target_sample_rate
121
+
122
+ if mel_spec_type == "vocos":
123
+ self.extractor = get_vocos_mel_spectrogram
124
+ elif mel_spec_type == "bigvgan":
125
+ self.extractor = get_bigvgan_mel_spectrogram
126
+
127
+ self.register_buffer("dummy", torch.tensor(0), persistent=False)
128
+
129
+ def forward(self, wav):
130
+ if self.dummy.device != wav.device:
131
+ self.to(wav.device)
132
+
133
+ mel = self.extractor(
134
+ waveform=wav,
135
+ n_fft=self.n_fft,
136
+ n_mel_channels=self.n_mel_channels,
137
+ target_sample_rate=self.target_sample_rate,
138
+ hop_length=self.hop_length,
139
+ win_length=self.win_length,
140
+ )
141
+
142
+ return mel
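
A hedged usage sketch of MelSpec with the defaults above (24 kHz, 1024-point FFT, hop 256, 100 mel bins); the waveform is random noise purely to exercise the shapes:

import torch

mel_spec = MelSpec(mel_spec_type="vocos")
wav = torch.randn(1, 24000)    # one second of audio, "b nw"
mel = mel_spec(wav)
print(mel.shape)               # torch.Size([1, 100, 94]), roughly nw / hop_length frames (center=True adds one)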
143
+
144
+
145
+ # sinusoidal position embedding
146
+
147
+
148
+ class SinusPositionEmbedding(nn.Module):
149
+ def __init__(self, dim):
150
+ super().__init__()
151
+ self.dim = dim
152
+
153
+ def forward(self, x, scale=1000):
154
+ device = x.device
155
+ half_dim = self.dim // 2
156
+ emb = math.log(10000) / (half_dim - 1)
157
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
158
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
159
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
160
+ return emb
161
+
162
+
163
+ # convolutional position embedding
164
+
165
+
166
+ class ConvPositionEmbedding(nn.Module):
167
+ def __init__(self, dim, kernel_size=31, groups=16):
168
+ super().__init__()
169
+ assert kernel_size % 2 != 0
170
+ self.conv1d = nn.Sequential(
171
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
172
+ nn.Mish(),
173
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
174
+ nn.Mish(),
175
+ )
176
+
177
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
178
+ if mask is not None:
179
+ mask = mask[..., None]
180
+ x = x.masked_fill(~mask, 0.0)
181
+
182
+ x = x.permute(0, 2, 1)
183
+ x = self.conv1d(x)
184
+ out = x.permute(0, 2, 1)
185
+
186
+ if mask is not None:
187
+ out = out.masked_fill(~mask, 0.0)
188
+
189
+ return out
190
+
191
+
192
+ # rotary positional embedding related
193
+
194
+
195
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
196
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
197
+ # has some connection to NTK literature
198
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
199
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
200
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
201
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
202
+ t = torch.arange(end, device=freqs.device) # type: ignore
203
+ freqs = torch.outer(t, freqs).float() # type: ignore
204
+ freqs_cos = torch.cos(freqs) # real part
205
+ freqs_sin = torch.sin(freqs) # imaginary part
206
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
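
A small behaviour check for precompute_freqs_cis: the returned table stacks the cosine and sine halves along the last dimension, and a theta_rescale_factor above 1 enlarges the effective theta so the rotations advance more slowly per position (the NTK-aware rescaling described in the comments above). The values here are only illustrative:

freqs = precompute_freqs_cis(dim=64, end=1024)
freqs_long = precompute_freqs_cis(dim=64, end=4096, theta_rescale_factor=4.0)
print(freqs.shape, freqs_long.shape)   # torch.Size([1024, 64]) torch.Size([4096, 64]); columns are [cos | sin]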
207
+
208
+
209
+ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
210
+ # length = length if isinstance(length, int) else length.max()
211
+ scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
212
+ pos = (
213
+ start.unsqueeze(1)
214
+ + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
215
+ )
216
+ # avoid extra long error.
217
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
218
+ return pos
219
+
220
+
221
+ # Global Response Normalization layer (Instance Normalization ?)
222
+
223
+
224
+ class GRN(nn.Module):
225
+ def __init__(self, dim):
226
+ super().__init__()
227
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
228
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
229
+
230
+ def forward(self, x):
231
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
232
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
233
+ return self.gamma * (x * Nx) + self.beta + x
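
One property of GRN as written that is worth noting: gamma and beta are initialised to zero, so at initialisation the layer is exactly the identity and the normalisation path only takes effect as gamma is learned. A quick, illustrative check:

import torch

grn = GRN(dim=16)
x = torch.randn(2, 10, 16)          # "b n d"
assert torch.allclose(grn(x), x)    # gamma = beta = 0 at init, so the output equals the input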
234
+
235
+
236
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
237
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
238
+
239
+
240
+ class ConvNeXtV2Block(nn.Module):
241
+ def __init__(
242
+ self,
243
+ dim: int,
244
+ intermediate_dim: int,
245
+ dilation: int = 1,
246
+ ):
247
+ super().__init__()
248
+ padding = (dilation * (7 - 1)) // 2
249
+ self.dwconv = nn.Conv1d(
250
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
251
+ ) # depthwise conv
252
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
253
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
254
+ self.act = nn.GELU()
255
+ self.grn = GRN(intermediate_dim)
256
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
257
+
258
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
259
+ residual = x
260
+ x = x.transpose(1, 2) # b n d -> b d n
261
+ x = self.dwconv(x)
262
+ x = x.transpose(1, 2) # b d n -> b n d
263
+ x = self.norm(x)
264
+ x = self.pwconv1(x)
265
+ x = self.act(x)
266
+ x = self.grn(x)
267
+ x = self.pwconv2(x)
268
+ return residual + x
269
+
270
+
271
+ # AdaLayerNormZero
272
+ # return with modulated x for attn input, and params for later mlp modulation
273
+
274
+
275
+ class AdaLayerNormZero(nn.Module):
276
+ def __init__(self, dim):
277
+ super().__init__()
278
+
279
+ self.silu = nn.SiLU()
280
+ self.linear = nn.Linear(dim, dim * 6)
281
+
282
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
283
+
284
+ def forward(self, x, emb=None):
285
+ emb = self.linear(self.silu(emb))
286
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
287
+
288
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
289
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
290
+
291
+
292
+ # AdaLayerNormZero for final layer
293
+ # return only the modulated x for attn input, since there is no further mlp modulation
294
+
295
+
296
+ class AdaLayerNormZero_Final(nn.Module):
297
+ def __init__(self, dim):
298
+ super().__init__()
299
+
300
+ self.silu = nn.SiLU()
301
+ self.linear = nn.Linear(dim, dim * 2)
302
+
303
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
304
+
305
+ def forward(self, x, emb):
306
+ emb = self.linear(self.silu(emb))
307
+ scale, shift = torch.chunk(emb, 2, dim=1)
308
+
309
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
310
+ return x
311
+
312
+
313
+ # FeedForward
314
+
315
+
316
+ class FeedForward(nn.Module):
317
+ def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
318
+ super().__init__()
319
+ inner_dim = int(dim * mult)
320
+ dim_out = dim_out if dim_out is not None else dim
321
+
322
+ activation = nn.GELU(approximate=approximate)
323
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
324
+ self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
325
+
326
+ def forward(self, x):
327
+ return self.ff(x)
328
+
329
+
330
+ # Attention with possible joint part
331
+ # modified from diffusers/src/diffusers/models/attention_processor.py
332
+
333
+
334
+ class Attention(nn.Module):
335
+ def __init__(
336
+ self,
337
+ processor: JointAttnProcessor | AttnProcessor,
338
+ dim: int,
339
+ heads: int = 8,
340
+ dim_head: int = 64,
341
+ dropout: float = 0.0,
342
+ context_dim: Optional[int] = None, # if not None -> joint attention
343
+ context_pre_only=None,
344
+ ):
345
+ super().__init__()
346
+
347
+ if not hasattr(F, "scaled_dot_product_attention"):
348
+ raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
349
+
350
+ self.processor = processor
351
+
352
+ self.dim = dim
353
+ self.heads = heads
354
+ self.inner_dim = dim_head * heads
355
+ self.dropout = dropout
356
+
357
+ self.context_dim = context_dim
358
+ self.context_pre_only = context_pre_only
359
+
360
+ self.to_q = nn.Linear(dim, self.inner_dim)
361
+ self.to_k = nn.Linear(dim, self.inner_dim)
362
+ self.to_v = nn.Linear(dim, self.inner_dim)
363
+
364
+ if self.context_dim is not None:
365
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
366
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
367
+ if self.context_pre_only is not None:
368
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
369
+
370
+ self.to_out = nn.ModuleList([])
371
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
372
+ self.to_out.append(nn.Dropout(dropout))
373
+
374
+ if self.context_pre_only is not None and not self.context_pre_only:
375
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
376
+
377
+ def forward(
378
+ self,
379
+ x: float["b n d"], # noised input x # noqa: F722
380
+ c: float["b n d"] = None, # context c # noqa: F722
381
+ mask: bool["b n"] | None = None, # noqa: F722
382
+ rope=None, # rotary position embedding for x
383
+ c_rope=None, # rotary position embedding for c
384
+ ) -> torch.Tensor:
385
+ if c is not None:
386
+ return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
387
+ else:
388
+ return self.processor(self, x, mask=mask, rope=rope)
389
+
390
+
391
+ # Attention processor
392
+
393
+
394
+ # from torch.nn.attention import SDPBackend
395
+ # torch.backends.cuda.enable_flash_sdp(True)
396
+ class AttnProcessor:
397
+ def __init__(self):
398
+ pass
399
+
400
+ def __call__(
401
+ self,
402
+ attn: Attention,
403
+ x: float["b n d"], # noised input x # noqa: F722
404
+ mask: bool["b n"] | None = None, # noqa: F722
405
+ rope=None, # rotary position embedding
406
+ ) -> torch.FloatTensor:
407
+ batch_size = x.shape[0]
408
+
409
+ # `sample` projections.
410
+ query = attn.to_q(x)
411
+ key = attn.to_k(x)
412
+ value = attn.to_v(x)
413
+
414
+ # apply rotary position embedding
415
+ if rope is not None:
416
+ freqs, xpos_scale = rope
417
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
418
+
419
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
420
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
421
+
422
+ # attention
423
+ inner_dim = key.shape[-1]
424
+ head_dim = inner_dim // attn.heads
425
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
426
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
427
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
428
+
429
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
430
+ if mask is not None:
431
+ attn_mask = mask
432
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
433
+ # print(3433333333,attn_mask.shape)
434
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
435
+ else:
436
+ attn_mask = None
437
+ # with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
438
+ # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True):
439
+ # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
440
+ # print(torch.backends.cuda.flash_sdp_enabled())
441
+ # print(torch.backends.cuda.mem_efficient_sdp_enabled())
442
+ # print(torch.backends.cuda.math_sdp_enabled())
443
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
444
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
445
+ x = x.to(query.dtype)
446
+
447
+ # linear proj
448
+ x = attn.to_out[0](x)
449
+ # dropout
450
+ x = attn.to_out[1](x)
451
+
452
+ if mask is not None:
453
+ mask = mask.unsqueeze(-1)
454
+ x = x.masked_fill(~mask, 0.0)
455
+
456
+ return x
457
+
458
+
459
+ # Joint Attention processor for MM-DiT
460
+ # modified from diffusers/src/diffusers/models/attention_processor.py
461
+
462
+
463
+ class JointAttnProcessor:
464
+ def __init__(self):
465
+ pass
466
+
467
+ def __call__(
468
+ self,
469
+ attn: Attention,
470
+ x: float["b n d"], # noised input x # noqa: F722
471
+ c: float["b nt d"] = None, # context c, here text # noqa: F722
472
+ mask: bool["b n"] | None = None, # noqa: F722
473
+ rope=None, # rotary position embedding for x
474
+ c_rope=None, # rotary position embedding for c
475
+ ) -> torch.FloatTensor:
476
+ residual = x
477
+
478
+ batch_size = c.shape[0]
479
+
480
+ # `sample` projections.
481
+ query = attn.to_q(x)
482
+ key = attn.to_k(x)
483
+ value = attn.to_v(x)
484
+
485
+ # `context` projections.
486
+ c_query = attn.to_q_c(c)
487
+ c_key = attn.to_k_c(c)
488
+ c_value = attn.to_v_c(c)
489
+
490
+ # apply rope for context and noised input independently
491
+ if rope is not None:
492
+ freqs, xpos_scale = rope
493
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
494
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
495
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
496
+ if c_rope is not None:
497
+ freqs, xpos_scale = c_rope
498
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
499
+ c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
500
+ c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
501
+
502
+ # attention
503
+ query = torch.cat([query, c_query], dim=1)
504
+ key = torch.cat([key, c_key], dim=1)
505
+ value = torch.cat([value, c_value], dim=1)
506
+
507
+ inner_dim = key.shape[-1]
508
+ head_dim = inner_dim // attn.heads
509
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
510
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
511
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
512
+
513
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
514
+ if mask is not None:
515
+ attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
516
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
517
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
518
+ else:
519
+ attn_mask = None
520
+
521
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
522
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
523
+ x = x.to(query.dtype)
524
+
525
+ # Split the attention outputs.
526
+ x, c = (
527
+ x[:, : residual.shape[1]],
528
+ x[:, residual.shape[1] :],
529
+ )
530
+
531
+ # linear proj
532
+ x = attn.to_out[0](x)
533
+ # dropout
534
+ x = attn.to_out[1](x)
535
+ if not attn.context_pre_only:
536
+ c = attn.to_out_c(c)
537
+
538
+ if mask is not None:
539
+ mask = mask.unsqueeze(-1)
540
+ x = x.masked_fill(~mask, 0.0)
541
+ # c = c.masked_fill(~mask, 0.) # no mask for c (text)
542
+
543
+ return x, c
544
+
545
+
546
+ # DiT Block
547
+
548
+
549
+ class DiTBlock(nn.Module):
550
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
551
+ super().__init__()
552
+
553
+ self.attn_norm = AdaLayerNormZero(dim)
554
+ self.attn = Attention(
555
+ processor=AttnProcessor(),
556
+ dim=dim,
557
+ heads=heads,
558
+ dim_head=dim_head,
559
+ dropout=dropout,
560
+ )
561
+
562
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
563
+ self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
564
+
565
+ def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
566
+ # pre-norm & modulation for attention input
567
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
568
+
569
+ # attention
570
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
571
+
572
+ # process attention output for input x
573
+ x = x + gate_msa.unsqueeze(1) * attn_output
574
+
575
+ norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
576
+ ff_output = self.ff(norm)
577
+ x = x + gate_mlp.unsqueeze(1) * ff_output
578
+
579
+ return x
580
+
581
+
582
+ # MMDiT Block https://arxiv.org/abs/2403.03206
583
+
584
+
585
+ class MMDiTBlock(nn.Module):
586
+ r"""
587
+ modified from diffusers/src/diffusers/models/attention.py
588
+
589
+ notes.
590
+ _c: context related. text, cond, etc. (left part in sd3 fig2.b)
591
+ _x: noised input related. (right part)
592
+ context_pre_only: the last layer only does prenorm + modulation, since there is no ffn afterwards
593
+ """
594
+
595
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
596
+ super().__init__()
597
+
598
+ self.context_pre_only = context_pre_only
599
+
600
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
601
+ self.attn_norm_x = AdaLayerNormZero(dim)
602
+ self.attn = Attention(
603
+ processor=JointAttnProcessor(),
604
+ dim=dim,
605
+ heads=heads,
606
+ dim_head=dim_head,
607
+ dropout=dropout,
608
+ context_dim=dim,
609
+ context_pre_only=context_pre_only,
610
+ )
611
+
612
+ if not context_pre_only:
613
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
614
+ self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
615
+ else:
616
+ self.ff_norm_c = None
617
+ self.ff_c = None
618
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
619
+ self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
620
+
621
+ def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
622
+ # pre-norm & modulation for attention input
623
+ if self.context_pre_only:
624
+ norm_c = self.attn_norm_c(c, t)
625
+ else:
626
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
627
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
628
+
629
+ # attention
630
+ x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
631
+
632
+ # process attention output for context c
633
+ if self.context_pre_only:
634
+ c = None
635
+ else: # if not last layer
636
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
637
+
638
+ norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
639
+ c_ff_output = self.ff_c(norm_c)
640
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
641
+
642
+ # process attention output for input x
643
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
644
+
645
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
646
+ x_ff_output = self.ff_x(norm_x)
647
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
648
+
649
+ return c, x
650
+
651
+
652
+ # time step conditioning embedding
653
+
654
+
655
+ class TimestepEmbedding(nn.Module):
656
+ def __init__(self, dim, freq_embed_dim=256):
657
+ super().__init__()
658
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
659
+ self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
660
+
661
+ def forward(self, timestep: float["b"]): # noqa: F821
662
+ time_hidden = self.time_embed(timestep)
663
+ time_hidden = time_hidden.to(timestep.dtype)
664
+ time = self.time_mlp(time_hidden) # b d
665
+ return time
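
A short sketch of the timestep conditioning path defined above (sinusoidal features followed by the MLP); the dimension is illustrative:

import torch

time_embed = TimestepEmbedding(dim=512)   # internally: 256-dim sinusoidal features -> Linear -> SiLU -> Linear
t = torch.rand(4)                         # one diffusion time per batch item, "b"
cond = time_embed(t)
print(cond.shape)                         # torch.Size([4, 512])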
GPT_SoVITS/feature_extractor/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from . import cnhubert
2
+
3
+ content_module_map = {"cnhubert": cnhubert}
GPT_SoVITS/feature_extractor/cnhubert.py ADDED
@@ -0,0 +1,46 @@
1
+ import logging
2
+ import os
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import (
7
+ HubertModel,
8
+ Wav2Vec2FeatureExtractor,
9
+ )
10
+ from transformers import logging as tf_logging
11
+
12
+ tf_logging.set_verbosity_error()
13
+
14
+ logging.getLogger("numba").setLevel(logging.WARNING)
15
+
16
+ cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
17
+
18
+
19
+ class CNHubert(nn.Module):
20
+ def __init__(self, base_path: str = ""):
21
+ super().__init__()
22
+ if not base_path:
23
+ base_path = cnhubert_base_path
24
+ if os.path.exists(base_path):
25
+ ...
26
+ else:
27
+ raise FileNotFoundError(base_path)
28
+ self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
29
+ self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)
30
+
31
+ def forward(self, x):
32
+ input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
33
+ feats = self.model(input_values)["last_hidden_state"]
34
+ return feats
35
+
36
+
37
+ def get_model():
38
+ model = CNHubert()
39
+ model.eval()
40
+ return model
41
+
42
+
43
+ def get_content(hmodel, wav_16k_tensor):
44
+ with torch.no_grad():
45
+ feats = hmodel(wav_16k_tensor)
46
+ return feats.transpose(1, 2)
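
A hedged end-to-end sketch of the helpers above; it assumes the chinese-hubert-base weights are already present at cnhubert_base_path, and uses random 16 kHz audio only to show the expected shapes:

import torch

model = get_model()                  # CNHubert in eval mode, loaded from cnhubert_base_path
wav_16k = torch.randn(16000)         # one second at 16 kHz
feats = get_content(model, wav_16k)  # (batch, hidden_size, frames) after the transpose in get_content
print(feats.shape)                   # e.g. torch.Size([1, 768, 49])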
GPT_SoVITS/inference_webui.py ADDED
@@ -0,0 +1,1104 @@
1
+ import argparse
2
+ import contextlib
3
+ import logging
4
+ import os
5
+ import re
6
+ import shutil
7
+ import traceback
8
+ import warnings
9
+ import zipfile
10
+ from functools import partial
11
+ from pathlib import Path
12
+ from time import time as ttime
13
+ from typing import Any
14
+
15
+ import gradio as gr
16
+ import librosa
17
+ import numpy as np
18
+ import spaces
19
+ import torch
20
+ import torchaudio
21
+ from huggingface_hub import hf_hub_download
22
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
23
+
24
+ from config import (
25
+ change_choices,
26
+ get_dtype,
27
+ get_weights_names,
28
+ )
29
+ from config import (
30
+ infer_device as default_device,
31
+ )
32
+ from GPT_SoVITS.Accelerate import PyTorch, T2SEngineProtocol, T2SRequest, backends
33
+ from GPT_SoVITS.Accelerate.logger import console
34
+ from GPT_SoVITS.feature_extractor import cnhubert
35
+ from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
36
+ from GPT_SoVITS.module.models import SynthesizerTrn
37
+ from GPT_SoVITS.process_ckpt import inspect_version
38
+ from GPT_SoVITS.sv import SV
39
+ from GPT_SoVITS.text import cleaned_text_to_sequence
40
+ from GPT_SoVITS.text.cleaner import clean_text
41
+ from GPT_SoVITS.text.LangSegmenter import LangSegmenter
42
+ from tools.assets import css, js, top_html
43
+ from tools.i18n.i18n import I18nAuto, scan_language_list
44
+ from tools.my_utils import DictToAttrRecursive
45
+
46
+ warnings.filterwarnings(
47
+ "ignore", message="MPS: The constant padding of more than 3 dimensions is not currently supported natively."
48
+ )
49
+ warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*")
50
+
51
+ logging.getLogger("markdown_it").setLevel(logging.ERROR)
52
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
53
+ logging.getLogger("httpcore").setLevel(logging.ERROR)
54
+ logging.getLogger("httpx").setLevel(logging.ERROR)
55
+ logging.getLogger("asyncio").setLevel(logging.ERROR)
56
+ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
57
+ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
58
+ logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
59
+
60
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
61
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
62
+
63
+
64
+ def install():
65
+ base = Path("GPT_SoVITS")
66
+ zip_path = hf_hub_download("XXXXRT/GPT-SoVITS-Pretrained", "pretrained_models.zip", repo_type="model")
67
+ tmp = base / "tmp_unzip"
68
+ if tmp.exists():
69
+ shutil.rmtree(tmp)
70
+ with zipfile.ZipFile(zip_path) as zf:
71
+ zf.extractall(tmp)
72
+ folder = next(tmp.iterdir())
73
+ shutil.move(str(folder), base / folder.name)
74
+ shutil.rmtree(tmp)
75
+
76
+
77
+ install()
78
+
79
+
80
+ _LANG_RE = re.compile(r"^[a-z]{2}[_-][A-Z]{2}$")
81
+
82
+
83
+ def lang_type(text: str) -> str:
84
+ if text == "Auto":
85
+ return text
86
+ if not _LANG_RE.match(text):
87
+ raise argparse.ArgumentTypeError(f"Unspported Format: {text}, Expected ll_CC/ll-CC")
88
+ ll, cc = re.split(r"[_-]", text)
89
+ language = f"{ll}_{cc}"
90
+ if language in scan_language_list():
91
+ return language
92
+ else:
93
+ return "Auto"
94
+
95
+
96
+ def build_parser() -> argparse.ArgumentParser:
97
+ p = argparse.ArgumentParser(
98
+ prog="inference_webui",
99
+ description=f"python -s -m GPT_SoVITS.inference_webui zh_CN -b {backends[-1]}",
100
+ )
101
+ p.add_argument(
102
+ "language",
103
+ nargs="?",
104
+ default="Auto",
105
+ type=lang_type,
106
+ help="Language Code, Such as zh_CN, en-US",
107
+ )
108
+ p.add_argument(
109
+ "--backends",
110
+ "-b",
111
+ choices=backends,
112
+ default=backends[-1],
113
+ help="AR Inference Backend",
114
+ required=False,
115
+ )
116
+ p.add_argument(
117
+ "--device",
118
+ "-d",
119
+ default=str(default_device),
120
+ help="Inference Device",
121
+ required=False,
122
+ )
123
+ p.add_argument(
124
+ "--port",
125
+ "-p",
126
+ default=9872,
127
+ type=int,
128
+ help="WebUI Binding Port",
129
+ required=False,
130
+ )
131
+ p.add_argument(
132
+ "--share",
133
+ "-s",
134
+ default=False,
135
+ action="store_true",
136
+ help="Gradio Share Link",
137
+ required=False,
138
+ )
139
+ p.add_argument(
140
+ "--cnhubert",
141
+ default="GPT_SoVITS/pretrained_models/chinese-hubert-base",
142
+ help="CNHuBERT Pretrain",
143
+ required=False,
144
+ )
145
+ p.add_argument(
146
+ "--bert",
147
+ default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
148
+ help="BERT Pretrain",
149
+ required=False,
150
+ )
151
+ p.add_argument(
152
+ "--gpt",
153
+ default="",
154
+ help="GPT Model",
155
+ required=False,
156
+ )
157
+ p.add_argument(
158
+ "--sovits",
159
+ default="",
160
+ help="SoVITS Model",
161
+ required=False,
162
+ )
163
+
164
+ return p
165
+
166
+
167
+ args = build_parser().parse_args()
168
+
169
+ hps: Any = None
170
+ vq_model: SynthesizerTrn | None = None
171
+ t2s_engine: T2SEngineProtocol | None = None
172
+
173
+ version = model_version = "v2"
174
+ cnhubert_base_path = str(args.cnhubert)
175
+ bert_path = str(args.bert)
176
+ infer_ttswebui = int(args.port)
177
+ is_share = bool(args.share)
178
+
179
+
180
+ i18n = I18nAuto(language=args.language)
181
+ ar_backend: str = args.backends
182
+ change_choices_i18n = partial(change_choices, i18n=i18n)
183
+
184
+ SoVITS_names, GPT_names = get_weights_names(i18n)
185
+
186
+
187
+ dict_language_v1 = {
188
+ i18n("中文"): "all_zh", # 全部按中文识别
189
+ i18n("英文"): "en", # 全部按英文识别
190
+ i18n("日文"): "all_ja", # 全部按日文识别
191
+ i18n("中英混合"): "zh", # 按中英混合识别
192
+ i18n("日英混合"): "ja", # 按日英混合识别
193
+ i18n("多语种混合"): "auto", # 多语种启动切分识别语种
194
+ }
195
+ dict_language_v2 = {
196
+ i18n("中文"): "all_zh", # 全部按中文识别
197
+ i18n("英文"): "en", # 全部按英文识别
198
+ i18n("日文"): "all_ja", # 全部按日文识别
199
+ i18n("粤语"): "all_yue", # 全部按粤语识别
200
+ i18n("韩文"): "all_ko", # 全部按韩文识别
201
+ i18n("中英混合"): "zh",
202
+ i18n("日英混合"): "ja",
203
+ i18n("粤英混合"): "yue",
204
+ i18n("韩英混合"): "ko",
205
+ i18n("多语种混合"): "auto", # 多语种启动切分识别语种
206
+ i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
207
+ }
208
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
209
+
210
+ punctuation = set(["!", "?", "…", ",", ".", "-", " "])
211
+ splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}
212
+ v3v4set = {"v3", "v4"}
213
+
214
+ infer_device = torch.device(args.device)
215
+ device = infer_device if infer_device.type == "cuda" else torch.device("cpu")
216
+
217
+ dtype = get_dtype(device.index)
218
+ is_half = dtype == torch.float16
219
+
220
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
221
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_path).to(infer_device, dtype)
222
+
223
+ cnhubert.cnhubert_base_path = cnhubert_base_path
224
+ ssl_model = cnhubert.get_model().to(infer_device, dtype)
225
+
226
+ spec_min = -12
227
+ spec_max = 2
228
+
229
+
230
+ def norm_spec(x):
231
+ return (x - spec_min) / (spec_max - spec_min) * 2 - 1
232
+
233
+
234
+ def denorm_spec(x):
235
+ return (x + 1) / 2 * (spec_max - spec_min) + spec_min
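
norm_spec and denorm_spec are exact inverses over the fixed [spec_min, spec_max] range; a tiny check:

import torch

x = torch.linspace(spec_min, spec_max, steps=5)
assert torch.allclose(denorm_spec(norm_spec(x)), x)   # maps [-12, 2] <-> [-1, 1] and back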
236
+
237
+
238
+ def mel_fn(x):
239
+ return mel_spectrogram_torch(
240
+ y=x,
241
+ n_fft=1024,
242
+ num_mels=100,
243
+ sampling_rate=24000,
244
+ hop_size=256,
245
+ win_size=1024,
246
+ fmin=0,
247
+ fmax=None,
248
+ center=False,
249
+ )
250
+
251
+
252
+ def mel_fn_v4(x):
253
+ return mel_spectrogram_torch(
254
+ y=x,
255
+ n_fft=1280,
256
+ num_mels=100,
257
+ sampling_rate=32000,
258
+ hop_size=320,
259
+ win_size=1280,
260
+ fmin=0,
261
+ fmax=None,
262
+ center=False,
263
+ )
264
+
265
+
266
+ gpt_path = str(args.gpt) or GPT_names[0][-1]
267
+ sovits_path = str(args.sovits) or SoVITS_names[0][-1]
268
+
269
+
270
+ def get_bert_feature(text, word2ph):
271
+ inputs = tokenizer(text, return_tensors="pt")
272
+ for i in inputs:
273
+ inputs[i] = inputs[i].to(infer_device)
274
+ res = bert_model(**inputs, output_hidden_states=True)
275
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
276
+
277
+ assert len(word2ph) == len(text)
278
+ phone_level_feature = []
279
+ for i in range(len(word2ph)):
280
+ repeat_feature = res[i].repeat(word2ph[i], 1)
281
+ phone_level_feature.append(repeat_feature)
282
+ phone_level_feature_t = torch.cat(phone_level_feature, dim=0)
283
+ return phone_level_feature_t.T
284
+
285
+
286
+ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
287
+ global vq_model, hps, version, model_version, dict_language
288
+ model_version, version, is_lora, hps, dict_s2 = inspect_version(sovits_path)
289
+ print(sovits_path, version, model_version, is_lora)
290
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
291
+ visible_sample_steps = visible_inp_refs = None
292
+ if prompt_language is not None and text_language is not None:
293
+ if prompt_language in list(dict_language.keys()):
294
+ prompt_text_update, prompt_language_update = gr.skip(), gr.update(choices=list(dict_language.keys()))
295
+ else:
296
+ prompt_text_update = gr.update(value="")
297
+ prompt_language_update = gr.update(value=i18n("中文"), choices=list(dict_language.keys()))
298
+ if text_language in list(dict_language.keys()):
299
+ text_update, text_language_update = gr.skip(), gr.skip()
300
+ else:
301
+ text_update = gr.update(value="")
302
+ text_language_update = gr.update(value=i18n("中文"), choices=list(dict_language.keys()))
303
+
304
+ if model_version in v3v4set:
305
+ visible_sample_steps = True
306
+ visible_inp_refs = False
307
+ else:
308
+ visible_sample_steps = False
309
+ visible_inp_refs = True
310
+ yield (
311
+ prompt_text_update,
312
+ prompt_language_update,
313
+ text_update,
314
+ text_language_update,
315
+ gr.update(
316
+ visible=visible_sample_steps,
317
+ value=32 if model_version == "v3" else 8,
318
+ choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
319
+ ),
320
+ gr.update(visible=visible_inp_refs),
321
+ gr.update(value=False, interactive=True if model_version not in v3v4set else False),
322
+ gr.update(visible=True if model_version == "v3" else False),
323
+ gr.update(value=i18n("模型加载中,请等待"), interactive=False),
324
+ )
325
+
326
+ hps = DictToAttrRecursive(hps)
327
+ hps.model.semantic_frame_rate = "25hz"
328
+ hps.model.version = model_version
329
+ if model_version not in v3v4set:
330
+ vq_model = SynthesizerTrn(
331
+ hps.data.filter_length // 2 + 1,
332
+ hps.train.segment_size // hps.data.hop_length,
333
+ n_speakers=hps.data.n_speakers,
334
+ **hps.model,
335
+ )
336
+ else:
337
+ raise RuntimeError("Unsupported model version")
338
+
339
+ if "pretrained" not in sovits_path:
340
+ if hasattr(vq_model, "enc_q"):
341
+ del vq_model.enc_q
342
+
343
+ if is_lora is False:
344
+ console.print(f">> loading sovits_{model_version}", vq_model.load_state_dict(dict_s2["weight"], strict=False))
345
+ else:
346
+ RuntimeError("Unsupported model version")
347
+
348
+ vq_model = vq_model.to(infer_device, dtype)
349
+
350
+ yield (
351
+ gr.skip(),
352
+ gr.skip(),
353
+ gr.skip(),
354
+ gr.skip(),
355
+ gr.skip(),
356
+ gr.skip(),
357
+ gr.skip(),
358
+ gr.skip(),
359
+ gr.update(value=i18n("合成语音"), interactive=True),
360
+ )
361
+
362
+
363
+ with contextlib.suppress(UnboundLocalError):
364
+ next(change_sovits_weights(sovits_path))
365
+
366
+
367
+ def change_gpt_weights(gpt_path):
368
+ global t2s_engine, config
369
+
370
+ t2s_engine = PyTorch.T2SEngineTorch(
371
+ PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend),
372
+ device,
373
+ dtype=dtype,
374
+ )
375
+ # t2s_engine.decoder_model.compile()
376
+ total = sum(p.numel() for p in t2s_engine.decoder_model.parameters())
377
+ console.print(">> Number of parameter: %.2fM" % (total / 1e6))
378
+
379
+
380
+ change_gpt_weights(gpt_path)
381
+
382
+
383
+ sv_cn_model = SV(infer_device, is_half)
384
+
385
+
386
+ resample_transform_dict = {}
387
+
388
+
389
+ def resample(audio_tensor, sr0, sr1, device):
390
+ global resample_transform_dict
391
+ key = f"{sr0}-{sr1}-{device}"
392
+ if key not in resample_transform_dict:
393
+ resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
394
+ return resample_transform_dict[key](audio_tensor)
395
+
396
+
397
+ def get_spepc(hps, filename, dtype, device, is_v2pro=False):
398
+ sr1 = int(hps.data.sampling_rate)
399
+ audio, sr0 = torchaudio.load_with_torchcodec(filename)
400
+ audio = audio.to(device)
401
+
402
+ if sr0 != sr1:
403
+ audio = resample(audio, sr0, sr1, device)
404
+ if audio.shape[0] > 1:
405
+ audio = audio.mean(0).unsqueeze(0)
406
+
407
+ maxx = float(audio.abs().max())
408
+ if maxx > 1:
409
+ audio /= min(2, maxx)
410
+ spec = spectrogram_torch(
411
+ audio,
412
+ hps.data.filter_length,
413
+ hps.data.sampling_rate,
414
+ hps.data.hop_length,
415
+ hps.data.win_length,
416
+ center=False,
417
+ )
418
+ spec = spec.to(dtype)
419
+ if is_v2pro is True:
420
+ audio = resample(audio, sr1, 16000, device).to(dtype)
421
+ return spec, audio
422
+
423
+
424
+ def clean_text_inf(text, language, version):
425
+ language = language.replace("all_", "")
426
+ phones, word2ph, norm_text = clean_text(text, language, version)
427
+ phones = cleaned_text_to_sequence(phones, version)
428
+ return phones, word2ph, norm_text
429
+
430
+
431
+ def get_bert_inf(phones, word2ph, norm_text, language):
432
+ language = language.replace("all_", "")
433
+ if language == "zh":
434
+ bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype)
435
+ else:
436
+ bert = torch.zeros(
437
+ (1024, len(phones)),
438
+ dtype=torch.float16 if is_half is True else torch.float32,
439
+ ).to(device)
440
+
441
+ return bert
442
+
443
+
444
+ def get_first(text):
445
+ pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
446
+ text = re.split(pattern, text)[0].strip()
447
+ return text
448
+
449
+
450
+ def get_phones_and_bert(text, language, version, final=False):
451
+ text = re.sub(r" {2,}", " ", text)
452
+ textlist = []
453
+ langlist = []
454
+ if language == "all_zh":
455
+ for tmp in LangSegmenter.getTexts(text, "zh"):
456
+ langlist.append(tmp["lang"])
457
+ textlist.append(tmp["text"])
458
+ elif language == "all_yue":
459
+ for tmp in LangSegmenter.getTexts(text, "zh"):
460
+ if tmp["lang"] == "zh":
461
+ tmp["lang"] = "yue"
462
+ langlist.append(tmp["lang"])
463
+ textlist.append(tmp["text"])
464
+ elif language == "all_ja":
465
+ for tmp in LangSegmenter.getTexts(text, "ja"):
466
+ langlist.append(tmp["lang"])
467
+ textlist.append(tmp["text"])
468
+ elif language == "all_ko":
469
+ for tmp in LangSegmenter.getTexts(text, "ko"):
470
+ langlist.append(tmp["lang"])
471
+ textlist.append(tmp["text"])
472
+ elif language == "en":
473
+ langlist.append("en")
474
+ textlist.append(text)
475
+ elif language == "auto":
476
+ for tmp in LangSegmenter.getTexts(text):
477
+ langlist.append(tmp["lang"])
478
+ textlist.append(tmp["text"])
479
+ elif language == "auto_yue":
480
+ for tmp in LangSegmenter.getTexts(text):
481
+ if tmp["lang"] == "zh":
482
+ tmp["lang"] = "yue"
483
+ langlist.append(tmp["lang"])
484
+ textlist.append(tmp["text"])
485
+ else:
486
+ for tmp in LangSegmenter.getTexts(text):
487
+ if langlist:
488
+ if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
489
+ textlist[-1] += tmp["text"]
490
+ continue
491
+ if tmp["lang"] == "en":
492
+ langlist.append(tmp["lang"])
493
+ else:
494
+ # CJK han characters cannot be told apart reliably, so defer to the user-selected language
495
+ langlist.append(language)
496
+ textlist.append(tmp["text"])
497
+ print(textlist)
498
+ print(langlist)
499
+ phones_list = []
500
+ bert_list = []
501
+ norm_text_list = []
502
+ for i in range(len(textlist)):
503
+ lang = langlist[i]
504
+ phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
505
+ bert = get_bert_inf(phones, word2ph, norm_text, lang)
506
+ phones_list.append(phones)
507
+ norm_text_list.append(norm_text)
508
+ bert_list.append(bert)
509
+ bert = torch.cat(bert_list, dim=1)
510
+ phones = sum(phones_list, [])
511
+ norm_text = "".join(norm_text_list)
512
+
513
+ if not final and len(phones) < 6:
514
+ return get_phones_and_bert("." + text, language, version, final=True)
515
+
516
+ return phones, bert.to(dtype), norm_text
517
+
518
+
519
+ def merge_short_text_in_array(texts, threshold):
520
+ if (len(texts)) < 2:
521
+ return texts
522
+ result = []
523
+ text = ""
524
+ for ele in texts:
525
+ text += ele
526
+ if len(text) >= threshold:
527
+ result.append(text)
528
+ text = ""
529
+ if len(text) > 0:
530
+ if len(result) == 0:
531
+ result.append(text)
532
+ else:
533
+ result[len(result) - 1] += text
534
+ return result
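
A toy illustration of the merging rule above, using threshold 5 (the value passed later in get_tts_wav); a short trailing piece is folded into the last merged chunk rather than emitted on its own:

pieces = ["ab", "cde", "fghij", "k"]
print(merge_short_text_in_array(pieces, 5))   # ['abcde', 'fghijk']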
535
+
536
+
537
+ sr_model = None
538
+
539
+
540
+ cache: dict[int, Any] = {}
541
+
542
+
543
+ @spaces.GPU
544
+ def get_tts_wav(
545
+ ref_wav_path,
546
+ prompt_text,
547
+ prompt_language,
548
+ text,
549
+ text_language,
550
+ how_to_cut=i18n("不切"),
551
+ top_k=20,
552
+ top_p=0.6,
553
+ temperature=0.6,
554
+ ref_free=False,
555
+ speed=1,
556
+ if_freeze=False,
557
+ inp_refs=None,
558
+ sample_steps=8,
559
+ if_sr=False,
560
+ pause_second=0.3,
561
+ ):
562
+ torch.set_grad_enabled(False)
563
+ ttfb_time = ttime()
564
+
565
+ if ref_wav_path:
566
+ pass
567
+ else:
568
+ gr.Warning(i18n("请上传参考音频"))
569
+ if text:
570
+ pass
571
+ else:
572
+ gr.Warning(i18n("请填入推理文本"))
573
+ t = []
574
+ if prompt_text is None or len(prompt_text) == 0:
575
+ ref_free = True
576
+ if model_version in v3v4set:
577
+ ref_free = False # s2v3 does not support ref_free yet
578
+ t0 = ttime()
579
+ prompt_language = dict_language[prompt_language]
580
+ text_language = dict_language[text_language]
581
+
582
+ if not ref_free:
583
+ prompt_text = prompt_text.strip("\n")
584
+ if prompt_text[-1] not in splits:
585
+ prompt_text += "。" if prompt_language != "en" else "."
586
+ print(">>", i18n("实际输入的参考文本:"), prompt_text)
587
+ text = text.strip("\n")
588
+ # if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
589
+
590
+ print(">>", i18n("实际输入的目标文本:"), text)
591
+ zero_wav = np.zeros(
592
+ int(hps.data.sampling_rate * pause_second),
593
+ dtype=np.float16 if is_half is True else np.float32,
594
+ )
595
+ zero_wav_torch = torch.from_numpy(zero_wav)
596
+ if is_half is True:
597
+ zero_wav_torch = zero_wav_torch.half().to(infer_device)
598
+ else:
599
+ zero_wav_torch = zero_wav_torch.to(infer_device)
600
+ if not ref_free:
601
+ assert vq_model
602
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
603
+ if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
604
+ gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
605
+ raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
606
+ wav16k_t = torch.from_numpy(wav16k)
607
+ if is_half is True:
608
+ wav16k_t = wav16k_t.half().to(infer_device)
609
+ else:
610
+ wav16k_t = wav16k_t.to(infer_device)
611
+ wav16k_t = torch.cat([wav16k_t, zero_wav_torch])
612
+ ssl_content = ssl_model.model(wav16k_t.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
613
+ codes = vq_model.extract_latent(ssl_content)
614
+ prompt_semantic = codes[0, 0]
615
+ prompt = prompt_semantic.unsqueeze(0).to(device)
616
+ else:
617
+ prompt = torch.zeros((1, 0)).to(device, torch.int32)
618
+
619
+ t1 = ttime()
620
+ t.append(t1 - t0)
621
+
622
+ if how_to_cut == i18n("凑四句一切"):
623
+ text = cut1(text)
624
+ elif how_to_cut == i18n("凑50字一切"):
625
+ text = cut2(text)
626
+ elif how_to_cut == i18n("按中文句号。切"):
627
+ text = cut3(text)
628
+ elif how_to_cut == i18n("按英文句号.切"):
629
+ text = cut4(text)
630
+ elif how_to_cut == i18n("按标点符号切"):
631
+ text = cut5(text)
632
+ while "\n\n" in text:
633
+ text = text.replace("\n\n", "\n")
634
+ texts = text.split("\n")
635
+ texts = process_text(texts)
636
+ texts = merge_short_text_in_array(texts, 5)
637
+ audio_opt = []
638
+ # s2v3 does not support ref_free yet
639
+ if not ref_free:
640
+ phones1, bert1, _ = get_phones_and_bert(prompt_text, prompt_language, version)
641
+ else:
642
+ phones1, bert1 = [], torch.zeros(1024, 0).to(device, dtype)
643
+
644
+ infer_len: list[int] = []
645
+ infer_time: list[float] = []
646
+ assert vq_model
647
+
648
+ for i_text, text in enumerate(texts):
649
+ # skip blank lines in the target text to avoid errors
650
+ if len(text.strip()) == 0:
651
+ continue
652
+ if text[-1] not in splits:
653
+ text += "。" if text_language != "en" else "."
654
+ print(">>", i18n("实际输入的目标文本(每句):"), text)
655
+ phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
656
+ print(">>", i18n("前端处理后的文本(每句):"), norm_text2)
657
+
658
+ bert = torch.cat([bert1, bert2], 1)
659
+ all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
660
+
661
+ bert = bert.to(device).unsqueeze(0)
662
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
663
+
664
+ t2 = ttime()
665
+ if i_text in cache and if_freeze is True:
666
+ pred_semantic = cache[i_text]
667
+ else:
668
+ t2s_request = T2SRequest(
669
+ [all_phoneme_ids.squeeze(0)],
670
+ all_phoneme_len,
671
+ prompt,
672
+ [bert.squeeze(0)],
673
+ valid_length=1,
674
+ top_k=top_k,
675
+ top_p=top_p,
676
+ temperature=temperature,
677
+ early_stop_num=1500,
678
+ use_cuda_graph=torch.cuda.is_available(),
679
+ # debug=True,
680
+ )
681
+ assert t2s_engine
682
+ t2s_result = t2s_engine.generate(t2s_request)
683
+ if t2s_result.exception is not None:
684
+ console.print(t2s_result.traceback)
685
+ raise RuntimeError()
686
+ pred_semantic_list = t2s_result.result
687
+ assert pred_semantic_list, t2s_result.traceback
688
+ pred_semantic = pred_semantic_list[0].unsqueeze(0).to(infer_device)
689
+ infer_len.append(pred_semantic.shape[-1])
690
+ infer_time.append(t2s_result.infer_speed[-1])
691
+
692
+ cache[i_text] = pred_semantic
693
+ t3 = ttime()
694
+ is_v2pro = model_version in {"v2Pro", "v2ProPlus"}
695
+
696
+ sv_emb: list[torch.Tensor] = []
697
+ if model_version not in v3v4set:
698
+ refers = []
699
+ if inp_refs:
700
+ for path in inp_refs:
701
+ try: # also extract the sv embedding here: either many sv embeddings with many refers, or a single sv embedding with a single refer
702
+ refer, audio_tensor = get_spepc(hps, path.name, dtype, infer_device, is_v2pro)
703
+ refers.append(refer)
704
+ if is_v2pro:
705
+ assert sv_cn_model
706
+ sv_emb.append(sv_cn_model.compute_embedding(audio_tensor))
707
+ except Exception as e:
708
+ print(e)
709
+ traceback.print_exc()
710
+ if len(refers) == 0:
711
+ refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, infer_device, is_v2pro)
712
+ refers = [refers]
713
+ if is_v2pro:
714
+ assert sv_cn_model
715
+ sv_emb = [sv_cn_model.compute_embedding(audio_tensor)]
716
+ if is_v2pro:
717
+ audio = vq_model.decode(
718
+ pred_semantic,
719
+ torch.LongTensor(phones2).to(infer_device).unsqueeze(0),
720
+ refers,
721
+ speed=speed,
722
+ sv_emb=sv_emb,
723
+ )[0][0] # type: ignore
724
+ else:
725
+ audio = vq_model.decode(
726
+ pred_semantic,
727
+ torch.LongTensor(phones2).to(infer_device).unsqueeze(0),
728
+ refers,
729
+ speed=speed,
730
+ )[0][0] # type: ignore
731
+ else:
732
+ raise RuntimeError("Unsupported model version")
733
+ if i_text == 0:
734
+ ttfb_time = ttime() - ttfb_time
735
+ max_audio = torch.abs(audio).max() # simple guard against 16-bit clipping
736
+ if max_audio > 1:
737
+ audio = audio / max_audio
738
+ audio_opt.append(audio)
739
+ audio_opt.append(zero_wav_torch) # zero_wav
740
+ t4 = ttime()
741
+ t.extend([t2 - t1, t3 - t2, t4 - t3])
742
+ t1 = ttime()
743
+
744
+ audio_opt_t = torch.cat(audio_opt, 0) # np.concatenate
745
+ opt_sr = 32000
746
+ audio_opt_n = audio_opt_t.cpu().numpy()
747
+
748
+ t0 = t[0]
749
+ t1 = sum(t[1::3])
750
+ t2 = sum(t[2::3])
751
+ t3 = sum(t[3::3])
752
+
753
+ infer_speed_avg = sum(infer_len) / sum(infer_time)
754
+ rtf_value = sum(t) / (len(audio_opt_n) / opt_sr)
755
+
756
+ console.print(f">> Time Stamps: {t0:.3f}\t{t1:.3f}\t{t2:.3f}\t{t3:.3f}")
757
+ console.print(f">> Infer Speed: {infer_speed_avg:.2f} Token/s")
758
+ console.print(f">> RTF: {rtf_value:.2f}")
759
+ if ttfb_time > 2:
760
+ console.print(f">> TTFB: {ttfb_time:.3f} s")
761
+ else:
762
+ console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms")
763
+
764
+ gr.Info(f"{infer_speed_avg:.2f} Token/s", title="Infer Speed")
765
+ gr.Info(f"{rtf_value:.2f}", title="RTF")
766
+
767
+ if ttfb_time > 2:
768
+ gr.Info(f">> TTFB: {ttfb_time:.3f} s")
769
+ else:
770
+ gr.Info(f">> TTFB: {ttfb_time * 1000:.3f} ms")
771
+
772
+ if torch.cuda.is_available():
773
+ torch.cuda.empty_cache()
774
+
775
+ yield opt_sr, (audio_opt_n * 32767).astype(np.int16)
776
+
777
+
778
+ def split(todo_text):
779
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
780
+ if todo_text[-1] not in splits:
781
+ todo_text += "。"
782
+ i_split_head = i_split_tail = 0
783
+ len_text = len(todo_text)
784
+ todo_texts = []
785
+ while 1:
786
+ if i_split_head >= len_text:
787
+ break # the text always ends with punctuation, so just break; the last segment was already appended
788
+ if todo_text[i_split_head] in splits:
789
+ i_split_head += 1
790
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
791
+ i_split_tail = i_split_head
792
+ else:
793
+ i_split_head += 1
794
+ return todo_texts
795
+
796
+
797
+ def cut1(inp):
798
+ inp = inp.strip("\n")
799
+ inps = split(inp)
800
+ split_idx: list[int | None] = list(range(0, len(inps) + 1, 4))
801
+ split_idx[-1] = None
802
+ if len(split_idx) > 1:
803
+ opts = []
804
+ for idx in range(len(split_idx) - 1):
805
+ opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
806
+ else:
807
+ opts = [inp]
808
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
809
+ return "\n".join(opts)
810
+
811
+
812
+ def cut2(inp):
813
+ inp = inp.strip("\n")
814
+ inps = split(inp)
815
+ if len(inps) < 2:
816
+ return inp
817
+ opts = []
818
+ summ = 0
819
+ tmp_str = ""
820
+ for i in range(len(inps)):
821
+ summ += len(inps[i])
822
+ tmp_str += inps[i]
823
+ if summ > 50:
824
+ summ = 0
825
+ opts.append(tmp_str)
826
+ tmp_str = ""
827
+ if tmp_str != "":
828
+ opts.append(tmp_str)
829
+ if len(opts) > 1 and len(opts[-1]) < 50: # if the last chunk is too short, merge it into the previous one
830
+ opts[-2] = opts[-2] + opts[-1]
831
+ opts = opts[:-1]
832
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
833
+ return "\n".join(opts)
834
+
835
+
836
+ def cut3(inp):
837
+ inp = inp.strip("\n")
838
+ opts = inp.strip("。").split("。")
839
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
840
+ return "\n".join(opts)
841
+
842
+
843
+ def cut4(inp):
844
+ inp = inp.strip("\n")
845
+ opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
846
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
847
+ return "\n".join(opts)
848
+
849
+
850
+ # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
851
+ def cut5(inp):
852
+ inp = inp.strip("\n")
853
+ punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
854
+ mergeitems = []
855
+ items = []
856
+
857
+ for i, char in enumerate(inp):
858
+ if char in punds:
859
+ if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
860
+ items.append(char)
861
+ else:
862
+ items.append(char)
863
+ mergeitems.append("".join(items))
864
+ items = []
865
+ else:
866
+ items.append(char)
867
+
868
+ if items:
869
+ mergeitems.append("".join(items))
870
+
871
+ opt = [item for item in mergeitems if not set(item).issubset(punds)]
872
+ return "\n".join(opt)
873
+
874
+
875
+ def process_text(texts):
876
+ _text = []
877
+ if all(text in [None, " ", "\n", ""] for text in texts):
878
+ raise ValueError(i18n("请输入有效文本"))
879
+ for text in texts:
880
+ if text in [None, " ", ""]:
881
+ pass
882
+ else:
883
+ _text.append(text)
884
+ return _text
885
+
886
+
887
+ def html_center(text, label="p"):
888
+ return f"""<div style="text-align: center; margin: 100; padding: 50;">
889
+ <{label} style="margin: 0; padding: 0;">{text}</{label}>
890
+ </div>"""
891
+
892
+
893
+ def html_left(text, label="p"):
894
+ return f"""<div style="text-align: left; margin: 0; padding: 0;">
895
+ <{label} style="margin: 0; padding: 0;">{text}</{label}>
896
+ </div>"""
897
+
898
+
899
+ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app:
900
+ gr.HTML(
901
+ top_html.format(
902
+ i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
903
+ + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
904
+ ),
905
+ elem_classes="markdown",
906
+ )
907
+ gr.Markdown(html_center(i18n("模型切换"), "h3"))
908
+ with gr.Row(equal_height=True):
909
+ with gr.Column(scale=2):
910
+ with gr.Row(equal_height=True):
911
+ GPT_dropdown = gr.Dropdown(
912
+ label=i18n("GPT模型列表"),
913
+ choices=GPT_names,
914
+ value=gpt_path,
915
+ interactive=True,
916
+ )
917
+ SoVITS_dropdown = gr.Dropdown(
918
+ label=i18n("SoVITS模型列表"),
919
+ choices=SoVITS_names,
920
+ value=sovits_path,
921
+ interactive=True,
922
+ )
923
+ with gr.Column(scale=1):
924
+ refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary", scale=14)
925
+ refresh_button.click(fn=change_choices_i18n, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
926
+ gr.Markdown(html_center(i18n("*请上传并填写参考信息"), "h3"))
927
+ with gr.Row(equal_height=True):
928
+ with gr.Column(scale=2):
929
+ with gr.Row(equal_height=True):
930
+ with gr.Column(scale=1):
931
+ inp_ref = gr.Audio(
932
+ label=i18n("请上传3~10秒内参考音频,超过会报错!"),
933
+ type="filepath",
934
+ sources="upload",
935
+ scale=13,
936
+ editable=False,
937
+ waveform_options={"show_recording_waveform": False},
938
+ )
939
+ with gr.Column(scale=1):
940
+ gr.Markdown(
941
+ html_center(
942
+ i18n("使用无参考文本模式时建议使用微调的GPT")
943
+ + "<br>"
944
+ + i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
945
+ )
946
+ )
947
+ ref_text_free = gr.Checkbox(
948
+ label=i18n("开启无参考文本模式"),
949
+ info=i18n("不填参考文本亦相当于开启") + ", " + i18n("v3暂不支持该模式,使用了会报错。"),
950
+ value=False,
951
+ interactive=True if model_version not in v3v4set else False,
952
+ show_label=True,
953
+ scale=1,
954
+ )
955
+ prompt_language = gr.Dropdown(
956
+ label="",
957
+ info=i18n("参考音频的语种"),
958
+ choices=list(dict_language.keys()),
959
+ value=i18n("中文"),
960
+ )
961
+ prompt_text = gr.Textbox(label="", info=i18n("参考音频的文本"), value="", lines=3, max_lines=3)
962
+
963
+ with gr.Column(scale=1):
964
+ inp_refs = (
965
+ gr.File(
966
+ label=i18n(
967
+ "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
968
+ ),
969
+ file_count="multiple",
970
+ )
971
+ if model_version not in v3v4set
972
+ else gr.File(
973
+ label=i18n(
974
+ "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
975
+ ),
976
+ file_count="multiple",
977
+ visible=False,
978
+ )
979
+ )
980
+ sample_steps = (
981
+ gr.Radio(
982
+ label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
983
+ value=32 if model_version == "v3" else 8,
984
+ choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
985
+ visible=True,
986
+ )
987
+ if model_version in v3v4set
988
+ else gr.Radio(
989
+ label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
990
+ choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
991
+ visible=False,
992
+ value=32 if model_version == "v3" else 8,
993
+ )
994
+ )
995
+ if_sr_Checkbox = gr.Checkbox(
996
+ label=i18n("v3输出如果觉得闷可以试试开超分"),
997
+ value=False,
998
+ interactive=True,
999
+ show_label=True,
1000
+ visible=False if model_version != "v3" else True,
1001
+ )
1002
+ gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
1003
+ with gr.Row(equal_height=True):
1004
+ with gr.Column(scale=2):
1005
+ text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=30, max_lines=40)
1006
+ with gr.Column(scale=1):
1007
+ text_language = gr.Dropdown(
1008
+ label=i18n("需要合成的语种") + i18n(".限制范围越小判别效果越好。"),
1009
+ choices=list(dict_language.keys()),
1010
+ value=i18n("中文"),
1011
+ scale=1,
1012
+ )
1013
+ how_to_cut = gr.Dropdown(
1014
+ label=i18n("怎么切"),
1015
+ choices=[
1016
+ i18n("不切"),
1017
+ i18n("凑四句一切"),
1018
+ i18n("凑50字一切"),
1019
+ i18n("按中文句号。切"),
1020
+ i18n("按英文句号.切"),
1021
+ i18n("按标点符号切"),
1022
+ ],
1023
+ value=i18n("凑四句一切"),
1024
+ interactive=True,
1025
+ scale=1,
1026
+ )
1027
+ if_freeze = gr.Checkbox(
1028
+ label=i18n("是否直接对上次合成结果调整语速和音色"),
1029
+ value=False,
1030
+ interactive=True,
1031
+ show_label=True,
1032
+ scale=1,
1033
+ )
1034
+ with gr.Row(equal_height=True):
1035
+ speed = gr.Slider(
1036
+ minimum=0.6, maximum=1.65, step=0.05, label=i18n("语速"), value=1, interactive=True, scale=1
1037
+ )
1038
+ pause_second_slider = gr.Slider(
1039
+ minimum=0.1,
1040
+ maximum=0.5,
1041
+ step=0.01,
1042
+ label=i18n("句间停顿秒数"),
1043
+ value=0.3,
1044
+ interactive=True,
1045
+ scale=1,
1046
+ )
1047
+ gr.Markdown(html_center(i18n("GPT采样参数(不懂就用默认):")))
1048
+ top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True, scale=1)
1049
+ top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True, scale=1)
1050
+ temperature = gr.Slider(
1051
+ minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True, scale=1
1052
+ )
1053
+ with gr.Row(equal_height=True):
1054
+ with gr.Column(scale=2):
1055
+ inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg")
1056
+ with gr.Column(scale=1):
1057
+ output = gr.Audio(
1058
+ label=i18n("输出的语音"),
1059
+ waveform_options={"show_recording_waveform": False},
1060
+ editable=False,
1061
+ )
1062
+
1063
+ inference_button.click(
1064
+ get_tts_wav,
1065
+ [
1066
+ inp_ref,
1067
+ prompt_text,
1068
+ prompt_language,
1069
+ text,
1070
+ text_language,
1071
+ how_to_cut,
1072
+ top_k,
1073
+ top_p,
1074
+ temperature,
1075
+ ref_text_free,
1076
+ speed,
1077
+ if_freeze,
1078
+ inp_refs,
1079
+ sample_steps,
1080
+ if_sr_Checkbox,
1081
+ pause_second_slider,
1082
+ ],
1083
+ [output],
1084
+ )
1085
+ SoVITS_dropdown.change(
1086
+ change_sovits_weights,
1087
+ [SoVITS_dropdown, prompt_language, text_language],
1088
+ [
1089
+ prompt_text,
1090
+ prompt_language,
1091
+ text,
1092
+ text_language,
1093
+ sample_steps,
1094
+ inp_refs,
1095
+ ref_text_free,
1096
+ if_sr_Checkbox,
1097
+ inference_button,
1098
+ ],
1099
+ )
1100
+ GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
1101
+
1102
+
1103
+ if __name__ == "__main__":
1104
+ app.queue(api_open=False, default_concurrency_limit=1, max_size=1024).launch()
GPT_SoVITS/module/attentions.py ADDED
@@ -0,0 +1,658 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from torch.nn.utils import remove_weight_norm
7
+ from torch.nn.utils.parametrizations import weight_norm
8
+
9
+ from . import commons
10
+ from .modules import LayerNorm
11
+
12
+
13
+ class Encoder(nn.Module):
14
+ def __init__(
15
+ self,
16
+ hidden_channels,
17
+ filter_channels,
18
+ n_heads,
19
+ n_layers,
20
+ kernel_size=1,
21
+ p_dropout=0.0,
22
+ window_size=4,
23
+ isflow=False,
24
+ **kwargs,
25
+ ):
26
+ super().__init__()
27
+ self.hidden_channels = hidden_channels
28
+ self.filter_channels = filter_channels
29
+ self.n_heads = n_heads
30
+ self.n_layers = n_layers
31
+ self.kernel_size = kernel_size
32
+ self.p_dropout = p_dropout
33
+ self.window_size = window_size
34
+
35
+ self.drop = nn.Dropout(p_dropout)
36
+ self.attn_layers = nn.ModuleList()
37
+ self.norm_layers_1 = nn.ModuleList()
38
+ self.ffn_layers = nn.ModuleList()
39
+ self.norm_layers_2 = nn.ModuleList()
40
+ for i in range(self.n_layers):
41
+ self.attn_layers.append(
42
+ MultiHeadAttention(
43
+ hidden_channels,
44
+ hidden_channels,
45
+ n_heads,
46
+ p_dropout=p_dropout,
47
+ window_size=window_size,
48
+ )
49
+ )
50
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
51
+ self.ffn_layers.append(
52
+ FFN(
53
+ hidden_channels,
54
+ hidden_channels,
55
+ filter_channels,
56
+ kernel_size,
57
+ p_dropout=p_dropout,
58
+ )
59
+ )
60
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
61
+ if isflow:
62
+ cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
63
+ self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
64
+ self.cond_layer = weight_norm_modules(cond_layer, name="weight")
65
+ self.gin_channels = kwargs["gin_channels"]
66
+
67
+ def forward(self, x, x_mask, g=None):
68
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
69
+ x = x * x_mask
70
+ if g is not None:
71
+ g = self.cond_layer(g)
72
+
73
+ for i in range(self.n_layers):
74
+ if g is not None:
75
+ x = self.cond_pre(x)
76
+ cond_offset = i * 2 * self.hidden_channels
77
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
78
+ x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
79
+ y = self.attn_layers[i](x, x, attn_mask)
80
+ y = self.drop(y)
81
+ x = self.norm_layers_1[i](x + y)
82
+
83
+ y = self.ffn_layers[i](x, x_mask)
84
+ y = self.drop(y)
85
+ x = self.norm_layers_2[i](x + y)
86
+ x = x * x_mask
87
+ return x
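A minimal forward-pass sketch for the Encoder above (illustrative only; it assumes the GPT_SoVITS/module package from this commit, including modules.LayerNorm, is importable and that torch is installed). Tensors follow the [batch, channels, time] layout used throughout this file.

    import torch
    from GPT_SoVITS.module.attentions import Encoder

    enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2, n_layers=3, kernel_size=5).eval()
    x = torch.randn(1, 192, 50)     # [batch, hidden_channels, time]
    x_mask = torch.ones(1, 1, 50)   # 1 = valid frame, 0 = padding
    with torch.no_grad():
        y = enc(x, x_mask)
    print(y.shape)                  # torch.Size([1, 192, 50])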
88
+
89
+
90
+ class Decoder(nn.Module):
91
+ def __init__(
92
+ self,
93
+ hidden_channels,
94
+ filter_channels,
95
+ n_heads,
96
+ n_layers,
97
+ kernel_size=1,
98
+ p_dropout=0.0,
99
+ proximal_bias=False,
100
+ proximal_init=True,
101
+ **kwargs,
102
+ ):
103
+ super().__init__()
104
+ self.hidden_channels = hidden_channels
105
+ self.filter_channels = filter_channels
106
+ self.n_heads = n_heads
107
+ self.n_layers = n_layers
108
+ self.kernel_size = kernel_size
109
+ self.p_dropout = p_dropout
110
+ self.proximal_bias = proximal_bias
111
+ self.proximal_init = proximal_init
112
+
113
+ self.drop = nn.Dropout(p_dropout)
114
+ self.self_attn_layers = nn.ModuleList()
115
+ self.norm_layers_0 = nn.ModuleList()
116
+ self.encdec_attn_layers = nn.ModuleList()
117
+ self.norm_layers_1 = nn.ModuleList()
118
+ self.ffn_layers = nn.ModuleList()
119
+ self.norm_layers_2 = nn.ModuleList()
120
+ for i in range(self.n_layers):
121
+ self.self_attn_layers.append(
122
+ MultiHeadAttention(
123
+ hidden_channels,
124
+ hidden_channels,
125
+ n_heads,
126
+ p_dropout=p_dropout,
127
+ proximal_bias=proximal_bias,
128
+ proximal_init=proximal_init,
129
+ )
130
+ )
131
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
132
+ self.encdec_attn_layers.append(
133
+ MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)
134
+ )
135
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
136
+ self.ffn_layers.append(
137
+ FFN(
138
+ hidden_channels,
139
+ hidden_channels,
140
+ filter_channels,
141
+ kernel_size,
142
+ p_dropout=p_dropout,
143
+ causal=True,
144
+ )
145
+ )
146
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
147
+
148
+ def forward(self, x, x_mask, h, h_mask):
149
+ """
150
+ x: decoder input
151
+ h: encoder output
152
+ """
153
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
154
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
155
+ x = x * x_mask
156
+ for i in range(self.n_layers):
157
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
158
+ y = self.drop(y)
159
+ x = self.norm_layers_0[i](x + y)
160
+
161
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
162
+ y = self.drop(y)
163
+ x = self.norm_layers_1[i](x + y)
164
+
165
+ y = self.ffn_layers[i](x, x_mask)
166
+ y = self.drop(y)
167
+ x = self.norm_layers_2[i](x + y)
168
+ x = x * x_mask
169
+ return x
170
+
171
+
172
+ class MultiHeadAttention(nn.Module):
173
+ def __init__(
174
+ self,
175
+ channels,
176
+ out_channels,
177
+ n_heads,
178
+ p_dropout=0.0,
179
+ window_size=None,
180
+ heads_share=True,
181
+ block_length=None,
182
+ proximal_bias=False,
183
+ proximal_init=False,
184
+ ):
185
+ super().__init__()
186
+ assert channels % n_heads == 0
187
+
188
+ self.channels = channels
189
+ self.out_channels = out_channels
190
+ self.n_heads = n_heads
191
+ self.p_dropout = p_dropout
192
+ self.window_size = window_size
193
+ self.heads_share = heads_share
194
+ self.block_length = block_length
195
+ self.proximal_bias = proximal_bias
196
+ self.proximal_init = proximal_init
197
+ self.attn = None
198
+
199
+ self.k_channels = channels // n_heads
200
+ self.conv_q = nn.Conv1d(channels, channels, 1)
201
+ self.conv_k = nn.Conv1d(channels, channels, 1)
202
+ self.conv_v = nn.Conv1d(channels, channels, 1)
203
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
204
+ self.drop = nn.Dropout(p_dropout)
205
+
206
+ if window_size is not None:
207
+ n_heads_rel = 1 if heads_share else n_heads
208
+ rel_stddev = self.k_channels**-0.5
209
+ self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
210
+ self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
211
+
212
+ nn.init.xavier_uniform_(self.conv_q.weight)
213
+ nn.init.xavier_uniform_(self.conv_k.weight)
214
+ nn.init.xavier_uniform_(self.conv_v.weight)
215
+ if proximal_init:
216
+ with torch.no_grad():
217
+ self.conv_k.weight.copy_(self.conv_q.weight)
218
+ self.conv_k.bias.copy_(self.conv_q.bias)
219
+
220
+ def forward(self, x, c, attn_mask=None):
221
+ q = self.conv_q(x)
222
+ k = self.conv_k(c)
223
+ v = self.conv_v(c)
224
+
225
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
226
+
227
+ x = self.conv_o(x)
228
+ return x
229
+
230
+ def attention(self, query, key, value, mask=None):
231
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
232
+ b, d, t_s, t_t = (*key.size(), query.size(2))
233
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
234
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
235
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
236
+
237
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
238
+ if self.window_size is not None:
239
+ assert t_s == t_t, "Relative attention is only available for self-attention."
240
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
241
+ rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
242
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
243
+ scores = scores + scores_local
244
+ if self.proximal_bias:
245
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
246
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
247
+ if mask is not None:
248
+ scores = scores.masked_fill(mask == 0, -1e4)
249
+ if self.block_length is not None:
250
+ assert t_s == t_t, "Local attention is only available for self-attention."
251
+ block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
252
+ scores = scores.masked_fill(block_mask == 0, -1e4)
253
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
254
+ p_attn = self.drop(p_attn)
255
+ output = torch.matmul(p_attn, value)
256
+ if self.window_size is not None:
257
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
258
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
259
+ output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
260
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
261
+ return output, p_attn
262
+
263
+ def _matmul_with_relative_values(self, x, y):
264
+ """
265
+ x: [b, h, l, m]
266
+ y: [h or 1, m, d]
267
+ ret: [b, h, l, d]
268
+ """
269
+ ret = torch.matmul(x, y.unsqueeze(0))
270
+ return ret
271
+
272
+ def _matmul_with_relative_keys(self, x, y):
273
+ """
274
+ x: [b, h, l, d]
275
+ y: [h or 1, m, d]
276
+ ret: [b, h, l, m]
277
+ """
278
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
279
+ return ret
280
+
281
+ def _get_relative_embeddings(self, relative_embeddings, length):
282
+ max_relative_position = 2 * self.window_size + 1
283
+ # Pad first before slice to avoid using cond ops.
284
+ pad_length = max(length - (self.window_size + 1), 0)
285
+ slice_start_position = max((self.window_size + 1) - length, 0)
286
+ slice_end_position = slice_start_position + 2 * length - 1
287
+ if pad_length > 0:
288
+ padded_relative_embeddings = F.pad(
289
+ relative_embeddings,
290
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
291
+ )
292
+ else:
293
+ padded_relative_embeddings = relative_embeddings
294
+ used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
295
+ return used_relative_embeddings
296
+
297
+ def _relative_position_to_absolute_position(self, x):
298
+ """
299
+ x: [b, h, l, 2*l-1]
300
+ ret: [b, h, l, l]
301
+ """
302
+ batch, heads, length, _ = x.size()
303
+ # Concat columns of pad to shift from relative to absolute indexing.
304
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
305
+
306
+ # Concat extra elements so as to add up to shape (len+1, 2*len-1).
307
+ x_flat = x.view([batch, heads, length * 2 * length])
308
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
309
+
310
+ # Reshape and slice out the padded elements.
311
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
312
+ return x_final
313
+
314
+ def _absolute_position_to_relative_position(self, x):
315
+ """
316
+ x: [b, h, l, l]
317
+ ret: [b, h, l, 2*l-1]
318
+ """
319
+ batch, heads, length, _ = x.size()
320
+ # pad along column
321
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
322
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
323
+ # add 0's in the beginning that will skew the elements after reshape
324
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
325
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
326
+ return x_final
327
+
328
+ def _attention_bias_proximal(self, length):
329
+ """Bias for self-attention to encourage attention to close positions.
330
+ Args:
331
+ length: an integer scalar.
332
+ Returns:
333
+ a Tensor with shape [1, 1, length, length]
334
+ """
335
+ r = torch.arange(length, dtype=torch.float32)
336
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
337
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
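A shape sanity check for the relative-position attention above (a sketch with arbitrary sizes, assuming the package from this commit is importable). forward(x, c, attn_mask) takes the query source and the key/value source separately, so self-attention simply passes the same tensor twice; window_size enables the learned relative-position embeddings.

    import torch
    from GPT_SoVITS.module.attentions import MultiHeadAttention

    attn = MultiHeadAttention(channels=192, out_channels=192, n_heads=2, window_size=4)
    x = torch.randn(1, 192, 10)            # [batch, channels, time]
    attn_mask = torch.ones(1, 1, 10, 10)   # 1 = attend, 0 = masked out
    y = attn(x, x, attn_mask=attn_mask)    # self-attention: same tensor as query and key/value source
    print(y.shape)                         # torch.Size([1, 192, 10])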
338
+
339
+
340
+ class FFN(nn.Module):
341
+ def __init__(
342
+ self,
343
+ in_channels,
344
+ out_channels,
345
+ filter_channels,
346
+ kernel_size,
347
+ p_dropout=0.0,
348
+ activation=None,
349
+ causal=False,
350
+ ):
351
+ super().__init__()
352
+ self.in_channels = in_channels
353
+ self.out_channels = out_channels
354
+ self.filter_channels = filter_channels
355
+ self.kernel_size = kernel_size
356
+ self.p_dropout = p_dropout
357
+ self.activation = activation
358
+ self.causal = causal
359
+
360
+ if causal:
361
+ self.padding = self._causal_padding
362
+ else:
363
+ self.padding = self._same_padding
364
+
365
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
366
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
367
+ self.drop = nn.Dropout(p_dropout)
368
+
369
+ def forward(self, x, x_mask):
370
+ x = self.conv_1(self.padding(x * x_mask))
371
+ if self.activation == "gelu":
372
+ x = x * torch.sigmoid(1.702 * x)
373
+ else:
374
+ x = torch.relu(x)
375
+ x = self.drop(x)
376
+ x = self.conv_2(self.padding(x * x_mask))
377
+ return x * x_mask
378
+
379
+ def _causal_padding(self, x):
380
+ if self.kernel_size == 1:
381
+ return x
382
+ pad_l = self.kernel_size - 1
383
+ pad_r = 0
384
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
385
+ x = F.pad(x, commons.convert_pad_shape(padding))
386
+ return x
387
+
388
+ def _same_padding(self, x):
389
+ if self.kernel_size == 1:
390
+ return x
391
+ pad_l = (self.kernel_size - 1) // 2
392
+ pad_r = self.kernel_size // 2
393
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
394
+ x = F.pad(x, commons.convert_pad_shape(padding))
395
+ return x
396
+
397
+
398
+ class Depthwise_Separable_Conv1D(nn.Module):
399
+ def __init__(
400
+ self,
401
+ in_channels,
402
+ out_channels,
403
+ kernel_size,
404
+ stride=1,
405
+ padding=0,
406
+ dilation=1,
407
+ bias=True,
408
+ padding_mode="zeros", # TODO: refine this type
409
+ device=None,
410
+ dtype=None,
411
+ ):
412
+ super().__init__()
413
+ self.depth_conv = nn.Conv1d(
414
+ in_channels=in_channels,
415
+ out_channels=in_channels,
416
+ kernel_size=kernel_size,
417
+ groups=in_channels,
418
+ stride=stride,
419
+ padding=padding,
420
+ dilation=dilation,
421
+ bias=bias,
422
+ padding_mode=padding_mode,
423
+ device=device,
424
+ dtype=dtype,
425
+ )
426
+ self.point_conv = nn.Conv1d(
427
+ in_channels=in_channels,
428
+ out_channels=out_channels,
429
+ kernel_size=1,
430
+ bias=bias,
431
+ device=device,
432
+ dtype=dtype,
433
+ )
434
+
435
+ def forward(self, input):
436
+ return self.point_conv(self.depth_conv(input))
437
+
438
+ def weight_norm(self):
439
+ self.depth_conv = weight_norm(self.depth_conv, name="weight")
440
+ self.point_conv = weight_norm(self.point_conv, name="weight")
441
+
442
+ def remove_weight_norm(self):
443
+ self.depth_conv = remove_weight_norm(self.depth_conv, name="weight")
444
+ self.point_conv = remove_weight_norm(self.point_conv, name="weight")
445
+
446
+
447
+ class Depthwise_Separable_TransposeConv1D(nn.Module):
448
+ def __init__(
449
+ self,
450
+ in_channels,
451
+ out_channels,
452
+ kernel_size,
453
+ stride=1,
454
+ padding=0,
455
+ output_padding=0,
456
+ bias=True,
457
+ dilation=1,
458
+ padding_mode="zeros", # TODO: refine this type
459
+ device=None,
460
+ dtype=None,
461
+ ):
462
+ super().__init__()
463
+ self.depth_conv = nn.ConvTranspose1d(
464
+ in_channels=in_channels,
465
+ out_channels=in_channels,
466
+ kernel_size=kernel_size,
467
+ groups=in_channels,
468
+ stride=stride,
469
+ output_padding=output_padding,
470
+ padding=padding,
471
+ dilation=dilation,
472
+ bias=bias,
473
+ padding_mode=padding_mode,
474
+ device=device,
475
+ dtype=dtype,
476
+ )
477
+ self.point_conv = nn.Conv1d(
478
+ in_channels=in_channels,
479
+ out_channels=out_channels,
480
+ kernel_size=1,
481
+ bias=bias,
482
+ device=device,
483
+ dtype=dtype,
484
+ )
485
+
486
+ def forward(self, input):
487
+ return self.point_conv(self.depth_conv(input))
488
+
489
+ def weight_norm(self):
490
+ self.depth_conv = weight_norm(self.depth_conv, name="weight")
491
+ self.point_conv = weight_norm(self.point_conv, name="weight")
492
+
493
+ def remove_weight_norm(self):
494
+ remove_weight_norm(self.depth_conv, name="weight")
495
+ remove_weight_norm(self.point_conv, name="weight")
496
+
497
+
498
+ def weight_norm_modules(module, name="weight", dim=0):
499
+ if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
500
+ module.weight_norm()
501
+ return module
502
+ else:
503
+ return weight_norm(module, name, dim)
504
+
505
+
506
+ def remove_weight_norm_modules(module, name="weight"):
507
+ if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
508
+ module.remove_weight_norm()
509
+ else:
510
+ remove_weight_norm(module, name)
511
+
512
+
513
+ class FFT(nn.Module):
514
+ def __init__(
515
+ self,
516
+ hidden_channels,
517
+ filter_channels,
518
+ n_heads,
519
+ n_layers=1,
520
+ kernel_size=1,
521
+ p_dropout=0.0,
522
+ proximal_bias=False,
523
+ proximal_init=True,
524
+ isflow=False,
525
+ **kwargs,
526
+ ):
527
+ super().__init__()
528
+ self.hidden_channels = hidden_channels
529
+ self.filter_channels = filter_channels
530
+ self.n_heads = n_heads
531
+ self.n_layers = n_layers
532
+ self.kernel_size = kernel_size
533
+ self.p_dropout = p_dropout
534
+ self.proximal_bias = proximal_bias
535
+ self.proximal_init = proximal_init
536
+ if isflow:
537
+ cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
538
+ self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
539
+ self.cond_layer = weight_norm_modules(cond_layer, name="weight")
540
+ self.gin_channels = kwargs["gin_channels"]
541
+ self.drop = nn.Dropout(p_dropout)
542
+ self.self_attn_layers = nn.ModuleList()
543
+ self.norm_layers_0 = nn.ModuleList()
544
+ self.ffn_layers = nn.ModuleList()
545
+ self.norm_layers_1 = nn.ModuleList()
546
+ for i in range(self.n_layers):
547
+ self.self_attn_layers.append(
548
+ MultiHeadAttention(
549
+ hidden_channels,
550
+ hidden_channels,
551
+ n_heads,
552
+ p_dropout=p_dropout,
553
+ proximal_bias=proximal_bias,
554
+ proximal_init=proximal_init,
555
+ )
556
+ )
557
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
558
+ self.ffn_layers.append(
559
+ FFN(
560
+ hidden_channels,
561
+ hidden_channels,
562
+ filter_channels,
563
+ kernel_size,
564
+ p_dropout=p_dropout,
565
+ causal=True,
566
+ )
567
+ )
568
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
569
+
570
+ def forward(self, x, x_mask, g=None):
571
+ """
572
+ x: decoder input
573
+ g: optional conditioning tensor (projected from gin_channels)
574
+ """
575
+ if g is not None:
576
+ g = self.cond_layer(g)
577
+
578
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
579
+ x = x * x_mask
580
+ for i in range(self.n_layers):
581
+ if g is not None:
582
+ x = self.cond_pre(x)
583
+ cond_offset = i * 2 * self.hidden_channels
584
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
585
+ x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
586
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
587
+ y = self.drop(y)
588
+ x = self.norm_layers_0[i](x + y)
589
+
590
+ y = self.ffn_layers[i](x, x_mask)
591
+ y = self.drop(y)
592
+ x = self.norm_layers_1[i](x + y)
593
+ x = x * x_mask
594
+ return x
595
+
596
+
597
+ class TransformerCouplingLayer(nn.Module):
598
+ def __init__(
599
+ self,
600
+ channels,
601
+ hidden_channels,
602
+ kernel_size,
603
+ n_layers,
604
+ n_heads,
605
+ p_dropout=0,
606
+ filter_channels=0,
607
+ mean_only=False,
608
+ wn_sharing_parameter=None,
609
+ gin_channels=0,
610
+ ):
611
+ assert channels % 2 == 0, "channels should be divisible by 2"
612
+ super().__init__()
613
+ self.channels = channels
614
+ self.hidden_channels = hidden_channels
615
+ self.kernel_size = kernel_size
616
+ self.n_layers = n_layers
617
+ self.half_channels = channels // 2
618
+ self.mean_only = mean_only
619
+
620
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
621
+ self.enc = (
622
+ Encoder(
623
+ hidden_channels,
624
+ filter_channels,
625
+ n_heads,
626
+ n_layers,
627
+ kernel_size,
628
+ p_dropout,
629
+ isflow=True,
630
+ gin_channels=gin_channels,
631
+ )
632
+ if wn_sharing_parameter is None
633
+ else wn_sharing_parameter
634
+ )
635
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
636
+ self.post.weight.data.zero_()
637
+ self.post.bias.data.zero_()
638
+
639
+ def forward(self, x, x_mask, g=None, reverse=False):
640
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
641
+ h = self.pre(x0) * x_mask
642
+ h = self.enc(h, x_mask, g=g)
643
+ stats = self.post(h) * x_mask
644
+ if not self.mean_only:
645
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
646
+ else:
647
+ m = stats
648
+ logs = torch.zeros_like(m)
649
+
650
+ if not reverse:
651
+ x1 = m + x1 * torch.exp(logs) * x_mask
652
+ x = torch.cat([x0, x1], 1)
653
+ logdet = torch.sum(logs, [1, 2])
654
+ return x, logdet
655
+ else:
656
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
657
+ x = torch.cat([x0, x1], 1)
658
+ return x
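The coupling layer above is invertible by construction: half of the channels pass through unchanged and parameterize an affine transform of the other half, so the reverse pass undoes the forward pass without any matrix inversion. A quick check (a sketch with arbitrary hyperparameters; gin_channels is set to a nonzero value only so the optional conditioning branch can be constructed):

    import torch
    from GPT_SoVITS.module.attentions import TransformerCouplingLayer

    flow = TransformerCouplingLayer(
        channels=192, hidden_channels=192, kernel_size=5,
        n_layers=2, n_heads=2, filter_channels=768, gin_channels=256,
    ).eval()
    x = torch.randn(1, 192, 20)
    x_mask = torch.ones(1, 1, 20)
    with torch.no_grad():
        z, logdet = flow(x, x_mask)             # forward: x -> z, plus the log-determinant
        x_rec = flow(z, x_mask, reverse=True)   # reverse: z -> x
    print(torch.allclose(x, x_rec, atol=1e-5))  # True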
GPT_SoVITS/module/attentions_onnx.py ADDED
@@ -0,0 +1,385 @@
1
+ import logging
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from . import commons
9
+
10
+
11
+ class LayerNorm(nn.Module):
12
+ def __init__(self, channels, eps=1e-5):
13
+ super().__init__()
14
+ self.channels = channels
15
+ self.eps = eps
16
+
17
+ self.gamma = nn.Parameter(torch.ones(channels))
18
+ self.beta = nn.Parameter(torch.zeros(channels))
19
+
20
+ def forward(self, x):
21
+ x = x.transpose(1, -1)
22
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
23
+ return x.transpose(1, -1)
24
+
25
+
26
+ @torch.jit.script
27
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
28
+ n_channels_int = n_channels[0]
29
+ in_act = input_a + input_b
30
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
31
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
32
+ acts = t_act * s_act
33
+ return acts
34
+
35
+
36
+ class Encoder(nn.Module):
37
+ def __init__(
38
+ self,
39
+ hidden_channels,
40
+ filter_channels,
41
+ n_heads,
42
+ n_layers,
43
+ kernel_size=1,
44
+ p_dropout=0.0,
45
+ window_size=4,
46
+ isflow=True,
47
+ **kwargs,
48
+ ):
49
+ super().__init__()
50
+ self.hidden_channels = hidden_channels
51
+ self.filter_channels = filter_channels
52
+ self.n_heads = n_heads
53
+ self.n_layers = n_layers
54
+ self.kernel_size = kernel_size
55
+ self.p_dropout = p_dropout
56
+ self.window_size = window_size
57
+ # if isflow:
58
+ # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
59
+ # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
60
+ # self.cond_layer = weight_norm(cond_layer, name='weight')
61
+ # self.gin_channels = 256
62
+ self.cond_layer_idx = self.n_layers
63
+ self.spk_emb_linear = nn.Linear(256, self.hidden_channels)
64
+ if "gin_channels" in kwargs:
65
+ self.gin_channels = kwargs["gin_channels"]
66
+ if self.gin_channels != 0:
67
+ self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
68
+ # vits2 says 3rd block, so idx is 2 by default
69
+ self.cond_layer_idx = kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
70
+ logging.debug("gin_channels: %s, cond_layer_idx: %s", self.gin_channels, self.cond_layer_idx)
71
+ assert self.cond_layer_idx < self.n_layers, "cond_layer_idx should be less than n_layers"
72
+ self.drop = nn.Dropout(p_dropout)
73
+ self.attn_layers = nn.ModuleList()
74
+ self.norm_layers_1 = nn.ModuleList()
75
+ self.ffn_layers = nn.ModuleList()
76
+ self.norm_layers_2 = nn.ModuleList()
77
+ for i in range(self.n_layers):
78
+ self.attn_layers.append(
79
+ MultiHeadAttention(
80
+ hidden_channels,
81
+ hidden_channels,
82
+ n_heads,
83
+ p_dropout=p_dropout,
84
+ window_size=window_size,
85
+ )
86
+ )
87
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
88
+ self.ffn_layers.append(
89
+ FFN(
90
+ hidden_channels,
91
+ hidden_channels,
92
+ filter_channels,
93
+ kernel_size,
94
+ p_dropout=p_dropout,
95
+ )
96
+ )
97
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
98
+
99
+ # def forward(self, x, x_mask, g=None):
100
+ # attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
101
+ # x = x * x_mask
102
+ # for i in range(self.n_layers):
103
+ # if i == self.cond_layer_idx and g is not None:
104
+ # g = self.spk_emb_linear(g.transpose(1, 2))
105
+ # g = g.transpose(1, 2)
106
+ # x = x + g
107
+ # x = x * x_mask
108
+ # y = self.attn_layers[i](x, x, attn_mask)
109
+ # y = self.drop(y)
110
+ # x = self.norm_layers_1[i](x + y)
111
+
112
+ # y = self.ffn_layers[i](x, x_mask)
113
+ # y = self.drop(y)
114
+ # x = self.norm_layers_2[i](x + y)
115
+ # x = x * x_mask
116
+ # return x
117
+
118
+ def forward(self, x, x_mask):
119
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
120
+ x = x * x_mask
121
+ for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zip(
122
+ self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
123
+ ):
124
+ y = attn_layers(x, x, attn_mask)
125
+ y = self.drop(y)
126
+ x = norm_layers_1(x + y)
127
+
128
+ y = ffn_layers(x, x_mask)
129
+ y = self.drop(y)
130
+ x = norm_layers_2(x + y)
131
+ x = x * x_mask
132
+ return x
133
+
134
+
135
+ class MultiHeadAttention(nn.Module):
136
+ def __init__(
137
+ self,
138
+ channels,
139
+ out_channels,
140
+ n_heads,
141
+ p_dropout=0.0,
142
+ window_size=None,
143
+ heads_share=True,
144
+ block_length=None,
145
+ proximal_bias=False,
146
+ proximal_init=False,
147
+ ):
148
+ super().__init__()
149
+ assert channels % n_heads == 0
150
+
151
+ self.channels = channels
152
+ self.out_channels = out_channels
153
+ self.n_heads = n_heads
154
+ self.p_dropout = p_dropout
155
+ self.window_size = window_size
156
+ self.heads_share = heads_share
157
+ self.block_length = block_length
158
+ self.proximal_bias = proximal_bias
159
+ self.proximal_init = proximal_init
160
+ self.attn = None
161
+
162
+ self.k_channels = channels // n_heads
163
+ self.conv_q = nn.Conv1d(channels, channels, 1)
164
+ self.conv_k = nn.Conv1d(channels, channels, 1)
165
+ self.conv_v = nn.Conv1d(channels, channels, 1)
166
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
167
+ self.drop = nn.Dropout(p_dropout)
168
+
169
+ if window_size is not None:
170
+ n_heads_rel = 1 if heads_share else n_heads
171
+ rel_stddev = self.k_channels**-0.5
172
+ self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
173
+ self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
174
+
175
+ nn.init.xavier_uniform_(self.conv_q.weight)
176
+ nn.init.xavier_uniform_(self.conv_k.weight)
177
+ nn.init.xavier_uniform_(self.conv_v.weight)
178
+ if proximal_init:
179
+ with torch.no_grad():
180
+ self.conv_k.weight.copy_(self.conv_q.weight)
181
+ self.conv_k.bias.copy_(self.conv_q.bias)
182
+
183
+ def forward(self, x, c, attn_mask: Optional[torch.Tensor] = None):
184
+ q = self.conv_q(x)
185
+ k = self.conv_k(c)
186
+ v = self.conv_v(c)
187
+
188
+ # x, self.attn = self.attention(q, k, v, mask=attn_mask)
189
+ x, _ = self.attention(q, k, v, mask=attn_mask)
190
+
191
+ x = self.conv_o(x)
192
+ return x
193
+
194
+ def attention(self, query, key, value, mask: Optional[torch.Tensor] = None):
195
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
196
+ b, d, t_s, _ = (*key.size(), query.size(2))
197
+ query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
198
+ key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
199
+ value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
200
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
201
+
202
+ if self.window_size is not None:
203
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
204
+ rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
205
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
206
+ scores = scores + scores_local
207
+
208
+ if mask is not None:
209
+ scores = scores.masked_fill(mask == 0, -1e4)
210
+
211
+ p_attn = F.softmax(scores, dim=-1)
212
+ p_attn = self.drop(p_attn)
213
+ output = torch.matmul(p_attn, value)
214
+
215
+ if self.window_size is not None:
216
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
217
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
218
+ output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
219
+
220
+ output = output.transpose(2, 3).contiguous().view(b, d, -1)
221
+ return output, p_attn
222
+
223
+ def _matmul_with_relative_values(self, x, y):
224
+ """
225
+ x: [b, h, l, m]
226
+ y: [h or 1, m, d]
227
+ ret: [b, h, l, d]
228
+ """
229
+ ret = torch.matmul(x, y.unsqueeze(0))
230
+ return ret
231
+
232
+ def _matmul_with_relative_keys(self, x, y):
233
+ """
234
+ x: [b, h, l, d]
235
+ y: [h or 1, m, d]
236
+ ret: [b, h, l, m]
237
+ """
238
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
239
+ return ret
240
+
241
+ def _get_relative_embeddings(self, relative_embeddings, length):
242
+ max_relative_position = 2 * self.window_size + 1
243
+ # Pad first before slice to avoid using cond ops.
244
+ pad_l = torch.zeros((1), dtype=torch.int64) + length - (self.window_size + 1)
245
+ pad_s = torch.zeros((1), dtype=torch.int64) + (self.window_size + 1) - length
246
+ pad_length = torch.max(pad_l, other=torch.zeros((1), dtype=torch.int64))
247
+ slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype=torch.int64))
248
+
249
+ slice_end_position = slice_start_position + 2 * length - 1
250
+ padded_relative_embeddings = F.pad(
251
+ relative_embeddings,
252
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
253
+ )
254
+ used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
255
+ return used_relative_embeddings
256
+
257
+ def _relative_position_to_absolute_position(self, x):
258
+ """
259
+ x: [b, h, l, 2*l-1]
260
+ ret: [b, h, l, l]
261
+ """
262
+ batch, heads, length, _ = x.size()
263
+ # Concat columns of pad to shift from relative to absolute indexing.
264
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
265
+
266
+ # Concat extra elements so as to add up to shape (len+1, 2*len-1).
267
+ x_flat = x.view([batch, heads, length * 2 * length])
268
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
269
+
270
+ # Reshape and slice out the padded elements.
271
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
272
+ return x_final
273
+
274
+ def _absolute_position_to_relative_position(self, x):
275
+ """
276
+ x: [b, h, l, l]
277
+ ret: [b, h, l, 2*l-1]
278
+ """
279
+ batch, heads, length, _ = x.size()
280
+ # pad along column
281
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
282
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
283
+ # add 0's in the beginning that will skew the elements after reshape
284
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
285
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
286
+ return x_final
287
+
288
+ def _attention_bias_proximal(self, length):
289
+ """Bias for self-attention to encourage attention to close positions.
290
+ Args:
291
+ length: an integer scalar.
292
+ Returns:
293
+ a Tensor with shape [1, 1, length, length]
294
+ """
295
+ r = torch.arange(length, dtype=torch.float32)
296
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
297
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
298
+
299
+
300
+ class FFN(nn.Module):
301
+ def __init__(
302
+ self,
303
+ in_channels,
304
+ out_channels,
305
+ filter_channels,
306
+ kernel_size,
307
+ p_dropout=0.0,
308
+ activation="",
309
+ causal=False,
310
+ ):
311
+ super().__init__()
312
+ self.in_channels = in_channels
313
+ self.out_channels = out_channels
314
+ self.filter_channels = filter_channels
315
+ self.kernel_size = kernel_size
316
+ self.p_dropout = p_dropout
317
+ self.activation = activation
318
+ self.causal = causal
319
+
320
+ # Judging from the surrounding context, this is always False here
321
+ # if causal:
322
+ # self.padding = self._causal_padding
323
+ # else:
324
+ # self.padding = self._same_padding
325
+
326
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
327
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
328
+ self.drop = nn.Dropout(p_dropout)
329
+
330
+ def forward(self, x, x_mask):
331
+ x = self.conv_1(self.padding(x * x_mask))
332
+ if self.activation == "gelu":
333
+ x = x * torch.sigmoid(1.702 * x)
334
+ else:
335
+ x = torch.relu(x)
336
+ x = self.drop(x)
337
+ x = self.conv_2(self.padding(x * x_mask))
338
+ return x * x_mask
339
+
340
+ def padding(self, x):
341
+ return self._same_padding(x)
342
+
343
+ def _causal_padding(self, x):
344
+ if self.kernel_size == 1:
345
+ return x
346
+ pad_l = self.kernel_size - 1
347
+ pad_r = 0
348
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
349
+ x = F.pad(x, commons.convert_pad_shape(padding))
350
+ return x
351
+
352
+ def _same_padding(self, x):
353
+ if self.kernel_size == 1:
354
+ return x
355
+ pad_l = (self.kernel_size - 1) // 2
356
+ pad_r = self.kernel_size // 2
357
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
358
+ x = F.pad(x, commons.convert_pad_shape(padding))
359
+ return x
360
+
361
+
362
+ class MRTE(nn.Module):
363
+ def __init__(
364
+ self,
365
+ content_enc_channels=192,
366
+ hidden_size=512,
367
+ out_channels=192,
368
+ kernel_size=5,
369
+ n_heads=4,
370
+ ge_layer=2,
371
+ ):
372
+ super(MRTE, self).__init__()
373
+ self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
374
+ self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
375
+ self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
376
+ self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
377
+
378
+ def forward(self, ssl_enc, ssl_mask, text, text_mask, ge):
379
+ attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
380
+
381
+ ssl_enc = self.c_pre(ssl_enc * ssl_mask)
382
+ text_enc = self.text_pre(text * text_mask)
383
+ x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
384
+ x = self.c_post(x * ssl_mask)
385
+ return x
GPT_SoVITS/module/commons.py ADDED
@@ -0,0 +1,185 @@
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def init_weights(m, mean=0.0, std=0.01):
7
+ classname = m.__class__.__name__
8
+ if classname.find("Conv") != -1:
9
+ m.weight.data.normal_(mean, std)
10
+
11
+
12
+ def get_padding(kernel_size, dilation=1):
13
+ return int((kernel_size * dilation - dilation) / 2)
14
+
15
+
16
+ # def convert_pad_shape(pad_shape):
17
+ # l = pad_shape[::-1]
18
+ # pad_shape = [item for sublist in l for item in sublist]
19
+ # return pad_shape
20
+
21
+
22
+ def intersperse(lst, item):
23
+ result = [item] * (len(lst) * 2 + 1)
24
+ result[1::2] = lst
25
+ return result
26
+
27
+
28
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
29
+ """KL(P||Q)"""
30
+ kl = (logs_q - logs_p) - 0.5
31
+ kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
32
+ return kl
33
+
34
+
35
+ def rand_gumbel(shape):
36
+ """Sample from the Gumbel distribution, protect from overflows."""
37
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
38
+ return -torch.log(-torch.log(uniform_samples))
39
+
40
+
41
+ def rand_gumbel_like(x):
42
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
43
+ return g
44
+
45
+
46
+ def slice_segments(x, ids_str, segment_size=4):
47
+ ret = torch.zeros_like(x[:, :, :segment_size])
48
+ for i in range(x.size(0)):
49
+ idx_str = ids_str[i]
50
+ idx_end = idx_str + segment_size
51
+ ret[i] = x[i, :, idx_str:idx_end]
52
+ return ret
53
+
54
+
55
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
56
+ b, d, t = x.size()
57
+ if x_lengths is None:
58
+ x_lengths = t
59
+ ids_str_max = x_lengths - segment_size + 1
60
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
61
+ ret = slice_segments(x, ids_str, segment_size)
62
+ return ret, ids_str
63
+
64
+
65
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
66
+ position = torch.arange(length, dtype=torch.float)
67
+ num_timescales = channels // 2
68
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
69
+ inv_timescales = min_timescale * torch.exp(
70
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
71
+ )
72
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
73
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
74
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
75
+ signal = signal.view(1, channels, length)
76
+ return signal
77
+
78
+
79
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
80
+ b, channels, length = x.size()
81
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
82
+ return x + signal.to(dtype=x.dtype, device=x.device)
83
+
84
+
85
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
86
+ b, channels, length = x.size()
87
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
88
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
89
+
90
+
91
+ def subsequent_mask(length):
92
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
93
+ return mask
94
+
95
+
96
+ @torch.jit.script
97
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
98
+ n_channels_int = n_channels[0]
99
+ in_act = input_a + input_b
100
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
101
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
102
+ acts = t_act * s_act
103
+ return acts
104
+
105
+
106
+ def convert_pad_shape(pad_shape):
107
+ l = pad_shape[::-1]
108
+ pad_shape = [item for sublist in l for item in sublist]
109
+ return pad_shape
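convert_pad_shape exists because F.pad expects padding amounts listed from the last dimension backwards, while the callers in this repo write them outermost-dimension first. A small worked example (illustrative values):

    import torch
    from torch.nn import functional as F
    from GPT_SoVITS.module.commons import convert_pad_shape

    pad = [[0, 0], [0, 0], [1, 0]]                   # per-dimension [before, after], outermost first
    print(convert_pad_shape(pad))                    # [1, 0, 0, 0, 0, 0]
    x = torch.zeros(2, 3, 4)
    print(F.pad(x, convert_pad_shape(pad)).shape)    # torch.Size([2, 3, 5]): only the last dim grew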
110
+
111
+
112
+ def shift_1d(x):
113
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
114
+ return x
115
+
116
+
117
+ def sequence_mask(length, max_length=None):
118
+ if max_length is None:
119
+ max_length = length.max()
120
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
121
+ return x.unsqueeze(0) < length.unsqueeze(1)
122
+
123
+
124
+ def generate_path(duration, mask):
125
+ """
126
+ duration: [b, 1, t_x]
127
+ mask: [b, 1, t_y, t_x]
128
+ """
129
+ device = duration.device
130
+
131
+ b, _, t_y, t_x = mask.shape
132
+ cum_duration = torch.cumsum(duration, -1)
133
+
134
+ cum_duration_flat = cum_duration.view(b * t_x)
135
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
136
+ path = path.view(b, t_x, t_y)
137
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
138
+ path = path.unsqueeze(1).transpose(2, 3) * mask
139
+ return path
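sequence_mask turns lengths into boolean padding masks, and generate_path turns per-token durations into a hard monotonic alignment between text tokens (t_x) and output frames (t_y). A small worked example with illustrative values:

    import torch
    from GPT_SoVITS.module.commons import sequence_mask, generate_path

    print(sequence_mask(torch.tensor([3, 5])).long())
    # tensor([[1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1]])

    duration = torch.tensor([[[2.0, 3.0]]])   # two text tokens lasting 2 and 3 frames
    attn_mask = torch.ones(1, 1, 5, 2)        # [b, 1, t_y, t_x]
    print(generate_path(duration, attn_mask)[0, 0])
    # tensor([[1., 0.],
    #         [1., 0.],
    #         [0., 1.],
    #         [0., 1.],
    #         [0., 1.]])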
140
+
141
+
142
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
143
+ if isinstance(parameters, torch.Tensor):
144
+ parameters = [parameters]
145
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
146
+ norm_type = float(norm_type)
147
+ if clip_value is not None:
148
+ clip_value = float(clip_value)
149
+
150
+ total_norm = 0
151
+ for p in parameters:
152
+ param_norm = p.grad.data.norm(norm_type)
153
+ total_norm += param_norm.item() ** norm_type
154
+ if clip_value is not None:
155
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
156
+ total_norm = total_norm ** (1.0 / norm_type)
157
+ return total_norm
158
+
159
+
160
+ def squeeze(x, x_mask=None, n_sqz=2):
161
+ b, c, t = x.size()
162
+
163
+ t = (t // n_sqz) * n_sqz
164
+ x = x[:, :, :t]
165
+ x_sqz = x.view(b, c, t // n_sqz, n_sqz)
166
+ x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
167
+
168
+ if x_mask is not None:
169
+ x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
170
+ else:
171
+ x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
172
+ return x_sqz * x_mask, x_mask
173
+
174
+
175
+ def unsqueeze(x, x_mask=None, n_sqz=2):
176
+ b, c, t = x.size()
177
+
178
+ x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
179
+ x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
180
+
181
+ if x_mask is not None:
182
+ x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
183
+ else:
184
+ x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
185
+ return x_unsqz * x_mask, x_mask
GPT_SoVITS/module/core_vq.py ADDED
@@ -0,0 +1,365 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # This implementation is inspired from
8
+ # https://github.com/lucidrains/vector-quantize-pytorch
9
+ # which is released under MIT License. Hereafter, the original license:
10
+ # MIT License
11
+ #
12
+ # Copyright (c) 2020 Phil Wang
13
+ #
14
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ # of this software and associated documentation files (the "Software"), to deal
16
+ # in the Software without restriction, including without limitation the rights
17
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ # copies of the Software, and to permit persons to whom the Software is
19
+ # furnished to do so, subject to the following conditions:
20
+ #
21
+ # The above copyright notice and this permission notice shall be included in all
22
+ # copies or substantial portions of the Software.
23
+ #
24
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ # SOFTWARE.
31
+
32
+ """Core vector quantization implementation."""
33
+
34
+ import typing as tp
35
+
36
+ from einops import rearrange, repeat
37
+ import torch
38
+ from torch import nn
39
+ import torch.nn.functional as F
40
+ from tqdm import tqdm
41
+
42
+
43
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
44
+ return val if val is not None else d
45
+
46
+
47
+ def ema_inplace(moving_avg, new, decay: float):
48
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
49
+
50
+
51
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
52
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
53
+
54
+
55
+ def uniform_init(*shape: int):
56
+ t = torch.empty(shape)
57
+ nn.init.kaiming_uniform_(t)
58
+ return t
59
+
60
+
61
+ def sample_vectors(samples, num: int):
62
+ num_samples, device = samples.shape[0], samples.device
63
+
64
+ if num_samples >= num:
65
+ indices = torch.randperm(num_samples, device=device)[:num]
66
+ else:
67
+ indices = torch.randint(0, num_samples, (num,), device=device)
68
+
69
+ return samples[indices]
70
+
71
+
72
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
73
+ dim, dtype = samples.shape[-1], samples.dtype
74
+ max_kmeans_samples = 500
75
+ samples = samples[:max_kmeans_samples, :]
76
+ means = sample_vectors(samples, num_clusters)
77
+
78
+ print("kmeans start ... ")
79
+ for _ in tqdm(range(num_iters)):
80
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
81
+ dists = -(diffs**2).sum(dim=-1)
82
+
83
+ buckets = dists.max(dim=-1).indices
84
+ bins = torch.bincount(buckets, minlength=num_clusters)
85
+ zero_mask = bins == 0
86
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
87
+
88
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
89
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
90
+ new_means = new_means / bins_min_clamped[..., None]
91
+
92
+ means = torch.where(zero_mask[..., None], means, new_means)
93
+
94
+ return means, bins
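A quick smoke test of the k-means helper (a sketch with arbitrary sizes; it prints a short tqdm progress bar, and only the first 500 samples are ever used, per the cap above):

    import torch
    from GPT_SoVITS.module.core_vq import kmeans

    samples = torch.randn(300, 8)
    means, bins = kmeans(samples, num_clusters=16, num_iters=5)
    print(means.shape, bins.shape)   # torch.Size([16, 8]) torch.Size([16])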
95
+
96
+
97
+ class EuclideanCodebook(nn.Module):
98
+ """Codebook with Euclidean distance.
99
+ Args:
100
+ dim (int): Dimension.
101
+ codebook_size (int): Codebook size.
102
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
103
+ If set to true, run the k-means algorithm on the first training batch and use
104
+ the learned centroids as initialization.
105
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
106
+ decay (float): Decay for exponential moving average over the codebooks.
107
+ epsilon (float): Epsilon value for numerical stability.
108
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
109
+ that have an exponential moving average cluster size less than the specified threshold with
110
+ randomly selected vector from the current batch.
111
+ """
112
+
113
+ def __init__(
114
+ self,
115
+ dim: int,
116
+ codebook_size: int,
117
+ kmeans_init: int = False,
118
+ kmeans_iters: int = 10,
119
+ decay: float = 0.99,
120
+ epsilon: float = 1e-5,
121
+ threshold_ema_dead_code: int = 2,
122
+ ):
123
+ super().__init__()
124
+ self.decay = decay
125
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
126
+ embed = init_fn(codebook_size, dim)
127
+
128
+ self.codebook_size = codebook_size
129
+
130
+ self.kmeans_iters = kmeans_iters
131
+ self.epsilon = epsilon
132
+ self.threshold_ema_dead_code = threshold_ema_dead_code
133
+
134
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
135
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
136
+ self.register_buffer("embed", embed)
137
+ self.register_buffer("embed_avg", embed.clone())
138
+
139
+ @torch.jit.ignore
140
+ def init_embed_(self, data):
141
+ if self.inited:
142
+ return
143
+
144
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
145
+ self.embed.data.copy_(embed)
146
+ self.embed_avg.data.copy_(embed.clone())
147
+ self.cluster_size.data.copy_(cluster_size)
148
+ self.inited.data.copy_(torch.Tensor([True]))
149
+ # Make sure all buffers across workers are in sync after initialization
150
+ # broadcast_tensors(self.buffers())
151
+
152
+ def replace_(self, samples, mask):
153
+ modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
154
+ self.embed.data.copy_(modified_codebook)
155
+
156
+ def expire_codes_(self, batch_samples):
157
+ if self.threshold_ema_dead_code == 0:
158
+ return
159
+
160
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
161
+ if not torch.any(expired_codes):
162
+ return
163
+
164
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
165
+ self.replace_(batch_samples, mask=expired_codes)
166
+ # broadcast_tensors(self.buffers())
167
+
168
+ def preprocess(self, x):
169
+ x = rearrange(x, "... d -> (...) d")
170
+ return x
171
+
172
+ def quantize(self, x):
173
+ embed = self.embed.t()
174
+ dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
175
+ embed_ind = dist.max(dim=-1).indices
176
+ return embed_ind
177
+
178
+ def postprocess_emb(self, embed_ind, shape):
179
+ return embed_ind.view(*shape[:-1])
180
+
181
+ def dequantize(self, embed_ind):
182
+ quantize = F.embedding(embed_ind, self.embed)
183
+ return quantize
184
+
185
+ def encode(self, x):
186
+ shape = x.shape
187
+ # pre-process
188
+ x = self.preprocess(x)
189
+ # quantize
190
+ embed_ind = self.quantize(x)
191
+ # post-process
192
+ embed_ind = self.postprocess_emb(embed_ind, shape)
193
+ return embed_ind
194
+
195
+ def decode(self, embed_ind):
196
+ quantize = self.dequantize(embed_ind)
197
+ return quantize
198
+
199
+ def forward(self, x):
200
+ shape, dtype = x.shape, x.dtype
201
+ x = self.preprocess(x)
202
+
203
+ self.init_embed_(x)
204
+
205
+ embed_ind = self.quantize(x)
206
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
207
+ embed_ind = self.postprocess_emb(embed_ind, shape)
208
+ quantize = self.dequantize(embed_ind)
209
+
210
+ if self.training:
211
+ # We do the expiry of code at that point as buffers are in sync
212
+ # and all the workers will take the same decision.
213
+ self.expire_codes_(x)
214
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
215
+ embed_sum = x.t() @ embed_onehot
216
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
217
+ cluster_size = (
218
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
219
+ )
220
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
221
+ self.embed.data.copy_(embed_normalized)
222
+
223
+ return quantize, embed_ind
224
+
225
+
226
+ class VectorQuantization(nn.Module):
227
+ """Vector quantization implementation.
228
+ Currently supports only euclidean distance.
229
+ Args:
230
+ dim (int): Dimension
231
+ codebook_size (int): Codebook size
232
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
233
+ decay (float): Decay for exponential moving average over the codebooks.
234
+ epsilon (float): Epsilon value for numerical stability.
235
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
236
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
237
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
238
+ that have an exponential moving average cluster size less than the specified threshold with
239
+ a randomly selected vector from the current batch.
240
+ commitment_weight (float): Weight for commitment loss.
241
+ """
242
+
243
+ def __init__(
244
+ self,
245
+ dim: int,
246
+ codebook_size: int,
247
+ codebook_dim: tp.Optional[int] = None,
248
+ decay: float = 0.99,
249
+ epsilon: float = 1e-5,
250
+ kmeans_init: bool = True,
251
+ kmeans_iters: int = 50,
252
+ threshold_ema_dead_code: int = 2,
253
+ commitment_weight: float = 1.0,
254
+ ):
255
+ super().__init__()
256
+ _codebook_dim: int = default(codebook_dim, dim)
257
+
258
+ requires_projection = _codebook_dim != dim
259
+ self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
260
+ self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
261
+
262
+ self.epsilon = epsilon
263
+ self.commitment_weight = commitment_weight
264
+
265
+ self._codebook = EuclideanCodebook(
266
+ dim=_codebook_dim,
267
+ codebook_size=codebook_size,
268
+ kmeans_init=kmeans_init,
269
+ kmeans_iters=kmeans_iters,
270
+ decay=decay,
271
+ epsilon=epsilon,
272
+ threshold_ema_dead_code=threshold_ema_dead_code,
273
+ )
274
+ self.codebook_size = codebook_size
275
+
276
+ @property
277
+ def codebook(self):
278
+ return self._codebook.embed
279
+
280
+ def encode(self, x):
281
+ x = rearrange(x, "b d n -> b n d")
282
+ x = self.project_in(x)
283
+ embed_in = self._codebook.encode(x)
284
+ return embed_in
285
+
286
+ def decode(self, embed_ind):
287
+ quantize = self._codebook.decode(embed_ind)
288
+ quantize = self.project_out(quantize)
289
+ quantize = rearrange(quantize, "b n d -> b d n")
290
+ return quantize
291
+
292
+ def forward(self, x):
293
+ device = x.device
294
+ x = rearrange(x, "b d n -> b n d")
295
+ x = self.project_in(x)
296
+
297
+ quantize, embed_ind = self._codebook(x)
298
+
299
+ if self.training:
300
+ quantize = x + (quantize - x).detach()
301
+
302
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
303
+
304
+ if self.training:
305
+ if self.commitment_weight > 0:
306
+ commit_loss = F.mse_loss(quantize.detach(), x)
307
+ loss = loss + commit_loss * self.commitment_weight
308
+
309
+ quantize = self.project_out(quantize)
310
+ quantize = rearrange(quantize, "b n d -> b d n")
311
+ return quantize, embed_ind, loss
312
+
313
+
314
+ class ResidualVectorQuantization(nn.Module):
315
+ """Residual vector quantization implementation.
316
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
317
+ """
318
+
319
+ def __init__(self, *, num_quantizers, **kwargs):
320
+ super().__init__()
321
+ self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
322
+
323
+ def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
324
+ quantized_out = 0.0
325
+ residual = x
326
+
327
+ all_losses = []
328
+ all_indices = []
329
+ out_quantized = []
330
+
331
+ n_q = n_q or len(self.layers)
332
+
333
+ for i, layer in enumerate(self.layers[:n_q]):
334
+ quantized, indices, loss = layer(residual)
335
+ residual = residual - quantized
336
+ quantized_out = quantized_out + quantized
337
+
338
+ all_indices.append(indices)
339
+ all_losses.append(loss)
340
+ if layers and i in layers:
341
+ out_quantized.append(quantized)
342
+
343
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
344
+ return quantized_out, out_indices, out_losses, out_quantized
345
+
346
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
347
+ residual = x
348
+ all_indices = []
349
+ n_q = n_q or len(self.layers)
350
+ st = st or 0
351
+ for layer in self.layers[st:n_q]:
352
+ indices = layer.encode(residual)
353
+ quantized = layer.decode(indices)
354
+ residual = residual - quantized
355
+ all_indices.append(indices)
356
+ out_indices = torch.stack(all_indices)
357
+ return out_indices
358
+
359
+ def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
360
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
361
+ for i, indices in enumerate(q_indices):
362
+ layer = self.layers[st + i]
363
+ quantized = layer.decode(indices)
364
+ quantized_out = quantized_out + quantized
365
+ return quantized_out
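The residual VQ stack above can be exercised end to end with dummy tensors. The sketch below is illustrative only: the import path follows this commit's file layout (`GPT_SoVITS/module/core_vq.py`), and every hyperparameter and shape is an assumption rather than a value used by the training code.

```python
# Minimal sketch: quantize, encode and decode a random feature sequence with the RVQ above.
import torch

from GPT_SoVITS.module.core_vq import ResidualVectorQuantization  # assumed import path

rvq = ResidualVectorQuantization(
    num_quantizers=8,   # number of residual codebooks (illustrative)
    dim=512,            # feature dimension (illustrative)
    codebook_size=1024,
    kmeans_init=False,  # uniform init so the example runs without a "first training batch"
)

x = torch.randn(2, 512, 50)              # (batch, dim, time)
quantized, indices, losses, _ = rvq(x)   # indices: (num_quantizers, batch, time)
codes = rvq.encode(x, n_q=4)             # restrict to the first 4 quantizers
recon = rvq.decode(codes)                # (batch, dim, time)
print(quantized.shape, codes.shape, recon.shape)
```

Decoding with fewer codebooks simply sums fewer residual contributions, which is why `encode` and `decode` accept the `n_q` and `st` offsets.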
GPT_SoVITS/module/data_utils.py ADDED
@@ -0,0 +1,1073 @@
 
1
+ import os
2
+ import random
3
+ import traceback
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.utils.data
8
+ from tqdm import tqdm
9
+
10
+ from GPT_SoVITS.text import cleaned_text_to_sequence
11
+ from tools.my_utils import load_audio
12
+
13
+ from .mel_processing import spec_to_mel_torch, spectrogram_torch
14
+
15
+ version = os.environ.get("version", None)
16
+
17
+
18
+ # ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
19
+ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
20
+ """
21
+ 1) loads audio, speaker_id, text pairs
22
+ 2) normalizes text and converts them to sequences of integers
23
+ 3) computes spectrograms from audio files.
24
+ """
25
+
26
+ def __init__(self, hparams, version=None, val=False):
27
+ exp_dir = hparams.exp_dir
28
+ self.path2 = "%s/2-name2text.txt" % exp_dir
29
+ self.path4 = "%s/4-cnhubert" % exp_dir
30
+ self.path5 = "%s/5-wav32k" % exp_dir
31
+ assert os.path.exists(self.path2)
32
+ assert os.path.exists(self.path4)
33
+ assert os.path.exists(self.path5)
34
+ self.is_v2Pro = version in {"v2Pro", "v2ProPlus"}
35
+ if self.is_v2Pro:
36
+ self.path7 = "%s/7-sv_cn" % exp_dir
37
+ assert os.path.exists(self.path7)
38
+ names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # strip the .pt suffix
39
+ names5 = set(os.listdir(self.path5))
40
+ if self.is_v2Pro:
41
+ names6 = set([name[:-3] for name in list(os.listdir(self.path7))]) # strip the .pt suffix
42
+ self.phoneme_data = {}
43
+ with open(self.path2, "r", encoding="utf8") as f:
44
+ lines = f.read().strip("\n").split("\n")
45
+
46
+ for line in lines:
47
+ tmp = line.split("\t")
48
+ if len(tmp) != 4:
49
+ continue
50
+ self.phoneme_data[tmp[0]] = [tmp[1]]
51
+ if self.is_v2Pro:
52
+ self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5 & names6)
53
+ else:
54
+ self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
55
+ tmp = self.audiopaths_sid_text
56
+ leng = len(tmp)
57
+ min_num = 100
58
+ if leng < min_num:
59
+ self.audiopaths_sid_text = []
60
+ for _ in range(max(2, int(min_num / leng))):
61
+ self.audiopaths_sid_text += tmp
62
+ self.max_wav_value = hparams.max_wav_value
63
+ self.sampling_rate = hparams.sampling_rate
64
+ self.filter_length = hparams.filter_length
65
+ self.hop_length = hparams.hop_length
66
+ self.win_length = hparams.win_length
67
+ self.sampling_rate = hparams.sampling_rate
68
+ self.val = val
69
+
70
+ random.seed(1234)
71
+ random.shuffle(self.audiopaths_sid_text)
72
+
73
+ print("phoneme_data_len:", len(self.phoneme_data.keys()))
74
+ print("wav_data_len:", len(self.audiopaths_sid_text))
75
+
76
+ audiopaths_sid_text_new = []
77
+ lengths = []
78
+ skipped_phone = 0
79
+ skipped_dur = 0
80
+ for audiopath in tqdm(self.audiopaths_sid_text):
81
+ try:
82
+ phoneme = self.phoneme_data[audiopath][0]
83
+ phoneme = phoneme.split(" ")
84
+ phoneme_ids = cleaned_text_to_sequence(phoneme, version)
85
+ except Exception:
86
+ print(f"{audiopath} not in self.phoneme_data !")
87
+ skipped_phone += 1
88
+ continue
89
+
90
+ size = os.path.getsize("%s/%s" % (self.path5, audiopath))
91
+ duration = size / self.sampling_rate / 2
92
+
93
+ if duration == 0:
94
+ print(f"Zero duration for {audiopath}, skipping...")
95
+ skipped_dur += 1
96
+ continue
97
+
98
+ if 54 > duration > 0.6 or self.val:
99
+ audiopaths_sid_text_new.append([audiopath, phoneme_ids])
100
+ lengths.append(size // (2 * self.hop_length))
101
+ else:
102
+ skipped_dur += 1
103
+ continue
104
+
105
+ print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
106
+ print("total left: ", len(audiopaths_sid_text_new))
107
+ assert len(audiopaths_sid_text_new) > 1 # must be enough to fill at least one batch; TODO
108
+ self.audiopaths_sid_text = audiopaths_sid_text_new
109
+ self.lengths = lengths
110
+
111
+ def get_audio_text_speaker_pair(self, audiopath_sid_text):
112
+ audiopath, phoneme_ids = audiopath_sid_text
113
+ text = torch.FloatTensor(phoneme_ids)
114
+ try:
115
+ spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
116
+ with torch.no_grad():
117
+ ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
118
+ if ssl.shape[-1] != spec.shape[-1]:
119
+ typee = ssl.dtype
120
+ ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
121
+ ssl.requires_grad = False
122
+ if self.is_v2Pro:
123
+ sv_emb = torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu")
124
+ except Exception:
125
+ traceback.print_exc()
126
+ spec = torch.zeros(1025, 100)
127
+ wav = torch.zeros(1, 100 * self.hop_length)
128
+ ssl = torch.zeros(1, 768, 100)
129
+ text = text[-1:]
130
+ if self.is_v2Pro:
131
+ sv_emb = torch.zeros(1, 20480)
132
+ print("load audio or ssl error!!!!!!", audiopath)
133
+ if self.is_v2Pro:
134
+ return (ssl, spec, wav, text, sv_emb)
135
+ else:
136
+ return (ssl, spec, wav, text)
137
+
138
+ def get_audio(self, filename):
139
+ audio_array = load_audio(filename, self.sampling_rate) # load_audio already normalizes to [-1, 1], no further /32768 needed
140
+ audio = torch.FloatTensor(audio_array) # /32768
141
+ audio_norm = audio
142
+ audio_norm = audio_norm.unsqueeze(0)
143
+ spec = spectrogram_torch(
144
+ audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
145
+ )
146
+ spec = torch.squeeze(spec, 0)
147
+ return spec, audio_norm
148
+
149
+ def get_sid(self, sid):
150
+ sid = torch.LongTensor([int(sid)])
151
+ return sid
152
+
153
+ def __getitem__(self, index):
154
+ # with torch.no_grad():
155
+ return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
156
+
157
+ def __len__(self):
158
+ return len(self.audiopaths_sid_text)
159
+
160
+ def random_slice(self, ssl, wav, mel):
161
+ assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ("first", ssl.shape, wav.shape)
162
+
163
+ len_mel = mel.shape[1]
164
+ if self.val:
165
+ reference_mel = mel[:, : len_mel // 3]
166
+ return reference_mel, ssl, wav, mel
167
+ dir = random.randint(0, 1)
168
+ sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
169
+
170
+ if dir == 0:
171
+ reference_mel = mel[:, :sep_point]
172
+ ssl = ssl[:, :, sep_point:]
173
+ wav2 = wav[:, sep_point * self.hop_length :]
174
+ mel = mel[:, sep_point:]
175
+ else:
176
+ reference_mel = mel[:, sep_point:]
177
+ ssl = ssl[:, :, :sep_point]
178
+ wav2 = wav[:, : sep_point * self.hop_length]
179
+ mel = mel[:, :sep_point]
180
+
181
+ assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
182
+ ssl.shape,
183
+ wav.shape,
184
+ wav2.shape,
185
+ mel.shape,
186
+ sep_point,
187
+ self.hop_length,
188
+ sep_point * self.hop_length,
189
+ dir,
190
+ )
191
+ return reference_mel, ssl, wav2, mel
192
+
193
+
194
+ class TextAudioSpeakerCollate:
195
+ """Zero-pads model inputs and targets"""
196
+
197
+ def __init__(self, return_ids=False, version=None):
198
+ self.return_ids = return_ids
199
+ self.is_v2Pro = version in {"v2Pro", "v2ProPlus"}
200
+
201
+ def __call__(self, batch):
202
+ """Collates a training batch of SSL features, spectrograms, waveforms and phoneme sequences
203
+ PARAMS
204
+ ------
205
+ batch: list of (ssl, spec, wav, text) tuples, or (ssl, spec, wav, text, sv_emb) for v2Pro/v2ProPlus
206
+ """
207
+ # Right zero-pad all one-hot text sequences to max input length
208
+ _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
209
+
210
+ max_ssl_len = max([x[0].size(2) for x in batch])
211
+ max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
212
+ max_spec_len = max([x[1].size(1) for x in batch])
213
+ max_spec_len = int(2 * ((max_spec_len // 2) + 1))
214
+ max_wav_len = max([x[2].size(1) for x in batch])
215
+ max_text_len = max([x[3].size(0) for x in batch])
216
+
217
+ ssl_lengths = torch.LongTensor(len(batch))
218
+ spec_lengths = torch.LongTensor(len(batch))
219
+ wav_lengths = torch.LongTensor(len(batch))
220
+ text_lengths = torch.LongTensor(len(batch))
221
+
222
+ spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
223
+ wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
224
+ ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
225
+ text_padded = torch.LongTensor(len(batch), max_text_len)
226
+
227
+ spec_padded.zero_()
228
+ wav_padded.zero_()
229
+ ssl_padded.zero_()
230
+ text_padded.zero_()
231
+
232
+ if self.is_v2Pro:
233
+ sv_embs = torch.FloatTensor(len(batch), 20480)
234
+
235
+ for i in range(len(ids_sorted_decreasing)):
236
+ row = batch[ids_sorted_decreasing[i]]
237
+
238
+ ssl = row[0]
239
+ ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
240
+ ssl_lengths[i] = ssl.size(2)
241
+
242
+ spec = row[1]
243
+ spec_padded[i, :, : spec.size(1)] = spec
244
+ spec_lengths[i] = spec.size(1)
245
+
246
+ wav = row[2]
247
+ wav_padded[i, :, : wav.size(1)] = wav
248
+ wav_lengths[i] = wav.size(1)
249
+
250
+ text = row[3]
251
+ text_padded[i, : text.size(0)] = text
252
+ text_lengths[i] = text.size(0)
253
+
254
+ if self.is_v2Pro:
255
+ sv_embs[i] = row[4]
256
+ if self.is_v2Pro:
257
+ return (
258
+ ssl_padded,
259
+ ssl_lengths,
260
+ spec_padded,
261
+ spec_lengths,
262
+ wav_padded,
263
+ wav_lengths,
264
+ text_padded,
265
+ text_lengths,
266
+ sv_embs,
267
+ )
268
+ else:
269
+ return (
270
+ ssl_padded,
271
+ ssl_lengths,
272
+ spec_padded,
273
+ spec_lengths,
274
+ wav_padded,
275
+ wav_lengths,
276
+ text_padded,
277
+ text_lengths,
278
+ )
279
+
280
+
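Before the V3/V4 variants below, here is a hedged sketch of how `TextAudioSpeakerCollate` pads a batch. It needs no data on disk: the fake items mimic the shapes produced by the loader above (768-dim SSL features, a 1025-bin spectrogram, waveform samples, integer phoneme ids); every number is an illustrative assumption.

```python
# Minimal sketch: run TextAudioSpeakerCollate on a hand-built batch and inspect padded shapes.
import torch

from GPT_SoVITS.module.data_utils import TextAudioSpeakerCollate  # assumed import path

def fake_item(n_frames: int, n_phones: int, hop: int = 640):
    ssl = torch.randn(1, 768, n_frames)    # (1, ssl_dim, T)
    spec = torch.randn(1025, n_frames)     # (n_fft // 2 + 1, T)
    wav = torch.randn(1, n_frames * hop)   # (1, samples)
    text = torch.arange(n_phones)          # phoneme ids
    return ssl, spec, wav, text

collate = TextAudioSpeakerCollate(version=None)  # non-v2Pro path, so no sv_emb column
ssl_p, ssl_len, spec_p, spec_len, wav_p, wav_len, text_p, text_len = collate(
    [fake_item(80, 12), fake_item(50, 7)]
)
print(ssl_p.shape, spec_p.shape, wav_p.shape, text_p.shape)  # padded to the longest item
```

Items are sorted by spectrogram length before padding, and the SSL/spectrogram axes are rounded up to an even frame count, matching the `max_*_len` arithmetic above.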
281
+ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
282
+ """
283
+ 1) loads audio, speaker_id, text pairs
284
+ 2) normalizes text and converts them to sequences of integers
285
+ 3) computes spectrograms from audio files.
286
+ """
287
+
288
+ def __init__(self, hparams, val=False):
289
+ exp_dir = hparams.exp_dir
290
+ self.path2 = "%s/2-name2text.txt" % exp_dir
291
+ self.path4 = "%s/4-cnhubert" % exp_dir
292
+ self.path5 = "%s/5-wav32k" % exp_dir
293
+ assert os.path.exists(self.path2)
294
+ assert os.path.exists(self.path4)
295
+ assert os.path.exists(self.path5)
296
+ names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # strip the .pt suffix
297
+ names5 = set(os.listdir(self.path5))
298
+ self.phoneme_data = {}
299
+ with open(self.path2, "r", encoding="utf8") as f:
300
+ lines = f.read().strip("\n").split("\n")
301
+
302
+ for line in lines:
303
+ tmp = line.split("\t")
304
+ if len(tmp) != 4:
305
+ continue
306
+ self.phoneme_data[tmp[0]] = [tmp[1]]
307
+
308
+ self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
309
+ tmp = self.audiopaths_sid_text
310
+ leng = len(tmp)
311
+ min_num = 100
312
+ if leng < min_num:
313
+ self.audiopaths_sid_text = []
314
+ for _ in range(max(2, int(min_num / leng))):
315
+ self.audiopaths_sid_text += tmp
316
+ self.max_wav_value = hparams.max_wav_value
317
+ self.sampling_rate = hparams.sampling_rate
318
+ self.filter_length = hparams.filter_length
319
+ self.hop_length = hparams.hop_length
320
+ self.win_length = hparams.win_length
321
+ self.sampling_rate = hparams.sampling_rate
322
+ self.val = val
323
+
324
+ random.seed(1234)
325
+ random.shuffle(self.audiopaths_sid_text)
326
+
327
+ print("phoneme_data_len:", len(self.phoneme_data.keys()))
328
+ print("wav_data_len:", len(self.audiopaths_sid_text))
329
+
330
+ audiopaths_sid_text_new = []
331
+ lengths = []
332
+ skipped_phone = 0
333
+ skipped_dur = 0
334
+ for audiopath in tqdm(self.audiopaths_sid_text):
335
+ try:
336
+ phoneme = self.phoneme_data[audiopath][0]
337
+ phoneme = phoneme.split(" ")
338
+ phoneme_ids = cleaned_text_to_sequence(phoneme, version)
339
+ except Exception:
340
+ print(f"{audiopath} not in self.phoneme_data !")
341
+ skipped_phone += 1
342
+ continue
343
+
344
+ size = os.path.getsize("%s/%s" % (self.path5, audiopath))
345
+ duration = size / self.sampling_rate / 2
346
+
347
+ if duration == 0:
348
+ print(f"Zero duration for {audiopath}, skipping...")
349
+ skipped_dur += 1
350
+ continue
351
+
352
+ if 54 > duration > 0.6 or self.val:
353
+ audiopaths_sid_text_new.append([audiopath, phoneme_ids])
354
+ lengths.append(size // (2 * self.hop_length))
355
+ else:
356
+ skipped_dur += 1
357
+ continue
358
+
359
+ print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
360
+ print("total left: ", len(audiopaths_sid_text_new))
361
+ assert len(audiopaths_sid_text_new) > 1 # must be enough to fill at least one batch; TODO
362
+ self.audiopaths_sid_text = audiopaths_sid_text_new
363
+ self.lengths = lengths
364
+ self.spec_min = -12
365
+ self.spec_max = 2
366
+
367
+ self.filter_length_mel = self.win_length_mel = 1024
368
+ self.hop_length_mel = 256
369
+ self.n_mel_channels = 100
370
+ self.sampling_rate_mel = 24000
371
+ self.mel_fmin = 0
372
+ self.mel_fmax = None
373
+
374
+ def norm_spec(self, x):
375
+ return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
376
+
377
+ def get_audio_text_speaker_pair(self, audiopath_sid_text):
378
+ audiopath, phoneme_ids = audiopath_sid_text
379
+ text = torch.FloatTensor(phoneme_ids)
380
+ try:
381
+ spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
382
+ with torch.no_grad():
383
+ ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
384
+ if ssl.shape[-1] != spec.shape[-1]:
385
+ typee = ssl.dtype
386
+ ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
387
+ ssl.requires_grad = False
388
+ except Exception:
389
+ traceback.print_exc()
390
+ mel = torch.zeros(100, 180)
391
+ # wav = torch.zeros(1, 96 * self.hop_length)
392
+ spec = torch.zeros(1025, 96)
393
+ ssl = torch.zeros(1, 768, 96)
394
+ text = text[-1:]
395
+ print("load audio or ssl error!!!!!!", audiopath)
396
+ return (ssl, spec, mel, text)
397
+
398
+ def get_audio(self, filename):
399
+ audio_array = load_audio(filename, self.sampling_rate) # load_audio already normalizes to [-1, 1], no further /32768 needed
400
+ audio = torch.FloatTensor(audio_array) # /32768
401
+ audio_norm = audio
402
+ audio_norm = audio_norm.unsqueeze(0)
403
+ audio_array24 = load_audio(
404
+ filename, 24000
405
+ ) # load_audio already normalizes to [-1, 1], no further /32768 needed; resampling here could be accelerated on GPU
406
+ audio24 = torch.FloatTensor(audio_array24) # /32768
407
+ audio_norm24 = audio24
408
+ audio_norm24 = audio_norm24.unsqueeze(0)
409
+
410
+ spec = spectrogram_torch(
411
+ audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
412
+ )
413
+ spec = torch.squeeze(spec, 0)
414
+
415
+ spec1 = spectrogram_torch(
416
+ audio_norm24,
417
+ self.filter_length_mel,
418
+ self.sampling_rate_mel,
419
+ self.hop_length_mel,
420
+ self.win_length_mel,
421
+ center=False,
422
+ )
423
+ mel = spec_to_mel_torch(
424
+ spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
425
+ )
426
+ mel = torch.squeeze(mel, 0)
427
+ mel = self.norm_spec(mel)
428
+ # print(1111111,spec.shape,mel.shape)
429
+ return spec, mel
430
+
431
+ def get_sid(self, sid):
432
+ sid = torch.LongTensor([int(sid)])
433
+ return sid
434
+
435
+ def __getitem__(self, index):
436
+ # with torch.no_grad():
437
+ return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
438
+
439
+ def __len__(self):
440
+ return len(self.audiopaths_sid_text)
441
+
442
+
443
+ class TextAudioSpeakerCollateV3:
444
+ """Zero-pads model inputs and targets"""
445
+
446
+ def __init__(self, return_ids=False):
447
+ self.return_ids = return_ids
448
+
449
+ def __call__(self, batch):
450
+ """Collates a training batch of SSL features, spectrograms, mel-spectrograms and phoneme sequences
451
+ PARAMS
452
+ ------
453
+ batch: list of (ssl, spec, mel, text) tuples
454
+ """
455
+ # ssl, spec, wav,mel, text
456
+ # Right zero-pad all one-hot text sequences to max input length
457
+ _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
458
+ # (ssl, spec,mel, text)
459
+ max_ssl_len = max([x[0].size(2) for x in batch])
460
+
461
+ max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
462
+ max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
463
+
464
+ # max_ssl_len = int(8 * ((max_ssl_len // 8) + 1))
465
+ # max_ssl_len1=max_ssl_len
466
+
467
+ max_spec_len = max([x[1].size(1) for x in batch])
468
+ max_spec_len = int(2 * ((max_spec_len // 2) + 1))
469
+ # max_wav_len = max([x[2].size(1) for x in batch])
470
+
471
+ max_text_len = max([x[3].size(0) for x in batch])
472
+ max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320
473
+
474
+ ssl_lengths = torch.LongTensor(len(batch))
475
+ spec_lengths = torch.LongTensor(len(batch))
476
+ text_lengths = torch.LongTensor(len(batch))
477
+ # wav_lengths = torch.LongTensor(len(batch))
478
+ mel_lengths = torch.LongTensor(len(batch))
479
+
480
+ spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
481
+ mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_mel_len)
482
+ ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
483
+ text_padded = torch.LongTensor(len(batch), max_text_len)
484
+ # wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
485
+
486
+ spec_padded.zero_()
487
+ mel_padded.zero_()
488
+ ssl_padded.zero_()
489
+ text_padded.zero_()
490
+ # wav_padded.zero_()
491
+
492
+ for i in range(len(ids_sorted_decreasing)):
493
+ row = batch[ids_sorted_decreasing[i]]
494
+ # ssl, spec, wav,mel, text
495
+ ssl = row[0]
496
+ ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
497
+ ssl_lengths[i] = ssl.size(2)
498
+
499
+ spec = row[1]
500
+ spec_padded[i, :, : spec.size(1)] = spec
501
+ spec_lengths[i] = spec.size(1)
502
+
503
+ # wav = row[2]
504
+ # wav_padded[i, :, :wav.size(1)] = wav
505
+ # wav_lengths[i] = wav.size(1)
506
+
507
+ mel = row[2]
508
+ mel_padded[i, :, : mel.size(1)] = mel
509
+ mel_lengths[i] = mel.size(1)
510
+
511
+ text = row[3]
512
+ text_padded[i, : text.size(0)] = text
513
+ text_lengths[i] = text.size(0)
514
+
515
+ # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
516
+ return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths
517
+
518
+
519
+ class TextAudioSpeakerLoaderV4(torch.utils.data.Dataset):
520
+ """
521
+ 1) loads audio, speaker_id, text pairs
522
+ 2) normalizes text and converts them to sequences of integers
523
+ 3) computes spectrograms from audio files.
524
+ """
525
+
526
+ def __init__(self, hparams, val=False):
527
+ exp_dir = hparams.exp_dir
528
+ self.path2 = "%s/2-name2text.txt" % exp_dir
529
+ self.path4 = "%s/4-cnhubert" % exp_dir
530
+ self.path5 = "%s/5-wav32k" % exp_dir
531
+ assert os.path.exists(self.path2)
532
+ assert os.path.exists(self.path4)
533
+ assert os.path.exists(self.path5)
534
+ names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # strip the .pt suffix
535
+ names5 = set(os.listdir(self.path5))
536
+ self.phoneme_data = {}
537
+ with open(self.path2, "r", encoding="utf8") as f:
538
+ lines = f.read().strip("\n").split("\n")
539
+
540
+ for line in lines:
541
+ tmp = line.split("\t")
542
+ if len(tmp) != 4:
543
+ continue
544
+ self.phoneme_data[tmp[0]] = [tmp[1]]
545
+
546
+ self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
547
+ tmp = self.audiopaths_sid_text
548
+ leng = len(tmp)
549
+ min_num = 100
550
+ if leng < min_num:
551
+ self.audiopaths_sid_text = []
552
+ for _ in range(max(2, int(min_num / leng))):
553
+ self.audiopaths_sid_text += tmp
554
+ self.max_wav_value = hparams.max_wav_value
555
+ self.sampling_rate = hparams.sampling_rate
556
+ self.filter_length = hparams.filter_length
557
+ self.hop_length = hparams.hop_length
558
+ self.win_length = hparams.win_length
559
+ self.sampling_rate = hparams.sampling_rate
560
+ self.val = val
561
+
562
+ random.seed(1234)
563
+ random.shuffle(self.audiopaths_sid_text)
564
+
565
+ print("phoneme_data_len:", len(self.phoneme_data.keys()))
566
+ print("wav_data_len:", len(self.audiopaths_sid_text))
567
+
568
+ audiopaths_sid_text_new = []
569
+ lengths = []
570
+ skipped_phone = 0
571
+ skipped_dur = 0
572
+ for audiopath in tqdm(self.audiopaths_sid_text):
573
+ try:
574
+ phoneme = self.phoneme_data[audiopath][0]
575
+ phoneme = phoneme.split(" ")
576
+ phoneme_ids = cleaned_text_to_sequence(phoneme, version)
577
+ except Exception:
578
+ print(f"{audiopath} not in self.phoneme_data !")
579
+ skipped_phone += 1
580
+ continue
581
+
582
+ size = os.path.getsize("%s/%s" % (self.path5, audiopath))
583
+ duration = size / self.sampling_rate / 2
584
+
585
+ if duration == 0:
586
+ print(f"Zero duration for {audiopath}, skipping...")
587
+ skipped_dur += 1
588
+ continue
589
+
590
+ if 54 > duration > 0.6 or self.val:
591
+ audiopaths_sid_text_new.append([audiopath, phoneme_ids])
592
+ lengths.append(size // (2 * self.hop_length))
593
+ else:
594
+ skipped_dur += 1
595
+ continue
596
+
597
+ print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
598
+ print("total left: ", len(audiopaths_sid_text_new))
599
+ assert len(audiopaths_sid_text_new) > 1 # must be enough to fill at least one batch; TODO
600
+ self.audiopaths_sid_text = audiopaths_sid_text_new
601
+ self.lengths = lengths
602
+ self.spec_min = -12
603
+ self.spec_max = 2
604
+
605
+ self.filter_length_mel = self.win_length_mel = 1280
606
+ self.hop_length_mel = 320
607
+ self.n_mel_channels = 100
608
+ self.sampling_rate_mel = 32000
609
+ self.mel_fmin = 0
610
+ self.mel_fmax = None
611
+
612
+ def norm_spec(self, x):
613
+ return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
614
+
615
+ def get_audio_text_speaker_pair(self, audiopath_sid_text):
616
+ audiopath, phoneme_ids = audiopath_sid_text
617
+ text = torch.FloatTensor(phoneme_ids)
618
+ try:
619
+ spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
620
+ with torch.no_grad():
621
+ ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
622
+ if ssl.shape[-1] != spec.shape[-1]:
623
+ typee = ssl.dtype
624
+ ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
625
+ ssl.requires_grad = False
626
+ except Exception:
627
+ traceback.print_exc()
628
+ mel = torch.zeros(100, 192)
629
+ # wav = torch.zeros(1, 96 * self.hop_length)
630
+ spec = torch.zeros(1025, 96)
631
+ ssl = torch.zeros(1, 768, 96)
632
+ text = text[-1:]
633
+ print("load audio or ssl error!!!!!!", audiopath)
634
+ return (ssl, spec, mel, text)
635
+
636
+ def get_audio(self, filename):
637
+ audio_array = load_audio(filename, self.sampling_rate) # load_audio already normalizes to [-1, 1], no further /32768 needed
638
+ audio = torch.FloatTensor(audio_array) # /32768
639
+ audio_norm = audio
640
+ audio_norm = audio_norm.unsqueeze(0)
641
+ spec = spectrogram_torch(
642
+ audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
643
+ )
644
+ spec = torch.squeeze(spec, 0)
645
+ spec1 = spectrogram_torch(audio_norm, 1280, 32000, 320, 1280, center=False)
646
+ mel = spec_to_mel_torch(spec1, 1280, 100, 32000, 0, None)
647
+ mel = self.norm_spec(torch.squeeze(mel, 0))
648
+ return spec, mel
649
+
650
+ def get_sid(self, sid):
651
+ sid = torch.LongTensor([int(sid)])
652
+ return sid
653
+
654
+ def __getitem__(self, index):
655
+ # with torch.no_grad():
656
+ return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
657
+
658
+ def __len__(self):
659
+ return len(self.audiopaths_sid_text)
660
+
661
+
662
+ class TextAudioSpeakerCollateV4:
663
+ """Zero-pads model inputs and targets"""
664
+
665
+ def __init__(self, return_ids=False):
666
+ self.return_ids = return_ids
667
+
668
+ def __call__(self, batch):
669
+ """Collates a training batch of SSL features, spectrograms, mel-spectrograms and phoneme sequences
670
+ PARAMS
671
+ ------
672
+ batch: list of (ssl, spec, mel, text) tuples
673
+ """
674
+ # ssl, spec, wav,mel, text
675
+ # Right zero-pad all one-hot text sequences to max input length
676
+ _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
677
+ # (ssl, spec,mel, text)
678
+ max_ssl_len = max([x[0].size(2) for x in batch])
679
+ max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
680
+ max_spec_len = max([x[1].size(1) for x in batch])
681
+ max_spec_len = int(2 * ((max_spec_len // 2) + 1))
682
+ # max_wav_len = max([x[2].size(1) for x in batch])
683
+ max_text_len = max([x[3].size(0) for x in batch])
684
+
685
+ ssl_lengths = torch.LongTensor(len(batch))
686
+ spec_lengths = torch.LongTensor(len(batch))
687
+ text_lengths = torch.LongTensor(len(batch))
688
+ # wav_lengths = torch.LongTensor(len(batch))
689
+ mel_lengths = torch.LongTensor(len(batch))
690
+
691
+ spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
692
+ mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_spec_len * 2)
693
+ ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
694
+ text_padded = torch.LongTensor(len(batch), max_text_len)
695
+ # wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
696
+
697
+ spec_padded.zero_()
698
+ mel_padded.zero_()
699
+ ssl_padded.zero_()
700
+ text_padded.zero_()
701
+ # wav_padded.zero_()
702
+
703
+ for i in range(len(ids_sorted_decreasing)):
704
+ row = batch[ids_sorted_decreasing[i]]
705
+ # ssl, spec, wav,mel, text
706
+ ssl = row[0]
707
+ ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
708
+ ssl_lengths[i] = ssl.size(2)
709
+
710
+ spec = row[1]
711
+ spec_padded[i, :, : spec.size(1)] = spec
712
+ spec_lengths[i] = spec.size(1)
713
+
714
+ # wav = row[2]
715
+ # wav_padded[i, :, :wav.size(1)] = wav
716
+ # wav_lengths[i] = wav.size(1)
717
+
718
+ mel = row[2]
719
+ mel_padded[i, :, : mel.size(1)] = mel
720
+ mel_lengths[i] = mel.size(1)
721
+
722
+ text = row[3]
723
+ text_padded[i, : text.size(0)] = text
724
+ text_lengths[i] = text.size(0)
725
+
726
+ # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
727
+ return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths
728
+
729
+
730
+ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
731
+ """
732
+ 1) loads audio, speaker_id, text pairs
733
+ 2) normalizes text and converts them to sequences of integers
734
+ 3) computes spectrograms from audio files.
735
+ """
736
+
737
+ def __init__(self, hparams, val=False):
738
+ exp_dir = hparams.exp_dir
739
+ self.path2 = "%s/2-name2text.txt" % exp_dir
740
+ self.path4 = "%s/4-cnhubert" % exp_dir
741
+ self.path5 = "%s/5-wav32k" % exp_dir
742
+ assert os.path.exists(self.path2)
743
+ assert os.path.exists(self.path4)
744
+ assert os.path.exists(self.path5)
745
+ names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # strip the .pt suffix
746
+ names5 = set(os.listdir(self.path5))
747
+ self.phoneme_data = {}
748
+ with open(self.path2, "r", encoding="utf8") as f:
749
+ lines = f.read().strip("\n").split("\n")
750
+
751
+ for line in lines:
752
+ tmp = line.split("\t")
753
+ if len(tmp) != 4:
754
+ continue
755
+ self.phoneme_data[tmp[0]] = [tmp[1]]
756
+
757
+ self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
758
+ tmp = self.audiopaths_sid_text
759
+ leng = len(tmp)
760
+ min_num = 100
761
+ if leng < min_num:
762
+ self.audiopaths_sid_text = []
763
+ for _ in range(max(2, int(min_num / leng))):
764
+ self.audiopaths_sid_text += tmp
765
+ self.max_wav_value = hparams.max_wav_value
766
+ self.sampling_rate = hparams.sampling_rate
767
+ self.filter_length = hparams.filter_length
768
+ self.hop_length = hparams.hop_length
769
+ self.win_length = hparams.win_length
770
+ self.sampling_rate = hparams.sampling_rate
771
+ self.val = val
772
+
773
+ random.seed(1234)
774
+ random.shuffle(self.audiopaths_sid_text)
775
+
776
+ print("phoneme_data_len:", len(self.phoneme_data.keys()))
777
+ print("wav_data_len:", len(self.audiopaths_sid_text))
778
+
779
+ audiopaths_sid_text_new = []
780
+ lengths = []
781
+ skipped_phone = 0
782
+ skipped_dur = 0
783
+ for audiopath in tqdm(self.audiopaths_sid_text):
784
+ try:
785
+ phoneme = self.phoneme_data[audiopath][0]
786
+ phoneme = phoneme.split(" ")
787
+ phoneme_ids = cleaned_text_to_sequence(phoneme, version)
788
+ except Exception:
789
+ print(f"{audiopath} not in self.phoneme_data !")
790
+ skipped_phone += 1
791
+ continue
792
+
793
+ size = os.path.getsize("%s/%s" % (self.path5, audiopath))
794
+ duration = size / self.sampling_rate / 2
795
+
796
+ if duration == 0:
797
+ print(f"Zero duration for {audiopath}, skipping...")
798
+ skipped_dur += 1
799
+ continue
800
+
801
+ if 54 > duration > 0.6 or self.val:
802
+ audiopaths_sid_text_new.append([audiopath, phoneme_ids])
803
+ lengths.append(size // (2 * self.hop_length))
804
+ else:
805
+ skipped_dur += 1
806
+ continue
807
+
808
+ print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
809
+ print("total left: ", len(audiopaths_sid_text_new))
810
+ assert len(audiopaths_sid_text_new) > 1 # must be enough to fill at least one batch; TODO
811
+ self.audiopaths_sid_text = audiopaths_sid_text_new
812
+ self.lengths = lengths
813
+ self.spec_min = -12
814
+ self.spec_max = 2
815
+
816
+ self.filter_length_mel = self.win_length_mel = 1024
817
+ self.hop_length_mel = 256
818
+ self.n_mel_channels = 100
819
+ self.sampling_rate_mel = 24000
820
+ self.mel_fmin = 0
821
+ self.mel_fmax = None
822
+
823
+ def norm_spec(self, x):
824
+ return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
825
+
826
+ def get_audio_text_speaker_pair(self, audiopath_sid_text):
827
+ audiopath, phoneme_ids = audiopath_sid_text
828
+ text = torch.FloatTensor(phoneme_ids)
829
+ try:
830
+ spec, mel, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
831
+ with torch.no_grad():
832
+ ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
833
+ if ssl.shape[-1] != spec.shape[-1]:
834
+ typee = ssl.dtype
835
+ ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
836
+ ssl.requires_grad = False
837
+ except Exception:
838
+ traceback.print_exc()
839
+ mel = torch.zeros(100, 180)
840
+ wav = torch.zeros(1, 96 * self.hop_length)
841
+ spec = torch.zeros(1025, 96)
842
+ ssl = torch.zeros(1, 768, 96)
843
+ text = text[-1:]
844
+ print("load audio or ssl error!!!!!!", audiopath)
845
+ return (ssl, spec, wav, mel, text)
846
+
847
+ def get_audio(self, filename):
848
+ audio_array = load_audio(filename, self.sampling_rate) # load_audio already normalizes to [-1, 1], no further /32768 needed
849
+ audio = torch.FloatTensor(audio_array) # /32768
850
+ audio_norm = audio
851
+ audio_norm = audio_norm.unsqueeze(0)
852
+ audio_array24 = load_audio(
853
+ filename, 24000
854
+ ) # load_audio already normalizes to [-1, 1], no further /32768 needed; resampling here could be accelerated on GPU
855
+ audio24 = torch.FloatTensor(audio_array24) # /32768
856
+ audio_norm24 = audio24
857
+ audio_norm24 = audio_norm24.unsqueeze(0)
858
+
859
+ spec = spectrogram_torch(
860
+ audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
861
+ )
862
+ spec = torch.squeeze(spec, 0)
863
+
864
+ spec1 = spectrogram_torch(
865
+ audio_norm24,
866
+ self.filter_length_mel,
867
+ self.sampling_rate_mel,
868
+ self.hop_length_mel,
869
+ self.win_length_mel,
870
+ center=False,
871
+ )
872
+ mel = spec_to_mel_torch(
873
+ spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
874
+ )
875
+ mel = torch.squeeze(mel, 0)
876
+ mel = self.norm_spec(mel)
877
+ # print(1111111,spec.shape,mel.shape)
878
+ return spec, mel, audio_norm
879
+
880
+ def get_sid(self, sid):
881
+ sid = torch.LongTensor([int(sid)])
882
+ return sid
883
+
884
+ def __getitem__(self, index):
885
+ # with torch.no_grad():
886
+ return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
887
+
888
+ def __len__(self):
889
+ return len(self.audiopaths_sid_text)
890
+
891
+
892
+ class TextAudioSpeakerCollateV3b:
893
+ """Zero-pads model inputs and targets"""
894
+
895
+ def __init__(self, return_ids=False):
896
+ self.return_ids = return_ids
897
+
898
+ def __call__(self, batch):
899
+ """Collates a training batch of SSL features, spectrograms, waveforms, mel-spectrograms and phoneme sequences
900
+ PARAMS
901
+ ------
902
+ batch: list of (ssl, spec, wav, mel, text) tuples
903
+ """
904
+ # ssl, spec, wav,mel, text
905
+ # Right zero-pad all one-hot text sequences to max input length
906
+ _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
907
+ # (ssl, spec,mel, text)
908
+ max_ssl_len = max([x[0].size(2) for x in batch])
909
+
910
+ max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
911
+ max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
912
+
913
+ # max_ssl_len = int(8 * ((max_ssl_len // 8) + 1))
914
+ # max_ssl_len1=max_ssl_len
915
+
916
+ max_spec_len = max([x[1].size(1) for x in batch])
917
+ max_spec_len = int(2 * ((max_spec_len // 2) + 1))
918
+ max_wav_len = max([x[2].size(1) for x in batch])
919
+ max_text_len = max([x[4].size(0) for x in batch])
920
+ max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320
921
+
922
+ ssl_lengths = torch.LongTensor(len(batch))
923
+ spec_lengths = torch.LongTensor(len(batch))
924
+ text_lengths = torch.LongTensor(len(batch))
925
+ wav_lengths = torch.LongTensor(len(batch))
926
+ mel_lengths = torch.LongTensor(len(batch))
927
+
928
+ spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
929
+ mel_padded = torch.FloatTensor(len(batch), batch[0][3].size(0), max_mel_len)
930
+ ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
931
+ text_padded = torch.LongTensor(len(batch), max_text_len)
932
+ wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
933
+
934
+ spec_padded.zero_()
935
+ mel_padded.zero_()
936
+ ssl_padded.zero_()
937
+ text_padded.zero_()
938
+ wav_padded.zero_()
939
+
940
+ for i in range(len(ids_sorted_decreasing)):
941
+ row = batch[ids_sorted_decreasing[i]]
942
+ # ssl, spec, wav,mel, text
943
+ ssl = row[0]
944
+ ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
945
+ ssl_lengths[i] = ssl.size(2)
946
+
947
+ spec = row[1]
948
+ spec_padded[i, :, : spec.size(1)] = spec
949
+ spec_lengths[i] = spec.size(1)
950
+
951
+ wav = row[2]
952
+ wav_padded[i, :, : wav.size(1)] = wav
953
+ wav_lengths[i] = wav.size(1)
954
+
955
+ mel = row[3]
956
+ mel_padded[i, :, : mel.size(1)] = mel
957
+ mel_lengths[i] = mel.size(1)
958
+
959
+ text = row[4]
960
+ text_padded[i, : text.size(0)] = text
961
+ text_lengths[i] = text.size(0)
962
+
963
+ return (
964
+ ssl_padded,
965
+ spec_padded,
966
+ mel_padded,
967
+ ssl_lengths,
968
+ spec_lengths,
969
+ text_padded,
970
+ text_lengths,
971
+ wav_padded,
972
+ wav_lengths,
973
+ mel_lengths,
974
+ )
975
+ # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths
976
+
977
+
978
+ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
979
+ """
980
+ Maintain similar input lengths in a batch.
981
+ Length groups are specified by boundaries.
982
+ Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
983
+
984
+ It removes samples which are not included in the boundaries.
985
+ Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
986
+ """
987
+
988
+ def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
989
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
990
+ self.lengths = dataset.lengths
991
+ self.batch_size = batch_size
992
+ self.boundaries = boundaries
993
+
994
+ self.buckets, self.num_samples_per_bucket = self._create_buckets()
995
+ self.total_size = sum(self.num_samples_per_bucket)
996
+ self.num_samples = self.total_size // self.num_replicas
997
+
998
+ def _create_buckets(self):
999
+ buckets = [[] for _ in range(len(self.boundaries) - 1)]
1000
+ for i in range(len(self.lengths)):
1001
+ length = self.lengths[i]
1002
+ idx_bucket = self._bisect(length)
1003
+ if idx_bucket != -1:
1004
+ buckets[idx_bucket].append(i)
1005
+
1006
+ i = len(buckets) - 1
1007
+ while i >= 0:
1008
+ if len(buckets[i]) == 0:
1009
+ buckets.pop(i)
1010
+ self.boundaries.pop(i + 1)
1011
+ i -= 1
1012
+
1013
+ num_samples_per_bucket = []
1014
+ for i in range(len(buckets)):
1015
+ len_bucket = len(buckets[i])
1016
+ total_batch_size = self.num_replicas * self.batch_size
1017
+ rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
1018
+ num_samples_per_bucket.append(len_bucket + rem)
1019
+ return buckets, num_samples_per_bucket
1020
+
1021
+ def __iter__(self):
1022
+ g = torch.Generator()
1023
+ g.manual_seed(self.epoch)
1024
+
1025
+ indices = []
1026
+ if self.shuffle:
1027
+ for bucket in self.buckets:
1028
+ indices.append(torch.randperm(len(bucket), generator=g).tolist())
1029
+ else:
1030
+ for bucket in self.buckets:
1031
+ indices.append(list(range(len(bucket))))
1032
+
1033
+ batches = []
1034
+ for i in range(len(self.buckets)):
1035
+ bucket = self.buckets[i]
1036
+ len_bucket = len(bucket)
1037
+ ids_bucket = indices[i]
1038
+ num_samples_bucket = self.num_samples_per_bucket[i]
1039
+
1040
+ rem = num_samples_bucket - len_bucket
1041
+ ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)]
1042
+
1043
+ ids_bucket = ids_bucket[self.rank :: self.num_replicas]
1044
+
1045
+ for j in range(len(ids_bucket) // self.batch_size):
1046
+ batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]]
1047
+ batches.append(batch)
1048
+
1049
+ if self.shuffle:
1050
+ batch_ids = torch.randperm(len(batches), generator=g).tolist()
1051
+ batches = [batches[i] for i in batch_ids]
1052
+ self.batches = batches
1053
+
1054
+ assert len(self.batches) * self.batch_size == self.num_samples
1055
+ return iter(self.batches)
1056
+
1057
+ def _bisect(self, x, lo=0, hi=None):
1058
+ if hi is None:
1059
+ hi = len(self.boundaries) - 1
1060
+
1061
+ if hi > lo:
1062
+ mid = (hi + lo) // 2
1063
+ if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
1064
+ return mid
1065
+ elif x <= self.boundaries[mid]:
1066
+ return self._bisect(x, lo, mid)
1067
+ else:
1068
+ return self._bisect(x, mid + 1, hi)
1069
+ else:
1070
+ return -1
1071
+
1072
+ def __len__(self):
1073
+ return self.num_samples // self.batch_size
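`DistributedBucketSampler` only requires a dataset that exposes a `lengths` list, so its bucketing can be checked outside a real DDP run by passing `num_replicas` and `rank` explicitly. The sketch below is an assumption-laden illustration, not the project's actual training wiring.

```python
# Minimal sketch: bucket similar-length items into fixed-size batches on a single process.
import random

import torch

from GPT_SoVITS.module.data_utils import DistributedBucketSampler  # assumed import path

class DummyLengthDataset(torch.utils.data.Dataset):
    def __init__(self, lengths):
        self.lengths = lengths                     # the sampler reads this attribute
    def __getitem__(self, idx):
        return torch.zeros(self.lengths[idx])
    def __len__(self):
        return len(self.lengths)

dataset = DummyLengthDataset([random.randint(32, 900) for _ in range(200)])
sampler = DistributedBucketSampler(
    dataset,
    batch_size=4,
    boundaries=[16, 64, 256, 1024],  # lengths outside (16, 1024] are discarded
    num_replicas=1,                  # explicit values avoid needing torch.distributed init
    rank=0,
    shuffle=True,
)
sampler.set_epoch(0)                 # deterministic reshuffle per epoch
first_batch = next(iter(sampler))    # a list of dataset indices from a single length bucket
print(first_batch, [dataset.lengths[i] for i in first_batch])
```

In training this would presumably be passed to `DataLoader(dataset, batch_sampler=sampler, collate_fn=...)` together with one of the collate classes above.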
GPT_SoVITS/module/losses.py ADDED
@@ -0,0 +1,70 @@
 
1
+ import math
2
+
3
+ import torch
4
+
5
+
6
+ def feature_loss(fmap_r, fmap_g):
7
+ loss = torch.tensor(0).to(fmap_r[0][0].device)
8
+ for dr, dg in zip(fmap_r, fmap_g):
9
+ for rl, gl in zip(dr, dg):
10
+ rl = rl.float().detach()
11
+ gl = gl.float()
12
+ loss = torch.mean(torch.abs(rl - gl)) + loss
13
+
14
+ return loss * 2
15
+
16
+
17
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
18
+ loss = torch.tensor(0).to(disc_real_outputs[0].device)
19
+ r_losses = []
20
+ g_losses = []
21
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
22
+ dr = dr.float()
23
+ dg = dg.float()
24
+ r_loss = torch.mean((1 - dr) ** 2)
25
+ g_loss = torch.mean(dg**2)
26
+ loss = r_loss + g_loss + loss
27
+ r_losses.append(r_loss.item())
28
+ g_losses.append(g_loss.item())
29
+
30
+ return loss, r_losses, g_losses
31
+
32
+
33
+ def generator_loss(disc_outputs):
34
+ loss = torch.tensor(0).to(disc_outputs[0].device)
35
+ gen_losses = []
36
+ for dg in disc_outputs:
37
+ dg = dg.float()
38
+ l = torch.mean((1 - dg) ** 2)
39
+ gen_losses.append(l)
40
+ loss = l + loss
41
+
42
+ return loss, gen_losses
43
+
44
+
45
+ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
46
+ """
47
+ z_p, logs_q: [b, h, t_t]
48
+ m_p, logs_p: [b, h, t_t]
49
+ """
50
+ z_p = z_p.float()
51
+ logs_q = logs_q.float()
52
+ m_p = m_p.float()
53
+ logs_p = logs_p.float()
54
+ z_mask = z_mask.float()
55
+
56
+ kl = logs_p - logs_q - 0.5
57
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
58
+ kl = torch.sum(kl * z_mask)
59
+ l = kl / torch.sum(z_mask)
60
+ return l
61
+
62
+
63
+ def mle_loss(z, m, logs, logdet, mask):
64
+ l = torch.sum(logs) + 0.5 * torch.sum(
65
+ torch.exp(-2 * logs) * ((z - m) ** 2)
66
+ ) # neg normal likelihood w/o the constant term
67
+ l = l - torch.sum(logdet) # log jacobian determinant
68
+ l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes
69
+ l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
70
+ return l
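All of the GAN loss helpers above operate on Python lists, one entry per sub-discriminator, and `feature_loss` additionally expects a list of feature-map lists. A hedged sketch with random stand-ins, purely to show the expected nesting and return values:

```python
# Minimal sketch: call the loss helpers with dummy discriminator outputs and feature maps.
import torch

from GPT_SoVITS.module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss  # assumed path

real_outs = [torch.rand(2, 1, 100) for _ in range(3)]  # one tensor per sub-discriminator
fake_outs = [torch.rand(2, 1, 100) for _ in range(3)]

d_loss, r_losses, g_losses = discriminator_loss(real_outs, fake_outs)  # scalar + per-D floats
adv_loss, gen_losses = generator_loss(fake_outs)

fmap_r = [[torch.rand(2, 8, 50), torch.rand(2, 8, 25)] for _ in range(3)]  # feature maps per layer
fmap_g = [[torch.rand(2, 8, 50), torch.rand(2, 8, 25)] for _ in range(3)]
fm_loss = feature_loss(fmap_r, fmap_g)

z_p, logs_q, m_p, logs_p = (torch.randn(2, 192, 40) for _ in range(4))  # (batch, channels, frames)
kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask=torch.ones(2, 1, 40))
```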
GPT_SoVITS/module/mel_processing.py ADDED
@@ -0,0 +1,142 @@
 
1
+ import torch
2
+ from librosa.filters import mel as librosa_mel_fn
3
+
4
+ MAX_WAV_VALUE = 32768.0
5
+
6
+
7
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
8
+ """
9
+ PARAMS
10
+ ------
11
+ C: compression factor
12
+ """
13
+ return torch.log(torch.clamp(x, min=clip_val) * C)
14
+
15
+
16
+ def dynamic_range_decompression_torch(x, C=1):
17
+ """
18
+ PARAMS
19
+ ------
20
+ C: compression factor used to compress
21
+ """
22
+ return torch.exp(x) / C
23
+
24
+
25
+ def spectral_normalize_torch(magnitudes):
26
+ output = dynamic_range_compression_torch(magnitudes)
27
+ return output
28
+
29
+
30
+ def spectral_de_normalize_torch(magnitudes):
31
+ output = dynamic_range_decompression_torch(magnitudes)
32
+ return output
33
+
34
+
35
+ mel_basis = {}
36
+ hann_window = {}
37
+
38
+
39
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
40
+ if torch.min(y) < -1.2:
41
+ print("min value is ", torch.min(y))
42
+ if torch.max(y) > 1.2:
43
+ print("max value is ", torch.max(y))
44
+
45
+ global hann_window
46
+ dtype_device = str(y.dtype) + "_" + str(y.device)
47
+ # wnsize_dtype_device = str(win_size) + '_' + dtype_device
48
+ key = "%s-%s-%s-%s-%s" % (dtype_device, n_fft, sampling_rate, hop_size, win_size)
49
+ # if wnsize_dtype_device not in hann_window:
50
+ if key not in hann_window:
51
+ # hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
52
+ hann_window[key] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
53
+
54
+ y = torch.nn.functional.pad(
55
+ y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
56
+ )
57
+ y = y.squeeze(1)
58
+ # spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
59
+ spec = torch.stft(
60
+ y,
61
+ n_fft,
62
+ hop_length=hop_size,
63
+ win_length=win_size,
64
+ window=hann_window[key],
65
+ center=center,
66
+ pad_mode="reflect",
67
+ normalized=False,
68
+ onesided=True,
69
+ return_complex=True,
70
+ )
71
+
72
+ spec = spec.abs().pow_(2).add_(1e-8).sqrt_()
73
+ return spec
74
+
75
+
76
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
77
+ global mel_basis
78
+ dtype_device = str(spec.dtype) + "_" + str(spec.device)
79
+ # fmax_dtype_device = str(fmax) + '_' + dtype_device
80
+ key = "%s-%s-%s-%s-%s-%s" % (dtype_device, n_fft, num_mels, sampling_rate, fmin, fmax)
81
+ # if fmax_dtype_device not in mel_basis:
82
+ if key not in mel_basis:
83
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
84
+ # mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
85
+ mel_basis[key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
86
+ # spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
87
+ spec = torch.matmul(mel_basis[key], spec)
88
+ spec = spectral_normalize_torch(spec)
89
+ return spec
90
+
91
+
92
+ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
93
+ if torch.min(y) < -1.2:
94
+ print("min value is ", torch.min(y))
95
+ if torch.max(y) > 1.2:
96
+ print("max value is ", torch.max(y))
97
+
98
+ global mel_basis, hann_window
99
+ dtype_device = str(y.dtype) + "_" + str(y.device)
100
+ # fmax_dtype_device = str(fmax) + '_' + dtype_device
101
+ fmax_dtype_device = "%s-%s-%s-%s-%s-%s-%s-%s" % (
102
+ dtype_device,
103
+ n_fft,
104
+ num_mels,
105
+ sampling_rate,
106
+ hop_size,
107
+ win_size,
108
+ fmin,
109
+ fmax,
110
+ )
111
+ # wnsize_dtype_device = str(win_size) + '_' + dtype_device
112
+ wnsize_dtype_device = fmax_dtype_device
113
+ if fmax_dtype_device not in mel_basis:
114
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
115
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
116
+ if wnsize_dtype_device not in hann_window:
117
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
118
+
119
+ y = torch.nn.functional.pad(
120
+ y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
121
+ )
122
+ y = y.squeeze(1)
123
+
124
+ spec = torch.stft(
125
+ y,
126
+ n_fft,
127
+ hop_length=hop_size,
128
+ win_length=win_size,
129
+ window=hann_window[wnsize_dtype_device],
130
+ center=center,
131
+ pad_mode="reflect",
132
+ normalized=False,
133
+ onesided=True,
134
+ return_complex=True,
135
+ )
136
+
137
+ spec = spec.abs().pow_(2).add_(1e-8).sqrt_()
138
+
139
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
140
+ spec = spectral_normalize_torch(spec)
141
+
142
+ return spec
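A hedged sketch of the spectrogram path above, reusing the 32 kHz parameters that `TextAudioSpeakerLoaderV4.get_audio` passes (n_fft 1280, hop 320, 100 mel bins); the waveform is random noise, used only to check shapes.

```python
# Minimal sketch: linear spectrogram -> log-compressed mel spectrogram.
import torch

from GPT_SoVITS.module.mel_processing import spec_to_mel_torch, spectrogram_torch  # assumed path

wav = torch.rand(1, 32000) * 2 - 1  # (batch, samples), roughly in [-1, 1] like load_audio output

spec = spectrogram_torch(
    wav, n_fft=1280, sampling_rate=32000, hop_size=320, win_size=1280, center=False
)                                    # (batch, 1280 // 2 + 1, frames)

mel = spec_to_mel_torch(
    spec, n_fft=1280, num_mels=100, sampling_rate=32000, fmin=0, fmax=None
)                                    # (batch, 100, frames), after dynamic range compression
print(spec.shape, mel.shape)
```

Note that both functions memoize the Hann window and mel filterbank per dtype/device/parameter combination in the module-level `hann_window` and `mel_basis` dicts, so repeated calls with the same settings reuse them.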
GPT_SoVITS/module/models.py ADDED
@@ -0,0 +1,1411 @@
 
1
+ import contextlib
2
+ import math
3
+ import random
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.cuda.amp import autocast
8
+ from torch.nn import Conv1d, Conv2d, ConvTranspose1d
9
+ from torch.nn import functional as F
10
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
11
+
12
+ from GPT_SoVITS.f5_tts.model import DiT
13
+ from GPT_SoVITS.text import symbols as symbols_v1
14
+ from GPT_SoVITS.text import symbols2 as symbols_v2
15
+ from GPT_SoVITS.utils import HParams
16
+ from tools.my_utils import _open_file
17
+
18
+ from . import attentions, commons, modules
19
+ from .commons import get_padding, init_weights
20
+ from .mrte_model import MRTE
21
+ from .quantize import ResidualVectorQuantizer
22
+
23
+
24
+ def set_serialization():
25
+ torch.serialization.add_safe_globals([(HParams, "utils.HParams")])
26
+ torch.serialization._open_file = _open_file
27
+
28
+
29
+ set_serialization()
30
+
31
+
32
+ class StochasticDurationPredictor(nn.Module):
33
+ def __init__(
34
+ self,
35
+ in_channels,
36
+ filter_channels,
37
+ kernel_size,
38
+ p_dropout,
39
+ n_flows=4,
40
+ gin_channels=0,
41
+ ):
42
+ super().__init__()
43
+ filter_channels = in_channels  # this override should be removed in a future version
44
+ self.in_channels = in_channels
45
+ self.filter_channels = filter_channels
46
+ self.kernel_size = kernel_size
47
+ self.p_dropout = p_dropout
48
+ self.n_flows = n_flows
49
+ self.gin_channels = gin_channels
50
+
51
+ self.log_flow = modules.Log()
52
+ self.flows = nn.ModuleList()
53
+ self.flows.append(modules.ElementwiseAffine(2))
54
+ for i in range(n_flows):
55
+ self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
56
+ self.flows.append(modules.Flip())
57
+
58
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
59
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
60
+ self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
61
+ self.post_flows = nn.ModuleList()
62
+ self.post_flows.append(modules.ElementwiseAffine(2))
63
+ for i in range(4):
64
+ self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
65
+ self.post_flows.append(modules.Flip())
66
+
67
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
68
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
69
+ self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
70
+ if gin_channels != 0:
71
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
72
+
73
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
74
+ x = torch.detach(x)
75
+ x = self.pre(x)
76
+ if g is not None:
77
+ g = torch.detach(g)
78
+ x = x + self.cond(g)
79
+ x = self.convs(x, x_mask)
80
+ x = self.proj(x) * x_mask
81
+
82
+ if not reverse:
83
+ flows = self.flows
84
+ assert w is not None
85
+
86
+ logdet_tot_q = 0
87
+ h_w = self.post_pre(w)
88
+ h_w = self.post_convs(h_w, x_mask)
89
+ h_w = self.post_proj(h_w) * x_mask
90
+ e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
91
+ z_q = e_q
92
+ for flow in self.post_flows:
93
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
94
+ logdet_tot_q += logdet_q
95
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
96
+ u = torch.sigmoid(z_u) * x_mask
97
+ z0 = (w - u) * x_mask
98
+ logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
99
+ logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
100
+
101
+ logdet_tot = 0
102
+ z0, logdet = self.log_flow(z0, x_mask)
103
+ logdet_tot += logdet
104
+ z = torch.cat([z0, z1], 1)
105
+ for flow in flows:
106
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
107
+ logdet_tot = logdet_tot + logdet
108
+ nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
109
+ return nll + logq # [b]
110
+ else:
111
+ flows = list(reversed(self.flows))
112
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
113
+ z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
114
+ for flow in flows:
115
+ z = flow(z, x_mask, g=x, reverse=reverse)
116
+ z0, z1 = torch.split(z, [1, 1], 1)
117
+ logw = z0
118
+ return logw
119
+
120
+
121
+ class DurationPredictor(nn.Module):
122
+ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
123
+ super().__init__()
124
+
125
+ self.in_channels = in_channels
126
+ self.filter_channels = filter_channels
127
+ self.kernel_size = kernel_size
128
+ self.p_dropout = p_dropout
129
+ self.gin_channels = gin_channels
130
+
131
+ self.drop = nn.Dropout(p_dropout)
132
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
133
+ self.norm_1 = modules.LayerNorm(filter_channels)
134
+ self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
135
+ self.norm_2 = modules.LayerNorm(filter_channels)
136
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
137
+
138
+ if gin_channels != 0:
139
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
140
+
141
+ def forward(self, x, x_mask, g=None):
142
+ x = torch.detach(x)
143
+ if g is not None:
144
+ g = torch.detach(g)
145
+ x = x + self.cond(g)
146
+ x = self.conv_1(x * x_mask)
147
+ x = torch.relu(x)
148
+ x = self.norm_1(x)
149
+ x = self.drop(x)
150
+ x = self.conv_2(x * x_mask)
151
+ x = torch.relu(x)
152
+ x = self.norm_2(x)
153
+ x = self.drop(x)
154
+ x = self.proj(x * x_mask)
155
+ return x * x_mask
156
+
157
+
158
+ class TextEncoder(nn.Module):
159
+ def __init__(
160
+ self,
161
+ out_channels,
162
+ hidden_channels,
163
+ filter_channels,
164
+ n_heads,
165
+ n_layers,
166
+ kernel_size,
167
+ p_dropout,
168
+ latent_channels=192,
169
+ version="v2",
170
+ ):
171
+ super().__init__()
172
+ self.out_channels = out_channels
173
+ self.hidden_channels = hidden_channels
174
+ self.filter_channels = filter_channels
175
+ self.n_heads = n_heads
176
+ self.n_layers = n_layers
177
+ self.kernel_size = kernel_size
178
+ self.p_dropout = p_dropout
179
+ self.latent_channels = latent_channels
180
+ self.version = version
181
+
182
+ self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
183
+
184
+ self.encoder_ssl = attentions.Encoder(
185
+ hidden_channels,
186
+ filter_channels,
187
+ n_heads,
188
+ n_layers // 2,
189
+ kernel_size,
190
+ p_dropout,
191
+ )
192
+
193
+ self.encoder_text = attentions.Encoder(
194
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
195
+ )
196
+
197
+ if self.version == "v1":
198
+ symbols = symbols_v1.symbols
199
+ else:
200
+ symbols = symbols_v2.symbols
201
+ self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
202
+
203
+ self.mrte = MRTE()
204
+
205
+ self.encoder2 = attentions.Encoder(
206
+ hidden_channels,
207
+ filter_channels,
208
+ n_heads,
209
+ n_layers // 2,
210
+ kernel_size,
211
+ p_dropout,
212
+ )
213
+
214
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
215
+
216
+ def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
217
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
218
+
219
+ y = self.ssl_proj(y * y_mask) * y_mask
220
+
221
+ y = self.encoder_ssl(y * y_mask, y_mask)
222
+
223
+ text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype)
224
+ if test == 1:
225
+ text[:, :] = 0
226
+ text = self.text_embedding(text).transpose(1, 2)
227
+ text = self.encoder_text(text * text_mask, text_mask)
228
+ y = self.mrte(y, y_mask, text, text_mask, ge)
229
+ y = self.encoder2(y * y_mask, y_mask)
230
+ if speed != 1:
231
+ y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
232
+ y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
233
+ stats = self.proj(y) * y_mask
234
+ m, logs = torch.split(stats, self.out_channels, dim=1)
235
+ return y, m, logs, y_mask
236
+
237
+
238
+ class ResidualCouplingBlock(nn.Module):
239
+ def __init__(
240
+ self,
241
+ channels,
242
+ hidden_channels,
243
+ kernel_size,
244
+ dilation_rate,
245
+ n_layers,
246
+ n_flows=4,
247
+ gin_channels=0,
248
+ ):
249
+ super().__init__()
250
+ self.channels = channels
251
+ self.hidden_channels = hidden_channels
252
+ self.kernel_size = kernel_size
253
+ self.dilation_rate = dilation_rate
254
+ self.n_layers = n_layers
255
+ self.n_flows = n_flows
256
+ self.gin_channels = gin_channels
257
+
258
+ self.flows = nn.ModuleList()
259
+ for i in range(n_flows):
260
+ self.flows.append(
261
+ modules.ResidualCouplingLayer(
262
+ channels,
263
+ hidden_channels,
264
+ kernel_size,
265
+ dilation_rate,
266
+ n_layers,
267
+ gin_channels=gin_channels,
268
+ mean_only=True,
269
+ )
270
+ )
271
+ self.flows.append(modules.Flip())
272
+
273
+ def forward(self, x, x_mask, g=None, reverse=False):
274
+ if not reverse:
275
+ for flow in self.flows:
276
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
277
+ else:
278
+ for flow in reversed(self.flows):
279
+ x = flow(x, x_mask, g=g, reverse=reverse)
280
+ return x
281
+
282
+
283
+ class PosteriorEncoder(nn.Module):
284
+ def __init__(
285
+ self,
286
+ in_channels,
287
+ out_channels,
288
+ hidden_channels,
289
+ kernel_size,
290
+ dilation_rate,
291
+ n_layers,
292
+ gin_channels=0,
293
+ ):
294
+ super().__init__()
295
+ self.in_channels = in_channels
296
+ self.out_channels = out_channels
297
+ self.hidden_channels = hidden_channels
298
+ self.kernel_size = kernel_size
299
+ self.dilation_rate = dilation_rate
300
+ self.n_layers = n_layers
301
+ self.gin_channels = gin_channels
302
+
303
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
304
+ self.enc = modules.WN(
305
+ hidden_channels,
306
+ kernel_size,
307
+ dilation_rate,
308
+ n_layers,
309
+ gin_channels=gin_channels,
310
+ )
311
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
312
+
313
+ def forward(self, x, x_lengths, g=None):
314
+ if g is not None:
315
+ g = g.detach()
316
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
317
+ x = self.pre(x) * x_mask
318
+ x = self.enc(x, x_mask, g=g)
319
+ stats = self.proj(x) * x_mask
320
+ m, logs = torch.split(stats, self.out_channels, dim=1)
321
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
322
+ return z, m, logs, x_mask
323
+
324
+
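The posterior encoder above samples its latent with the reparameterization trick, z = (m + ε · exp(logs)) · mask, so gradients flow through m and logs while the noise stays external. A tiny sketch of just that step (shapes are illustrative):

```python
# Sketch of the masked reparameterization used in PosteriorEncoder.forward (shapes illustrative).
import torch

m = torch.zeros(2, 192, 50)      # predicted mean,    (B, C, T)
logs = torch.zeros(2, 192, 50)   # predicted log-std, (B, C, T)
x_mask = torch.ones(2, 1, 50)    # 1 for valid frames, 0 for padding

eps = torch.randn_like(m)                  # noise is sampled outside the graph
z = (m + eps * torch.exp(logs)) * x_mask   # differentiable w.r.t. m and logs
print(z.shape)                             # torch.Size([2, 192, 50])
```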
325
+ class Encoder(nn.Module):
326
+ def __init__(
327
+ self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
328
+ ):
329
+ super().__init__()
330
+ self.in_channels = in_channels
331
+ self.out_channels = out_channels
332
+ self.hidden_channels = hidden_channels
333
+ self.kernel_size = kernel_size
334
+ self.dilation_rate = dilation_rate
335
+ self.n_layers = n_layers
336
+ self.gin_channels = gin_channels
337
+
338
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
339
+ self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
340
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
341
+
342
+ def forward(self, x, x_lengths, g=None):
343
+ if g is not None:
344
+ g = g.detach()
345
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
346
+ x = self.pre(x) * x_mask
347
+ x = self.enc(x, x_mask, g=g)
348
+ stats = self.proj(x) * x_mask
349
+ return stats, x_mask
350
+
351
+
352
+ class WNEncoder(nn.Module):
353
+ def __init__(
354
+ self,
355
+ in_channels,
356
+ out_channels,
357
+ hidden_channels,
358
+ kernel_size,
359
+ dilation_rate,
360
+ n_layers,
361
+ gin_channels=0,
362
+ ):
363
+ super().__init__()
364
+ self.in_channels = in_channels
365
+ self.out_channels = out_channels
366
+ self.hidden_channels = hidden_channels
367
+ self.kernel_size = kernel_size
368
+ self.dilation_rate = dilation_rate
369
+ self.n_layers = n_layers
370
+ self.gin_channels = gin_channels
371
+
372
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
373
+ self.enc = modules.WN(
374
+ hidden_channels,
375
+ kernel_size,
376
+ dilation_rate,
377
+ n_layers,
378
+ gin_channels=gin_channels,
379
+ )
380
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
381
+ self.norm = modules.LayerNorm(out_channels)
382
+
383
+ def forward(self, x, x_lengths, g=None):
384
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
385
+ x = self.pre(x) * x_mask
386
+ x = self.enc(x, x_mask, g=g)
387
+ out = self.proj(x) * x_mask
388
+ out = self.norm(out)
389
+ return out
390
+
391
+
392
+ class Generator(torch.nn.Module):
393
+ def __init__(
394
+ self,
395
+ initial_channel,
396
+ resblock,
397
+ resblock_kernel_sizes,
398
+ resblock_dilation_sizes,
399
+ upsample_rates,
400
+ upsample_initial_channel,
401
+ upsample_kernel_sizes,
402
+ gin_channels=0,
403
+ is_bias=False,
404
+ ):
405
+ super(Generator, self).__init__()
406
+ self.num_kernels = len(resblock_kernel_sizes)
407
+ self.num_upsamples = len(upsample_rates)
408
+ self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
409
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
410
+
411
+ self.ups = nn.ModuleList()
412
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
413
+ self.ups.append(
414
+ weight_norm(
415
+ ConvTranspose1d(
416
+ upsample_initial_channel // (2**i),
417
+ upsample_initial_channel // (2 ** (i + 1)),
418
+ k,
419
+ u,
420
+ padding=(k - u) // 2,
421
+ )
422
+ )
423
+ )
424
+
425
+ self.resblocks = nn.ModuleList()
426
+ for i in range(len(self.ups)):
427
+ ch = upsample_initial_channel // (2 ** (i + 1))
428
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
429
+ self.resblocks.append(resblock(ch, k, d))
430
+
431
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias)
432
+ self.ups.apply(init_weights)
433
+
434
+ if gin_channels != 0:
435
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
436
+
437
+ def forward(self, x, g=None):
438
+ x = self.conv_pre(x)
439
+ if g is not None:
440
+ x = x + self.cond(g)
441
+
442
+ for i in range(self.num_upsamples):
443
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
444
+ x = self.ups[i](x)
445
+ xs = None
446
+ for j in range(self.num_kernels):
447
+ if xs is None:
448
+ xs = self.resblocks[i * self.num_kernels + j](x)
449
+ else:
450
+ xs += self.resblocks[i * self.num_kernels + j](x)
451
+ x = xs / self.num_kernels
452
+ x = F.leaky_relu(x)
453
+ x = self.conv_post(x)
454
+ x = torch.tanh(x)
455
+
456
+ return x
457
+
458
+ def remove_weight_norm(self):
459
+ print("Removing weight norm...")
460
+ for l in self.ups:
461
+ remove_weight_norm(l)
462
+ for l in self.resblocks:
463
+ l.remove_weight_norm()
464
+
465
+
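The Generator's overall upsampling factor is the product of `upsample_rates`, so each input frame becomes that many waveform samples. A quick check with made-up rates (the shipped configs may differ):

```python
# Illustrative check of the Generator's total upsampling factor (rates here are examples only).
import math

upsample_rates = [10, 8, 2, 2]      # example values, not necessarily the repo's config
factor = math.prod(upsample_rates)  # 320 waveform samples per input frame
frames = 100
print(frames, "frames ->", frames * factor, "waveform samples")
```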
466
+ class DiscriminatorP(torch.nn.Module):
467
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
468
+ super(DiscriminatorP, self).__init__()
469
+ self.period = period
470
+ self.use_spectral_norm = use_spectral_norm
471
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
472
+ self.convs = nn.ModuleList(
473
+ [
474
+ norm_f(
475
+ Conv2d(
476
+ 1,
477
+ 32,
478
+ (kernel_size, 1),
479
+ (stride, 1),
480
+ padding=(get_padding(kernel_size, 1), 0),
481
+ )
482
+ ),
483
+ norm_f(
484
+ Conv2d(
485
+ 32,
486
+ 128,
487
+ (kernel_size, 1),
488
+ (stride, 1),
489
+ padding=(get_padding(kernel_size, 1), 0),
490
+ )
491
+ ),
492
+ norm_f(
493
+ Conv2d(
494
+ 128,
495
+ 512,
496
+ (kernel_size, 1),
497
+ (stride, 1),
498
+ padding=(get_padding(kernel_size, 1), 0),
499
+ )
500
+ ),
501
+ norm_f(
502
+ Conv2d(
503
+ 512,
504
+ 1024,
505
+ (kernel_size, 1),
506
+ (stride, 1),
507
+ padding=(get_padding(kernel_size, 1), 0),
508
+ )
509
+ ),
510
+ norm_f(
511
+ Conv2d(
512
+ 1024,
513
+ 1024,
514
+ (kernel_size, 1),
515
+ 1,
516
+ padding=(get_padding(kernel_size, 1), 0),
517
+ )
518
+ ),
519
+ ]
520
+ )
521
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
522
+
523
+ def forward(self, x):
524
+ fmap = []
525
+
526
+ # 1d to 2d
527
+ b, c, t = x.shape
528
+ if t % self.period != 0: # pad first
529
+ n_pad = self.period - (t % self.period)
530
+ x = F.pad(x, (0, n_pad), "reflect")
531
+ t = t + n_pad
532
+ x = x.view(b, c, t // self.period, self.period)
533
+
534
+ for l in self.convs:
535
+ x = l(x)
536
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
537
+ fmap.append(x)
538
+ x = self.conv_post(x)
539
+ fmap.append(x)
540
+ x = torch.flatten(x, 1, -1)
541
+
542
+ return x, fmap
543
+
544
+
545
+ class DiscriminatorS(torch.nn.Module):
546
+ def __init__(self, use_spectral_norm=False):
547
+ super(DiscriminatorS, self).__init__()
548
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
549
+ self.convs = nn.ModuleList(
550
+ [
551
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
552
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
553
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
554
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
555
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
556
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
557
+ ]
558
+ )
559
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
560
+
561
+ def forward(self, x):
562
+ fmap = []
563
+
564
+ for l in self.convs:
565
+ x = l(x)
566
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
567
+ fmap.append(x)
568
+ x = self.conv_post(x)
569
+ fmap.append(x)
570
+ x = torch.flatten(x, 1, -1)
571
+
572
+ return x, fmap
573
+
574
+
575
+ v2pro_set = {"v2Pro", "v2ProPlus"}
576
+
577
+
578
+ class MultiPeriodDiscriminator(torch.nn.Module):
579
+ def __init__(self, use_spectral_norm=False, version=None):
580
+ super(MultiPeriodDiscriminator, self).__init__()
581
+ if version in v2pro_set:
582
+ periods = [2, 3, 5, 7, 11, 17, 23]
583
+ else:
584
+ periods = [2, 3, 5, 7, 11]
585
+
586
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
587
+ discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
588
+ self.discriminators = nn.ModuleList(discs)
589
+
590
+ def forward(self, y, y_hat):
591
+ y_d_rs = []
592
+ y_d_gs = []
593
+ fmap_rs = []
594
+ fmap_gs = []
595
+ for i, d in enumerate(self.discriminators):
596
+ y_d_r, fmap_r = d(y)
597
+ y_d_g, fmap_g = d(y_hat)
598
+ y_d_rs.append(y_d_r)
599
+ y_d_gs.append(y_d_g)
600
+ fmap_rs.append(fmap_r)
601
+ fmap_gs.append(fmap_g)
602
+
603
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
604
+
605
+
606
+ class ReferenceEncoder(nn.Module):
607
+ """
608
+ inputs --- [N, Ty/r, n_mels*r] mels
609
+ outputs --- [N, ref_enc_gru_size]
610
+ """
611
+
612
+ def __init__(self, spec_channels, gin_channels=0):
613
+ super().__init__()
614
+ self.spec_channels = spec_channels
615
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
616
+ K = len(ref_enc_filters)
617
+ filters = [1] + ref_enc_filters
618
+ convs = [
619
+ weight_norm(
620
+ nn.Conv2d(
621
+ in_channels=filters[i],
622
+ out_channels=filters[i + 1],
623
+ kernel_size=(3, 3),
624
+ stride=(2, 2),
625
+ padding=(1, 1),
626
+ )
627
+ )
628
+ for i in range(K)
629
+ ]
630
+ self.convs = nn.ModuleList(convs)
631
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
632
+
633
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
634
+ self.gru = nn.GRU(
635
+ input_size=ref_enc_filters[-1] * out_channels,
636
+ hidden_size=256 // 2,
637
+ batch_first=True,
638
+ )
639
+ self.proj = nn.Linear(128, gin_channels)
640
+
641
+ def forward(self, inputs):
642
+ N = inputs.size(0)
643
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
644
+ for conv in self.convs:
645
+ out = conv(out)
646
+ # out = wn(out)
647
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
648
+
649
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
650
+ T = out.size(1)
651
+ N = out.size(0)
652
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
653
+
654
+ self.gru.flatten_parameters()
655
+ memory, out = self.gru(out) # out --- [1, N, 128]
656
+
657
+ return self.proj(out.squeeze(0)).unsqueeze(-1)
658
+
659
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
660
+ for i in range(n_convs):
661
+ L = (L - kernel_size + 2 * pad) // stride + 1
662
+ return L
663
+
664
+
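`calculate_channels` above just iterates the standard conv output-length formula, L' = ⌊(L − k + 2p) / s⌋ + 1, once per conv layer (k=3, s=2, p=1, six layers), i.e. roughly halving each time. For example, with a 704-bin input (704 is just an example size):

```python
# Worked example of the recursion in ReferenceEncoder.calculate_channels
# (kernel=3, stride=2, pad=1, six conv layers; 704 is just an example input size).
L = 704
for _ in range(6):
    L = (L - 3 + 2 * 1) // 2 + 1
    print(L)  # 352, 176, 88, 44, 22, 11
```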
665
+ class Quantizer_module(torch.nn.Module):
666
+ def __init__(self, n_e, e_dim):
667
+ super(Quantizer_module, self).__init__()
668
+ self.embedding = nn.Embedding(n_e, e_dim)
669
+ self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
670
+
671
+ def forward(self, x):
672
+ d = (
673
+ torch.sum(x**2, 1, keepdim=True)
674
+ + torch.sum(self.embedding.weight**2, 1)
675
+ - 2 * torch.matmul(x, self.embedding.weight.T)
676
+ )
677
+ min_indicies = torch.argmin(d, 1)
678
+ z_q = self.embedding(min_indicies)
679
+ return z_q, min_indicies
680
+
681
+
682
+ class Quantizer(torch.nn.Module):
683
+ def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
684
+ super(Quantizer, self).__init__()
685
+ assert embed_dim % n_code_groups == 0
686
+ self.quantizer_modules = nn.ModuleList(
687
+ [Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
688
+ )
689
+ self.n_code_groups = n_code_groups
690
+ self.embed_dim = embed_dim
691
+
692
+ def forward(self, xin):
693
+ # B, C, T
694
+ B, C, T = xin.shape
695
+ xin = xin.transpose(1, 2)
696
+ x = xin.reshape(-1, self.embed_dim)
697
+ x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1)
698
+ min_indicies = []
699
+ z_q = []
700
+ for _x, m in zip(x, self.quantizer_modules):
701
+ _z_q, _min_indicies = m(_x)
702
+ z_q.append(_z_q)
703
+ min_indicies.append(_min_indicies) # B * T,
704
+ z_q = torch.cat(z_q, -1).reshape(xin.shape)
705
+ loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
706
+ z_q = xin + (z_q - xin).detach()
707
+ z_q = z_q.transpose(1, 2)
708
+ codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
709
+ return z_q, loss, codes.transpose(1, 2)
710
+
711
+ def embed(self, x):
712
+ # idx: N, 4, T
713
+ x = x.transpose(1, 2)
714
+ x = torch.split(x, 1, 2)
715
+ ret = []
716
+ for q, embed in zip(x, self.quantizer_modules):
717
+ q = embed.embedding(q.squeeze(-1))
718
+ ret.append(q)
719
+ ret = torch.cat(ret, -1)
720
+ return ret.transpose(1, 2) # N, C, T
721
+
722
+
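`Quantizer_module` above selects the nearest codebook entry via the expansion ‖x − e‖² = ‖x‖² + ‖e‖² − 2·x·eᵀ, which is exactly the three-term `d`. A compact numerical check:

```python
# Check that the expanded distance in Quantizer_module.forward equals an explicit ||x - e||^2.
import torch

x = torch.randn(5, 128)           # 5 vectors to quantize
codebook = torch.randn(160, 128)  # 160 codes of dimension 128

d_expanded = x.pow(2).sum(1, keepdim=True) + codebook.pow(2).sum(1) - 2 * x @ codebook.T
d_direct = torch.cdist(x, codebook).pow(2)

print(torch.allclose(d_expanded, d_direct, atol=1e-3))  # True, up to float error
print(d_expanded.argmin(1))                             # nearest-code indices, as in forward()
```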
723
+ class CodePredictor(nn.Module):
724
+ def __init__(
725
+ self,
726
+ hidden_channels,
727
+ filter_channels,
728
+ n_heads,
729
+ n_layers,
730
+ kernel_size,
731
+ p_dropout,
732
+ n_q=8,
733
+ dims=1024,
734
+ ssl_dim=768,
735
+ ):
736
+ super().__init__()
737
+ self.hidden_channels = hidden_channels
738
+ self.filter_channels = filter_channels
739
+ self.n_heads = n_heads
740
+ self.n_layers = n_layers
741
+ self.kernel_size = kernel_size
742
+ self.p_dropout = p_dropout
743
+
744
+ self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
745
+ self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)
746
+
747
+ self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
748
+
749
+ self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
750
+ self.n_q = n_q
751
+ self.dims = dims
752
+
753
+ def forward(self, x, x_mask, refer, codes, infer=False):
754
+ x = x.detach()
755
+ x = self.vq_proj(x * x_mask) * x_mask
756
+ g = self.ref_enc(refer, x_mask)
757
+ x = x + g
758
+ x = self.encoder(x * x_mask, x_mask)
759
+ x = self.out_proj(x * x_mask) * x_mask
760
+ logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
761
+ target = codes[1:].transpose(0, 1)
762
+ if not infer:
763
+ logits = logits.reshape(-1, self.dims)
764
+ target = target.reshape(-1)
765
+ loss = torch.nn.functional.cross_entropy(logits, target)
766
+ return loss
767
+ else:
768
+ _, top10_preds = torch.topk(logits, 10, dim=-1)
769
+ correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
770
+ top10_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()
772
+
773
+ print("Top-10 Accuracy:", top10_acc, "%")
773
+
774
+ pred_codes = torch.argmax(logits, dim=-1)
775
+ acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
776
+ print("Top-1 Accuracy:", acc, "%")
777
+
778
+ return pred_codes.transpose(0, 1)
779
+
780
+
781
+ class SynthesizerTrn(nn.Module):
782
+ """
783
+ Synthesizer for Training
784
+ """
785
+
786
+ def __init__(
787
+ self,
788
+ spec_channels,
789
+ segment_size,
790
+ inter_channels,
791
+ hidden_channels,
792
+ filter_channels,
793
+ n_heads,
794
+ n_layers,
795
+ kernel_size,
796
+ p_dropout,
797
+ resblock,
798
+ resblock_kernel_sizes,
799
+ resblock_dilation_sizes,
800
+ upsample_rates,
801
+ upsample_initial_channel,
802
+ upsample_kernel_sizes,
803
+ n_speakers=0,
804
+ gin_channels=0,
805
+ use_sdp=True,
806
+ semantic_frame_rate=None,
807
+ freeze_quantizer=None,
808
+ version="v2",
809
+ **kwargs,
810
+ ):
811
+ super().__init__()
812
+ self.spec_channels = spec_channels
813
+ self.inter_channels = inter_channels
814
+ self.hidden_channels = hidden_channels
815
+ self.filter_channels = filter_channels
816
+ self.n_heads = n_heads
817
+ self.n_layers = n_layers
818
+ self.kernel_size = kernel_size
819
+ self.p_dropout = p_dropout
820
+ self.resblock = resblock
821
+ self.resblock_kernel_sizes = resblock_kernel_sizes
822
+ self.resblock_dilation_sizes = resblock_dilation_sizes
823
+ self.upsample_rates = upsample_rates
824
+ self.upsample_initial_channel = upsample_initial_channel
825
+ self.upsample_kernel_sizes = upsample_kernel_sizes
826
+ self.segment_size = segment_size
827
+ self.n_speakers = n_speakers
828
+ self.gin_channels = gin_channels
829
+ self.version = version
830
+
831
+ self.use_sdp = use_sdp
832
+ self.enc_p = TextEncoder(
833
+ inter_channels,
834
+ hidden_channels,
835
+ filter_channels,
836
+ n_heads,
837
+ n_layers,
838
+ kernel_size,
839
+ p_dropout,
840
+ version=version,
841
+ )
842
+ self.dec = Generator(
843
+ inter_channels,
844
+ resblock,
845
+ resblock_kernel_sizes,
846
+ resblock_dilation_sizes,
847
+ upsample_rates,
848
+ upsample_initial_channel,
849
+ upsample_kernel_sizes,
850
+ gin_channels=gin_channels,
851
+ )
852
+ self.enc_q = PosteriorEncoder(
853
+ spec_channels,
854
+ inter_channels,
855
+ hidden_channels,
856
+ 5,
857
+ 1,
858
+ 16,
859
+ gin_channels=gin_channels,
860
+ )
861
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
862
+
863
+ # self.version=os.environ.get("version","v1")
864
+ if self.version == "v1":
865
+ self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
866
+ else:
867
+ self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
868
+
869
+ ssl_dim = 768
870
+ assert semantic_frame_rate in ["25hz", "50hz"]
871
+ self.semantic_frame_rate = semantic_frame_rate
872
+ if semantic_frame_rate == "25hz":
873
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
874
+ else:
875
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
876
+
877
+ self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
878
+ self.freeze_quantizer = freeze_quantizer
879
+
880
+ self.is_v2pro = self.version in v2pro_set
881
+ if self.is_v2pro:
882
+ self.sv_emb = nn.Linear(20480, gin_channels)
883
+ self.ge_to512 = nn.Linear(gin_channels, 512)
884
+ self.prelu = nn.PReLU(num_parameters=gin_channels)
885
+
886
+ def forward(self, ssl, y, y_lengths, text, text_lengths, sv_emb=None):
887
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
888
+ if self.version == "v1":
889
+ ge = self.ref_enc(y * y_mask, y_mask)
890
+ else:
891
+ ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
892
+ if self.is_v2pro:
893
+ sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
894
+ ge += sv_emb.unsqueeze(-1)
895
+ ge = self.prelu(ge)
896
+ ge512 = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
897
+ with autocast(enabled=False):
898
+ maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
899
+ with maybe_no_grad:
900
+ if self.freeze_quantizer:
901
+ self.ssl_proj.eval()
902
+ self.quantizer.eval()
903
+ ssl = self.ssl_proj(ssl)
904
+ quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
905
+
906
+ if self.semantic_frame_rate == "25hz":
907
+ quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
908
+
909
+ x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
910
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
911
+ z_p = self.flow(z, y_mask, g=ge)
912
+
913
+ z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
914
+ o = self.dec(z_slice, g=ge)
915
+ return (
916
+ o,
917
+ commit_loss,
918
+ ids_slice,
919
+ y_mask,
920
+ y_mask,
921
+ (z, z_p, m_p, logs_p, m_q, logs_q),
922
+ quantized,
923
+ )
924
+
925
+ def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
926
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
927
+ if self.version == "v1":
928
+ ge = self.ref_enc(y * y_mask, y_mask)
929
+ else:
930
+ ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
931
+
932
+ ssl = self.ssl_proj(ssl)
933
+ quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
934
+ if self.semantic_frame_rate == "25hz":
935
+ quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
936
+
937
+ x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
938
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
939
+
940
+ z = self.flow(z_p, y_mask, g=ge, reverse=True)
941
+
942
+ o = self.dec((z * y_mask)[:, :, :], g=ge)
943
+ return o, y_mask, (z, z_p, m_p, logs_p)
944
+
945
+ def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
946
+ def get_ge(refer, sv_emb):
947
+ ge = None
948
+ if refer is not None:
949
+ refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
950
+ refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
951
+ if self.version == "v1":
952
+ ge = self.ref_enc(refer * refer_mask, refer_mask)
953
+ else:
954
+ ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
955
+ if self.is_v2pro:
956
+ sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
957
+ ge += sv_emb.unsqueeze(-1)
958
+ ge = self.prelu(ge)
959
+ return ge
960
+
961
+ if isinstance(refer, list):
962
+ ges = []
963
+ for idx, _refer in enumerate(refer):
964
+ ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
965
+ ges.append(ge)
966
+ ge = torch.stack(ges, 0).mean(0)
967
+ else:
968
+ ge = get_ge(refer, sv_emb)
969
+
970
+ y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
971
+ text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
972
+
973
+ quantized = self.quantizer.decode(codes)
974
+ if self.semantic_frame_rate == "25hz":
975
+ quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
976
+ x, m_p, logs_p, y_mask = self.enc_p(
977
+ quantized,
978
+ y_lengths,
979
+ text,
980
+ text_lengths,
981
+ self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
982
+ speed,
983
+ )
984
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
985
+
986
+ z = self.flow(z_p, y_mask, g=ge, reverse=True)
987
+
988
+ o = self.dec((z * y_mask)[:, :, :], g=ge)
989
+ return o
990
+
991
+ def extract_latent(self, x) -> torch.Tensor:
992
+ ssl = self.ssl_proj(x)
993
+ quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
994
+ return codes.transpose(0, 1)
995
+
996
+
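At inference time `SynthesizerTrn` is normally driven through `extract_latent` (SSL features of the prompt → semantic codes) and `decode` (codes + phoneme ids + reference spectrogram → waveform). Below is a hedged sketch of that call order; the tensor shapes and preprocessing are assumptions, not pinned down by this file.

```python
# Hedged sketch of the usual SynthesizerTrn inference path; shapes/preprocessing are assumptions.
import torch


def synthesize(vq_model, ssl_feats, phoneme_ids, refer_spec, sv_emb=None):
    """ssl_feats: (1, 768, T_ssl) SSL features of the prompt audio;
    phoneme_ids: (1, T_text) symbol ids; refer_spec: (1, spec_channels, T_ref) spectrogram."""
    with torch.no_grad():
        codes = vq_model.extract_latent(ssl_feats)  # (1, n_q=1, T_code) semantic tokens
        # In the full pipeline an AR text-to-semantic model predicts codes for the target text;
        # decoding the prompt's own codes, as done here, just reconstructs the prompt audio.
        audio = vq_model.decode(codes, phoneme_ids, refer_spec, noise_scale=0.5, sv_emb=sv_emb)
    return audio  # (1, 1, n_samples)
```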
997
+ class CFM(torch.nn.Module):
998
+ def __init__(self, in_channels, dit):
999
+ super().__init__()
1000
+ self.sigma_min = 1e-6
1001
+
1002
+ self.estimator = dit
1003
+
1004
+ self.in_channels = in_channels
1005
+
1006
+ self.criterion = torch.nn.MSELoss()
1007
+
1008
+ self.use_conditioner_cache = True
1009
+
1010
+ @torch.inference_mode()
1011
+ def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
1012
+ """Sample by integrating the learned flow (Euler steps, optional classifier-free guidance)."""
1013
+ B, T = mu.size(0), mu.size(1)
1014
+ x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
1015
+ prompt_len = prompt.size(-1)
1016
+ prompt_x = torch.zeros_like(x, dtype=mu.dtype)
1017
+ prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
1018
+ x[..., :prompt_len] = 0
1019
+ mu = mu.transpose(2, 1)
1020
+ t = 0
1021
+ d = 1 / n_timesteps
1022
+ text_cache = None
1023
+ text_cfg_cache = None
1024
+ dt_cache = None
1025
+ d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
1026
+ for j in range(n_timesteps):
1027
+ t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
1028
+ # v_pred = model(x, t_tensor, d_tensor, **extra_args)
1029
+ v_pred, text_emb, dt = self.estimator(
1030
+ x,
1031
+ prompt_x,
1032
+ x_lens,
1033
+ t_tensor,
1034
+ d_tensor,
1035
+ mu,
1036
+ use_grad_ckpt=False,
1037
+ drop_audio_cond=False,
1038
+ drop_text=False,
1039
+ infer=True,
1040
+ text_cache=text_cache,
1041
+ dt_cache=dt_cache,
1042
+ )
1043
+ v_pred = v_pred.transpose(2, 1)
1044
+ if self.use_conditioner_cache:
1045
+ text_cache = text_emb
1046
+ dt_cache = dt
1047
+ if inference_cfg_rate > 1e-5:
1048
+ neg, text_cfg_emb, _ = self.estimator(
1049
+ x,
1050
+ prompt_x,
1051
+ x_lens,
1052
+ t_tensor,
1053
+ d_tensor,
1054
+ mu,
1055
+ use_grad_ckpt=False,
1056
+ drop_audio_cond=True,
1057
+ drop_text=True,
1058
+ infer=True,
1059
+ text_cache=text_cfg_cache,
1060
+ dt_cache=dt_cache,
1061
+ )
1062
+ neg = neg.transpose(2, 1)
1063
+ if self.use_conditioner_cache:
1064
+ text_cfg_cache = text_cfg_emb
1065
+ v_pred = v_pred + (v_pred - neg) * inference_cfg_rate
1066
+ x = x + d * v_pred
1067
+ t = t + d
1068
+ x[:, :, :prompt_len] = 0
1069
+ return x
1070
+
1071
+ def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
1072
+ b, _, t = x1.shape
1073
+ t = torch.rand([b], device=mu.device, dtype=x1.dtype)
1074
+ x0 = torch.randn_like(x1, device=mu.device)
1075
+ vt = x1 - x0
1076
+ xt = x0 + t[:, None, None] * vt
1077
+ dt = torch.zeros_like(t, device=mu.device)
1078
+ prompt = torch.zeros_like(x1)
1079
+ for i in range(b):
1080
+ prompt[i, :, : prompt_lens[i]] = x1[i, :, : prompt_lens[i]]
1081
+ xt[i, :, : prompt_lens[i]] = 0
1082
+ gailv = 0.3  # gailv ("probability"): chance of using the two-step averaged velocity target below  # if ttime()>1736250488 else 0.1
1083
+ if random.random() < gailv:
1084
+ base = torch.randint(2, 8, (t.shape[0],), device=mu.device)
1085
+ d = 1 / torch.pow(2, base)
1086
+ d_input = d.clone()
1087
+ d_input[d_input < 1e-2] = 0
1088
+ # with torch.no_grad():
1089
+ v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
1090
+ # v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach()
1091
+ x_mid = xt + d[:, None, None] * v_pred_1
1092
+ # v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach()
1093
+ v_pred_2 = self.estimator(x_mid, prompt, x_lens, t + d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
1094
+ vt = (v_pred_1 + v_pred_2) / 2
1095
+ vt = vt.detach()
1096
+ dt = 2 * d
1097
+
1098
+ vt_pred = self.estimator(xt, prompt, x_lens, t, dt, mu, use_grad_ckpt).transpose(2, 1)
1099
+ loss = 0
1100
+ for i in range(b):
1101
+ loss += self.criterion(vt_pred[i, :, prompt_lens[i] : x_lens[i]], vt[i, :, prompt_lens[i] : x_lens[i]])
1102
+ loss /= b
1103
+
1104
+ return loss
1105
+
1106
+
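`CFM.inference` above integrates the learned velocity field with `n_timesteps` explicit Euler steps, optionally applying classifier-free guidance by extrapolating away from the unconditional prediction. The per-step update in isolation looks like this (toy tensors, no model):

```python
# The per-step update from CFM.inference, isolated for clarity (toy tensors, no model).
import torch


def euler_cfg_step(x, v_cond, v_uncond, d, cfg_rate):
    v = v_cond
    if cfg_rate > 1e-5:
        v = v + (v - v_uncond) * cfg_rate  # classifier-free guidance, as in the loop above
    return x + d * v                       # explicit Euler step, d = 1 / n_timesteps


x = torch.zeros(1, 100, 50)  # (B, mel_channels, T)
x = euler_cfg_step(x, torch.ones_like(x), torch.zeros_like(x), d=0.25, cfg_rate=0.5)
print(x.mean().item())       # 0.375 = 0.25 * (1 + (1 - 0) * 0.5)
```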
1107
+ def set_no_grad(net_g):
1108
+ for name, param in net_g.named_parameters():
1109
+ param.requires_grad = False
1110
+
1111
+
1112
+ class SynthesizerTrnV3(nn.Module):
1113
+ """
1114
+ Synthesizer for Training
1115
+ """
1116
+
1117
+ def __init__(
1118
+ self,
1119
+ spec_channels,
1120
+ segment_size,
1121
+ inter_channels,
1122
+ hidden_channels,
1123
+ filter_channels,
1124
+ n_heads,
1125
+ n_layers,
1126
+ kernel_size,
1127
+ p_dropout,
1128
+ resblock,
1129
+ resblock_kernel_sizes,
1130
+ resblock_dilation_sizes,
1131
+ upsample_rates,
1132
+ upsample_initial_channel,
1133
+ upsample_kernel_sizes,
1134
+ n_speakers=0,
1135
+ gin_channels=0,
1136
+ use_sdp=True,
1137
+ semantic_frame_rate=None,
1138
+ freeze_quantizer=None,
1139
+ version="v3",
1140
+ **kwargs,
1141
+ ):
1142
+ super().__init__()
1143
+ self.spec_channels = spec_channels
1144
+ self.inter_channels = inter_channels
1145
+ self.hidden_channels = hidden_channels
1146
+ self.filter_channels = filter_channels
1147
+ self.n_heads = n_heads
1148
+ self.n_layers = n_layers
1149
+ self.kernel_size = kernel_size
1150
+ self.p_dropout = p_dropout
1151
+ self.resblock = resblock
1152
+ self.resblock_kernel_sizes = resblock_kernel_sizes
1153
+ self.resblock_dilation_sizes = resblock_dilation_sizes
1154
+ self.upsample_rates = upsample_rates
1155
+ self.upsample_initial_channel = upsample_initial_channel
1156
+ self.upsample_kernel_sizes = upsample_kernel_sizes
1157
+ self.segment_size = segment_size
1158
+ self.n_speakers = n_speakers
1159
+ self.gin_channels = gin_channels
1160
+ self.version = version
1161
+
1162
+ self.model_dim = 512
1163
+ self.use_sdp = use_sdp
1164
+ self.enc_p = TextEncoder(
1165
+ inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
1166
+ )
1167
+ self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
1168
+
1169
+ ssl_dim = 768
1170
+ assert semantic_frame_rate in ["25hz", "50hz"]
1171
+ self.semantic_frame_rate = semantic_frame_rate
1172
+ if semantic_frame_rate == "25hz":
1173
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
1174
+ else:
1175
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
1176
+
1177
+ self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
1178
+ self.freeze_quantizer = freeze_quantizer
1179
+ inter_channels2 = 512
1180
+ self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
1181
+ self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
1182
+ self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
1183
+ self.cfm = CFM(
1184
+ 100,
1185
+ DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
1186
+ ) # text_dim is condition feature dim
1187
+ if self.freeze_quantizer is True:
1188
+ set_no_grad(self.ssl_proj)
1189
+ set_no_grad(self.quantizer)
1190
+ set_no_grad(self.enc_p)
1191
+
1192
+ def forward(
1193
+ self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths, use_grad_ckpt
1194
+ ): # ssl_lengths no need now
1195
+ with autocast(enabled=False):
1196
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
1197
+ ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
1198
+ maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
1199
+ with maybe_no_grad:
1200
+ if self.freeze_quantizer:
1201
+ self.ssl_proj.eval() #
1202
+ self.quantizer.eval()
1203
+ self.enc_p.eval()
1204
+ ssl = self.ssl_proj(ssl)
1205
+ quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
1206
+ quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
1207
+ x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
1208
+ fea = self.bridge(x)
1209
+ fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
1210
+ fea, y_mask_ = self.wns1(
1211
+ fea, mel_lengths, ge
1212
+ ) ##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
1213
+ B = ssl.shape[0]
1214
+ prompt_len_max = mel_lengths * 2 / 3
1215
+ prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)
1216
+ minn = min(mel.shape[-1], fea.shape[-1])
1217
+ mel = mel[:, :, :minn]
1218
+ fea = fea[:, :, :minn]
1219
+ cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
1220
+ return cfm_loss
1221
+
1222
+ @torch.no_grad()
1223
+ def decode_encp(self, codes, text, refer, ge=None, speed=1):
1224
+ # print(2333333,refer.shape)
1225
+ # ge=None
1226
+ if ge is None:
1227
+ refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
1228
+ refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
1229
+ ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
1230
+ y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
1231
+ if speed == 1:
1232
+ sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4))
1233
+ else:
1234
+ sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4) / speed) + 1
1235
+ y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
1236
+ text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
1237
+
1238
+ quantized = self.quantizer.decode(codes)
1239
+ if self.semantic_frame_rate == "25hz":
1240
+ quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
1241
+ x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
1242
+ fea = self.bridge(x)
1243
+ fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
1244
+ #### more WN parameters to learn the mel
1245
+ fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
1246
+ return fea, ge
1247
+
1248
+ def extract_latent(self, x):
1249
+ ssl = self.ssl_proj(x)
1250
+ quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
1251
+ return codes.transpose(0, 1)
1252
+
1253
+
1254
+ class SynthesizerTrnV3b(nn.Module):
1255
+ """
1256
+ Synthesizer for Training
1257
+ """
1258
+
1259
+ def __init__(
1260
+ self,
1261
+ spec_channels,
1262
+ segment_size,
1263
+ inter_channels,
1264
+ hidden_channels,
1265
+ filter_channels,
1266
+ n_heads,
1267
+ n_layers,
1268
+ kernel_size,
1269
+ p_dropout,
1270
+ resblock,
1271
+ resblock_kernel_sizes,
1272
+ resblock_dilation_sizes,
1273
+ upsample_rates,
1274
+ upsample_initial_channel,
1275
+ upsample_kernel_sizes,
1276
+ n_speakers=0,
1277
+ gin_channels=0,
1278
+ use_sdp=True,
1279
+ semantic_frame_rate=None,
1280
+ freeze_quantizer=None,
1281
+ **kwargs,
1282
+ ):
1283
+ super().__init__()
1284
+ self.spec_channels = spec_channels
1285
+ self.inter_channels = inter_channels
1286
+ self.hidden_channels = hidden_channels
1287
+ self.filter_channels = filter_channels
1288
+ self.n_heads = n_heads
1289
+ self.n_layers = n_layers
1290
+ self.kernel_size = kernel_size
1291
+ self.p_dropout = p_dropout
1292
+ self.resblock = resblock
1293
+ self.resblock_kernel_sizes = resblock_kernel_sizes
1294
+ self.resblock_dilation_sizes = resblock_dilation_sizes
1295
+ self.upsample_rates = upsample_rates
1296
+ self.upsample_initial_channel = upsample_initial_channel
1297
+ self.upsample_kernel_sizes = upsample_kernel_sizes
1298
+ self.segment_size = segment_size
1299
+ self.n_speakers = n_speakers
1300
+ self.gin_channels = gin_channels
1301
+
1302
+ self.model_dim = 512
1303
+ self.use_sdp = use_sdp
1304
+ self.enc_p = TextEncoder(
1305
+ inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
1306
+ )
1307
+ # self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
1308
+ self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
1309
+ self.dec = Generator(
1310
+ inter_channels,
1311
+ resblock,
1312
+ resblock_kernel_sizes,
1313
+ resblock_dilation_sizes,
1314
+ upsample_rates,
1315
+ upsample_initial_channel,
1316
+ upsample_kernel_sizes,
1317
+ gin_channels=gin_channels,
1318
+ )
1319
+ self.enc_q = PosteriorEncoder(
1320
+ spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels
1321
+ )
1322
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
1323
+
1324
+ ssl_dim = 768
1325
+ assert semantic_frame_rate in ["25hz", "50hz"]
1326
+ self.semantic_frame_rate = semantic_frame_rate
1327
+ if semantic_frame_rate == "25hz":
1328
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
1329
+ else:
1330
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
1331
+
1332
+ self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
1333
+ self.freeze_quantizer = freeze_quantizer
1334
+
1335
+ inter_channels2 = 512
1336
+ self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
1337
+ self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
1338
+ self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
1339
+ self.cfm = CFM(
1340
+ 100,
1341
+ DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
1342
+ ) # text_dim is condition feature dim
1343
+
1344
+ def forward(self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths): # ssl_lengths no need now
1345
+ with autocast(enabled=False):
1346
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
1347
+ ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
1348
+ # ge = self.ref_enc(y * y_mask, y_mask)#change back, new spec setting is whole 24k
1349
+ # ge=None
1350
+ maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
1351
+ with maybe_no_grad:
1352
+ if self.freeze_quantizer:
1353
+ self.ssl_proj.eval()
1354
+ self.quantizer.eval()
1355
+ ssl = self.ssl_proj(ssl)
1356
+ quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
1357
+ quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
1358
+ x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
1359
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
1360
+ z_p = self.flow(z, y_mask, g=ge)
1361
+ z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
1362
+ o = self.dec(z_slice, g=ge)
1363
+ fea = self.bridge(x)
1364
+ fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
1365
+ fea, y_mask_ = self.wns1(fea, mel_lengths, ge)
1366
+ learned_mel = self.linear_mel(fea)
1367
+ B = ssl.shape[0]
1368
+ prompt_len_max = mel_lengths * 2 / 3
1369
+ prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long) #
1370
+ minn = min(mel.shape[-1], fea.shape[-1])
1371
+ mel = mel[:, :, :minn]
1372
+ fea = fea[:, :, :minn]
1373
+ cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt=False)  # fea == cond, mel_lengths == target mel lengths; ge not needed
1374
+ return (
1375
+ commit_loss,
1376
+ cfm_loss,
1377
+ F.mse_loss(learned_mel, mel),
1378
+ o,
1379
+ ids_slice,
1380
+ y_mask,
1381
+ y_mask,
1382
+ (z, z_p, m_p, logs_p, m_q, logs_q),
1383
+ quantized,
1384
+ )
1385
+
1386
+ @torch.no_grad()
1387
+ def decode_encp(self, codes, text, refer, ge=None):
1388
+ # print(2333333,refer.shape)
1389
+ # ge=None
1390
+ if ge is None:
1391
+ refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
1392
+ refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
1393
+ ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
1394
+ y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
1395
+ y_lengths1 = torch.LongTensor([int(codes.size(2) * 2.5 * 1.5)]).to(codes.device)
1396
+ text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
1397
+
1398
+ quantized = self.quantizer.decode(codes)
1399
+ if self.semantic_frame_rate == "25hz":
1400
+ quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
1401
+ x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
1402
+ fea = self.bridge(x)
1403
+ fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
1404
+ #### more WN parameters to learn the mel
1405
+ fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
1406
+ return fea, ge
1407
+
1408
+ def extract_latent(self, x) -> torch.Tensor:
1409
+ ssl = self.ssl_proj(x)
1410
+ quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
1411
+ return codes.transpose(0, 1)