Commit 0f1d9a2
1 Parent(s): 273b181
Add vits model and normalizing flow. Jupyter Notebook as example call
Files changed:
- pvq_manipulation/Example_Notebook.ipynb +331 -0
- pvq_manipulation/helper/characters.yaml +4 -0
- pvq_manipulation/helper/moving_batch_norm.py +140 -0
- pvq_manipulation/helper/utils.py +228 -0
- pvq_manipulation/helper/vad.py +193 -0
- pvq_manipulation/models/ffjord.py +247 -0
- pvq_manipulation/models/hubert.py +207 -0
- pvq_manipulation/models/ode_functions.py +96 -0
- pvq_manipulation/models/vits.py +742 -0
- setup.py +13 -0
pvq_manipulation/Example_Notebook.ipynb
ADDED
@@ -0,0 +1,331 @@
The added notebook, rendered as its cells (the surrounding .ipynb JSON scaffolding is omitted):

# %%
import numpy as np
from pathlib import Path
import padertorch as pt
import paderbox as pb
import time
import torch
import torchaudio
import ipywidgets as widgets
from onnxruntime import InferenceSession
from pvq_manipulation.models.vits import Vits_NT
from pvq_manipulation.models.ffjord import FFJORD
from IPython.display import display, Audio, clear_output
from pvq_manipulation.models.hubert import HubertExtractor, SID_LARGE_LAYER
from paderbox.transform.module_resample import resample_sox
from pvq_manipulation.helper.vad import EnergyVAD
from train_tts_nt.helper.utils import rms_norm

# %% [markdown]
# # load TTS model

# %%
storage_dir_tts = Path("./Saved_models/tts_model/")
tts_model = Vits_NT.load_model(storage_dir_tts, checkpoint="checkpoint_390000.pth")

# %% [markdown]
# # load normalizing flow

# %%
storage_dir_normalizing_flow = Path("./Saved_models/norm_flow")
config_norm_flow = pb.io.load_yaml(storage_dir_normalizing_flow / "config.yaml")
normalizing_flow = FFJORD.load_model(storage_dir_normalizing_flow, checkpoint="checkpoints/ckpt_best_loss.pth")

# %% [markdown]
# # load hubert features model

# %%
hubert_model = HubertExtractor(
    layer=SID_LARGE_LAYER,
    model_name="HUBERT_LARGE",
    backend="torchaudio",
    device='cpu',
    storage_dir='/net/vol/rautenberg/storage/hubert'  # target storage dir hubert model
)

# %% [markdown]
# # Example Synthesis

# %%
speaker_id = 1034
example_id = "1034_121119_000028_000001"

wav_1 = tts_model.synthesize_from_example({
    'text': "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
    'd_vector_storage_root': f"./Saved_models/Dataset/Embeddings/{speaker_id}/{example_id}.pth"
})
display(Audio(wav_1, rate=24_000, normalize=True))
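Not part of the notebook: a minimal sketch for keeping the synthesized audio on disk instead of only playing it inline. It assumes `wav_1` from the cell above (an array-like waveform at 24 kHz) and that the `soundfile` package is installed, which is not a stated dependency of this repository.

import numpy as np
import soundfile as sf

# Flatten to a 1-D float32 waveform and write it next to the notebook.
wav_np = np.asarray(wav_1, dtype=np.float32).squeeze()
sf.write("synthesis_unmodified.wav", wav_np, samplerate=24_000)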
# %% [markdown]
# # Manipulation Block

# %%
def get_manipulation(
    example,
    d_vector,
    labels,
    flow,
    tts_model,
    manipulation_idx=0,
    manipulation_fkt=1,
):
    labels_manipulated = labels.clone()
    labels_manipulated[:, manipulation_idx] += manipulation_fkt

    output_forward = flow.forward((d_vector.float(), labels))[0]
    sampled_class_manipulated = flow.sample((output_forward, labels_manipulated))[0]

    wav = tts_model.synthesize_from_example({
        'text': "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
        'd_vector': d_vector.detach().numpy(),
        'd_vector_man': sampled_class_manipulated.detach().numpy(),
    })
    return wav


def extract_speaker_embedding(example):
    observation, sr = pb.io.load_audio(example['audio_path']['observation'], return_sample_rate=True)
    observation = resample_sox(observation, in_rate=sr, out_rate=16_000)

    vad = EnergyVAD(sample_rate=16_000)
    if observation.ndim == 1:
        observation = observation[None, :]

    observation = vad({'audio_data': observation})['audio_data']

    with torch.no_grad():
        example = tts_model.speaker_manager.prepare_example({'audio_data': {'observation': observation}, **example})
        example = pt.data.utils.collate_fn([example])
        example['features'] = torch.tensor(np.array(example['features']))
        d_vector = tts_model.speaker_manager.forward(example)[0]
    return d_vector

# %%
def load_speaker_labels(example, config_norm_flow, reg_stor_dir=Path('./Saved_models/pvq_extractor/')):
    audio, _ = torchaudio.load(example['audio_path']['observation'])
    num_samples = torch.tensor([audio.shape[-1]])

    if torch.cuda.is_available():
        audio = audio.cuda()
        num_samples = num_samples.cuda()
    providers = ["CPUExecutionProvider"]

    with torch.no_grad():
        features, seq_len = hubert_model(
            audio,
            24_000,
            sequence_lengths=num_samples,
        )
        features = np.mean(features.squeeze(0).detach().cpu().numpy(), axis=-1)

    pvqd_predictions = {}
    for pvq in ['Breathiness', 'Loudness', 'Pitch', 'Resonance', 'Roughness', 'Strain', 'Weight']:
        with open(reg_stor_dir / f"{pvq}.onnx", "rb") as fid:
            onnx = fid.read()
        sess = InferenceSession(onnx, providers=providers)
        pred = sess.run(None, {"X": features[None]})[0].squeeze(1)
        pvqd_predictions[pvq] = pred.tolist()[0]
    labels = []
    for key in config_norm_flow['speaker_conditioning']:
        labels.append(pvqd_predictions[key] / 100)
    return torch.tensor(labels)
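Not part of the notebook: a small sketch to check which perceptual voice quality (PVQ) dimension a given `manipulation_idx` refers to. It assumes the objects defined above (`example`, `config_norm_flow`, `load_speaker_labels`) are available; the conditioning order is whatever the flow config lists under `speaker_conditioning`.

# Print the PVQ label vector in the order the flow is conditioned on.
labels = load_speaker_labels(example, config_norm_flow)
for idx, (name, value) in enumerate(zip(config_norm_flow['speaker_conditioning'], labels.tolist())):
    print(f"manipulation_idx={idx}: {name} = {value:.2f}")  # values are PVQ scores divided by 100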
# %% [markdown]
# # Get example manipulation

# %%
example = {
    'audio_path': {'observation': "./Saved_models/Dataset/Audio_files/1034_121119_000028_000001.wav"},
    'speaker_id': 1034,
    'example_id': "1034_121119_000028_000001",
}

d_vector = extract_speaker_embedding(example)
labels = load_speaker_labels(example, config_norm_flow)

wav_manipulated = get_manipulation(
    example=example,
    d_vector=d_vector,
    labels=labels[None, :],
    flow=normalizing_flow,
    tts_model=tts_model,
    manipulation_idx=0,
    manipulation_fkt=1,
)

# %%
example = {
    'audio_path': {'observation': "./Saved_models/Dataset/Audio_files/1034_121119_000028_000001.wav"},
    'speaker_id': 1034,
    'example_id': "1034_121119_000028_000001",
}

label_options = ['Weight', 'Resonance', 'Breathiness', 'Roughness', 'Loudness', 'Strain', 'Pitch']

manipulation_idx_widget = widgets.Dropdown(
    options=[(label, i) for i, label in enumerate(label_options)],
    value=2,  # default value: Breathiness
    description='Type:',
    style={'description_width': 'initial'}
)

manipulation_fkt_widget = widgets.FloatSlider(
    value=1.0, min=-2.0, max=2.0, step=0.1,
    description='Strength:',
    style={'description_width': 'initial'}
)

run_button = widgets.Button(description="Run Manipulation")

audio_output = widgets.Output()

def update_manipulation(b):
    manipulation_idx = manipulation_idx_widget.value
    manipulation_fkt = manipulation_fkt_widget.value

    d_vector = extract_speaker_embedding(example)
    labels = load_speaker_labels(example, config_norm_flow)

    with audio_output:
        clear_output(wait=True)
        display(widgets.Label("Processing..."))

    time.sleep(1)

    wav_manipulated = get_manipulation(
        example=example,
        d_vector=d_vector,
        labels=labels[None, :],
        flow=normalizing_flow,
        tts_model=tts_model,
        manipulation_idx=manipulation_idx,
        manipulation_fkt=manipulation_fkt,
    )

    with audio_output:
        clear_output(wait=True)
        display(Audio(wav_manipulated, rate=24_000, normalize=True))
        display(Audio(example['audio_path']['observation'], rate=24_000, normalize=True))

    print(f"Manipulated {label_options[manipulation_idx]} with strength {manipulation_fkt}")

run_button.on_click(update_manipulation)
display(manipulation_idx_widget, manipulation_fkt_widget, run_button, audio_output)

Notebook metadata: kernel "voice editing" (voice_editing), Python 3.11.9, nbformat 4.5.
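Not part of the notebook: a sketch that sweeps the manipulation strength for one PVQ dimension and writes each variant to disk, so several strengths can be compared offline. It assumes the notebook objects above (`example`, `config_norm_flow`, `normalizing_flow`, `tts_model`, `extract_speaker_embedding`, `load_speaker_labels`, `get_manipulation`) are loaded and that `soundfile` is available.

import numpy as np
import soundfile as sf

d_vector = extract_speaker_embedding(example)
labels = load_speaker_labels(example, config_norm_flow)

for strength in (-1.0, -0.5, 0.0, 0.5, 1.0):
    wav = get_manipulation(
        example=example,
        d_vector=d_vector,
        labels=labels[None, :],
        flow=normalizing_flow,
        tts_model=tts_model,
        manipulation_idx=0,  # first entry of config_norm_flow['speaker_conditioning']
        manipulation_fkt=strength,
    )
    # One file per strength, e.g. manipulated_+0.5.wav
    sf.write(f"manipulated_{strength:+.1f}.wav", np.asarray(wav, dtype=np.float32).squeeze(), samplerate=24_000)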
pvq_manipulation/helper/characters.yaml
ADDED
@@ -0,0 +1,4 @@
Yourtts:
  "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\u00af\u00b7\u00df\u00e0\u00e1\u00e2\u00e3\u00e4\u00e6\u00e7\u00e8\u00e9\u00ea\u00eb\u00ec\u00ed\u00ee\u00ef\u00f1\u00f2\u00f3\u00f4\u00f5\u00f6\u00f9\u00fa\u00fb\u00fc\u00ff\u0101\u0105\u0107\u0113\u0119\u011b\u012b\u0131\u0142\u0144\u014d\u0151\u0153\u015b\u016b\u0171\u017a\u017c\u01ce\u01d0\u01d2\u01d4\u0430\u0431\u0432\u0433\u0434\u0435\u0436\u0437\u0438\u0439\u043a\u043b\u043c\u043d\u043e\u043f\u0440\u0441\u0442\u0443\u0444\u0445\u0446\u0447\u0448\u0449\u044a\u044b\u044c\u044d\u044e\u044f\u0451\u0454\u0456\u0457\u0491\u2013!'(),-.:;? "
German:
  "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;?\u00af\u2013\u00fc\u00f6\u00e4\u00df\u201a\u2018\u2019"
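Not part of the commit: a minimal sketch of how the character sets above could be loaded and used to check a transcript before synthesis. It assumes PyYAML is installed and that the file is read from the path added in this commit.

import yaml

with open("pvq_manipulation/helper/characters.yaml") as fid:
    charsets = yaml.safe_load(fid)

text = "It took me quite a long time to develop a voice."
# Characters in the transcript that the Yourtts character set does not cover.
unknown = sorted(set(text) - set(charsets["Yourtts"]))
print("characters missing from the Yourtts set:", unknown)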
pvq_manipulation/helper/moving_batch_norm.py
ADDED
@@ -0,0 +1,140 @@
"""
This Code is adapted from https://github.com/RameenAbdal/StyleFlow/blob/master/module/normalization.py
"""
import torch
import torch.nn as nn
from torch.nn import Parameter


class MovingBatchNormNd(nn.Module):
    def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True):
        super(MovingBatchNormNd, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.decay = decay
        self.bn_lag = bn_lag
        self.register_buffer('step', torch.zeros(1))
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    @property
    def shape(self):
        raise NotImplementedError

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)
        if self.affine:
            self.weight.data.zero_()
            self.bias.data.zero_()

    def forward(self, x, c=None, logpx=None, reverse=False):
        if reverse:
            return self._reverse(x, logpx)
        else:
            return self._forward(x, logpx)

    def _forward(self, x, logpx=None):
        num_channels = x.size(-1)
        used_mean = self.running_mean.clone().detach()
        used_var = self.running_var.clone().detach()

        if self.training:
            # compute batch statistics
            x_t = x.transpose(0, -1).reshape(num_channels, -1)
            batch_mean = torch.mean(x_t, dim=1)
            batch_var = torch.var(x_t, dim=1)

            # moving average
            if self.bn_lag > 0:
                used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
                used_mean /= (1. - self.bn_lag**(self.step[0] + 1))
                used_var = batch_var - (1 - self.bn_lag) * (batch_var - used_var.detach())
                used_var /= (1. - self.bn_lag**(self.step[0] + 1))

            # update running estimates
            self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
            self.running_var -= self.decay * (self.running_var - batch_var.data)
            self.step += 1

        # perform normalization
        used_mean = used_mean.view(*self.shape).expand_as(x)
        used_var = used_var.view(*self.shape).expand_as(x)

        y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps))

        if self.affine:
            weight = self.weight.view(*self.shape).expand_as(x)
            bias = self.bias.view(*self.shape).expand_as(x)
            y = y * torch.exp(weight) + bias

        if logpx is None:
            return y
        else:
            # import ipdb
            # ipdb.set_trace()
            return y, logpx - self._logdetgrad(x, used_var).sum(-1, keepdim=True)

    def _reverse(self, y, logpy=None):
        used_mean = self.running_mean
        used_var = self.running_var

        if self.affine:
            weight = self.weight.view(*self.shape).expand_as(y)
            bias = self.bias.view(*self.shape).expand_as(y)
            y = (y - bias) * torch.exp(-weight)

        used_mean = used_mean.view(*self.shape).expand_as(y)
        used_var = used_var.view(*self.shape).expand_as(y)
        x = y * torch.exp(0.5 * torch.log(used_var + self.eps)) + used_mean

        if logpy is None:
            return x
        else:
            return x, logpy + self._logdetgrad(x, used_var).sum(-1, keepdim=True)

    def _logdetgrad(self, x, used_var):
        logdetgrad = -0.5 * torch.log(used_var + self.eps)
        if self.affine:
            weight = self.weight.view(*self.shape).expand(*x.size())
            logdetgrad += weight
        return logdetgrad

    def __repr__(self):
        return (
            '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},'
            ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__)
        )


def stable_var(x, mean=None, dim=1):
    if mean is None:
        mean = x.mean(dim, keepdim=True)
    mean = mean.view(-1, 1)
    res = torch.pow(x - mean, 2)
    max_sqr = torch.max(res, dim, keepdim=True)[0]
    var = torch.mean(res / max_sqr, 1, keepdim=True) * max_sqr
    var = var.view(-1)
    # change nan to zero
    var[var != var] = 0
    return var


class MovingBatchNorm1d(MovingBatchNormNd):
    @property
    def shape(self):
        return [1, -1]

    def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False):
        ret = super(MovingBatchNorm1d, self).forward(
            x, context, logpx=logpx, reverse=reverse)
        return ret
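Not part of the commit: a quick numerical sketch of the property the flow relies on, namely that MovingBatchNorm1d is invertible and that its forward and reverse log-determinant contributions cancel. It assumes the module above is importable from the package path added in this commit.

import torch
from pvq_manipulation.helper.moving_batch_norm import MovingBatchNorm1d

bn = MovingBatchNorm1d(4).eval()   # eval mode: running statistics are used, nothing is updated
x = torch.randn(8, 4)
logpx = torch.zeros(8, 1)

y, logpy = bn(x, logpx=logpx)                          # forward pass (x -> y)
x_rec, logpx_rec = bn(y, logpx=logpy, reverse=True)    # inverse pass (y -> x)

print(torch.allclose(x, x_rec, atol=1e-5))          # input is reconstructed
print(torch.allclose(logpx, logpx_rec, atol=1e-5))  # log-det terms cancel out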
pvq_manipulation/helper/utils.py
ADDED
@@ -0,0 +1,228 @@
import paderbox as pb
import torch

from coqpit import Coqpit
from dataclasses import dataclass, field
from paderbox.transform.module_resample import resample_sox

from typing import List

from TTS.tts.models.vits import VitsAudioConfig, VitsArgs
from TTS.tts.configs.shared_configs import BaseTTSConfig


def load_audio(file_path, target_sr):
    """Load the audio file normalized in [-1, 1]

    Return Shapes:
        - x: :math:`[1, T]`
    """
    if type(file_path) is dict:
        if 'observation' in file_path:
            file_path = file_path['observation']

    x, sr = pb.io.load_audio(file_path, return_sample_rate=True)
    if sr != target_sr:
        x = resample_sox(x, in_rate=sr, out_rate=target_sr)
    x = torch.tensor(x, dtype=torch.float32)[None, :]
    x[x < -1] = -1
    x[x > 1] = 1
    assert (x > 1).sum() + (x < -1).sum() == 0
    return x, target_sr


@dataclass
class VitsAudioConfig_NT(Coqpit):
    fft_size: int = 1024
    sample_rate: int = 16000
    win_length: int = 1024
    hop_length: int = 256
    num_mels: int = 80
    mel_fmin: int = 0
    mel_fmax: int = None
    fading: str = 'half'
    window: str = 'hann'
    pad: bool = True


@dataclass
class VitsConfig_NT(BaseTTSConfig):
    """Defines parameters for VITS End2End TTS model.

    Args:
        model (str):
            Model name. Do not change unless you know what you are doing.

        model_args (VitsArgs):
            Model architecture arguments. Defaults to `VitsArgs()`.

        audio (VitsAudioConfig):
            Audio processing configuration. Defaults to `VitsAudioConfig()`.

        grad_clip (List):
            Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.

        lr_gen (float):
            Initial learning rate for the generator. Defaults to 0.0002.

        lr_disc (float):
            Initial learning rate for the discriminator. Defaults to 0.0002.

        lr_scheduler_gen (str):
            Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_gen_params (dict):
            Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

        lr_scheduler_disc (str):
            Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_disc_params (dict):
            Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

        scheduler_after_epoch (bool):
            If true, step the schedulers after each epoch else after each step. Defaults to `False`.

        optimizer (str):
            Name of the optimizer to use with both the generator and the discriminator networks. One of the
            `torch.optim.*`. Defaults to `AdamW`.

        kl_loss_alpha (float):
            Loss weight for KL loss. Defaults to 1.0.

        disc_loss_alpha (float):
            Loss weight for the discriminator loss. Defaults to 1.0.

        gen_loss_alpha (float):
            Loss weight for the generator loss. Defaults to 1.0.

        feat_loss_alpha (float):
            Loss weight for the feature matching loss. Defaults to 1.0.

        mel_loss_alpha (float):
            Loss weight for the mel loss. Defaults to 45.0.

        return_wav (bool):
            If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.

        compute_linear_spec (bool):
            If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.

        use_weighted_sampler (bool):
            If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.

        weighted_sampler_attrs (dict):
            Key returned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
            by overweighting `root_path` by 2.0. Defaults to `{}`.

        weighted_sampler_multipliers (dict):
            Weight each unique value of a key returned by the formatter for weighted sampling.
            For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
            It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.

        r (int):
            Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.

        add_blank (bool):
            If true, a blank token is added in between every character. Defaults to `True`.

        test_sentences (List[List]):
            List of sentences with speaker and language information to be used for testing.

        language_ids_file (str):
            Path to the language ids file.

        use_language_embedding (bool):
            If true, language embedding is used. Defaults to `False`.

    Note:
        Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.

    Example:

        >>> from TTS.tts.configs.vits_config import VitsConfig
        >>> config = VitsConfig()
    """
    model: str = "vits"
    # model specific params
    model_args: VitsArgs = field(default_factory=VitsArgs)
    audio: VitsAudioConfig = field(default_factory=VitsAudioConfig)

    # optimizer
    grad_clip: List[float] = field(default_factory=lambda: [1000, 1000, 1000])
    lr_gen: float = 0.0002
    lr_disc: float = 0.0002
    lr_scheduler_gen: str = "ExponentialLR"
    lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    lr_scheduler_disc: str = "ExponentialLR"
    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    scheduler_after_epoch: bool = True
    optimizer: str = "AdamW"
    optimizer_params: dict = field(
        default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})

    # loss params
    kl_loss_alpha: float = 1.0
    disc_loss_alpha: float = 1.0
    gen_loss_alpha: float = 1.0
    feat_loss_alpha: float = 1.0
    mel_loss_alpha: float = 45.0
    dur_loss_alpha: float = 1.0
    speaker_encoder_loss_alpha: float = 1.0

    # data loader params
    return_wav: bool = True
    compute_linear_spec: bool = True

    # sampler params
    use_weighted_sampler: bool = False  # TODO: move it to the base config
    weighted_sampler_attrs: dict = field(default_factory=lambda: {})
    weighted_sampler_multipliers: dict = field(default_factory=lambda: {})

    # overrides
    r: int = 1  # DO NOT CHANGE
    add_blank: bool = True

    # testing
    test_sentences: List[List] = field(
        default_factory=lambda: [
            ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
            ["Be a voice, not an echo."],
            ["I'm sorry Dave. I'm afraid I can't do that."],
            ["This cake is great. It's so delicious and moist."],
            ["Prior to November 22, 1963."],
        ]
    )

    # multi-speaker settings
    # use speaker embedding layer
    num_speakers: int = 0
    use_speaker_embedding: bool = False
    speakers_file: str = None
    speaker_embedding_channels: int = 256
    language_ids_file: str = None
    use_language_embedding: bool = False

    # use d-vectors
    d_vectors_stor_file: bool = False
    d_vector_model_file: str = None
    d_vector_dim: int = None
    d_vector_model: str = None
    dataset_dict: dict = None
    gan_speaker_conditioning: bool = True

    sample_rate: int = 16_000
    use_vad: bool = True
    use_phone_labels: bool = False

    CONFIG_SOLVER: str = ''

    use_speaker_embedding_cond: bool = True

    def __post_init__(self):
        for key, val in self.model_args.items():
            if hasattr(self, key):
                self[key] = val
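Not part of the commit: a short sketch of how `load_audio` above behaves. The WAV path is hypothetical; the function returns a clipped float32 tensor of shape [1, T] at the requested sample rate.

from pvq_manipulation.helper.utils import load_audio

x, sr = load_audio("some_recording.wav", target_sr=16_000)
print(x.shape, x.dtype, sr)             # torch.Size([1, T]) torch.float32 16000
print(float(x.min()), float(x.max()))   # values are clipped to [-1, 1]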
pvq_manipulation/helper/vad.py
ADDED
@@ -0,0 +1,193 @@
import numpy as np
import paderbox as pb
import padertorch as pt
import typing

from dataclasses import dataclass


@pb.utils.functional.partial_decorator
def conv_smoothing(signal, window_length=7, threshold=3):
    """

    Boundary effects are visible at beginning and end of signal.

    Examples:
        >>> conv_smoothing(np.array([False, True, True, True, False, False, False, True]), 3, 2)
        array([False, True, True, True, False, False, False, False])

    Args:
        signal:
        window_length:
        threshold:

    Returns:

    """
    left_context = right_context = (window_length - 1) // 2
    if window_length % 2 == 0:
        right_context += 1
    act_conv = np.sum(pb.array.segment_axis(
        np.pad(signal, (left_context, right_context), mode='constant'),
        length=window_length, shift=1, axis=0, end='cut'
    ), axis=-1)
    # act_conv = np.convolve(signal, np.ones(window_length), 's')
    act = act_conv >= threshold
    assert act.shape == signal.shape, (act.shape, signal.shape)
    return act


@dataclass
class VAD(pt.Configurable):
    smoothing: typing.Optional[typing.Callable] = None

    def reset(self):
        """Override for a stateful VAD"""
        pass

    def compute_vad(self, signal, time_resolution=True):
        raise NotImplementedError()

    def vad_to_time(self, vad, time_length):
        raise NotImplementedError()

    def __call__(self, signal, time_resolution=True, reset=True):
        if reset:
            self.reset()

        vad = self.compute_vad(signal)

        if self.smoothing is not None:
            vad = pb.array.interval.ArrayInterval(self.smoothing(vad))

        if time_resolution:
            vad = self.vad_to_time(vad, time_length=signal.shape[-1])

        return vad


class EnergyVAD(VAD):
    def __init__(self, sample_rate, threshold=0.3):
        self.sample_rate = sample_rate
        self.threshold = threshold

    @staticmethod
    def remove_silence(signal, vad_mask):
        return signal[vad_mask == 1]

    def __call__(self, example):
        signal = example['audio_data']  # B T
        vad_mask = self.get_vad_mask(signal)
        signal = self.remove_silence(signal, vad_mask)
        example['audio_data'] = signal
        example['vad_mask'] = vad_mask
        return example

    def get_vad_mask(self, signal):
        window_size = int(0.1 * self.sample_rate + 1)

        half_context = (window_size - 1) // 2
        std = np.std(signal, axis=-1, keepdims=True)
        signal = signal - np.mean(signal, axis=-1, keepdims=True)
        signal = np.abs(signal)
        zeros = np.zeros(
            [
                signal.shape[0],
                half_context,
            ]
        )
        signal = np.concatenate([zeros, signal, zeros], axis=1)
        sliding_max = np.max(pb.array.segment_axis(
            signal,
            length=window_size, shift=1, axis=1, end='cut'
        ), axis=-1)
        return sliding_max > self.threshold * std


@dataclass
class ThresholdVAD(VAD):
    """
    Energy-based VAD for almost clean files. Tested on WSJ clean data by Lukas
    Drude.

    Attributes:
        threshold: Fraction of total signal standard deviation. Use 0.3 for
            (almost) clean files (SNR >= 20dB, think LibriTTS) and 0.7 for less
            clean files (think LibriSpeech).
        window_size: Size of sliding max window.
        sample_rate: Sampling rate of audio data.
        smoothing: Optional callable that uses a sliding window over the raw
            decision to return a smoothed VAD.
    """
    threshold: float = 0.3
    window_size: typing.Optional[int] = None
    sample_rate: int = 16_000
    smoothing: typing.Optional[typing.Callable] = None

    @classmethod
    def finalize_dogmatic_config(cls, config):
        rate = config['sample_rate']
        config['smoothing'] = {
            'partial': conv_smoothing,
            'window_length': int(0.3 * rate),
            'threshold': int(0.1 * rate),
        }

    def __post_init__(self):
        if self.window_size is None:
            self.window_size = int(0.1 * self.sample_rate + 1)

        assert self.window_size % 2 == 1, self.window_size

    def __call__(self, example):
        if isinstance(example, dict):
            signal = example['audio_data']
            if signal.ndim == 2 and signal.shape[0] == 1:
                signal = signal[0]
            elif signal.ndim == 2 and signal.shape[0] != 1:
                raise ValueError(
                    'Only mono signals are supported but audio_data has shape '
                    f'{signal.shape}'
                )
            vad = super().__call__(signal)
            intervals = np.asarray(vad.intervals)
            start, stop = zip(*intervals)
            example['vad'] = vad
            example['vad_start_samples'] = start
            example['vad_stop_samples'] = stop
        else:
            example = super().__call__(example)
        return example

    def _detect_voice_activity(self, signal):
        assert signal.ndim == 1, signal.shape

        half_context = (self.window_size - 1) // 2
        std = np.std(signal)
        signal = signal - np.mean(signal)
        assert np.min(signal) < 0
        assert np.max(signal) > 0
        signal = np.abs(signal)

        sliding_max = np.max(pb.array.segment_axis(
            np.pad(signal, (half_context, half_context), mode='constant'),
            length=self.window_size, shift=1, axis=0, end='cut'
        ), axis=-1)

        assert sliding_max.shape == signal.shape, (
            sliding_max.shape, signal.shape
        )

        unconstrained = sliding_max > self.threshold * std

        return unconstrained

    def compute_vad(self, signal, time_resolution=True):
        assert time_resolution
        return pb.array.interval.ArrayInterval(
            self._detect_voice_activity(signal)
        )

    def vad_to_time(self, vad, time_length):
        assert time_length == vad.shape[-1], (time_length, vad.shape[-1])
        return vad
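Not part of the commit: a synthetic sanity check of `EnergyVAD` above. A signal whose first half is silent and whose second half is noise-like should be trimmed to roughly the noisy part; the numbers are illustrative, not from the repository.

import numpy as np
from pvq_manipulation.helper.vad import EnergyVAD

sample_rate = 16_000
silence = np.zeros(sample_rate)                       # one second of silence
speech_like = 0.1 * np.random.randn(sample_rate)      # one second of noise as a speech stand-in
signal = np.concatenate([silence, speech_like])[None, :]  # shape (1, T): mono batch

vad = EnergyVAD(sample_rate=sample_rate)
out = vad({'audio_data': signal})
# The silent region is dropped, so the kept audio is clearly shorter than the input.
print(signal.shape[-1], out['audio_data'].shape[-1])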
pvq_manipulation/models/ffjord.py
ADDED
@@ -0,0 +1,247 @@
import torch
import paderbox as pb

from padertorch.base import Model
from torchdiffeq import odeint_adjoint as odeint
from pvq_manipulation.helper.moving_batch_norm import MovingBatchNorm1d


class ODEBlock(torch.nn.Module):
    def __init__(
        self,
        ode_function,
        train_flag=True,
        reverse=False,
    ):
        super().__init__()
        self.time_deriv_func = ode_function
        self.noise = None
        self.reverse = reverse
        self.train_flag = train_flag

    def forward(
        self,
        time: torch.Tensor,
        states: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Helper function to use a neural network for dy(t)/dt = f_theta(t, y(t))

        Hutchinson's trace estimator, as proposed in the FFJORD paper, was adapted from:
        https://github.com/RameenAbdal/StyleFlow/blob/master/module/odefunc.py

        Args:
            time (torch.Tensor): Scalar tensor representing time
            states (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
                - z (torch.Tensor): (batch_size, feature_dim) representing the input data.
                - d_log_dz_dt (torch.Tensor): (batch_size, 1) representing the log derivative.
                - labels (torch.Tensor): (batch_size, num_labeled_classes)

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - dz_dt (torch.Tensor): (batch_size, feature_dim) The derivative of z w.r.t. time
                - d_log_dz_dt (torch.Tensor): (batch_size, 1) The negative log derivative
                - labels (torch.Tensor): (batch_size, num_labeled_classes)
        """
        z, d_log_dz_dt, labels = states

        if self.noise is None:
            self.noise = self.sample_rademacher_like(z)

        with torch.enable_grad():
            z.requires_grad_(True)

            dz_dt = self.time_deriv_func.forward(time, z, labels)
            if self.train_flag:
                d_log_dz_dt = self.divergence_approx(dz_dt, z, self.noise)
            else:
                d_log_dz_dt = torch.zeros_like(z[:, 0]).requires_grad_(True)

        labels = torch.zeros_like(labels).requires_grad_(True)
        return dz_dt, -d_log_dz_dt.view(z.shape[0], 1), labels

    def divergence_approx(self, f, y, e=None):
        e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
        e_dzdx_e = e_dzdx.mul(e)

        cnt = 0
        while not e_dzdx_e.requires_grad and cnt < 10:
            e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
            e_dzdx_e = e_dzdx * e
            cnt += 1

        approx_tr_dzdx = e_dzdx_e.sum(dim=-1)
        assert approx_tr_dzdx.requires_grad, \
            "(failed to add node to graph) f=%s %s, y(rgrad)=%s, e_dzdx:%s, e:%s, e_dzdx_e:%s cnt:%s" \
            % (
                f.size(), f.requires_grad, y.requires_grad, e_dzdx.requires_grad, e.requires_grad,
                e_dzdx_e.requires_grad, cnt)
        return approx_tr_dzdx

    def before_odeint(self, e=None):
        self.noise = e

    def sample_rademacher_like(self, z):
        if not self.training:
            torch.manual_seed(0)
        return torch.randint(low=0, high=2, size=z.shape).to(z) * 2 - 1


class FFJORD(Model):
    """
    This class is an implementation of the FFJORD model as proposed in
    https://arxiv.org/pdf/1810.01367
    """
    def __init__(self, ode_function, normalize=True):
        super().__init__()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.input_dim = ode_function.input_dim
        self.time_deriv_func = ODEBlock(ode_function=ode_function)
        self.latent_dist = torch.distributions.MultivariateNormal(
            torch.zeros(self.input_dim, device=device),
            torch.eye(self.input_dim, device=device),
        )
        self.normalize = normalize
        if self.normalize:
            self.input_norm = MovingBatchNorm1d(self.input_dim, bn_lag=0)
            self.output_norm = MovingBatchNorm1d(self.input_dim, bn_lag=0)

    @staticmethod
    def load_model(model_path, checkpoint):
        model_dict = pb.io.load_yaml(model_path / "config.yaml")
        model = Model.from_config(model_dict['model'])
        cp = torch.load(
            model_path / checkpoint,
            map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        )
        model_weights = cp.copy()
        model.load_state_dict(model_weights['model'])
        model.eval()
        return model

    def forward(
        self,
        state: tuple[torch.Tensor, torch.Tensor],
        integration_times: torch.Tensor = torch.tensor([0.0, 1.0])
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Integration from t_1 (data distribution) to t_0 (base distribution).
        (training step)

        Args:
            state (Tuple[torch.Tensor, torch.Tensor]):
                - z (torch.Tensor): (batch_size, feature_dim) representing the input data.
                - labels (torch.Tensor): (batch_size, num_labeled_classes)

            integration_times (torch.Tensor, optional): A tensor of shape (2,)
                specifying the start and end times for integration. Defaults to torch.tensor([0.0, 1.0]).

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - dz_dt (torch.Tensor): A tensor of shape (batch_size, feature_dim) representing the derivative of z w.r.t. time.
                - -d_log_dz_dt (torch.Tensor): (batch_size, 1) representing the negative log derivative.
                - labels (torch.Tensor): (batch_size, num_labeled_classes)
        """
        z_1, labels = state

        if z_1.dim() == 3:
            z_1 = z_1.squeeze(1)

        delta_logpz = torch.zeros(z_1.shape[0], 1).to(z_1.device)

        if self.normalize:
            z_1, delta_logpz = self.input_norm(z_1, context=labels, logpx=delta_logpz)

        self.time_deriv_func.before_odeint()
        state = odeint(
            self.time_deriv_func,              # Calculates time derivatives.
            (z_1, delta_logpz, labels),        # Values to update, init states.
            integration_times.to(z_1.device),  # When to evaluate.
            method='dopri5',                   # Runge-Kutta
            atol=1e-5,                         # Error tolerance
            rtol=1e-5,                         # Error tolerance
        )
        if self.normalize:
            dz_dt, d_delta_log_dz_t = self.output_norm(state[0], context=state[2], logpx=state[1])
        else:
            dz_dt, d_delta_log_dz_t = state[0], state[1]

        state = (dz_dt, d_delta_log_dz_t, labels)

        if len(integration_times) == 2:
            state = tuple(s[1] if s.shape[0] > 1 else s[0] for s in state)
        return state

    def sample(
        self,
        state: tuple[torch.Tensor, torch.Tensor],
        integration_times: torch.Tensor = torch.tensor([1.0, 0.0])
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This is the sampling step. Integration from t_0 (base distribution) to t_1 (data distribution).

        Args:
            state (Tuple[torch.Tensor, torch.Tensor]):
                - z_0 (torch.Tensor): (batch_size, feature_dim) representing the initial state from the base distribution
                - labels (torch.Tensor): (batch_size, num_labeled_classes)

            integration_times (torch.Tensor, optional): A tensor of shape (2,) specifying the start (t_0) and end (t_1) times for integration.
                Defaults to torch.tensor([1.0, 0.0])

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - z_t1 (torch.Tensor): (batch_size, feature_dim) representing the sampled data at time t_1 (data distribution).
                - labels (torch.Tensor): (batch_size, num_labeled_classes)
        """
        z_0, labels = state
        delta_logpz = torch.zeros(z_0.shape[0], 1).to(z_0.device)
        if self.normalize:
            z_0, delta_logpz = self.output_norm(
                z_0,
                context=labels,
                logpx=delta_logpz,
                reverse=True
            )

        state = odeint(
            self.time_deriv_func,              # Calculates time derivatives.
            (z_0, delta_logpz, labels),        # Values to update, init states.
            integration_times.to(z_0.device),  # When to evaluate.
            method='dopri5',                   # Runge-Kutta
            atol=1e-5,                         # Error tolerance
            rtol=1e-5,                         # Error tolerance
        )
        if self.normalize:
            dz_dt, d_delta_log_dz_t = self.input_norm(
                state[0],
                context=state[2],
                logpx=state[1],
                reverse=True
            )
        else:
            dz_dt, d_delta_log_dz_t = state[0], state[1]
        state = (dz_dt, d_delta_log_dz_t, labels)

        if len(integration_times) == 2:
            state = tuple(s[1] if s.shape[0] > 1 else s[0] for s in state)
        return state

    def example_to_device(self, examples, device):
        observations = [example['observation'] for example in examples]
        labels = [example['speaker_conditioning'].tolist() for example in examples if 'speaker_conditioning' in example]
        observations_tensor = torch.tensor(observations, device=device, dtype=torch.float)
        labels_tensor = torch.tensor(labels, device=device, dtype=torch.float) if labels else None
        return observations_tensor, labels_tensor

    def review(self, example, outputs):
        z_t0, delta_logpz, _ = outputs
        logpz_t1 = self.latent_dist.log_prob(z_t0) - delta_logpz
        loss = -torch.mean(logpz_t1)
        return dict(loss=loss, scalars=dict(loss=loss))

    def modify_summary(self, summary):
        summary = super().modify_summary(summary)
        return summary
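Not part of the commit: a sketch of the round trip that `get_manipulation` in the notebook relies on, namely that `forward` maps a d-vector to the base distribution under given labels and `sample` maps it back, so decoding under edited labels yields the manipulated embedding. It assumes a trained flow loaded via `FFJORD.load_model` (as in the notebook); the tensors below are stand-ins, and the label width 7 stands in for `len(config_norm_flow['speaker_conditioning'])`.

import torch

d_vector = torch.randn(1, normalizing_flow.input_dim)  # stand-in for a real speaker embedding
labels = torch.rand(1, 7)                               # stand-in for the PVQ scores / 100

z_0 = normalizing_flow.forward((d_vector, labels))[0]        # encode: data -> base distribution
d_vector_rec = normalizing_flow.sample((z_0, labels))[0]     # decode with the same labels
print(torch.allclose(d_vector, d_vector_rec, atol=1e-3))     # approximately recovered

labels_edit = labels.clone()
labels_edit[:, 0] += 1.0                                     # shift one PVQ dimension
d_vector_edit = normalizing_flow.sample((z_0, labels_edit))[0]  # decode under the edited labels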
pvq_manipulation/models/hubert.py
ADDED
@@ -0,0 +1,207 @@
import os
from pathlib import Path
from contextlib import nullcontext
import typing as tp
from typing import List, Tuple, Optional

import einops
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio

import padertorch as pt
from padertorch.contrib.je.modules.conv_utils import (
    compute_conv_output_sequence_lengths
)
from padertorch.utils import to_numpy
from transformers.models.hubert.modeling_hubert import HubertModel

# See https://ieeexplore.ieee.org/abstract/document/9814838, Fig. 2
PR_BASE_LAYER = 11
PR_LARGE_LAYER = 22
SID_BASE_LAYER = 4
SID_LARGE_LAYER = 6


def tuple_to_int(sequence) -> list:
    return list(map(lambda t: t[0], sequence))


class HubertExtractor(pt.Module):
    """Extract HuBERT features from raw waveform.

    Args:
        model_name (str): Name of the pretrained HuBERT model on huggingface.co.
            Defaults to "facebook/hubert-large-ll60k".
        layer (int): Index of the layer to extract features from. Defaults to
            22.
        freeze (bool): If True, freeze the weights of the encoder
            (i.e., no finetuning of Transformer layers). Defaults to True.
    """

    def __init__(
        self,
        model_name: str = "facebook/hubert-large-ll60k",
        layer: tp.Union[int, str] = PR_LARGE_LAYER,
        freeze: bool = True,
        detach: bool = False,
        device: str = "cpu",
        backend: str = "torchaudio",
        storage_dir: str = None,
    ):
        super().__init__()

        if not freeze and detach:
            raise ValueError(
                'detach=True only supported if freeze=True\n'
                f'Got: freeze={freeze}, detach={detach}'
            )
        if backend == "torchaudio":
            bundle = getattr(torchaudio.pipelines, model_name)
            self.model = bundle.get_model(dl_kwargs={'model_dir': storage_dir}).eval().to(device)
            self.sampling_rate = bundle.sample_rate
        else:
            raise ValueError(f'Unknown backend: {backend}')
        self.backend = backend

        if freeze:
            for param in self.model.parameters():
                param.requires_grad = False
        else:
            # Always freeze feature extractor and feature projection layers
            for param in self.model.feature_extractor.parameters():
                param.requires_grad = False
            for param in self.model.feature_projection.parameters():
                param.requires_grad = False

        self.layer = layer
        self.freeze = freeze
        self.detach = detach

        if self.layer == 'all':
            self.weights = torch.nn.Parameter(
                torch.ones(24), requires_grad=True
            )

    @property
    def cache_dir(self):
        return Path(os.environ['STORAGE_ROOT']) / 'huggingface' / 'hub'

    @property
    def context(self):
        if self.detach:
            return torch.no_grad()
        else:
            return nullcontext()

    def compute_output_lengths(
        self, input_lengths: Optional[List[int]]
    ) -> Optional[List[int]]:
        """Compute the number of time frames for each batch entry.

        Args:
            input_lengths: List with number of samples per batch entry.

        Returns:
            List with number of time frames per batch entry.
        """
        if input_lengths is None:
            return input_lengths
        output_lengths = np.asarray(input_lengths) + self.window_size - 1
        for kernel_size, dilation, stride in zip(
            self.kernel_sizes, self.dilations, self.strides,
        ):
            output_lengths = compute_conv_output_sequence_lengths(
                output_lengths,
                kernel_size=kernel_size,
|
| 118 |
+
dilation=dilation,
|
| 119 |
+
pad_type=None,
|
| 120 |
+
stride=stride,
|
| 121 |
+
)
|
| 122 |
+
return output_lengths.tolist()
|
| 123 |
+
|
| 124 |
+
def forward(
|
| 125 |
+
self,
|
| 126 |
+
time_signal: torch.Tensor,
|
| 127 |
+
sampling_rate: int,
|
| 128 |
+
sequence_lengths: Optional[List[int]] = None,
|
| 129 |
+
extract_features: bool = False,
|
| 130 |
+
other_inputs: Optional[dict] = None,
|
| 131 |
+
) -> Tuple[torch.Tensor, Optional[List[int]]]:
|
| 132 |
+
"""Extract HuBERT features from raw waveform.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
time_signal: Time signal of shape (batch, 1, time) or (batch, time)
|
| 136 |
+
sampled at 16 kHz.
|
| 137 |
+
sequence_lengths: List with number of samples per batch entry.
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
x (torch.Tensor): HuBERT features of shape
|
| 141 |
+
(batch, D, time frames).
|
| 142 |
+
sequence_lengths (List[int]): List with number of time frames per
|
| 143 |
+
batch entry.
|
| 144 |
+
"""
|
| 145 |
+
del other_inputs
|
| 146 |
+
|
| 147 |
+
if time_signal.ndim == 3:
|
| 148 |
+
time_signal = einops.rearrange(time_signal, 'b c t -> (b c) t')
|
| 149 |
+
|
| 150 |
+
time_signal = torchaudio.functional.resample(
|
| 151 |
+
time_signal, sampling_rate, self.sampling_rate
|
| 152 |
+
)
|
| 153 |
+
if sequence_lengths is not None:
|
| 154 |
+
if isinstance(sequence_lengths, (list, tuple)):
|
| 155 |
+
sequence_lengths = torch.tensor(sequence_lengths).long() \
|
| 156 |
+
.to(time_signal.device)
|
| 157 |
+
sequence_lengths = (
|
| 158 |
+
sequence_lengths / sampling_rate * self.sampling_rate
|
| 159 |
+
).long()
|
| 160 |
+
|
| 161 |
+
if self.freeze or self.detach:
|
| 162 |
+
self.model.eval()
|
| 163 |
+
with self.context:
|
| 164 |
+
if self.backend == "torchaudio":
|
| 165 |
+
self.model: torchaudio.models.Wav2Vec2Model
|
| 166 |
+
x, sequence_lengths = self.model.extract_features(
|
| 167 |
+
time_signal, lengths=sequence_lengths,
|
| 168 |
+
num_layers=self.layer,
|
| 169 |
+
)
|
| 170 |
+
if isinstance(self.layer, int):
|
| 171 |
+
x = x[-1].transpose(1, 2)
|
| 172 |
+
else:
|
| 173 |
+
raise NotImplementedError(self.layer)
|
| 174 |
+
return x, sequence_lengths
|
| 175 |
+
|
| 176 |
+
self.model: HubertModel
|
| 177 |
+
n_pad = self.window_size - 1
|
| 178 |
+
time_signal = F.pad(time_signal, (0, n_pad), value=0)
|
| 179 |
+
if extract_features:
|
| 180 |
+
features = self.model.feature_extractor(time_signal.float()) \
|
| 181 |
+
.transpose(1, 2)
|
| 182 |
+
x = self.model.feature_projection(features).transpose(1, 2)
|
| 183 |
+
else:
|
| 184 |
+
outputs = self.model(
|
| 185 |
+
time_signal.float(), output_hidden_states=True
|
| 186 |
+
)
|
| 187 |
+
if isinstance(self.layer, int):
|
| 188 |
+
x = outputs.hidden_states[self.layer].transpose(1, 2)
|
| 189 |
+
if self.detach:
|
| 190 |
+
x = x.detach()
|
| 191 |
+
elif self.layer == 'all':
|
| 192 |
+
hidden_states = []
|
| 193 |
+
for _, hidden_state in enumerate(outputs.hidden_states):
|
| 194 |
+
x = hidden_state.transpose(1, 2)
|
| 195 |
+
if self.detach:
|
| 196 |
+
x = x.detach()
|
| 197 |
+
hidden_states.append(x)
|
| 198 |
+
hidden_states = torch.stack(hidden_states) # (L, B, D, T)
|
| 199 |
+
x = (hidden_states * self.weights[:, None, None, None]) \
|
| 200 |
+
.sum(dim=0)
|
| 201 |
+
else:
|
| 202 |
+
raise ValueError(f'Unknown layer: {self.layer}')
|
| 203 |
+
|
| 204 |
+
sequence_lengths = to_numpy(sequence_lengths)
|
| 205 |
+
sequence_lengths = self.compute_output_lengths(sequence_lengths)
|
| 206 |
+
|
| 207 |
+
return x.unsqueeze(1), sequence_lengths
|
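Example (not part of the commit): a minimal sketch of calling the extractor on a batch of raw audio with the torchaudio backend. The pipeline name 'HUBERT_LARGE' and the shapes are assumptions for illustration; torchaudio downloads the checkpoint on first use.

import torch
from pvq_manipulation.models.hubert import HubertExtractor, SID_LARGE_LAYER

extractor = HubertExtractor(
    model_name='HUBERT_LARGE',       # torchaudio.pipelines bundle name (assumed)
    layer=SID_LARGE_LAYER,           # layer 6, used for speaker-identity features
    backend='torchaudio',
)
wav = torch.randn(2, 1, 16000)       # (batch, channel, samples)
features, frame_lengths = extractor(
    wav, sampling_rate=16000, sequence_lengths=[16000, 12000]
)
# features: per-frame HuBERT representations from the selected layer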
pvq_manipulation/models/ode_functions.py
ADDED
@@ -0,0 +1,96 @@
"""
Implementation of Δz = f(t, z, labels).
f() is a neural network with the architecture defined in StyleFlow:
"StyleFlow: Attribute-conditioned Exploration of StyleGAN-Generated Images using Conditional Continuous Normalizing Flows"
"""
import torch


class CNFNN(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        condition_dim,
        hidden_channels,
    ):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        hidden_dims = hidden_channels + [input_dim]
        self.input_dim = input_dim

        for idx, hidden_dim in enumerate(hidden_dims):
            self.layers.append(CNFBlock(
                input_dim=input_dim,
                condition_dim=condition_dim,
                output_dim=hidden_dim,
                output_layer=False if idx < len(hidden_dims) - 1 else True,
            ))
            input_dim = hidden_dim

    def forward(self, t, z, labels):
        """
        Computes Δz = f(t, z, labels).

        Args:
            t (torch.Tensor): () Time step of the ODE
            z (torch.Tensor): (Batch_size, Input_dim) Intermediate value
            labels (torch.Tensor): (Batch_size, condition_dim) Speaker attributes

        Returns:
            Δz (torch.Tensor): (Batch_size, Input_dim) Computed delta
        """
        for layer in self.layers:
            z = layer(t, z, labels)
        return z


class CNFBlock(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        condition_dim,
        output_layer,
    ):
        super().__init__()
        self._layer = torch.nn.Linear(input_dim, output_dim)
        self._hyper_bias = torch.nn.Linear(
            1 + condition_dim,
            output_dim,
            bias=False
        )
        self._hyper_gate = torch.nn.Linear(
            1 + condition_dim,
            output_dim
        )
        self.output_layer = output_layer

    def forward(self, t, z, labels):
        """
        Args:
            t (torch.Tensor): () Time step of the ODE
            z (torch.Tensor): (Batch_size, Input_dim) Intermediate value
            labels (torch.Tensor): (Batch_size, condition_dim) Speaker attributes

        Returns:
            z (torch.Tensor): (Batch_size, Output_dim) Intermediate value
        """
        if labels.dim() == 1:
            labels = labels[:, None]
        elif labels.dim() == 3:
            labels = labels.squeeze(1)

        tz_cat = torch.cat((t.expand(z.shape[0], 1), labels), dim=1)

        gate = torch.sigmoid(self._hyper_gate(tz_cat))
        bias = self._hyper_bias(tz_cat)

        if z.dim() == 3:
            gate = gate.unsqueeze(1)
            bias = bias.unsqueeze(1)

        z = self._layer(z) * gate + bias

        if not self.output_layer:
            z = torch.tanh(z)
        return z
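Example (not part of the commit): a minimal sketch of evaluating the conditional dynamics network on a batch of latent vectors. The dimensions are placeholders chosen for illustration.

import torch
from pvq_manipulation.models.ode_functions import CNFNN

func = CNFNN(input_dim=512, condition_dim=4, hidden_channels=[512, 512])
z = torch.randn(8, 512)        # batch of latent points
labels = torch.rand(8, 4)      # speaker attributes used as conditioning
t = torch.tensor(0.5)          # scalar ODE time
dz_dt = func(t, z, labels)     # Δz, same shape as z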
pvq_manipulation/models/vits.py
ADDED
@@ -0,0 +1,742 @@
"""
This is a wrapper for the TTS VITS model.
TTS.tts.models.vits
https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/models/vits.py
"""
import os
import numpy as np
import paderbox as pb
import padertorch as pt
import torch

from coqpit import Coqpit
from padertorch.ops._stft import STFT
from pathlib import Path
from pvq_manipulation.helper.utils import VitsAudioConfig_NT, VitsConfig_NT, load_audio

from torch.utils.data import DataLoader
from torch.cuda.amp.autocast_mode import autocast
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.models.vits import Vits, VitsArgs, VitsDataset, spec_to_mel, wav_to_spec
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import embedding_to_torch, numpy_to_torch
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.helpers import generate_path, rand_segments, segment, sequence_mask
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from trainer.trainer import to_cuda
from typing import Dict, List, Union


STORAGE_ROOT = Path(os.getenv('STORAGE_ROOT')).expanduser()


class Vits_NT(Vits):
    def __init__(
        self,
        config: Coqpit,
        ap: "AudioProcessor" = None,
        tokenizer: "TTSTokenizer" = None,
        speaker_manager: SpeakerManager = None,
        language_manager: LanguageManager = None,
        sample_rate: int = None,
    ):
        super().__init__(
            config,
            ap,
            tokenizer,
            speaker_manager,
            language_manager
        )
        self.sample_rate = sample_rate
        self.embedded_speaker_dim = self.args.d_vector_dim
        self.posterior_encoder = PosteriorEncoder(
            self.args.out_channels,
            self.args.hidden_channels,
            self.args.hidden_channels,
            kernel_size=self.args.kernel_size_posterior_encoder,
            dilation_rate=self.args.dilation_rate_posterior_encoder,
            num_layers=self.args.num_layers_posterior_encoder,
            cond_channels=self.embedded_speaker_dim,
        )

        self.flow = ResidualCouplingBlocks(
            self.args.hidden_channels,
            self.args.hidden_channels,
            kernel_size=self.args.kernel_size_flow,
            dilation_rate=self.args.dilation_rate_flow,
            num_layers=self.args.num_layers_flow,
            cond_channels=self.embedded_speaker_dim,
        )

        self.text_encoder = TextEncoder(
            self.args.num_chars,
            self.args.hidden_channels,
            self.args.hidden_channels,
            self.args.hidden_channels_ffn_text_encoder,
            self.args.num_heads_text_encoder,
            self.args.num_layers_text_encoder,
            self.args.kernel_size_text_encoder,
            self.args.dropout_p_text_encoder,
            language_emb_dim=self.embedded_language_dim,
        )
        self.waveform_decoder = HifiganGenerator(
            self.args.hidden_channels,
            1,
            self.args.resblock_type_decoder,
            self.args.resblock_dilation_sizes_decoder,
            self.args.resblock_kernel_sizes_decoder,
            self.args.upsample_kernel_sizes_decoder,
            self.args.upsample_initial_channel_decoder,
            self.args.upsample_rates_decoder,
            inference_padding=0,
            cond_channels=self.embedded_speaker_dim if self.config.gan_speaker_conditioning else 0,
            conv_pre_weight_norm=False,
            conv_post_weight_norm=False,
            conv_post_bias=False,
        )
        self.speaker_manager = self.speaker_manager
        self.speaker_encoder = self.speaker_manager

        self.speaker_manager.eval()

        self.epoch = 0
        self.num_epochs = config['epochs']
        self.lr_lambda = 0
        self.config_solver = config['CONFIG_SOLVER']
        self.config = config

        self.stft = STFT(
            size=self.config.audio.win_length,
            shift=self.config.audio.hop_length,
            window_length=self.config.audio.win_length,
            fading=self.config.audio.fading,
            window=self.config.audio.window,
            pad=self.config.audio.pad
        )

    def get_spectogram_nt(self, wav):
        """
        Extracts a magnitude spectrogram from audio.
        Args:
            wav (torch.Tensor): (Batch_size, Num_samples)
        Returns:
            spectrogram (torch.Tensor): (Batch_size, Frequency_bins, Time) spectrogram
        """
        wav = wav.squeeze(1)
        stft_signal = self.stft(wav)
        stft_signal = torch.einsum('btf -> bft', stft_signal)
        spectrogram = stft_signal.real ** 2 + stft_signal.imag ** 2
        spectrogram = torch.sqrt(spectrogram + 1e-6)
        return spectrogram

    def get_aux_input_from_test_sentences(self, sentence_info):
        """
        Get aux input for the inference step from test sentences.
        Args:
            sentence_info (dict): Expected keys:
                - "d_vector_storage_root" (str)
                - "d_vector" (torch.Tensor)
                - "d_vector_man" (torch.Tensor) (optional)
        Returns:
            aux_input (dict): aux input for the inference step
        """
        if 'd_vector' not in sentence_info.keys():
            d_vector_file = sentence_info['d_vector_storage_root']
            d_vector = torch.load(d_vector_file)
            return {"d_vector": d_vector, **sentence_info}
        else:
            return sentence_info

    @staticmethod
    def init_from_config(
        config: "VitsConfig",
        samples=None,
        verbose=True
    ):
        """
        Initialize the model from a config.
        Args:
            config (VitsConfig): Model config.
            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                Defaults to None.
        Returns:
            model (Vits): Initialized model.
        """
        upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item()
        assert (upsample_rate == config.audio.hop_length), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"
        ap = AudioProcessor.init_from_config(config, verbose=verbose)
        tokenizer, new_config = TTSTokenizer.init_from_config(config)
        language_manager = LanguageManager.init_from_config(config)
        speaker_manager = pt.Module.from_storage_dir(
            config['d_vector_model_file'],
            checkpoint_name='ckpt_latest.pth',
            consider_mpi=False,
            config_name='config.json',
        )
        speaker_manager.num_speakers = config['num_speakers']
        for param in speaker_manager.parameters():
            param.requires_grad = False

        return Vits_NT(
            new_config,
            ap,
            tokenizer,
            speaker_manager=speaker_manager,
            language_manager=language_manager,
            sample_rate=config['sample_rate'],
        )

    @torch.no_grad()
    def inference(self, x, aux_input=None):
        """
        Note:
            To run in batch mode, provide `x_lengths`; otherwise the model assumes a batch size of 1.

        Args:
            x (torch.Tensor): (batch_size, T_seq) or (T_seq) Input character sequence IDs
            aux_input (dict): Expected keys:
                - d_vector (torch.Tensor): (batch_size, Feature_dim) speaker embedding
                - x_lengths (torch.Tensor): (batch_size) length of each text token sequence

        Returns:
            model_outputs (torch.Tensor): (batch_size, T_wav) Synthesized waveform
        """
        speaker_embedding = aux_input['d_vector'].detach()[:, :, None]
        if aux_input['d_vector_man'] is not None:
            speaker_embedding_man = aux_input['d_vector_man'].detach()[:, :, None]
        else:
            speaker_embedding_man = speaker_embedding
        aux_input['tokens'] = x.clone()
        x_lengths = self._set_x_lengths(x, aux_input)
        x, m_p, logs_p, x_mask = self.text_encoder(
            x,
            x_lengths,
            lang_emb=None
        )
        logw = self.duration_predictor(
            x,
            x_mask,
            g=speaker_embedding,
            lang_emb=None,
        )

        w = torch.exp(logw) * x_mask * self.length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1)  # [B, 1, T_dec]

        attn_mask = x_mask * y_mask.transpose(1, 2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
        m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale

        z = self.flow(z_p, y_mask, g=speaker_embedding_man, reverse=True)
        z, _, _, y_mask = self.upsampling_z(
            z,
            y_lengths=y_lengths,
            y_mask=y_mask
        )
        o = self.waveform_decoder(
            (z * y_mask)[:, :, : self.max_inference_len],
            g=speaker_embedding_man if self.config.gan_speaker_conditioning else None
        )
        return o

    def forward(self, x, x_lengths, y, y_lengths, aux_input, inference=False):
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): (Batch, T_seq) Input character sequence IDs
            x_lengths (torch.Tensor): (Batch) Input character sequence lengths.
            y (torch.Tensor): (Batch_size, Frequency_bins, Time) Input spectrograms.
            y_lengths (torch.Tensor): (Batch) Input spectrogram lengths.
            aux_input (dict): Expected keys:
                - d_vector (torch.Tensor): (batch_size, Feature_dim) speaker embedding
                - waveform (torch.Tensor): (Batch_size, Num_samples) Target waveform
        Returns:
            Dict: model outputs keyed by the output name.
        """
        outputs = {}
        speaker_embedding = aux_input['d_vector'].detach()[:, :, None]
        x, m_p, logs_p, x_mask = self.text_encoder(
            x,
            x_lengths,
            lang_emb=None
        )
        z, m_q, logs_q, y_mask = self.posterior_encoder(
            y,
            y_lengths,
            g=speaker_embedding,
        )
        z_p = self.flow(z, y_mask, g=speaker_embedding)
        outputs, attn = self.forward_mas(
            outputs,
            z_p,
            m_p,
            logs_p,
            x,
            x_mask,
            y_mask,
            g=speaker_embedding,
            lang_emb=None,
        )
        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

        z_slice, slice_ids = rand_segments(
            z,
            y_lengths,
            self.spec_segment_size,
            let_short_samples=True,
            pad_short=True
        )
        z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(
            z_slice,
            slice_ids=slice_ids,
        )

        wav_seg = segment(
            aux_input['waveform'],
            slice_ids * self.config.audio.hop_length,
            spec_segment_size * self.config.audio.hop_length,
            pad_short=True,
        )
        o = self.waveform_decoder(
            z_slice,
            g=speaker_embedding if self.config.gan_speaker_conditioning else None
        )

        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None:
            wavs_batch = torch.cat((wav_seg, o), dim=0)
            if self.audio_transform is not None:
                wavs_batch = self.audio_transform(wavs_batch)
            with torch.no_grad():
                pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True)
            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
        else:
            gt_spk_emb, syn_spk_emb = None, None

        outputs.update(
            {
                "model_outputs": o,
                "alignments": attn.squeeze(1),
                "m_p": m_p,
                "logs_p": logs_p,
                "z": z,
                "z_p": z_p,
                "m_q": m_q,
                "logs_q": logs_q,
                "waveform_seg": wav_seg,
                "gt_spk_emb": gt_spk_emb,
                "syn_spk_emb": syn_spk_emb,
                "slice_ids": slice_ids,
                "z_slice": z_slice,
                "speaker_embedding": speaker_embedding,
            }
        )
        return outputs

    @staticmethod
    def load_model(model_path, checkpoint):
        """
        Load model from checkpoint.

        Args:
            model_path (Path): model storage directory
            checkpoint (str): checkpoint file name

        Returns:
            model (pvq_manipulation.models.vits.Vits_NT): model
        """
        config = pb.io.load_json(model_path / "config.json")
        model_args = VitsArgs(**config['model_args'])
        audio_config = VitsAudioConfig_NT(**config['audio'])
        characters_config = CharactersConfig(**config['characters'])
        del config['audio']
        del config['characters']
        del config['model_args']

        config = VitsConfig_NT(
            model_args=model_args,
            audio=audio_config,
            characters=characters_config,
            **config,
        )
        model = Vits_NT.init_from_config(config)
        cp = torch.load(
            model_path / checkpoint,
            map_location=torch.device('cpu')
        )
        model_weights = cp['model'].copy()
        model.load_state_dict(model_weights, strict=False)
        model.eval()
        return model

    def synthesize_from_example(self, s_info):
        """
        Synthesize voice from an example.

        Args:
            s_info (dict): Expected keys:
                - "speaker_id" (str),
                - "example_id" (str),
                - "audio_path" (str),
                - "d_vector_storage_root" (str),
                - "text" (str) specifying the text to synthesize
        Returns:
            wav (np.ndarray): synthesized waveform
        """
        aux_inputs = self.get_aux_input_from_test_sentences(s_info)
        use_cuda = "cuda" in str(next(self.parameters()).device)

        device = next(self.parameters()).device
        if use_cuda:
            device = "cuda"

        text_inputs = np.asarray(
            self.tokenizer.text_to_ids(aux_inputs["text"], language=None),
            dtype=np.int32,
        )
        d_vector = embedding_to_torch(aux_inputs["d_vector"], device=device)

        if "d_vector_man" in aux_inputs.keys():
            d_vector_man = embedding_to_torch(aux_inputs["d_vector_man"], device=device)

        text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
        text_inputs = text_inputs.unsqueeze(0)

        wav = self.inference(
            text_inputs,
            aux_input={
                "x_lengths": torch.tensor(
                    text_inputs.shape[1:2]
                ).to(text_inputs.device),
                "d_vector": d_vector,
                "d_vector_man": d_vector_man if "d_vector_man" in aux_inputs.keys() else None
            }
        )[0].data.cpu().numpy().squeeze()
        return wav

    def format_batch_on_device(self, batch):
        """Format batch on device."""
        ac = self.config.audio

        batch['waveform'] = to_cuda(batch['waveform'])
        wav = batch["waveform"]

        batch['spec'] = self.get_spectogram_nt(wav)

        if self.args.encoder_sample_rate:
            spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
            if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor):
                spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
            else:
                batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)]
        else:
            spec_mel = batch["spec"]

        batch["mel"] = spec_to_mel(
            spec=spec_mel,
            n_fft=ac.fft_size,
            num_mels=ac.num_mels,
            sample_rate=ac.sample_rate,
            fmin=ac.mel_fmin,
            fmax=ac.mel_fmax,
        )

        if self.args.encoder_sample_rate:
            assert batch["spec"].shape[2] == int(
                batch["mel"].shape[2] / self.interpolate_factor
            ), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
        else:
            assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"

        batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int()
        batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int()

        if self.args.encoder_sample_rate:
            assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0
        else:
            assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0

        batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1)
        batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1)
        return batch

    def train_step(
        self,
        batch: dict,
        criterion: torch.nn.Module,
        optimizer_idx: int,
    ):
        """
        Perform a single training step. Run the model forward pass and compute losses.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.
            optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        if optimizer_idx == 0:
            # generator pass
            outputs = self.forward(
                batch["tokens"],
                batch["token_lens"],
                batch["spec"],
                batch["spec_lens"],
                aux_input={
                    **batch,
                },
            )
            # cache tensors for the generator pass
            self.model_outputs_cache = outputs  # pylint: disable=attribute-defined-outside-init
            scores_disc_fake, _, scores_disc_real, _ = self.disc(
                outputs["model_outputs"].detach(),
                outputs["waveform_seg"]
            )
            # compute loss
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    scores_disc_real,
                    scores_disc_fake,
                )
            return outputs, loss_dict

        if optimizer_idx == 1:
            # compute melspec segment
            with autocast(enabled=False):
                if self.args.encoder_sample_rate:
                    spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
                else:
                    spec_segment_size = self.spec_segment_size
                mel_slice = segment(
                    batch["mel"].float(),
                    self.model_outputs_cache["slice_ids"],
                    spec_segment_size,
                    pad_short=True
                )

                spec = self.get_spectogram_nt(
                    self.model_outputs_cache["model_outputs"].float(),
                )
                mel_slice_hat = spec_to_mel(
                    spec=spec,
                    n_fft=self.config.audio.fft_size,
                    num_mels=self.config.audio.num_mels,
                    sample_rate=self.config.audio.sample_rate,
                    fmin=self.config.audio.mel_fmin,
                    fmax=self.config.audio.mel_fmax,
                )

            # compute discriminator scores and features
            scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc(
                self.model_outputs_cache["model_outputs"],
                self.model_outputs_cache["waveform_seg"],
            )

            # compute losses
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    mel_slice_hat=mel_slice.float(),
                    mel_slice=mel_slice_hat.float(),
                    z_p=self.model_outputs_cache["z_p"].float(),
                    logs_q=self.model_outputs_cache["logs_q"].float(),
                    m_p=self.model_outputs_cache["m_p"].float(),
                    logs_p=self.model_outputs_cache["logs_p"].float(),
                    z_len=batch["spec_lens"],
                    scores_disc_fake=scores_disc_fake,
                    feats_disc_fake=feats_disc_fake,
                    feats_disc_real=feats_disc_real,
                    loss_duration=self.model_outputs_cache["loss_duration"],
                    use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
                    gt_spk_emb=self.model_outputs_cache["gt_spk_emb"],
                    syn_spk_emb=self.model_outputs_cache["syn_spk_emb"],
                )
            return self.model_outputs_cache, loss_dict
        raise ValueError(" [!] Unexpected `optimizer_idx`.")

    @torch.no_grad()
    def test_run(self, assets):
        """Generic test run for `tts` models used by `Trainer`.

        You can override this for a different behaviour.

        Returns:
            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        for idx, s_info in enumerate(test_sentences):
            wav = self.synthesize_from_example(s_info)
            test_audios["{}-audio".format(idx)] = wav
        return {"figures": test_figures, "audios": test_audios}

    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: Union[List[Dict], List[List]],
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":
        dataset = VitsDataset_NT(
            model_args=self.args,
            speaker_manager=self.speaker_manager,
            config=self.config,
            use_phone_labels=config.use_phone_labels,
            sample_rate=self.sample_rate,
            samples=samples,
            batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
            min_text_len=config.min_text_len,
            max_text_len=config.max_text_len,
            min_audio_len=config.min_audio_len,
            max_audio_len=config.max_audio_len,
            phoneme_cache_path=config.phoneme_cache_path,
            precompute_num_workers=config.precompute_num_workers,
            verbose=verbose,
            tokenizer=self.tokenizer,
            start_by_longest=config.start_by_longest,
        )

        # sort input sequences from short to long
        dataset.preprocess_samples()

        # get samplers
        sampler = self.get_sampler(config, dataset, num_gpus)
        loader = DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=dataset.collate_fn,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=False,
        )
        return loader


class VitsDataset_NT(VitsDataset):
    def __init__(
        self,
        model_args,
        speaker_manager,
        sample_rate,
        config,
        use_phone_labels,
        *args,
        **kwargs
    ):
        super().__init__(model_args, *args, **kwargs)
        self.speaker_manager = speaker_manager
        self.sample_rate = sample_rate
        self.config = config
        self.use_phone_labels = use_phone_labels

    def __getitem__(self, idx):
        example = self.samples[idx]
        token_ids = self.get_token_ids(idx, example["text"])

        wav, _ = load_audio(example["audio_file"], target_sr=self.sample_rate)

        speaker_id = example['speaker_name']
        example_id = example['example_id']
        d_vector = None
        for dataset_dict_sub in self.config.dataset_dict['datasets'].values():
            d_vector_file = dataset_dict_sub['d_vector_storage_root']
            if (Path(d_vector_file) / f'{speaker_id}/{example_id}.pth').is_file():
                d_vector = torch.load(Path(d_vector_file) / f'{speaker_id}/{example_id}.pth')
                break
        if d_vector is None:
            raise ValueError(f'Could not find d_vector for example {example_id}')

        if d_vector.dim() == 1:
            d_vector = d_vector[None, :]
        return {
            "raw_text": example['text'],
            "token_ids": token_ids,
            "token_len": len(token_ids),
            "wav": wav,
            "d_vector": d_vector,
            "speaker_name": example["speaker_name"]
        }

    def collate_fn(self, batch):
        """
        Collate a list of samples from a Dataset into a batch for VITS.

        Args:
            batch (list of dict): Expected keys per sample:
                - wav (torch.Tensor)
                - token_ids (list)
                - token_len (int)
                - speaker_name (str)
                - raw_text (str)
                - d_vector (torch.Tensor)
        Returns:
            dict with keys:
                - tokens (torch.Tensor): (B, T)
                - token_lens (torch.Tensor): (B)
                - token_rel_lens (torch.Tensor): (B)
                - waveform (torch.Tensor): (B, 1, T)
                - waveform_lens (torch.Tensor): (B)
                - waveform_rel_lens (torch.Tensor): (B)
                - speaker_names (list): (B)
                - raw_text (list): (B)
                - d_vector (torch.Tensor): (B, Feature_dim)
        """
        B = len(batch)
        batch = {k: [dic[k] for dic in batch] for k in batch[0]}

        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor(
                [x.size(1) for x in batch["wav"]]
            ),
            dim=0,
            descending=True
        )

        max_text_len = max([len(x) for x in batch["token_ids"]])
        token_lens = torch.LongTensor(batch["token_len"])
        token_rel_lens = token_lens / token_lens.max()

        wav_lens = [w.shape[1] for w in batch["wav"]]
        wav_lens = torch.LongTensor(wav_lens)
        wav_lens_max = torch.max(wav_lens)
        wav_rel_lens = wav_lens / wav_lens_max

        token_padded = torch.LongTensor(B, max_text_len)
        wav_padded = torch.FloatTensor(B, 1, wav_lens_max)
        token_padded = token_padded.zero_() + self.pad_id
        wav_padded = wav_padded.zero_() + self.pad_id
        for i in range(len(ids_sorted_decreasing)):
            token_ids = batch["token_ids"][i]
            token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids)
            wav = batch["wav"][i]
            wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav)

        return {
            "tokens": token_padded,
            "token_lens": token_lens,
            "token_rel_lens": token_rel_lens,
            "waveform": wav_padded,
            "waveform_lens": wav_lens,
            "waveform_rel_lens": wav_rel_lens,
            "speaker_names": batch["speaker_name"],
            "raw_text": batch["raw_text"],
            "d_vector": torch.concatenate(batch["d_vector"]) if 'd_vector' in batch.keys() else None,
        }
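Example (not part of the commit): a minimal sketch of loading a trained checkpoint and synthesizing with a stored d-vector, mirroring the flow in the Example_Notebook. The storage paths, checkpoint name, and d-vector file are placeholders.

from pathlib import Path
from pvq_manipulation.models.vits import Vits_NT

model_path = Path('/path/to/vits_model_storage')                 # placeholder directory with config.json
tts_model = Vits_NT.load_model(model_path, 'checkpoint.pth')     # placeholder checkpoint name

wav = tts_model.synthesize_from_example({
    'text': 'An example sentence to synthesize.',
    'd_vector_storage_root': '/path/to/d_vectors/speaker/example.pth',  # placeholder d-vector file
})
# wav is a numpy array at the model sample rate; a manipulated embedding can be
# passed additionally via the optional 'd_vector_man' key.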
setup.py
ADDED
@@ -0,0 +1,13 @@
from distutils.core import setup

setup(
    name='pvq_manipulation',
    version='0.0.0',
    author='Department of Communications Engineering, Paderborn University',
    author_email='[email protected]',
    license='MIT',
    keywords='audio speech',
    install_requires=[
        'torchdiffeq',
    ],
)