Update to V1.5
This view is limited to 50 files because it contains too many changes.
- app.py +286 -340
- fish_speech/callbacks/__init__.py +3 -3
- fish_speech/callbacks/grad_norm.py +113 -113
- fish_speech/configs/base.yaml +87 -87
- fish_speech/configs/firefly_gan_vq.yaml +33 -33
- fish_speech/configs/lora/r_8_alpha_16.yaml +4 -4
- fish_speech/configs/model/dual_ar_2_codebook_large.yaml +0 -9
- fish_speech/configs/model/dual_ar_2_codebook_medium.yaml +0 -9
- fish_speech/configs/model/dual_ar_2_codebook_small.yaml +0 -13
- fish_speech/configs/model/naive_2_codebook_small.yaml +0 -12
- fish_speech/configs/text2semantic_finetune.yaml +83 -83
- fish_speech/configs/text2semantic_finetune_lora.yaml +0 -13
- fish_speech/configs/text2semantic_pretrain.yaml +0 -74
- fish_speech/configs/text2semantic_sft.yaml +0 -87
- fish_speech/configs/vqgan_finetune.yaml +0 -135
- fish_speech/configs/vqgan_pretrain.yaml +0 -139
- fish_speech/conversation.py +267 -2
- fish_speech/datasets/concat_repeat.py +53 -53
- fish_speech/datasets/protos/text-data.proto +24 -24
- fish_speech/datasets/protos/text_data_pb2.py +33 -33
- fish_speech/datasets/protos/text_data_stream.py +36 -36
- fish_speech/datasets/semantic.py +496 -496
- fish_speech/datasets/text.py +0 -661
- fish_speech/datasets/vqgan.py +147 -147
- fish_speech/i18n/README.md +27 -27
- fish_speech/i18n/__init__.py +3 -3
- fish_speech/i18n/core.py +40 -40
- fish_speech/i18n/locale/en_US.json +123 -122
- fish_speech/i18n/locale/es_ES.json +123 -122
- fish_speech/i18n/locale/ja_JP.json +123 -123
- fish_speech/i18n/locale/ko_KR.json +123 -0
- fish_speech/i18n/locale/pt_BR.json +133 -133
- fish_speech/i18n/locale/zh_CN.json +123 -122
- fish_speech/i18n/scan.py +122 -122
- fish_speech/models/text2semantic/lit_module.py +202 -202
- fish_speech/models/text2semantic/llama.py +887 -779
- fish_speech/models/text2semantic/lora.py +92 -92
- fish_speech/models/vqgan/lit_module.py +0 -442
- fish_speech/models/vqgan/modules/discriminator.py +0 -44
- fish_speech/models/vqgan/modules/firefly.py +596 -596
- fish_speech/models/vqgan/modules/fsq.py +116 -116
- fish_speech/models/vqgan/modules/reference.py +0 -113
- fish_speech/models/vqgan/modules/wavenet.py +0 -225
- fish_speech/models/vqgan/spectrogram.py +0 -122
- fish_speech/models/vqgan/utils.py +94 -94
- fish_speech/scheduler.py +40 -40
- fish_speech/text/__init__.py +4 -4
- fish_speech/text/chn_text_norm/.gitignore +114 -114
- fish_speech/text/chn_text_norm/README.md +36 -36
- fish_speech/text/chn_text_norm/basic_class.py +172 -172
app.py
CHANGED

@@ -10,7 +10,7 @@ import gc

 # Download if not exists
 os.makedirs("checkpoints", exist_ok=True)
-snapshot_download(repo_id="fishaudio/fish-speech-1.
+snapshot_download(repo_id="fishaudio/fish-speech-1.5", local_dir="./checkpoints/fish-speech-1.5")

 print("All checkpoints downloaded")

@@ -31,11 +31,11 @@ torchaudio.set_audio_backend("soundfile")
 from loguru import logger
 from transformers import AutoTokenizer

-from
-from tools.vqgan.inference import load_model as load_vqgan_model
+from fish_speech.i18n import i18n
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+from fish_speech.utils import autocast_exclude_mps, set_seed
 from tools.api import decode_vq_tokens, encode_reference
-from tools.
+from tools.file import AUDIO_EXTENSIONS, list_files
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
@@ -44,20 +44,43 @@ from tools.llama.generate import (
 )
 from tools.vqgan.inference import load_model as load_decoder_model

+from tools.schema import (
+    GLOBAL_NUM_SAMPLES,
+    ASRPackRequest,
+    ServeASRRequest,
+    ServeASRResponse,
+    ServeASRSegment,
+    ServeAudioPart,
+    ServeForwardMessage,
+    ServeMessage,
+    ServeRequest,
+    ServeResponse,
+    ServeStreamDelta,
+    ServeStreamResponse,
+    ServeTextPart,
+    ServeTimedASRResponse,
+    ServeTTSRequest,
+    ServeVQGANDecodeRequest,
+    ServeVQGANDecodeResponse,
+    ServeVQGANEncodeRequest,
+    ServeVQGANEncodeResponse,
+    ServeVQPart,
+    ServeReferenceAudio
+)
 # Make einx happy
 os.environ["EINX_FILTER_TRACEBACK"] = "false"


 HEADER_MD = """# Fish Speech

-## The demo in this space is version 1.
-## 该 Demo 为 Fish Speech 1.
+## The demo in this space is version 1.5, Please check [Fish Audio](https://fish.audio) for the best model.
+## 该 Demo 为 Fish Speech 1.5 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.

 A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
 由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.

-You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.
-你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1.
+You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).
+你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1.5) 找到模型.

 Related code and weights are released under CC BY-NC-SA 4.0 License.
 相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.
@@ -65,8 +88,8 @@ Related code and weights are released under CC BY-NC-SA 4.0 License.
 We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
 我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.

-The model running in this WebUI is Fish Speech V1.
-在此 WebUI 中运行的模型是 Fish Speech V1.
+The model running in this WebUI is Fish Speech V1.5 Medium.
+在此 WebUI 中运行的模型是 Fish Speech V1.5 Medium.
 """

 TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
@@ -95,48 +118,77 @@ def build_html_error_message(error):

 @GPU_DECORATOR
 @torch.inference_mode()
-def inference(
-    [old parameter list truncated in this view]
-    streaming=False
-):
-    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
-        return (
-            None,
-            None,
-            "Text is too long, please keep it under {} characters.".format(
-                args.max_gradio_length
-            ),
-        )
-
-    [old reference-audio handling not recoverable in this view]
+def inference(req: ServeTTSRequest):
+
+    global prompt_tokens, prompt_texts
+
+    idstr: str | None = req.reference_id
+    if idstr is not None:
+        ref_folder = Path("references") / idstr
+        ref_folder.mkdir(parents=True, exist_ok=True)
+        ref_audios = list_files(
+            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
+        )
+
+        if req.use_memory_cache == "never" or (
+            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
+        ):
+            prompt_tokens = [
+                encode_reference(
+                    decoder_model=decoder_model,
+                    reference_audio=audio_to_bytes(str(ref_audio)),
+                    enable_reference_audio=True,
+                )
+                for ref_audio in ref_audios
+            ]
+            prompt_texts = [
+                read_ref_text(str(ref_audio.with_suffix(".lab")))
+                for ref_audio in ref_audios
+            ]
+        else:
+            logger.info("Use same references")
+
+    else:
+        # Parse reference audio aka prompt
+        refs = req.references
+
+        if req.use_memory_cache == "never" or (
+            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
+        ):
+            prompt_tokens = [
+                encode_reference(
+                    decoder_model=decoder_model,
+                    reference_audio=ref.audio,
+                    enable_reference_audio=True,
+                )
+                for ref in refs
+            ]
+            prompt_texts = [ref.text for ref in refs]
+        else:
+            logger.info("Use same references")
+
+    if req.seed is not None:
+        set_seed(req.seed)
+        logger.warning(f"set seed: {req.seed}")

     # LLAMA Inference
     request = dict(
         device=decoder_model.device,
-        max_new_tokens=max_new_tokens,
-        text=
+        max_new_tokens=req.max_new_tokens,
+        text=(
+            req.text
+            if not req.normalize
+            else ChnNormedText(raw_text=req.text).normalize()
+        ),
+        top_p=req.top_p,
+        repetition_penalty=req.repetition_penalty,
+        temperature=req.temperature,
         compile=args.compile,
-        iterative_prompt=chunk_length > 0,
-        chunk_length=chunk_length,
-        max_length=
-        prompt_tokens=prompt_tokens
-        prompt_text=
+        iterative_prompt=req.chunk_length > 0,
+        chunk_length=req.chunk_length,
+        max_length=4096,
+        prompt_tokens=prompt_tokens,
+        prompt_text=prompt_texts,
     )

     response_queue = queue.Queue()
@@ -152,19 +204,15 @@ def inference(
     while True:
         result: WrappedGenerateResponse = response_queue.get()
         if result.status == "error":
-            [old error handling not recoverable in this view]
+            yield None, None, build_html_error_message(result.response)
+            break

         result: GenerateResponse = result.response
         if result.action == "next":
             break

-        with [
-            device_type=(
-                "cpu"
-                if decoder_model.device.type == "mps"
-                else decoder_model.device.type
-            ),
-            dtype=args.precision,
+        with autocast_exclude_mps(
+            device_type=decoder_model.device.type, dtype=args.precision
         ):
             fake_audios = decode_vq_tokens(
                 decoder_model=decoder_model,
@@ -179,79 +227,24 @@ def inference(
                 None,
                 None,
                 build_html_error_message(
-                    "No audio generated, please check the input text."
+                    i18n("No audio generated, please check the input text.")
                 ),
             )

-    # [old comment truncated in this view]
+    # No matter streaming or not, we need to return the final audio
     audio = np.concatenate(segments, axis=0)
-    [old return statement not recoverable in this view]
+    yield None, (decoder_model.spec_transform.sample_rate, audio), None

     if torch.cuda.is_available():
         torch.cuda.empty_cache()
     gc.collect()

-
-def inference_with_auto_rerank(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    use_auto_rerank,
-    streaming=False,
-):
-    max_attempts = 2 if use_auto_rerank else 1
-    best_wer = float("inf")
-    best_audio = None
-    best_sample_rate = None
-
-    for attempt in range(max_attempts):
-        _, (sample_rate, audio), message = inference(
-            text,
-            enable_reference_audio,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            chunk_length,
-            top_p,
-            repetition_penalty,
-            temperature,
-            streaming=False,
-        )
-
-        if audio is None:
-            return None, None, message
-
-        if not use_auto_rerank:
-            return None, (sample_rate, audio), None
-
-        asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
-        wer = calculate_wer(text, asr_result["text"])
-
-        if wer <= 0.3 and not asr_result["huge_gap"]:
-            return None, (sample_rate, audio), None
-
-        if wer < best_wer:
-            best_wer = wer
-            best_audio = audio
-            best_sample_rate = sample_rate
-
-        if attempt == max_attempts - 1:
-            break
-
-    return None, (best_sample_rate, best_audio), None
-
-
 n_audios = 4

 global_audio_list = []
 global_error_list = []

+
 def inference_wrapper(
     text,
     enable_reference_audio,
@@ -262,14 +255,14 @@ def inference_wrapper(
     top_p,
     repetition_penalty,
     temperature,
+    seed,
     batch_infer_num,
-    if_load_asr_model,
 ):
     audios = []
     errors = []

     for _ in range(batch_infer_num):
-        result =
+        result = inference(
             text,
             enable_reference_audio,
             reference_audio,
@@ -279,10 +272,10 @@ def inference_wrapper(
             top_p,
             repetition_penalty,
             temperature,
-            [removed argument not shown in this view]
+            seed,
         )

-        _, audio_data, error_message = result
+        _, audio_data, error_message = next(result)

         audios.append(
             gr.Audio(value=audio_data if audio_data else None, visible=True),
@@ -314,52 +307,17 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
     buffer.close()
     return wav_header_bytes

-
 def normalize_text(user_input, use_normalization):
     if use_normalization:
         return ChnNormedText(raw_text=user_input).normalize()
     else:
         return user_input

-
-[removed ASR helper; its signature is not recoverable in this view]
-    global asr_model
-
-    if if_load:
-        gr.Warning("Loading faster whisper model...")
-        if asr_model is None:
-            asr_model = load_model()
-        return gr.Checkbox(label="Unload faster whisper model", value=if_load)
-
-    if if_load is False:
-        gr.Warning("Unloading faster whisper model...")
-        del asr_model
-        asr_model = None
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
-        return gr.Checkbox(label="Load faster whisper model", value=if_load)
-
-
-def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
-    if if_load and asr_model is not None:
-        if (
-            if_auto_label
-            and enable_ref
-            and ref_audio is not None
-            and ref_text.strip() == ""
-        ):
-            data, sample_rate = librosa.load(ref_audio)
-            res = batch_asr(asr_model, [data], sample_rate)[0]
-            ref_text = res["text"]
-    else:
-        gr.Warning("Whisper model not loaded!")
-
-    return gr.Textbox(value=ref_text)
-
+def update_examples():
+    examples_dir = Path("references")
+    examples_dir.mkdir(parents=True, exist_ok=True)
+    example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True)
+    return gr.Dropdown(choices=example_audios + [""])

 def build_app():
     with gr.Blocks(theme=gr.themes.Base()) as app:
@@ -377,202 +335,185 @@ def build_app():
         with gr.Row():
             with gr.Column(scale=3):
                 text = gr.Textbox(
-                    label="Input Text", placeholder=TEXTBOX_PLACEHOLDER, lines=10
+                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
                 )
                 refined_text = gr.Textbox(
-                    label="Realtime Transform Text",
-                    placeholder=
-                    "Normalization Result Preview (Currently Only Chinese)"
+                    label=i18n("Realtime Transform Text"),
+                    placeholder=i18n(
+                        "Normalization Result Preview (Currently Only Chinese)"
+                    ),
                     lines=5,
                     interactive=False,
                 )

                 with gr.Row():
-                    [old checkbox constructor truncated in this view]
-                        label="Text Normalization
-                        value=False,
-                        scale=1,
-                    )
-
-                    if_load_asr_model = gr.Checkbox(
-                        label="Load / Unload ASR model for auto-reranking",
+                    normalize = gr.Checkbox(
+                        label=i18n("Text Normalization"),
                         value=False,
-                        scale=3,
                     )

                 with gr.Row():
-                    with gr.
-                    [removed advanced-config and reference UI lines not recoverable in this view]
+                    with gr.Column():
+                        with gr.Tab(label=i18n("Advanced Config")):
+                            with gr.Row():
+                                chunk_length = gr.Slider(
+                                    label=i18n("Iterative Prompt Length, 0 means off"),
+                                    minimum=0,
+                                    maximum=300,
+                                    value=200,
+                                    step=8,
+                                )
+
+                                max_new_tokens = gr.Slider(
+                                    label=i18n(
+                                        "Maximum tokens per batch, 0 means no limit"
+                                    ),
+                                    minimum=0,
+                                    maximum=2048,
+                                    value=0,
+                                    step=8,
+                                )
+
+                            with gr.Row():
+                                top_p = gr.Slider(
+                                    label="Top-P",
+                                    minimum=0.6,
+                                    maximum=0.9,
+                                    value=0.7,
+                                    step=0.01,
+                                )
+
+                                repetition_penalty = gr.Slider(
+                                    label=i18n("Repetition Penalty"),
+                                    minimum=1,
+                                    maximum=1.5,
+                                    value=1.2,
+                                    step=0.01,
+                                )
+
+                            with gr.Row():
+                                temperature = gr.Slider(
+                                    label="Temperature",
+                                    minimum=0.6,
+                                    maximum=0.9,
+                                    value=0.7,
+                                    step=0.01,
+                                )
+                                seed = gr.Number(
+                                    label="Seed",
+                                    info="0 means randomized inference, otherwise deterministic",
+                                    value=0,
+                                )
+
+                        with gr.Tab(label=i18n("Reference Audio")):
+                            with gr.Row():
+                                gr.Markdown(
+                                    i18n(
+                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
+                                    )
+                                )
+                            with gr.Row():
+                                reference_id = gr.Textbox(
+                                    label=i18n("Reference ID"),
+                                    placeholder="Leave empty to use uploaded references",
+                                )
+
+                            with gr.Row():
+                                use_memory_cache = gr.Radio(
+                                    label=i18n("Use Memory Cache"),
+                                    choices=["never", "on-demand", "always"],
+                                    value="on-demand",
+                                )
+
+                            with gr.Row():
+                                reference_audio = gr.Audio(
+                                    label=i18n("Reference Audio"),
+                                    type="filepath",
+                                )
+                            with gr.Row():
+                                reference_text = gr.Textbox(
+                                    label=i18n("Reference Text"),
+                                    lines=1,
+                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+                                    value="",
+                                )

             with gr.Column(scale=3):
-                for _ in range(n_audios):
-                    with gr.Row():
-                        error = gr.HTML(
-                            label="Error Message",
-                            visible=True if _ == 0 else False,
-                        )
-                        global_error_list.append(error)
-                    with gr.Row():
-                        audio = gr.Audio(
-                            label="Generated Audio",
-                            type="numpy",
-                            interactive=False,
-                            visible=True if _ == 0 else False,
-                        )
-                        global_audio_list.append(audio)
-
                 with gr.Row():
-                    [removed streaming-audio widget; labels truncated in this view]
-                        interactive=False,
-                    )
+                    error = gr.HTML(
+                        label=i18n("Error Message"),
+                        visible=True,
+                    )
+                with gr.Row():
+                    audio = gr.Audio(
+                        label=i18n("Generated Audio"),
+                        type="numpy",
                         interactive=False,
+                        visible=True,
                     )
+
                 with gr.Row():
                     with gr.Column(scale=3):
                         generate = gr.Button(
-                            value="\U0001F3A7 " + "Generate", variant="primary"
-                        )
-                        generate_stream = gr.Button(
-                            value="\U0001F3A7 " + "Streaming Generate",
-                            variant="primary",
+                            value="\U0001F3A7 " + i18n("Generate"), variant="primary"
                         )

         text.input(
-            fn=normalize_text, inputs=[text,
+            fn=normalize_text, inputs=[text, normalize], outputs=[refined_text]
         )

-        [removed event wiring (streaming / ASR checkboxes) not recoverable in this view]
+        def inference_wrapper(
+            text,
+            normalize,
+            reference_id,
+            reference_audio,
+            reference_text,
+            max_new_tokens,
+            chunk_length,
+            top_p,
+            repetition_penalty,
+            temperature,
+            seed,
+            use_memory_cache,
+        ):
+            references = []
+            if reference_audio:
+                # 将文件路径转换为字节 (convert the file path into bytes)
+                with open(reference_audio, 'rb') as audio_file:
+                    audio_bytes = audio_file.read()
+                references = [
+                    ServeReferenceAudio(audio=audio_bytes, text=reference_text)
+                ]
+
+            req = ServeTTSRequest(
+                text=text,
+                normalize=normalize,
+                reference_id=reference_id if reference_id else None,
+                references=references,
+                max_new_tokens=max_new_tokens,
+                chunk_length=chunk_length,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                temperature=temperature,
+                seed=int(seed) if seed else None,
+                use_memory_cache=use_memory_cache,
+            )
+
+            for result in inference(req):
+                if result[2]:  # Error message
+                    return None, result[2]
+                elif result[1]:  # Audio data
+                    return result[1], None
+
+            return None, i18n("No audio generated")
+
+        # Submit
         generate.click(
             inference_wrapper,
             [
                 refined_text,
-                [removed input not shown in this view]
+                normalize,
+                reference_id,
                 reference_audio,
                 reference_text,
                 max_new_tokens,
@@ -580,26 +521,28 @@ def build_app():
                 top_p,
                 repetition_penalty,
                 temperature,
-                [two removed inputs not shown in this view]
+                seed,
+                use_memory_cache,
             ],
-            [
+            [audio, error],
             concurrency_limit=1,
         )
+
         return app


+
 def parse_args():
     parser = ArgumentParser()
     parser.add_argument(
         "--llama-checkpoint-path",
         type=Path,
-        default="checkpoints/fish-speech-1.
+        default="checkpoints/fish-speech-1.5",
     )
     parser.add_argument(
         "--decoder-checkpoint-path",
         type=Path,
-        default="checkpoints/fish-speech-1.
+        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
     )
     parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
     parser.add_argument("--device", type=str, default="cuda")
@@ -634,17 +577,20 @@ if __name__ == "__main__":

     # Dry run to check if the model is loaded correctly and avoid the first-time latency
     list(
-        [removed warm-up call not recoverable in this view]
+        inference(
+            ServeTTSRequest(
+                text="Hello world.",
+                references=[],
+                reference_id=None,
+                max_new_tokens=0,
+                chunk_length=200,
+                top_p=0.7,
+                repetition_penalty=1.5,
+                temperature=0.7,
+                emotion=None,
+                format="wav",
+            )
+        )
     )

     logger.info("Warming up done, launching the web UI...")
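For orientation, here is a minimal sketch (not part of the commit) of how the new request-based entry point introduced above can be driven outside the Gradio UI. The field names are taken from the hunks in this diff; any `ServeTTSRequest` defaults that are not shown here (for example `normalize`, `seed`, `use_memory_cache`) are assumed to exist with sensible defaults, and the app globals (`decoder_model`, `llama_queue`, `args`) are assumed to be initialized as in `app.py`.

```python
# Minimal sketch, assuming app.py's globals are already set up.
from tools.schema import ServeTTSRequest

req = ServeTTSRequest(
    text="Hello world.",
    references=[],        # or [ServeReferenceAudio(audio=..., text=...)]
    reference_id=None,
    max_new_tokens=0,
    chunk_length=200,
    top_p=0.7,
    repetition_penalty=1.2,
    temperature=0.7,
    format="wav",
)

# inference() is a generator yielding (stream_chunk, (sample_rate, audio), error).
for _, audio_tuple, error in inference(req):
    if error is not None:
        print("error:", error)
    elif audio_tuple is not None:
        sample_rate, audio = audio_tuple
        print(f"got {audio.shape[0] / sample_rate:.2f}s of audio at {sample_rate} Hz")
```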
fish_speech/callbacks/__init__.py
CHANGED
@@ -1,3 +1,3 @@
All three lines are removed and re-added with identical content (apparently a line-ending/whitespace-only change):

from .grad_norm import GradNormMonitor

__all__ = ["GradNormMonitor"]
fish_speech/callbacks/grad_norm.py
CHANGED
@@ -1,113 +1,113 @@
The whole file is removed and re-added with identical content (apparently a line-ending/whitespace-only change). The file reads:

from typing import Optional, Union

import lightning.pytorch as pl
import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from torch import Tensor, nn
from torch.utils._foreach_utils import (
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)


@torch.no_grad()
def grad_norm(
    parameters: Union[Tensor, list[Tensor]],
    norm_type: float = 2.0,
) -> float:
    """
    Returns the norm of the gradients of the given parameters.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        norm_type (float): type of the used p-norm.

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """  # noqa: E501

    if isinstance(parameters, Tensor):
        parameters = [parameters]

    grads = [p.grad for p in parameters if p.grad is not None]
    if len(grads) == 0:
        return None

    first_device = grads[0].device
    grouped_grads: dict[
        tuple[torch.device, torch.dtype], list[list[Tensor]]
    ] = _group_tensors_by_device_and_dtype(
        [[g.detach() for g in grads]]
    )  # type: ignore[assignment]

    norms = []
    for (device, _), ([grads], _) in grouped_grads.items():
        if _has_foreach_support(grads, device=device):
            norms.extend(torch._foreach_norm(grads, norm_type))
        else:
            norms.extend([torch.norm(g, norm_type) for g in grads])

    return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)


class GradNormMonitor(Callback):
    """
    Callback that computes the gradient norm of the model parameters.
    """

    def __init__(
        self,
        norm_type: float = 2.0,
        logging_interval: str = "step",
        sub_module: Optional[Union[str, list[str]]] = None,
    ) -> None:
        """
        Args:
            norm_type (float): type of the used p-norm.
            logging_interval (str): "step" or "epoch".
        """
        super().__init__()

        self.norm_type = norm_type
        self.logging_interval = logging_interval
        self.sub_module = sub_module

    def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
        """
        Computes the gradient norm of the model parameters and logs it to the logger.

        Args:
            trainer (Trainer): The trainer object
            model (LightningModule): The current lightningModule
        """

        lightning_model = model

        if self.sub_module is None:
            return self.log_sub_module_grad_norm(lightning_model, model, "")

        sub_modules = self.sub_module
        if isinstance(sub_modules, str):
            sub_modules = [sub_modules]

        for sub_module in sub_modules:
            self.log_sub_module_grad_norm(
                lightning_model, getattr(model, sub_module), f"/{sub_module}"
            )

    def log_sub_module_grad_norm(
        self, lightning_model: LightningModule, model: nn.Module, path: str
    ) -> None:
        grad_norm_val = grad_norm(model.parameters(), self.norm_type)
        if grad_norm_val is None:
            return

        on_step = self.logging_interval == "step"
        lightning_model.log(
            f"train{path}/grad_norm",
            grad_norm_val,
            on_step=on_step,
            on_epoch=not on_step,
        )
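As a usage note (not part of the commit), the callback above plugs into a Lightning `Trainer` like any other callback; the trainer arguments below are placeholders.

```python
import lightning.pytorch as pl

from fish_speech.callbacks import GradNormMonitor

# Logs the 2-norm of all gradients every training step under "train/grad_norm".
trainer = pl.Trainer(
    max_steps=1000,
    callbacks=[GradNormMonitor(norm_type=2.0, logging_interval="step")],
)
# trainer.fit(model, datamodule=...)  # model is any LightningModule
```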
fish_speech/configs/base.yaml
CHANGED
@@ -1,87 +1,87 @@
The whole file is removed and re-added with identical content (apparently a line-ending/whitespace-only change). The file reads:

# Base configuration for training a model
paths:
  run_dir: results/${project}
  ckpt_dir: ${paths.run_dir}/checkpoints

hydra:
  run:
    dir: ${paths.run_dir}

# Lightning Trainer
trainer:
  _target_: lightning.pytorch.trainer.Trainer

  default_root_dir: ${paths.run_dir}
  accelerator: gpu
  num_nodes: 1
  devices: auto
  strategy:
    _target_: lightning.pytorch.strategies.DDPStrategy
    process_group_backend: nccl # This should be override when training on windows

  precision: bf16-mixed

  # disable validation by epoch end
  check_val_every_n_epoch: null
  val_check_interval: 5000
  max_steps: 100_000

  # Use torch.backends.cudnn.benchmark to speed up training
  benchmark: true

# Callbacks
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${paths.ckpt_dir}
    filename: "step_{step:09d}"
    save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
    save_top_k: 5 # save 5 latest checkpoints
    monitor: step # use step to monitor checkpoints
    mode: max # save the latest checkpoint with the highest global_step
    every_n_epochs: null # don't save checkpoints by epoch end
    every_n_train_steps: 5000 # save checkpoints every 5000 steps
    auto_insert_metric_name: false

  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 2 # the maximum depth of layer nesting that the summary will include

  learning_rate_monitor:
    _target_: lightning.pytorch.callbacks.LearningRateMonitor
    logging_interval: step
    log_momentum: false

  grad_norm_monitor:
    _target_: fish_speech.callbacks.GradNormMonitor
    norm_type: 2
    logging_interval: step

# Logger
logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: "${paths.run_dir}/tensorboard/"
    name: null
    log_graph: false
    default_hp_metric: true
    prefix: ""

  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   # name: "" # name of the run (normally generated by wandb)
  #   save_dir: "${paths.run_dir}"
  #   offline: False
  #   id: null # pass correct id to resume experiment!
  #   anonymous: null # enable anonymous logging
  #   project: "fish-speech"
  #   log_model: False # upload lightning ckpts
  #   prefix: "" # a string to put at the beginning of metric keys
  #   # entity: "" # set to name of your wandb team
  #   group: ""
  #   tags: ["vq", "hq", "finetune"]
  #   job_type: ""

# Loop
train: true
test: false
fish_speech/configs/firefly_gan_vq.yaml
CHANGED
@@ -1,33 +1,33 @@
The whole file is removed and re-added with identical content (apparently a line-ending/whitespace-only change). The file reads:

_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
spec_transform:
  _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
  sample_rate: 44100
  n_mels: 160
  n_fft: 2048
  hop_length: 512
  win_length: 2048
backbone:
  _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
  input_channels: 160
  depths: [3, 3, 9, 3]
  dims: [128, 256, 384, 512]
  drop_path_rate: 0.2
  kernel_size: 7
head:
  _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
  hop_length: 512
  upsample_rates: [8, 8, 2, 2, 2] # aka. strides
  upsample_kernel_sizes: [16, 16, 4, 4, 4]
  resblock_kernel_sizes: [3, 7, 11]
  resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
  num_mels: 512
  upsample_initial_channel: 512
  pre_conv_kernel_size: 13
  post_conv_kernel_size: 13
quantizer:
  _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
  input_dim: 512
  n_groups: 8
  n_codebooks: 1
  levels: [8, 5, 5, 5]
  downsample_factor: [2, 2]
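As a side note (not in the commit), configs like this one, built entirely from `_target_` entries, are normally materialized with Hydra's `instantiate`; a hedged sketch follows, assuming `FireflyArchitecture` accepts the nested modules as keyword arguments and that loading the weights from `--decoder-checkpoint-path` is handled separately (as the `load_decoder_model` helper appears to do).

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Build the decoder architecture described by the config; each nested
# _target_ (spec_transform, backbone, head, quantizer) is instantiated recursively.
cfg = OmegaConf.load("fish_speech/configs/firefly_gan_vq.yaml")
model = instantiate(cfg)
# state_dict loading from the generator checkpoint would happen after this.
```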
fish_speech/configs/lora/r_8_alpha_16.yaml
CHANGED
@@ -1,4 +1,4 @@
The whole file is removed and re-added with identical content (apparently a line-ending/whitespace-only change). The file reads:

_target_: fish_speech.models.text2semantic.lora.LoraConfig
r: 8
lora_alpha: 16
lora_dropout: 0.01
fish_speech/configs/model/dual_ar_2_codebook_large.yaml
DELETED
@@ -1,9 +0,0 @@
The deleted file contained:

defaults:
  - dual_ar_2_codebook_small
  - _self_

config:
  n_layer: 30
  n_fast_layer: 6
  n_head: 24
  dim: 1536
fish_speech/configs/model/dual_ar_2_codebook_medium.yaml
DELETED
@@ -1,9 +0,0 @@
The deleted file contained:

defaults:
  - dual_ar_2_codebook_small
  - _self_

config:
  n_layer: 24
  n_fast_layer: 6
  n_head: 16
  dim: 1024
fish_speech/configs/model/dual_ar_2_codebook_small.yaml
DELETED
@@ -1,13 +0,0 @@
The deleted file contained:

_target_: fish_speech.models.text2semantic.llama.DualARTransformer
config:
  _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
  max_seq_len: ${max_length}
  vocab_size: 264 # pad 262 to 8x
  n_layer: 12
  n_fast_layer: 4
  n_head: 12
  dim: 768
  rope_base: 10000
  norm_eps: 1e-5
  num_codebooks: 2 # input/output codebook size
  codebook_size: 1032 # codebook size 1024 + 2 special tokens
fish_speech/configs/model/naive_2_codebook_small.yaml
DELETED
@@ -1,12 +0,0 @@
The deleted file contained:

_target_: fish_speech.models.text2semantic.llama.NaiveTransformer
config:
  _target_: fish_speech.models.text2semantic.llama.NaiveModelArgs
  max_seq_len: ${max_length}
  vocab_size: 36408
  n_layer: 12
  n_head: 12
  dim: 768
  rope_base: 10000
  norm_eps: 1e-5
  num_codebooks: 2 # input/output codebook size
  codebook_size: 1032 # codebook size 1024 + 2 special tokens
fish_speech/configs/text2semantic_finetune.yaml
CHANGED
@@ -1,83 +1,83 @@
The whole file is removed and re-added with identical content (apparently a line-ending/whitespace-only change). The file reads:

defaults:
  - base
  - _self_

project: text2semantic_finetune_dual_ar
max_length: 4096
pretrained_ckpt_path: checkpoints/fish-speech-1.4

# Lightning Trainer
trainer:
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  gradient_clip_algorithm: "norm"
  max_steps: 1000
  precision: bf16-true
  limit_val_batches: 10
  val_check_interval: 100

# Dataset Configuration
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: ${pretrained_ckpt_path}

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  causal: true
  max_length: ${max_length}
  use_speaker: false
  interactive_prob: 0.7

val_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  causal: true
  max_length: ${max_length}
  use_speaker: false
  interactive_prob: 0.7

data:
  _target_: fish_speech.datasets.semantic.SemanticDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 8
  tokenizer: ${tokenizer}
  max_length: ${max_length}

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
  model:
    _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
    path: ${pretrained_ckpt_path}
    load_weights: true
    max_length: ${max_length}
    lora_config: null

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1e-4
    weight_decay: 0
    betas: [0.9, 0.95]
    eps: 1e-5

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 10

# Callbacks
callbacks:
  model_checkpoint:
    every_n_train_steps: ${trainer.val_check_interval}
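A quick illustration (not from the commit) of how the `${...}` interpolations used throughout this config resolve with OmegaConf; the values below are copied from the file above, and only the relevant keys are reproduced.

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "pretrained_ckpt_path": "checkpoints/fish-speech-1.4",
        "max_length": 4096,
        "tokenizer": {
            "_target_": "transformers.AutoTokenizer.from_pretrained",
            "pretrained_model_name_or_path": "${pretrained_ckpt_path}",
        },
    }
)

# Interpolations are resolved lazily, on access:
assert cfg.tokenizer.pretrained_model_name_or_path == "checkpoints/fish-speech-1.4"
```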
fish_speech/configs/text2semantic_finetune_lora.yaml
DELETED
@@ -1,13 +0,0 @@
The deleted file contained:

defaults:
  - text2semantic_finetune
  - _self_

project: text2semantic_finetune_dual_ar_lora

# Model Configuration
model:
  save_lora_only: true
  lora_config:
    _target_: fish_speech.models.text2semantic.lit_module.LoraConfig
    r: 8
    lora_alpha: 16
fish_speech/configs/text2semantic_pretrain.yaml
DELETED
@@ -1,74 +0,0 @@
-defaults:
-  - base
-  - model@model.model: dual_ar_2_codebook_small
-  - _self_
-
-project: text2semantic_pretrain_dual_ar_debug
-max_length: 2048
-
-# Lightning Trainer
-trainer:
-  accumulate_grad_batches: 1
-  gradient_clip_val: 1.0
-  gradient_clip_algorithm: 'norm'
-  max_steps: 1_000_000
-  precision: bf16-true
-  limit_val_batches: 10
-
-# Dataset Configuration
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/fish-speech-1
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
-  proto_files:
-    - data/protos/train
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-  num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: false
-  interactive_prob: 0.5
-
-val_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
-  proto_files:
-    - data/protos/test
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-  num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: false
-  interactive_prob: 0.5
-
-data:
-  _target_: fish_speech.datasets.text.TextDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 8
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.text2semantic.TextToSemantic
-  model: {}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 3e-4
-    weight_decay: 0.01
-    betas: [0.9, 0.95]
-    eps: 1e-5
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 2000
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0.1
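The deleted pretrain config pointed its LambdaLR at a cosine schedule with warmup and a final_lr_ratio floor. A sketch of what such a lambda typically computes, assuming linear warmup followed by cosine decay toward final_lr_ratio of the base learning rate (the exact fish_speech.scheduler implementation may differ):

import math


def cosine_schedule_with_warmup_lr_lambda(
    step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    final_lr_ratio: float = 0.0,
) -> float:
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
    # Decay from 1.0 down to final_lr_ratio instead of all the way to zero.
    return final_lr_ratio + (1.0 - final_lr_ratio) * cosine


# With num_warmup_steps=2000, num_training_steps=1_000_000 and final_lr_ratio=0.1,
# the factor ramps to 1.0 and then decays to 0.1 of the base lr by the end of training.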
fish_speech/configs/text2semantic_sft.yaml
DELETED
@@ -1,87 +0,0 @@
-defaults:
-  - base
-  - model@model.model: dual_ar_8_codebook_small
-  - _self_
-
-project: text2semantic_sft_medium_dual_ar
-max_length: 4096
-ckpt_path: results/text2semantic_pretrain_medium_dual_ar/checkpoints/step_000060000.ckpt
-resume_weights_only: true
-
-# Lightning Trainer
-trainer:
-  accumulate_grad_batches: 1
-  gradient_clip_val: 1.0
-  gradient_clip_algorithm: 'norm'
-  max_steps: 10_000
-  precision: bf16-true
-  limit_val_batches: 10
-  val_check_interval: 500
-
-# Dataset Configuration
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/speech-lm-v1
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
-  use_data_server: false
-  proto_files:
-    - data/protos/sft/train_Genshin.protos
-    - data/protos/sft/sft.protos
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-  num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: false
-  phones_prob: 0.5
-  interactive_prob: 0.5
-
-val_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
-  use_data_server: false
-  proto_files:
-    - data/protos/sft/val_Genshin.protos
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-  num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: false
-  phones_prob: 0.5
-  interactive_prob: 0.5
-
-data:
-  _target_: fish_speech.datasets.text.TextDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 8
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.text2semantic.TextToSemantic
-  model: {}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 4e-5
-    weight_decay: 0
-    betas: [0.9, 0.95]
-    eps: 1e-5
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 100
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0
-
-callbacks:
-  model_checkpoint:
-    every_n_train_steps: 1000
-    save_top_k: 10
fish_speech/configs/vqgan_finetune.yaml
DELETED
@@ -1,135 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vq-gan-finetune
-ckpt_path: checkpoints/vq-gan-group-fsq-2x1024.pth
-resume_weights_only: true
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: auto
-  precision: bf16-mixed
-  max_steps: 100_000
-  val_check_interval: 5000
-  strategy: ddp_find_unused_parameters_true
-
-sample_rate: 44100
-hop_length: 512
-num_mels: 128
-n_fft: 2048
-win_length: 2048
-freeze_encoder: true
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/filelist.train.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  slice_frames: 512
-
-val_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/filelist.val.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-data:
-  _target_: fish_speech.datasets.vqgan.VQGANDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 16
-  val_batch_size: 16
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vqgan.VQGAN
-
-  sampling_rate: ${sample_rate}
-  weight_adv: 0.2
-  weight_vq: 1.0
-  weight_mel: 1.0
-  freeze_encoder: false
-
-  encoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    input_channels: ${num_mels}
-    residual_channels: 768
-    residual_layers: 20
-    dilation_cycle: 4
-
-  quantizer:
-    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
-    input_dim: 768
-    n_codebooks: 1
-    n_groups: 2
-    levels: [8, 5, 5, 5]
-
-  decoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    output_channels: ${num_mels}
-    residual_channels: 768
-    residual_layers: 20
-    dilation_cycle: 4
-    condition_channels: 768
-
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
-
-  vocoder:
-    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
-    ckpt_path: null # You may download the pretrained vocoder and set the path here
-
-  encode_mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-    f_min: 0.0
-    f_max: 8000.0
-
-  gt_mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 4e-5
-    betas: [0.8, 0.99]
-    eps: 1e-5
-    weight_decay: 0.01
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 100
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0
-
-callbacks:
-  model_summary:
-    _target_: lightning.pytorch.callbacks.ModelSummary
-    max_depth: 1
-
-  model_checkpoint:
-    every_n_train_steps: ${trainer.val_check_interval}
-
-  grad_norm_monitor:
-    sub_module:
-      - encoder
-      - decoder
-      - quantizer
-      - discriminator
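The spectrogram settings in the deleted config (44.1 kHz audio, n_fft 2048, win_length 2048, hop_length 512, 128 mel bins) can be reproduced with a plain torchaudio transform. The sketch below is only a rough stand-in for the project's LogMelSpectrogram module, under the assumption that it computes a log-compressed mel spectrogram with these parameters:

import torch
import torchaudio


class LogMel(torch.nn.Module):
    def __init__(self, sample_rate=44100, n_fft=2048, hop_length=512, win_length=2048, n_mels=128):
        super().__init__()
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mels=n_mels,
            power=1.0,
        )

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Clamp before the log to avoid -inf on silent frames.
        return torch.log(torch.clamp(self.mel(wav), min=1e-5))


mels = LogMel()(torch.randn(1, 44100))  # -> (1, 128, num_frames)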
fish_speech/configs/vqgan_pretrain.yaml
DELETED
@@ -1,139 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vq-gan-pretrain
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: auto
-  precision: bf16-mixed
-  max_steps: 1_000_000
-  val_check_interval: 5000
-  strategy: ddp_find_unused_parameters_true
-
-sample_rate: 44100
-hop_length: 512
-num_mels: 128
-n_fft: 2048
-win_length: 2048
-
-# Dataset Configuration
-train_dataset:
-  _target_: torch.utils.data.ConcatDataset
-  datasets:
-    - _target_: fish_speech.datasets.vqgan.VQGANDataset
-      filelist: data/gigaspeech/vq_train_filelist.txt
-      sample_rate: ${sample_rate}
-      hop_length: ${hop_length}
-      slice_frames: 512
-    - _target_: fish_speech.datasets.vqgan.VQGANDataset
-      filelist: data/sft/vq_train_filelist.txt
-      sample_rate: ${sample_rate}
-      hop_length: ${hop_length}
-      slice_frames: 512
-
-val_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/sft/vq_val_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-data:
-  _target_: fish_speech.datasets.vqgan.VQGANDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 32
-  val_batch_size: 32
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vqgan.VQGAN
-
-  sampling_rate: ${sample_rate}
-  weight_adv: 0.2
-  weight_vq: 1.0
-  weight_mel: 1.0
-  freeze_encoder: false
-
-  encoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    input_channels: ${num_mels}
-    residual_channels: 768
-    residual_layers: 20
-    dilation_cycle: 4
-
-  quantizer:
-    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
-    input_dim: 768
-    n_codebooks: 1
-    n_groups: 2
-    levels: [8, 5, 5, 5]
-
-  decoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    output_channels: ${num_mels}
-    residual_channels: 768
-    residual_layers: 20
-    dilation_cycle: 4
-    condition_channels: 768
-
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
-
-  vocoder:
-    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
-    ckpt_path: null # You may download the pretrained vocoder and set the path here
-
-  encode_mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-    f_min: 0.0
-    f_max: 8000.0
-
-  gt_mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 1e-4
-    betas: [0.8, 0.99]
-    eps: 1e-5
-    weight_decay: 0.01
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 100
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0
-
-callbacks:
-  model_summary:
-    _target_: lightning.pytorch.callbacks.ModelSummary
-    max_depth: 1
-
-  model_checkpoint:
-    every_n_train_steps: ${trainer.val_check_interval}
-
-  grad_norm_monitor:
-    sub_module:
-      - encoder
-      - decoder
-      - quantizer
-      - discriminator
fish_speech/conversation.py
CHANGED
@@ -1,2 +1,267 @@
-
-
+from dataclasses import dataclass, field
+from typing import Literal
+
+import torch
+
+from .tokenizer import MODALITY_TOKENS, FishTokenizer
+
+CODEBOOK_PAD_TOKEN_ID = 0
+
+
+@dataclass(kw_only=True)
+class BasePart:
+    pass
+
+
+@dataclass(kw_only=True)
+class VQPart(BasePart):
+    codes: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class TextPart(BasePart):
+    text: str
+
+
+@dataclass(kw_only=True)
+class EncodedMessage:
+    tokens: torch.Tensor
+    labels: torch.Tensor
+    vq_mask_tokens: torch.Tensor | None = None
+    vq_mask_labels: torch.Tensor | None = None
+    vq_parts: list[torch.Tensor]
+    vq_require_losses: torch.Tensor | None = None
+
+
+@dataclass(kw_only=True)
+class Message:
+    role: Literal["system", "user", "assistant"]
+    parts: list[VQPart | TextPart] = field(default_factory=list)
+    add_im_start: bool = True
+    add_im_end: bool = True
+    cal_loss: bool = False
+    modality: Literal["text", "voice", "interleave"] | None = None
+
+    # By default, ignore the loss of the auto-generated im_start token
+    ignore_im_start_loss: bool = True
+
+    def encode(
+        self: "Message",
+        tokenizer: FishTokenizer,
+    ) -> EncodedMessage:
+        all_tokens = []
+        all_labels = []
+
+        # Multi-modal tokens
+        vq_parts = []
+        vq_masks = []
+
+        parts = self.parts.copy()
+        if self.add_im_start:
+            modality_token = MODALITY_TOKENS[self.modality] if self.modality else ""
+            parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n{modality_token}"))
+
+        if self.add_im_end:
+            parts.append(TextPart(text="<|im_end|>"))
+
+        for part in parts:
+            if isinstance(part, TextPart):
+                tokens = torch.tensor(
+                    tokenizer.encode(part.text),
+                    dtype=torch.int,
+                )
+            elif isinstance(part, VQPart):
+                curr_codes = part.codes.clone()
+                tokens = torch.tensor(
+                    [
+                        tokenizer.semantic_id_to_token_id[i.item()]
+                        for i in curr_codes[0].int()
+                    ],
+                    dtype=torch.int,
+                )
+                vq_parts.append(curr_codes)
+            else:
+                raise ValueError(f"Unsupported part type: {type(part)}")
+
+            all_tokens.append(tokens)
+            if isinstance(part, VQPart):
+                vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
+            else:
+                vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
+
+            if self.cal_loss:
+                all_labels.append(tokens.clone())
+            else:
+                all_labels.append(torch.full_like(tokens, -100))
+
+        tokens = torch.cat(all_tokens, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        vq_masks = torch.cat(vq_masks, dim=0)
+
+        assert tokens.shape == labels.shape == vq_masks.shape
+
+        if self.ignore_im_start_loss and self.add_im_start:
+            labels[: len(all_tokens[0])] = -100
+
+        return EncodedMessage(
+            tokens=tokens,
+            labels=labels,
+            vq_parts=vq_parts,
+            vq_mask_tokens=vq_masks,
+            vq_mask_labels=vq_masks,
+        )
+
+
+@dataclass
+class Conversation:
+    messages: list[Message]
+
+    def __init__(self: "Conversation", messages: list[Message] | None = None):
+        self.messages = messages or []
+
+    def encode(
+        self: "Conversation",
+        tokenizer: FishTokenizer,
+        add_shift: bool = True,
+        ignore_loss_tokens: list[str] = [],
+    ) -> EncodedMessage:
+        # Build the input_ids and labels
+        tokens = []
+        labels = []
+        vq_parts = []
+        vq_mask_tokens = []
+        vq_mask_labels = []
+        vq_require_losses = []
+        ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens]
+
+        for message in self.messages:
+            encoded = message.encode(
+                tokenizer,
+            )
+            tokens.append(encoded.tokens)
+            labels.append(encoded.labels)
+            vq_parts.extend(encoded.vq_parts)
+            vq_mask_tokens.append(encoded.vq_mask_tokens)
+            vq_mask_labels.append(encoded.vq_mask_labels)
+            vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))
+
+        tokens = torch.cat(tokens, dim=0)
+        labels = torch.cat(labels, dim=0)
+        vq_mask_tokens = torch.cat(vq_mask_tokens, dim=0)
+        vq_mask_labels = torch.cat(vq_mask_labels, dim=0)
+        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
+
+        if add_shift:
+            tokens = tokens[:-1]
+            labels = labels[1:]
+            vq_mask_tokens = vq_mask_tokens[:-1]
+            vq_mask_labels = vq_mask_labels[1:]
+
+        for i in ignore_loss_token_ids:
+            assert i != -100 and i is not None
+            labels[labels == i] = -100
+
+        assert tokens.dtype in [
+            torch.int,
+            torch.long,
+        ], f"Invalid dtype: {tokens.dtype}, conv: {conversation}"
+
+        return EncodedMessage(
+            tokens=tokens,
+            labels=labels,
+            vq_parts=vq_parts,
+            vq_mask_tokens=vq_mask_tokens,
+            vq_mask_labels=vq_mask_labels,
+            vq_require_losses=vq_require_losses,
+        )
+
+    def encode_for_inference(
+        self: "Conversation",
+        tokenizer: FishTokenizer,
+        num_codebooks: int,
+    ) -> EncodedMessage:
+        # self.visualize(tokenizer)
+
+        encoded = self.encode(tokenizer, add_shift=False)
+        tokens = encoded.tokens
+        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
+        values[0] = tokens
+
+        if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
+            return values
+
+        vq_parts = encoded.vq_parts
+        vq_parts = [part.to(values.device) for part in vq_parts]
+        vq_parts = torch.cat(vq_parts, dim=1)
+        values[0, encoded.vq_mask_tokens] = vq_parts[0] + tokenizer.semantic_begin_id
+        values[1:, encoded.vq_mask_tokens] = vq_parts
+
+        return values
+
+    def visualize(
+        self: "Conversation",
+        tokenizer: FishTokenizer,
+        ignore_loss_tokens: list[str] = [],
+    ):
+        encoded = self.encode(
+            tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
+        )
+
+        # Colors for alternating tokens
+        colors = {
+            "blue": "\033[94m",  # Light blue
+            "cyan": "\033[96m",  # Cyan
+            "green": "\033[92m",  # Light green
+            "dark_green": "\033[32m",  # Dark green
+        }
+        blue_idx = 0
+        green_idx = 0
+
+        def print_in_blue(x):
+            nonlocal blue_idx
+            color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
+            print(f"{color}{x}\033[0m", end="")
+            blue_idx += 1
+
+        def print_in_green(x):
+            nonlocal green_idx
+            color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
+            print(f"{color}{x}\033[0m", end="")
+            green_idx += 1
+
+        for tok, lab in zip(encoded.tokens, encoded.labels):
+            val = tokenizer.decode([tok])
+
+            if lab == -100:
+                print_in_green(val)
+            else:
+                print_in_blue(val)
+
+        print()
+
+    def append(self: "Conversation", message: Message):
+        self.messages.append(message)
+
+
+if __name__ == "__main__":
+    message0 = Message(
+        role="user",
+        parts=[
+            TextPart(text="Hello, how are you?"),
+            VQPart(codes=torch.zeros((4, 10))),
+        ],
+        cal_loss=False,
+    )
+
+    message1 = Message(
+        role="assistant",
+        parts=[TextPart(text="I'm fine, thank you.")],
+        cal_loss=True,
+    )
+    conversation = Conversation([message0, message1])
+    tokenizer = FishTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
+    conversation.visualize(tokenizer)
+
+    encoded = conversation.encode(tokenizer)
+    print(encoded)
+    print(tokenizer.batch_decode(encoded.tokens))
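The add_shift branch in Conversation.encode above drops the last input token and the first label so that position i is trained to predict token i+1. A tiny illustration with made-up token ids:

import torch

# Pretend these came from Conversation.encode(..., add_shift=False)
tokens = torch.tensor([10, 11, 12, 13])
labels = torch.tensor([10, 11, 12, 13])

# With add_shift=True the sequence becomes (input, next-token target) pairs:
inputs = tokens[:-1]   # [10, 11, 12]
targets = labels[1:]   # [11, 12, 13]  -- position i predicts token i+1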
fish_speech/datasets/concat_repeat.py
CHANGED
@@ -1,53 +1,53 @@
import bisect
import random
from typing import Iterable

from torch.utils.data import Dataset, IterableDataset


class ConcatRepeatDataset(Dataset):
    datasets: list[Dataset]
    cumulative_sizes: list[int]
    repeats: list[int]

    @staticmethod
    def cumsum(sequence, repeats):
        r, s = [], 0
        for dataset, repeat in zip(sequence, repeats):
            l = len(dataset) * repeat
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
        super().__init__()

        self.datasets = list(datasets)
        self.repeats = repeats

        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            repeats
        ), "datasets and repeats should have the same length"

        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatRepeatDataset does not support IterableDataset"

        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)

        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]

        dataset = self.datasets[dataset_idx]

        return dataset[sample_idx % len(dataset)]
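A usage sketch for ConcatRepeatDataset, with throw-away in-memory datasets (the data here is illustrative):

from torch.utils.data import Dataset

from fish_speech.datasets.concat_repeat import ConcatRepeatDataset


class ListDataset(Dataset):
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


small = ListDataset(["a", "b"])
large = ListDataset(["x", "y", "z", "w"])

# Repeat the small dataset twice per epoch so it is not drowned out by the large one.
ds = ConcatRepeatDataset([small, large], repeats=[2, 1])
print(len(ds))                          # 2*2 + 4*1 = 8
print([ds[i] for i in range(len(ds))])  # a, b, a, b, x, y, z, w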
fish_speech/datasets/protos/text-data.proto
CHANGED
@@ -1,24 +1,24 @@
syntax = "proto3";

package text_data;

message Semantics {
  repeated uint32 values = 1;
}

message Sentence {
  repeated string texts = 1;
  repeated Semantics semantics = 3;
}

message TextData {
  string source = 1;
  string name = 2;
  repeated Sentence sentences = 4;
}

message SampledData {
  string source = 1;
  string name = 2;
  repeated Sentence samples = 3;
}
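After compiling this schema with protoc, the messages are used like ordinary Python classes; the generated module ships as text_data_pb2.py, shown in the next file. A small construction and round-trip sketch, assuming the package path used elsewhere in this diff:

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData

sentence = Sentence(
    texts=["Hello world."],
    semantics=[Semantics(values=[12, 34, 56])],
)
data = TextData(source="demo", name="speaker_0", sentences=[sentence])

blob = data.SerializeToString()    # protobuf wire-format bytes
restored = TextData.FromString(blob)
assert restored.sentences[0].texts[0] == "Hello world."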
fish_speech/datasets/protos/text_data_pb2.py
CHANGED
@@ -1,33 +1,33 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler.  DO NOT EDIT!
# source: text-data.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
    b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
    DESCRIPTOR._options = None
    _globals["_SEMANTICS"]._serialized_start = 30
    _globals["_SEMANTICS"]._serialized_end = 57
    _globals["_SENTENCE"]._serialized_start = 59
    _globals["_SENTENCE"]._serialized_end = 125
    _globals["_TEXTDATA"]._serialized_start = 127
    _globals["_TEXTDATA"]._serialized_end = 207
    _globals["_SAMPLEDDATA"]._serialized_start = 209
    _globals["_SAMPLEDDATA"]._serialized_end = 290
# @@protoc_insertion_point(module_scope)
fish_speech/datasets/protos/text_data_stream.py
CHANGED
@@ -1,36 +1,36 @@
import struct

from .text_data_pb2 import TextData


def read_pb_stream(f):
    while True:
        buf = f.read(4)
        if len(buf) == 0:
            break
        size = struct.unpack("I", buf)[0]
        buf = f.read(size)
        text_data = TextData()
        text_data.ParseFromString(buf)
        yield text_data


def write_pb_stream(f, text_data):
    buf = text_data.SerializeToString()
    f.write(struct.pack("I", len(buf)))
    f.write(buf)


def pack_pb_stream(text_data):
    buf = text_data.SerializeToString()
    return struct.pack("I", len(buf)) + buf


def split_pb_stream(f):
    while True:
        head = f.read(4)
        if len(head) == 0:
            break
        size = struct.unpack("I", head)[0]
        buf = f.read(size)
        yield head + buf
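A round-trip sketch for the length-prefixed stream helpers above; the file name is illustrative:

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import read_pb_stream, write_pb_stream

record = TextData(
    source="demo",
    name="speaker_0",
    sentences=[Sentence(texts=["hi"], semantics=[Semantics(values=[1, 2, 3])])],
)

with open("demo.protos", "wb") as f:
    # Each record is written as a 4-byte length prefix (struct "I") followed by the payload.
    write_pb_stream(f, record)

with open("demo.protos", "rb") as f:
    for item in read_pb_stream(f):
        print(item.name, list(item.sentences[0].semantics[0].values))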
fish_speech/datasets/semantic.py
CHANGED
@@ -1,496 +1,496 @@
import random
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from random import Random
from typing import Optional, Union

import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from fish_speech.datasets.protos.text_data_pb2 import SampledData
from fish_speech.datasets.protos.text_data_stream import read_pb_stream
from fish_speech.text.clean import clean_text
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)


def split_by_rank_worker(files):
    # We need to know the total number of devices
    # to split the data properly

    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Split by worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files


class AutoTextSemanticInstructionDataset(IterableDataset):
    """
    Auto Augment Dataset by Speaker

    1. Random concatenate multiple sentences from the same speaker to form a longer sentence
    2. Automatically normalize the text

    For interactive mode, we use the following format (multiple sequences):
    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>

    For non-interactive mode, we use the following format (one long sequence):
    <s> [INST] text [/INST] ... </s>
    """

    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        use_speaker: bool | float = True,
        causal: bool = True,
        num_codebooks: Optional[int] = None,
        skip_text_prob: float = 0.0,
    ):
        """
        Args:
            proto_files: proto buf files if using local data
            seed: random seed
            interactive_prob: probability to use interactive mode
            max_length: max length of the text
            tokenizer: tokenizer
            use_speaker: include speaker information in the prompt
            causal: use causal sampling when using local data, disable will lead to random sampling
            num_codebooks: number of codebooks, if None, it will be automatically detected
            skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
        """

        super().__init__()

        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"

        self.seed = seed
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.interactive_prob = interactive_prob
        self.use_speaker = use_speaker
        self.proto_files = proto_files
        self.causal = causal
        self.num_codebooks = num_codebooks
        self.skip_text_prob = skip_text_prob

        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
        self.groups = None

    def init_mock_data_server(self):
        if self.groups is not None:
            return

        # Expand the proto files
        expanded_proto_files = []
        for filename in self.proto_files:
            for i in braceexpand(filename):
                i = Path(i)
                if i.is_file():
                    expanded_proto_files.append(i)
                elif i.is_dir():
                    expanded_proto_files.extend(i.rglob("*.proto"))
                    expanded_proto_files.extend(i.rglob("*.protos"))
                else:
                    raise ValueError(f"{i} is not a file or directory")

        expanded_proto_files = sorted(expanded_proto_files)
        Random(self.seed).shuffle(expanded_proto_files)

        self.groups = []
        shard_proto_files = split_by_rank_worker(expanded_proto_files)
        log.info(
            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
        )

        count = 0
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                for text_data in read_pb_stream(f):
                    self.groups.append(text_data)
                    count += 1

        log.info(f"Read total {count} groups of data")

        # Shuffle the lines
        Random(self.seed).shuffle(self.groups)
        self.group_weights = [len(i.sentences) for i in self.groups]

    def __iter__(self):
        while True:
            yield self.augment()

    def tokenize_sentence(self, sentence: str):
        sentence = clean_text(sentence)
        tokens = self.tokenizer.encode(
            f"{sentence}",
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)

    def sample_data(self):
        if self.groups is None:
            self.init_mock_data_server()

        # Shuffle unique lines, estimate that each sample is at least 20 tokens
        num_samples = self.max_length // 20

        # choice group based on their number of samples
        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]

        if self.causal:
            # Sample in order
            if num_samples >= len(group.sentences):
                samples = group.sentences
            else:
                begin = random.randint(0, len(group.sentences) - num_samples)
                samples = group.sentences[begin : begin + num_samples]
        else:
            samples = random.choices(
                group.sentences, k=min(num_samples, len(group.sentences))
            )

        return SampledData(
            source=group.source,
            name=group.name,
            samples=samples,
        )

    def augment(self):
        final_text, final_semantic = [], []
        response = self.sample_data()
        if len(response.samples) == 0:
            # Invalid group
            return None

        samples = list(response.samples)
        idx = 0
        use_interactive = random.random() < self.interactive_prob

        if use_interactive is False:
            # Random sample based on speaker using a truncated normal distribution
            a = torch.tensor([0], dtype=torch.float32)
            torch.nn.init.trunc_normal_(
                a,
                mean=self.max_length // 2,
                std=self.max_length // 4,
                a=10,
                b=self.max_length,
            )
            remaining_tokens = a.long().item() - 4
        else:
            remaining_tokens = self.max_length

        # Use speaker
        if isinstance(self.use_speaker, float):
            use_speaker = random.random() < self.use_speaker
        else:
            use_speaker = self.use_speaker

        all_tokens, all_labels = [], []
        while remaining_tokens > 0 and len(samples) > 0:
            sentence = samples.pop(0)

            text = random.choice(sentence.texts)
            text, length = self.tokenize_sentence(text)
            remaining_tokens -= length + len(sentence.semantics[0].values)

            if use_interactive is False:
                final_text.append(text)
                final_semantic.append(sentence.semantics)
            else:
                # For interactive mode, we only apply speaker for the first sentence
                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
                tokens, labels = self.pack_sentences(
                    sentences=[text],
                    semantics=[sentence.semantics],
                    speaker=response.name if use_speaker else None,
                    skip_text=random.random() < self.skip_text_prob,
                )

                all_tokens.append(tokens)
                all_labels.append(labels)

            idx += 1

        if use_interactive is False:
            tokens, labels = self.pack_sentences(
                final_text,
                semantics=final_semantic,
                speaker=response.name if use_speaker else None,
            )
            all_tokens.append(tokens)
            all_labels.append(labels)

        tokens = torch.cat(all_tokens, dim=1)
        labels = torch.cat(all_labels, dim=1)

        # Verify that the length is correct
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        data = {"tokens": tokens, "labels": labels}

        return data

    def pack_sentences(
        self,
        sentences: list[str],
        semantics: list,
        speaker: Optional[str] = None,
        skip_text: bool = False,
    ):
        if speaker is None:
            speaker = "assistant"

        cated_sentences = " ".join(sentences)
        if skip_text:
            cated_sentences = "<|skip_text|>"

        final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
        final_text = final_text + f"<|im_start|>{speaker}\n"

        encoded = self.tokenizer.encode(
            final_text,
            add_special_tokens=False,
            truncation=False,
            max_length=10**6,
        )
        semantic_length = sum([len(i[0].values) for i in semantics])
        prompt_length = len(encoded)
        num_codebooks = (
            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
        )

        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
        tokens = (
            encoded
            + [self.semantic_token_id] * semantic_length
            + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
        )

        # Codebook bos/padding: 0, eos: 1
        codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
        for segment in semantics:
            for book_idx, book in zip(range(num_codebooks), segment):
                for j in book.values:
                    codes[book_idx].append(int(j) + 1)

        for book in codes:
            book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)

        tokens = [tokens] + codes

        tokens = torch.tensor(tokens, dtype=torch.long)
        labels = tokens.clone()

        if skip_text:
            # If text is not provided, the sentence is used for condition only, all labels are -100
            torch.fill_(labels, -100)
            return tokens, labels

        # Mask out the <s> tokens for semantic, predict semantic tokens only
        # Since we don't mask out the input tokens, the language modeling still works
        labels[1:, :prompt_length] = -100

        tokens = tokens[:, :-1]
        labels = labels[:, 1:]

        # Verify the padding is correct, and the last token is eos
        assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()

        return tokens, labels


@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 1024

    def __call__(self, examples):
        if "negative_tokens" in examples:
            positive_examples = []
            negative_examples = []

            for i in examples:
                positive_examples.append(
                    {
                        "tokens": i["tokens"],
                        "labels": i["labels"],
                    }
                )
                negative_examples.append(
                    {
                        "tokens": i["negative_tokens"],
                        "labels": i["negative_labels"],
                    }
                )

            examples = positive_examples + negative_examples

        return self.batchify(examples)

    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
        tokens, attention_masks, labels = [], [], []

        # Calculate the max length
        max_tokens_length = 0
        for example in examples:
            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
        max_tokens_length = min(max_tokens_length, self.max_length)

        for example in examples:
            _tokens = example[tokens_key][:, :max_tokens_length]
            _labels = example[labels_key][:, :max_tokens_length]
            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
            tokens_length = _tokens.size(1)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(
                1
            ), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < max_tokens_length:
                _tokens = F.pad(
                    _tokens,
                    (0, max_tokens_length - tokens_length),
                    value=self.tokenizer.eos_token_id,
                )
                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                _labels = F.pad(
                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }


class InterleaveDataset(IterableDataset):
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Random choice one
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])


class SemanticDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
        val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
            persistent_workers=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
            persistent_workers=True,
        )


if __name__ == "__main__":
    from tqdm import tqdm

    ds = AutoTextSemanticInstructionDataset(
        ["data/protos"],
        tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
        use_speaker=False,
        interactive_prob=1.0,
        skip_text_prob=0.5,
    )

    for i in ds:
        print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
        # i["labels"][0][i["labels"][0] == -100] = 0
        # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
        break
|
250 |
+
if use_interactive is False:
|
251 |
+
tokens, labels = self.pack_sentences(
|
252 |
+
final_text,
|
253 |
+
semantics=final_semantic,
|
254 |
+
speaker=response.name if use_speaker else None,
|
255 |
+
)
|
256 |
+
all_tokens.append(tokens)
|
257 |
+
all_labels.append(labels)
|
258 |
+
|
259 |
+
tokens = torch.cat(all_tokens, dim=1)
|
260 |
+
labels = torch.cat(all_labels, dim=1)
|
261 |
+
|
262 |
+
# Verify that the length is correct
|
263 |
+
assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
|
264 |
+
|
265 |
+
data = {"tokens": tokens, "labels": labels}
|
266 |
+
|
267 |
+
return data
|
268 |
+
|
269 |
+
def pack_sentences(
|
270 |
+
self,
|
271 |
+
sentences: list[str],
|
272 |
+
semantics: list,
|
273 |
+
speaker: Optional[str] = None,
|
274 |
+
skip_text: bool = False,
|
275 |
+
):
|
276 |
+
if speaker is None:
|
277 |
+
speaker = "assistant"
|
278 |
+
|
279 |
+
cated_sentences = " ".join(sentences)
|
280 |
+
if skip_text:
|
281 |
+
cated_sentences = "<|skip_text|>"
|
282 |
+
|
283 |
+
final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
|
284 |
+
final_text = final_text + f"<|im_start|>{speaker}\n"
|
285 |
+
|
286 |
+
encoded = self.tokenizer.encode(
|
287 |
+
final_text,
|
288 |
+
add_special_tokens=False,
|
289 |
+
truncation=False,
|
290 |
+
max_length=10**6,
|
291 |
+
)
|
292 |
+
semantic_length = sum([len(i[0].values) for i in semantics])
|
293 |
+
prompt_length = len(encoded)
|
294 |
+
num_codebooks = (
|
295 |
+
len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
|
296 |
+
)
|
297 |
+
|
298 |
+
# Pack the tokens and semantics (add <s> and </s> to semantic tokens)
|
299 |
+
tokens = (
|
300 |
+
encoded
|
301 |
+
+ [self.semantic_token_id] * semantic_length
|
302 |
+
+ self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
|
303 |
+
)
|
304 |
+
|
305 |
+
# Codebook bos/padding: 0, eos: 1
|
306 |
+
codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
|
307 |
+
for segment in semantics:
|
308 |
+
for book_idx, book in zip(range(num_codebooks), segment):
|
309 |
+
for j in book.values:
|
310 |
+
codes[book_idx].append(int(j) + 1)
|
311 |
+
|
312 |
+
for book in codes:
|
313 |
+
book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
|
314 |
+
|
315 |
+
tokens = [tokens] + codes
|
316 |
+
|
317 |
+
tokens = torch.tensor(tokens, dtype=torch.long)
|
318 |
+
labels = tokens.clone()
|
319 |
+
|
320 |
+
if skip_text:
|
321 |
+
# If text is not provided, the sentence is used for condition only, all labels are -100
|
322 |
+
torch.fill_(labels, -100)
|
323 |
+
return tokens, labels
|
324 |
+
|
325 |
+
# Mask out the <s> tokens for semantic, predict semantic tokens only
|
326 |
+
# Since we don't mask out the input tokens, the language modeling still works
|
327 |
+
labels[1:, :prompt_length] = -100
|
328 |
+
|
329 |
+
tokens = tokens[:, :-1]
|
330 |
+
labels = labels[:, 1:]
|
331 |
+
|
332 |
+
# Verify the padding is correct, and the last token is eos
|
333 |
+
assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
|
334 |
+
assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
|
335 |
+
|
336 |
+
return tokens, labels
|
337 |
+
|
338 |
+
|
339 |
+
@dataclass
|
340 |
+
class TextDataCollator:
|
341 |
+
tokenizer: AutoTokenizer
|
342 |
+
max_length: int = 1024
|
343 |
+
|
344 |
+
def __call__(self, examples):
|
345 |
+
if "negative_tokens" in examples:
|
346 |
+
positive_examples = []
|
347 |
+
negative_examples = []
|
348 |
+
|
349 |
+
for i in examples:
|
350 |
+
positive_examples.append(
|
351 |
+
{
|
352 |
+
"tokens": i["tokens"],
|
353 |
+
"labels": i["labels"],
|
354 |
+
}
|
355 |
+
)
|
356 |
+
negative_examples.append(
|
357 |
+
{
|
358 |
+
"tokens": i["negative_tokens"],
|
359 |
+
"labels": i["negative_labels"],
|
360 |
+
}
|
361 |
+
)
|
362 |
+
|
363 |
+
examples = positive_examples + negative_examples
|
364 |
+
|
365 |
+
return self.batchify(examples)
|
366 |
+
|
367 |
+
def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
|
368 |
+
tokens, attention_masks, labels = [], [], []
|
369 |
+
|
370 |
+
# Calculate the max length
|
371 |
+
max_tokens_length = 0
|
372 |
+
for example in examples:
|
373 |
+
max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
|
374 |
+
max_tokens_length = min(max_tokens_length, self.max_length)
|
375 |
+
|
376 |
+
for example in examples:
|
377 |
+
_tokens = example[tokens_key][:, :max_tokens_length]
|
378 |
+
_labels = example[labels_key][:, :max_tokens_length]
|
379 |
+
_attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
|
380 |
+
tokens_length = _tokens.size(1)
|
381 |
+
_attention_mask[:tokens_length] = False
|
382 |
+
|
383 |
+
assert tokens_length == _labels.size(
|
384 |
+
1
|
385 |
+
), f"{tokens_length} != {_labels.size(1)}"
|
386 |
+
|
387 |
+
if tokens_length < max_tokens_length:
|
388 |
+
_tokens = F.pad(
|
389 |
+
_tokens,
|
390 |
+
(0, max_tokens_length - tokens_length),
|
391 |
+
value=self.tokenizer.eos_token_id,
|
392 |
+
)
|
393 |
+
_tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
|
394 |
+
_labels = F.pad(
|
395 |
+
_labels, (0, max_tokens_length - _labels.size(1)), value=-100
|
396 |
+
)
|
397 |
+
|
398 |
+
tokens.append(_tokens)
|
399 |
+
attention_masks.append(_attention_mask)
|
400 |
+
labels.append(_labels)
|
401 |
+
|
402 |
+
tokens = torch.stack(tokens, dim=0)
|
403 |
+
attention_masks = torch.stack(attention_masks, dim=0)
|
404 |
+
labels = torch.stack(labels, dim=0)
|
405 |
+
|
406 |
+
return {
|
407 |
+
"inputs": tokens,
|
408 |
+
"attention_masks": attention_masks,
|
409 |
+
"labels": labels,
|
410 |
+
}
|
411 |
+
|
412 |
+
|
413 |
+
class InterleaveDataset(IterableDataset):
|
414 |
+
def __init__(
|
415 |
+
self,
|
416 |
+
datasets: list[IterableDataset],
|
417 |
+
probabilities: list[float],
|
418 |
+
seed: int = 42,
|
419 |
+
):
|
420 |
+
super().__init__()
|
421 |
+
|
422 |
+
self.datasets = datasets
|
423 |
+
self.probabilities = probabilities
|
424 |
+
self.seed = seed
|
425 |
+
|
426 |
+
def __iter__(self):
|
427 |
+
rng = np.random.default_rng(self.seed)
|
428 |
+
dataset_iterators = [iter(dataset) for dataset in self.datasets]
|
429 |
+
|
430 |
+
while True:
|
431 |
+
# Random choice one
|
432 |
+
dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
|
433 |
+
dataset_iterator = dataset_iterators[dataset_idx]
|
434 |
+
|
435 |
+
try:
|
436 |
+
yield next(dataset_iterator)
|
437 |
+
except StopIteration:
|
438 |
+
# Exhausted, create a new iterator
|
439 |
+
dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
|
440 |
+
yield next(dataset_iterators[dataset_idx])
|
441 |
+
|
442 |
+
|
443 |
+
class SemanticDataModule(LightningDataModule):
|
444 |
+
def __init__(
|
445 |
+
self,
|
446 |
+
train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
|
447 |
+
val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
|
448 |
+
batch_size: int = 32,
|
449 |
+
tokenizer: AutoTokenizer = None,
|
450 |
+
max_length: int = 1024,
|
451 |
+
num_workers: int = 4,
|
452 |
+
):
|
453 |
+
super().__init__()
|
454 |
+
|
455 |
+
self.train_dataset = train_dataset
|
456 |
+
self.val_dataset = val_dataset
|
457 |
+
self.batch_size = batch_size
|
458 |
+
self.tokenizer = tokenizer
|
459 |
+
self.max_length = max_length
|
460 |
+
self.num_workers = num_workers
|
461 |
+
|
462 |
+
def train_dataloader(self):
|
463 |
+
return DataLoader(
|
464 |
+
self.train_dataset,
|
465 |
+
batch_size=self.batch_size,
|
466 |
+
collate_fn=TextDataCollator(self.tokenizer, self.max_length),
|
467 |
+
num_workers=self.num_workers,
|
468 |
+
persistent_workers=True,
|
469 |
+
)
|
470 |
+
|
471 |
+
def val_dataloader(self):
|
472 |
+
return DataLoader(
|
473 |
+
self.val_dataset,
|
474 |
+
batch_size=self.batch_size,
|
475 |
+
collate_fn=TextDataCollator(self.tokenizer, self.max_length),
|
476 |
+
num_workers=self.num_workers,
|
477 |
+
persistent_workers=True,
|
478 |
+
)
|
479 |
+
|
480 |
+
|
481 |
+
if __name__ == "__main__":
|
482 |
+
from tqdm import tqdm
|
483 |
+
|
484 |
+
ds = AutoTextSemanticInstructionDataset(
|
485 |
+
["data/protos"],
|
486 |
+
tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
|
487 |
+
use_speaker=False,
|
488 |
+
interactive_prob=1.0,
|
489 |
+
skip_text_prob=0.5,
|
490 |
+
)
|
491 |
+
|
492 |
+
for i in ds:
|
493 |
+
print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
|
494 |
+
# i["labels"][0][i["labels"][0] == -100] = 0
|
495 |
+
# print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
|
496 |
+
break
|
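The packed layout produced by pack_sentences above is easiest to see on a toy example. The following is a minimal sketch, not part of the repository: it assumes two codebooks and fakes the tokenizer output with made-up integer ids, showing how row 0 carries the text prompt followed by <|semantic|> placeholders, while the codebook rows carry CODEBOOK_PAD_TOKEN_ID under the prompt and the shifted (value + 1) codes under the placeholders.

# Minimal sketch of the pack_sentences layout, assuming 2 codebooks.
# All token ids and semantic values below are made-up illustration numbers.
import torch

CODEBOOK_PAD_TOKEN_ID = 0
SEMANTIC_TOKEN_ID = 32001  # hypothetical id for <|semantic|>
IM_END_ID = 32000          # hypothetical id for <|im_end|>

encoded = [101, 102, 103]            # stand-in for the tokenized text prompt
semantic_values = [[7, 8, 9, 10],    # codebook 0 values for one sentence
                   [3, 4, 5, 6]]     # codebook 1 values

prompt_length = len(encoded)
semantic_length = len(semantic_values[0])

# Row 0: text prompt, one <|semantic|> placeholder per frame, then <|im_end|>
row0 = encoded + [SEMANTIC_TOKEN_ID] * semantic_length + [IM_END_ID]

# Codebook rows: padding under the prompt, shifted codes (+1) under the
# placeholders, and one trailing pad aligned with <|im_end|>
codes = [
    [CODEBOOK_PAD_TOKEN_ID] * prompt_length
    + [v + 1 for v in book]
    + [CODEBOOK_PAD_TOKEN_ID]
    for book in semantic_values
]

tokens = torch.tensor([row0] + codes, dtype=torch.long)
labels = tokens.clone()
labels[1:, :prompt_length] = -100  # codebook rows are not predicted over the prompt

print(tokens)  # shape: (1 + num_codebooks, prompt_length + semantic_length + 1)
print(labels)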
fish_speech/datasets/text.py
DELETED
@@ -1,661 +0,0 @@
import random
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from random import Random
from typing import Optional, Union

import grpc
import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.datasets.protos.text_data_pb2 import SampledData
from fish_speech.datasets.protos.text_data_stream import read_pb_stream
from fish_speech.text.clean import clean_text
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)

CODEBOOK_PAD_TOKEN_ID = 0
CODEBOOK_EOS_TOKEN_ID = 1


def split_by_rank_worker(files):
    # We need to know the total number of devices
    # to split the data properly

    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Split by worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files


class StreamTextDataset(IterableDataset):
    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
        prefix: Optional[str] = None,
        seed: int = 42,
        parquet_batch_size: int = 10000,
        repo: str = "uonlp/CulturaX",
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
    ):
        super().__init__()

        self.seed = seed
        self.parquet_batch_size = parquet_batch_size
        self.repo = repo
        self.max_length = max_length
        self.tokenizer = tokenizer

        if files is None and prefix is None:
            raise ValueError("Either files or prefix must be specified")

        if prefix is not None:
            files = HfApi().list_repo_files(repo, repo_type="dataset")
            files = [
                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
            ]
            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
        else:
            if isinstance(files, str):
                files = [files]

            files = list(chain.from_iterable(map(braceexpand, files)))
            log.info(f"Expanded {len(files)} files in {repo}")

        # Get sharded files
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def __iter__(self):
        files = split_by_rank_worker(self.files)
        random.shuffle(files)

        for filename in files:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                log.exception(f"Failed to parse {filename}: {e}")

    def parse_data(self, filename: str):
        for data in self.parse_data_internal(filename):
            text = data["text"]

            # encode
            tokens = self.tokenizer.encode(
                text,
                add_special_tokens=False,
                truncation=False,
                max_length=10**6,
            )

            # Random choice self.max_length
            if len(tokens) > self.max_length:
                start = random.randint(0, len(tokens) - self.max_length)
                tokens = tokens[start : start + self.max_length - 1]

            tokens = (
                [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            )
            # Pad dims
            placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)

            tokens = torch.concat(
                [
                    torch.tensor([tokens], dtype=torch.long),
                    placeholder_multi_codebook,
                ],
                dim=0,
            )
            labels = tokens.clone()
            tokens = tokens[:, :-1]
            labels = labels[:, 1:]
            labels[1:] = -100  # remove all placeholders

            yield {"tokens": tokens, "labels": labels}

    def parse_data_internal(self, filename: str):
        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"

        with xopen(url, mode="rb") as stream:
            parquet_file = pq.ParquetFile(stream)

            for batch in parquet_file.iter_batches(
                batch_size=self.parquet_batch_size, columns=["text"]
            ):
                # In-batch shuffling
                texts = [{"text": text.as_py()} for text in batch["text"]]
                random.shuffle(texts)
                yield from texts


class AutoAugTextDataset(IterableDataset):
    """
    Auto Augment Dataset by Speaker

    1. Random concatenate multiple sentences from the same speaker to form a longer sentence
    2. Automatically normalize the text

    For interactive mode, we use the following format (multiple sequences):
    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>

    For non-interactive mode, we use the following format (one long sequence):
    <s> [INST] text [/INST] ... </s>
    """

    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        use_speaker: bool = True,
        causual: bool = True,
        use_negative_samples: bool = False,
        num_codebooks: Optional[int] = None,
    ):
        """
        Args:
            proto_files: proto buf files if using local data
            seed: random seed
            interactive_prob: probability to use interactive mode
            max_length: max length of the text
            tokenizer: tokenizer
            use_speaker: include speaker information in the prompt
            causual: use causual sampling when using local data, disable will lead to random sampling
            use_negative_samples: generate negative samples
            num_codebooks: number of codebooks, if None, it will be automatically detected
        """

        super().__init__()

        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"

        self.seed = seed
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.interactive_prob = interactive_prob
        self.use_speaker = use_speaker
        self.proto_files = proto_files
        self.causual = causual
        self.use_negative_samples = use_negative_samples
        self.num_codebooks = num_codebooks

        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
        self.groups = None

    def init_mock_data_server(self):
        if self.groups is not None:
            return

        # Expand the proto files
        expanded_proto_files = []
        for filename in self.proto_files:
            for i in braceexpand(filename):
                i = Path(i)
                if i.is_file():
                    expanded_proto_files.append(i)
                elif i.is_dir():
                    expanded_proto_files.extend(i.rglob("*.proto"))
                    expanded_proto_files.extend(i.rglob("*.protos"))
                else:
                    raise ValueError(f"{i} is not a file or directory")

        expanded_proto_files = sorted(expanded_proto_files)
        Random(self.seed).shuffle(expanded_proto_files)

        self.groups = []
        shard_proto_files = split_by_rank_worker(expanded_proto_files)
        log.info(
            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
        )

        count = 0
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                for text_data in read_pb_stream(f):
                    self.groups.append(text_data)
                    count += 1

        log.info(f"Read total {count} groups of data")

        # Shuffle the lines
        Random(self.seed).shuffle(self.groups)
        self.group_weights = [len(i.sentences) for i in self.groups]

    def __iter__(self):
        while True:
            yield self.augment()

    def tokenize_sentence(self, sentence: str):
        sentence = clean_text(sentence)
        tokens = self.tokenizer.encode(
            f"{sentence}",
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)

    def sample_data(self):
        if self.groups is None:
            self.init_mock_data_server()

        # Shuffle unique lines, estimate that each sample is at least 20 tokens
        num_samples = self.max_length // 20

        # choice group based on their number of samples
        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]

        if self.causual:
            # Sample in order
            if num_samples >= len(group.sentences):
                samples = group.sentences
            else:
                begin = random.randint(0, len(group.sentences) - num_samples)
                samples = group.sentences[begin : begin + num_samples]
        else:
            samples = random.choices(
                group.sentences, k=min(num_samples, len(group.sentences))
            )

        return SampledData(
            source=group.source,
            name=group.name,
            samples=samples,
        )

    def augment(self):
        # Random sample based on speaker using a truncated normal distribution
        a = torch.tensor([0], dtype=torch.float32)
        torch.nn.init.trunc_normal_(
            a,
            mean=self.max_length // 2,
            std=self.max_length // 4,
            a=10,
            b=self.max_length,
        )
        remaining_tokens = a.long().item() - 4

        final_text, final_semantic = [], []
        response = self.sample_data()
        if len(response.samples) == 0:
            # Invalid group
            return None

        samples = list(response.samples)
        idx = 0
        use_interactive = random.random() < self.interactive_prob

        all_tokens, all_labels = [], []
        while remaining_tokens > 0 and len(samples) > 0:
            sentence = samples.pop(0)

            text = random.choice(sentence.texts)
            text, length = self.tokenize_sentence(text)
            remaining_tokens -= length + len(sentence.semantics[0].values)

            if use_interactive is False:
                final_text.append(text)
                final_semantic.append(sentence.semantics)
            else:
                # For interactive mode, we only apply speaker for the first sentence
                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
                tokens, labels = self.pack_sentences(
                    sentences=[text],
                    semantics=[sentence.semantics],
                    speaker=response.name if (self.use_speaker and idx == 0) else None,
                    add_bos=idx == 0,
                )

                all_tokens.append(tokens)
                all_labels.append(labels)

            idx += 1

        if use_interactive is False:
            tokens, labels = self.pack_sentences(
                final_text,
                semantics=final_semantic,
                speaker=response.name if self.use_speaker else None,
                add_bos=True,
            )
            all_tokens.append(tokens)
            all_labels.append(labels)

        tokens = torch.cat(all_tokens, dim=1)
        labels = torch.cat(all_labels, dim=1)

        # Verify that the length is correct
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        # Verify bos token
        assert tokens[0, 0] == self.tokenizer.bos_token_id

        data = {"tokens": tokens, "labels": labels}

        if self.use_negative_samples:
            negative_samples = self.generate_negative_samples(all_tokens, all_labels)
            data.update(negative_samples)

        return data

    def generate_negative_samples(self, all_tokens, all_labels):
        new_tokens, new_labels = [], []

        for tokens, labels in zip(all_tokens, all_labels):
            # If all codebooks are not -100, we find where it starts
            start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
            assert (labels[1:, start:] != -100).all()  # This shouldn't happen

            mode = random.choice(["repeat", "lost", "noise"])
            begin = random.randint(start, labels.size(1) - 1)
            end = random.randint(begin, labels.size(1) - 1)

            if mode == "repeat":
                tokens = torch.cat(
                    [
                        tokens[:, :begin],
                        tokens[:, begin:end],
                        tokens[:, begin:end],
                        tokens[:, end:],
                    ],
                    dim=1,
                )
                labels = torch.cat(
                    [
                        labels[:, :begin],
                        labels[:, begin:end],
                        labels[:, begin:end],
                        labels[:, end:],
                    ],
                    dim=1,
                )
            elif mode == "lost":
                tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
                labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
            elif mode == "noise":
                middle_tokens, middle_labels = (
                    tokens[:, begin:end],
                    labels[:, begin:end],
                )
                random_order0 = torch.randperm(middle_tokens.size(1))
                random_order1 = torch.randperm(middle_tokens.size(1))
                middle_tokens = middle_tokens[:, random_order0]
                middle_labels = middle_labels[:, random_order1]
                tokens = torch.cat(
                    [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
                )
                labels = torch.cat(
                    [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
                )

            new_tokens.append(tokens)
            new_labels.append(labels)

        tokens = torch.cat(new_tokens, dim=1)
        labels = torch.cat(new_labels, dim=1)

        # Verify that the length is correct
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        return {"negative_tokens": tokens, "negative_labels": labels}

    def pack_sentences(
        self,
        sentences: list[str],
        semantics=list,
        speaker: Optional[str] = None,
        add_bos: bool = True,
    ):
        if speaker is not None:
            sentences = [f"[SPK: {speaker}]"] + sentences

        final_text = "<|im_start|>user<|im_sep|>" + " ".join(sentences) + "<|im_end|>"
        final_text = final_text + "<|im_start|>assistant<|im_sep|>"

        encoded = self.tokenizer.encode(
            final_text,
            add_special_tokens=False,
            truncation=False,
            max_length=10**6,
        )
        semantic_length = sum([len(i[0].values) for i in semantics])
        prompt_length = len(encoded)
        num_codebooks = (
            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
        )

        bos_bias = 1 if add_bos else 0

        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
        tokens = (
            encoded
            + [self.semantic_token_id] * semantic_length
            + self.tokenizer.convert_tokens_to_ids(
                ["<|im_end|>", "<|end_of_sequence|>"]
            )
        )

        if add_bos:
            tokens = [self.tokenizer.bos_token_id] + tokens

        # Codebook bos/padding: 0, eos: 1
        codes = [
            [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
            for _ in range(num_codebooks)
        ]
        for segment in semantics:
            for book_idx, book in zip(range(num_codebooks), segment):
                for j in book.values:
                    codes[book_idx].append(int(j) + 2)

        for book in codes:
            book.extend([CODEBOOK_EOS_TOKEN_ID] * 2)

        tokens = [tokens] + codes

        tokens = torch.tensor(tokens, dtype=torch.long)
        labels = tokens.clone()

        # Mask out the <s> tokens for semantic, predict semantic tokens only
        # Since we don't mask out the input tokens, the language modeling still works
        labels[1:, : (prompt_length + bos_bias)] = -100

        tokens = tokens[:, :-1]
        labels = labels[:, 1:]

        # Verify the padding is correct, and the last token is eos
        assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
        assert labels[0, -1] == self.tokenizer.eos_token_id
        assert (labels[1:, -2:] == CODEBOOK_EOS_TOKEN_ID).all()

        return tokens, labels


@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 1024

    def __call__(self, examples):
        if "negative_tokens" in examples:
            positive_examples = []
            negative_examples = []

            for i in examples:
                positive_examples.append(
                    {
                        "tokens": i["tokens"],
                        "labels": i["labels"],
                    }
                )
                negative_examples.append(
                    {
                        "tokens": i["negative_tokens"],
                        "labels": i["negative_labels"],
                    }
                )

            examples = positive_examples + negative_examples

        return self.batchify(examples)

    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
        tokens, attention_masks, labels = [], [], []

        # Calculate the max length
        max_tokens_length = 0
        for example in examples:
            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
        max_tokens_length = min(max_tokens_length, self.max_length)

        for example in examples:
            _tokens = example[tokens_key][:, :max_tokens_length]
            _labels = example[labels_key][:, :max_tokens_length]
            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
            tokens_length = _tokens.size(1)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(
                1
            ), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < max_tokens_length:
                _tokens = F.pad(
                    _tokens,
                    (0, max_tokens_length - tokens_length),
                    value=self.tokenizer.eos_token_id,
                )
                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                _labels = F.pad(
                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }


class InterleaveDataset(IterableDataset):
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Random choice one
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])


class TextDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )


if __name__ == "__main__":
    from tqdm import tqdm

    ds = AutoAugTextDataset(
        ["data/protos"],
        tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
        use_speaker=False,
        interactive_prob=1.0,
        use_negative_samples=False,
    )

    for i in ds:
        print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
        # i["labels"][0][i["labels"][0] == -100] = 0
        # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
        break
fish_speech/datasets/vqgan.py
CHANGED
@@ -1,147 +1,147 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import librosa
import numpy as np
import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from fish_speech.utils import RankedLogger

logger = RankedLogger(__name__, rank_zero_only=False)


class VQGANDataset(Dataset):
    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
        hop_length: int = 640,
        slice_frames: Optional[int] = None,
    ):
        super().__init__()

        filelist = Path(filelist)
        root = filelist.parent

        self.files = [
            root / line.strip()
            for line in filelist.read_text(encoding="utf-8").splitlines()
            if line.strip()
        ]
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.slice_frames = slice_frames

    def __len__(self):
        return len(self.files)

    def get_item(self, idx):
        file = self.files[idx]

        audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)

        # Slice audio and features
        if (
            self.slice_frames is not None
            and audio.shape[0] > self.slice_frames * self.hop_length
        ):
            start = np.random.randint(
                0, audio.shape[0] - self.slice_frames * self.hop_length
            )
            audio = audio[start : start + self.slice_frames * self.hop_length]

        if len(audio) == 0:
            return None

        max_value = np.abs(audio).max()
        if max_value > 1.0:
            audio = audio / max_value

        return {
            "audio": torch.from_numpy(audio),
        }

    def __getitem__(self, idx):
        try:
            return self.get_item(idx)
        except Exception as e:
            import traceback

            traceback.print_exc()
            logger.error(f"Error loading {self.files[idx]}: {e}")
            return None


@dataclass
class VQGANCollator:
    def __call__(self, batch):
        batch = [x for x in batch if x is not None]

        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
        audio_maxlen = audio_lengths.max()

        # Rounds up to nearest multiple of 2 (audio_lengths)
        audios = []
        for x in batch:
            audios.append(
                torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
            )

        return {
            "audios": torch.stack(audios),
            "audio_lengths": audio_lengths,
        }


class VQGANDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            shuffle=True,
            persistent_workers=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            persistent_workers=True,
        )


if __name__ == "__main__":
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    for batch in dataloader:
        print(batch["audios"].shape)
        print(batch["features"].shape)
        print(batch["audio_lengths"])
        print(batch["feature_lengths"])
        break
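As a quick sanity check of the collator above, the following minimal sketch (assuming only torch is installed; the waveforms are random stand-ins) pads two dummy mono clips of different lengths the same way VQGANCollator does: every clip is right-padded with zeros up to the longest clip in the batch, and the true lengths are kept separately.

# Minimal sketch of the VQGANCollator padding behaviour with dummy audio.
import torch
import torch.nn.functional as F

batch = [{"audio": torch.randn(16000)}, {"audio": torch.randn(12000)}]

audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
audio_maxlen = int(audio_lengths.max())

# Right-pad each clip with zeros so every clip has the same length
audios = torch.stack(
    [F.pad(x["audio"], (0, audio_maxlen - len(x["audio"]))) for x in batch]
)

print(audios.shape)   # torch.Size([2, 16000])
print(audio_lengths)  # tensor([16000, 12000])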
fish_speech/i18n/README.md
CHANGED
@@ -1,27 +1,27 @@
## i18n Folder Attribution

The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:

### fish_speech/i18n/core.py

**Related code from RVC:**
[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)

**Initial commit:**
add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)

**Initial author:**
[@L4Ph](https://github.com/L4Ph)

### fish_speech/i18n/scan.py

**Related code from RVC:**
[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)

**Initial commit:**
File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)

**Initial author:**
[@towzeur](https://github.com/towzeur)

We appreciate the contributions of the RVC project and its authors.
fish_speech/i18n/__init__.py
CHANGED
@@ -1,3 +1,3 @@
from .core import i18n

__all__ = ["i18n"]
fish_speech/i18n/core.py
CHANGED
@@ -1,40 +1,40 @@
import json
import locale
from pathlib import Path

I18N_FILE_PATH = Path(__file__).parent / "locale"
DEFAULT_LANGUAGE = "en_US"


def load_language_list(language):
    with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
        language_list = json.load(f)

    return language_list


class I18nAuto:
    def __init__(self):
        i18n_file = Path(".locale")

        if i18n_file.exists():
            with open(i18n_file, "r", encoding="utf-8") as f:
                language = f.read().strip()
        else:
            # getlocale can't identify the system's language ((None, None))
            language = locale.getdefaultlocale()[0]

        if (I18N_FILE_PATH / f"{language}.json").exists() is False:
            language = DEFAULT_LANGUAGE

        self.language = language
        self.language_map = load_language_list(language)

    def __call__(self, key):
        return self.language_map.get(key, key)

    def __repr__(self):
        return "Use Language: " + self.language


i18n = I18nAuto()
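A small usage sketch for the i18n helper above (assuming it is run from a fish-speech checkout so the `fish_speech.i18n` package and its locale JSON files are importable): keys present in the selected locale are translated, and unknown keys fall back to the key itself.

# Minimal usage sketch for the i18n helper; assumes a fish-speech checkout.
from fish_speech.i18n import i18n

print(i18n)                      # e.g. "Use Language: en_US"
print(i18n("Generate"))          # translated label when the key exists
print(i18n("Some unknown key"))  # falls back to the key itself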
fish_speech/i18n/locale/en_US.json
CHANGED
@@ -1,122 +1,123 @@
 {
 "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
 "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
 "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
 "Accumulate Gradient Batches": "Accumulate Gradient Batches",
 "Add to Processing Area": "Add to Processing Area",
 "Added path successfully!": "Added path successfully!",
 "Advanced Config": "Advanced Config",
 "Base LLAMA Model": "Base LLAMA Model",
 "Batch Inference": "Batch Inference",
 "Batch Size": "Batch Size",
 "Changing with the Model Path": "Changing with the Model Path",
 "Chinese": "Chinese",
 "Compile Model": "Compile Model",
 "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
 "Copy": "Copy",
 "Data Preprocessing": "Data Preprocessing",
 "Data Preprocessing Path": "Data Preprocessing Path",
 "Data Source": "Data Source",
 "Decoder Model Config": "Decoder Model Config",
 "Decoder Model Path": "Decoder Model Path",
 "Disabled": "Disabled",
 "Enable Reference Audio": "Enable Reference Audio",
 "English": "English",
 "Error Message": "Error Message",
 "File Preprocessing": "File Preprocessing",
 "Generate": "Generate",
 "Generated Audio": "Generated Audio",
 "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
 "Infer interface is closed": "Infer interface is closed",
 "Inference Configuration": "Inference Configuration",
 "Inference Server Configuration": "Inference Server Configuration",
 "Inference Server Error": "Inference Server Error",
 "Inferring interface is launched at {}": "Inferring interface is launched at {}",
 "Initial Learning Rate": "Initial Learning Rate",
 "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
 "Input Text": "Input Text",
 "Invalid path: {}": "Invalid path: {}",
 "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
 "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
 "Japanese": "Japanese",
 "LLAMA Configuration": "LLAMA Configuration",
 "LLAMA Model Config": "LLAMA Model Config",
 "LLAMA Model Path": "LLAMA Model Path",
 "Labeling Device": "Labeling Device",
 "LoRA Model to be merged": "LoRA Model to be merged",
 "Maximum Audio Duration": "Maximum Audio Duration",
 "Maximum Length per Sample": "Maximum Length per Sample",
 "Maximum Training Steps": "Maximum Training Steps",
 "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
 "Merge": "Merge",
 "Merge LoRA": "Merge LoRA",
 "Merge successfully": "Merge successfully",
 "Minimum Audio Duration": "Minimum Audio Duration",
 "Model Output Path": "Model Output Path",
 "Model Size": "Model Size",
 "Move": "Move",
 "Move files successfully": "Move files successfully",
 "No audio generated, please check the input text.": "No audio generated, please check the input text.",
 "No selected options": "No selected options",
 "Number of Workers": "Number of Workers",
 "Open Inference Server": "Open Inference Server",
 "Open Labeler WebUI": "Open Labeler WebUI",
 "Open Tensorboard": "Open Tensorboard",
 "Opened labeler in browser": "Opened labeler in browser",
 "Optional Label Language": "Optional Label Language",
 "Optional online ver": "Optional online ver",
 "Output Path": "Output Path",
 "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
 "Precision": "Precision",
 "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
 "Put your text here.": "Put your text here.",
 "Reference Audio": "Reference Audio",
 "Reference Text": "Reference Text",
 "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
 "Remove Selected Data": "Remove Selected Data",
 "Removed path successfully!": "Removed path successfully!",
 "Repetition Penalty": "Repetition Penalty",
 "Save model every n steps": "Save model every n steps",
 "Select LLAMA ckpt": "Select LLAMA ckpt",
 "Select VITS ckpt": "Select VITS ckpt",
 "Select VQGAN ckpt": "Select VQGAN ckpt",
 "Select source file processing method": "Select source file processing method",
 "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
 "Selected: {}": "Selected: {}",
 "Speaker": "Speaker",
 "Speaker is identified by the folder name": "Speaker is identified by the folder name",
 "Start Training": "Start Training",
 "Streaming Audio": "Streaming Audio",
 "Streaming Generate": "Streaming Generate",
 "Tensorboard Host": "Tensorboard Host",
 "Tensorboard Log Path": "Tensorboard Log Path",
 "Tensorboard Port": "Tensorboard Port",
 "Tensorboard interface is closed": "Tensorboard interface is closed",
 "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
 "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
 "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
 "Training Configuration": "Training Configuration",
 "Training Error": "Training Error",
 "Training stopped": "Training stopped",
 "Type name of the speaker": "Type name of the speaker",
 "Type the path or select from the dropdown": "Type the path or select from the dropdown",
 "Use LoRA": "Use LoRA",
 "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
 "Use filelist": "Use filelist",
 "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
 "VITS Configuration": "VITS Configuration",
 "VQGAN Configuration": "VQGAN Configuration",
 "Validation Batch Size": "Validation Batch Size",
 "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
 "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
 "WebUI Host": "WebUI Host",
 "WebUI Port": "WebUI Port",
 "Whisper Model": "Whisper Model",
 "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
 "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
 "latest": "latest",
 "new": "new",
 "Realtime Transform Text": "Realtime Transform Text",
 "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
-"Text Normalization": "Text Normalization"
-}
+"Text Normalization": "Text Normalization",
+"Select Example Audio": "Select Example Audio"
+}
fish_speech/i18n/locale/es_ES.json
CHANGED
@@ -1,122 +1,123 @@
 {
 "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
 "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
 "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
 "Accumulate Gradient Batches": "Acumular lotes de gradientes",
 "Add to Processing Area": "Agregar al Área de Procesamiento",
 "Added path successfully!": "¡Ruta agregada exitosamente!",
 "Advanced Config": "Configuración Avanzada",
 "Base LLAMA Model": "Modelo Base LLAMA",
 "Batch Inference": "Inferencia por Lote",
 "Batch Size": "Tamaño del Lote",
 "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
 "Chinese": "Chino",
 "Compile Model": "Compilar Modelo",
 "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
 "Copy": "Copiar",
 "Data Preprocessing": "Preprocesamiento de Datos",
 "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
 "Data Source": "Fuente de Datos",
 "Decoder Model Config": "Configuración del modelo decodificador",
 "Decoder Model Path": "Ruta del modelo decodificador",
 "Disabled": "Desactivado",
 "Enable Reference Audio": "Habilitar Audio de Referencia",
 "English": "Inglés",
 "Error Message": "Mensaje de Error",
 "File Preprocessing": "Preprocesamiento de Archivos",
 "Generate": "Generar",
 "Generated Audio": "Audio Generado",
 "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
 "Infer interface is closed": "La interfaz de inferencia está cerrada",
 "Inference Configuration": "Configuración de Inferencia",
 "Inference Server Configuration": "Configuración del Servidor de Inferencia",
 "Inference Server Error": "Error del Servidor de Inferencia",
 "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
 "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
 "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
 "Input Text": "Texto de Entrada",
 "Invalid path: {}": "Ruta inválida: {}",
 "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
 "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
 "Japanese": "Japonés",
 "LLAMA Configuration": "Configuración de LLAMA",
 "LLAMA Model Config": "Configuración del Modelo LLAMA",
 "LLAMA Model Path": "Ruta del Modelo LLAMA",
 "Labeling Device": "Dispositivo de Etiquetado",
 "LoRA Model to be merged": "Modelo LoRA a fusionar",
 "Maximum Audio Duration": "Duración máxima de audio",
 "Maximum Length per Sample": "Longitud Máxima por Muestra",
 "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
 "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
 "Merge": "Fusionar",
 "Merge LoRA": "Fusionar LoRA",
 "Merge successfully": "Fusionado exitosamente",
 "Minimum Audio Duration": "Duración mínima de audio",
 "Model Output Path": "Ruta de Salida del Modelo",
 "Model Size": "Tamaño del Modelo",
 "Move": "Mover",
 "Move files successfully": "Archivos movidos exitosamente",
 "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
 "No selected options": "No hay opciones seleccionadas",
 "Number of Workers": "Número de Trabajadores",
 "Open Inference Server": "Abrir Servidor de Inferencia",
 "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
 "Open Tensorboard": "Abrir Tensorboard",
 "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
 "Optional Label Language": "Idioma de Etiquetado Opcional",
 "Optional online ver": "Ver en línea opcional",
 "Output Path": "Ruta de Salida",
 "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
 "Precision": "Precisión",
 "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
 "Put your text here.": "Ponga su texto aquí.",
 "Reference Audio": "Audio de Referencia",
 "Reference Text": "Texto de Referencia",
 "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
 "Remove Selected Data": "Eliminar Datos Seleccionados",
 "Removed path successfully!": "¡Ruta eliminada exitosamente!",
 "Repetition Penalty": "Penalización por Repetición",
 "Save model every n steps": "Guardar modelo cada n pasos",
 "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
 "Select VITS ckpt": "Seleccionar punto de control VITS",
 "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
 "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
 "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
 "Selected: {}": "Seleccionado: {}",
 "Speaker": "Hablante",
 "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
 "Start Training": "Iniciar Entrenamiento",
 "Streaming Audio": "transmisión de audio",
 "Streaming Generate": "síntesis en flujo",
 "Tensorboard Host": "Host de Tensorboard",
 "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
 "Tensorboard Port": "Puerto de Tensorboard",
 "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
 "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
 "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
 "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
 "Training Configuration": "Configuración de Entrenamiento",
 "Training Error": "Error de Entrenamiento",
 "Training stopped": "Entrenamiento detenido",
 "Type name of the speaker": "Escriba el nombre del hablante",
 "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
 "Use LoRA": "Usar LoRA",
 "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
 "Use filelist": "Usar lista de archivos",
 "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
 "VITS Configuration": "Configuración de VITS",
 "VQGAN Configuration": "Configuración de VQGAN",
 "Validation Batch Size": "Tamaño del Lote de Validación",
 "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
 "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
 "WebUI Host": "Host de WebUI",
 "WebUI Port": "Puerto de WebUI",
 "Whisper Model": "Modelo Whisper",
 "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
 "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
 "latest": "más reciente",
 "new": "nuevo",
 "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
 "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
-"Text Normalization": "Normalización de Texto"
-}
+"Text Normalization": "Normalización de Texto",
+"Select Example Audio": "Selecionar áudio de exemplo"
+}
fish_speech/i18n/locale/ja_JP.json
CHANGED
@@ -1,123 +1,123 @@
 {
 "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
 "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
 "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
 "Accumulate Gradient Batches": "勾配バッチの累積",
 "Add to Processing Area": "処理エリアに追加",
 "Added path successfully!": "パスの追加に成功しました!",
 "Advanced Config": "詳細設定",
 "Base LLAMA Model": "基本LLAMAモデル",
 "Batch Inference": "バッチ推論",
 "Batch Size": "バッチサイズ",
 "Changing with the Model Path": "モデルのパスに伴って変化する",
 "Chinese": "中国語",
 "Compile Model": "モデルのコンパイル",
 "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
 "Copy": "コピー",
 "Data Preprocessing": "データ前処理",
 "Data Preprocessing Path": "データ前処理パス",
 "Data Source": "データソース",
 "Decoder Model Config": "デコーダーモデルの構成",
 "Decoder Model Path": "デコーダーモデルのパス",
 "Disabled": "無効",
 "Enable Reference Audio": "リファレンスオーディオを有効にする",
 "English": "英語",
 "Error Message": "エラーメッセージ",
 "File Preprocessing": "文書前处理",
 "Generate": "生成",
 "Generated Audio": "生成されたオーディオ",
 "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
 "Infer interface is closed": "推論インターフェースが閉じられています",
 "Inference Configuration": "推論設定",
 "Inference Server Configuration": "推論サーバー設定",
 "Inference Server Error": "推論サーバーエラー",
 "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
 "Initial Learning Rate": "初期学習率",
 "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
 "Input Text": "入力テキスト",
 "Invalid path: {}": "無効なパス: {}",
 "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
 "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
 "Japanese": "日本語",
 "LLAMA Configuration": "LLAMA設定",
 "LLAMA Model Config": "LLAMAモデル設定",
 "LLAMA Model Path": "LLAMAモデルパス",
 "Labeling Device": "ラベリングデバイス",
 "LoRA Model to be merged": "マージするLoRAモデル",
 "Maximum Audio Duration": "最大オーディオの長さ",
 "Maximum Length per Sample": "サンプルあたりの最大長",
 "Maximum Training Steps": "最大トレーニングステップ数",
 "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
 "Merge": "マージ",
 "Merge LoRA": "LoRAのマージ",
 "Merge successfully": "マージに成功しました",
 "Minimum Audio Duration": "最小オーディオの長さ",
 "Model Output Path": "モデル出力パス",
 "Model Size": "モデルサイズ",
 "Move": "移動",
 "Move files successfully": "ファイルの移動に成功しました",
 "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
 "No selected options": "選択されたオプションはありません",
 "Number of Workers": "ワーカー数",
 "Open Inference Server": "推論サーバーを開く",
 "Open Labeler WebUI": "ラベラーWebUIを開く",
 "Open Tensorboard": "Tensorboardを開く",
 "Opened labeler in browser": "ブラウザでラベラーを開きました",
 "Optional Label Language": "オプションのラベル言語",
 "Optional online ver": "オプションのオンラインバージョン",
 "Output Path": "出力パス",
 "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
 "Precision": "精度",
 "Probability of applying Speaker Condition": "話者条件を適用する確率",
 "Put your text here.": "ここにテキストを入力してください。",
 "Reference Audio": "リファレンスオーディオ",
 "Reference Text": "リファレンステキスト",
 "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
 "Remove Selected Data": "選択したデータを削除",
 "Removed path successfully!": "パスの削除に成功しました!",
 "Repetition Penalty": "反復ペナルティ",
 "Save model every n steps": "nステップごとにモデルを保存",
 "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
 "Select VITS ckpt": "VITS チェックポイントを選択",
 "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
 "Select source file processing method": "ソースファイルの処理方法を選択",
 "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
 "Selected: {}": "選択済み: {}",
 "Speaker": "話者",
 "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
 "Start Training": "トレーニング開始",
 "Streaming Audio": "ストリーミングオーディオ",
 "Streaming Generate": "ストリーミング合成",
 "Tensorboard Host": "Tensorboardホスト",
 "Tensorboard Log Path": "Tensorboardログパス",
 "Tensorboard Port": "Tensorboardポート",
 "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
 "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
 "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
 "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
 "Training Configuration": "トレーニング設定",
 "Training Error": "トレーニングエラー",
 "Training stopped": "トレーニングが停止しました",
 "Type name of the speaker": "話者の名前を入力",
 "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
 "Use LoRA": "LoRAを使用",
 "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
 "Use filelist": "ファイルリストを使用",
 "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
 "VITS Configuration": "VITS の構成",
 "VQGAN Configuration": "VQGAN の構成",
 "Validation Batch Size": "検証バッチサイズ",
 "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
 "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
 "WebUI Host": "WebUIホスト",
 "WebUI Port": "WebUIポート",
 "Whisper Model": "Whisperモデル",
 "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
 "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
 "latest": "最新",
 "new": "新規",
 "Realtime Transform Text": "リアルタイム変換テキスト",
 "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
-"Text Normalization": "テキスト正規化"
-
-}
+"Text Normalization": "テキスト正規化",
+"Select Example Audio": "サンプル音声を選択"
+}
fish_speech/i18n/locale/ko_KR.json
ADDED
@@ -0,0 +1,123 @@
+{
+"16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
+"5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
+"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
+"Accumulate Gradient Batches": "그라디언트 배치 누적",
+"Add to Processing Area": "처리 영역에 추가",
+"Added path successfully!": "경로가 성공적으로 추가되었습니다!",
+"Advanced Config": "고급 설정",
+"Base LLAMA Model": "기본 LLAMA 모델",
+"Batch Inference": "배치 추론",
+"Batch Size": "배치 크기",
+"Changing with the Model Path": "모델 경로에 따라 변경 중",
+"Chinese": "중국어",
+"Compile Model": "모델 컴파일",
+"Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
+"Copy": "복사",
+"Data Preprocessing": "데이터 전처리",
+"Data Preprocessing Path": "데이터 전처리 경로",
+"Data Source": "데이터 소스",
+"Decoder Model Config": "디코더 모델 설정",
+"Decoder Model Path": "디코더 모델 경로",
+"Disabled": "비활성화 됨",
+"Enable Reference Audio": "참고 음성 활성화",
+"English": "영어",
+"Error Message": "오류 메시지",
+"File Preprocessing": "파일 전처리",
+"Generate": "생성",
+"Generated Audio": "생성된 오디오",
+"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
+"Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
+"Inference Configuration": "추론 설정",
+"Inference Server Configuration": "추론 서버 설정",
+"Inference Server Error": "추론 서버 오류",
+"Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
+"Initial Learning Rate": "초기 학습률",
+"Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
+"Input Text": "입력 텍스트",
+"Invalid path: {}": "유효하지 않은 경로: {}",
+"It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
+"Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
+"Japanese": "일본어",
+"LLAMA Configuration": "LLAMA 설정",
+"LLAMA Model Config": "LLAMA 모델 설정",
+"LLAMA Model Path": "LLAMA 모델 경로",
+"Labeling Device": "라벨링 장치",
+"LoRA Model to be merged": "병합할 LoRA 모델",
+"Maximum Audio Duration": "최대 오디오 길이",
+"Maximum Length per Sample": "샘플당 최대 길이",
+"Maximum Training Steps": "최대 학습 단계",
+"Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
+"Merge": "병합",
+"Merge LoRA": "LoRA 병합",
+"Merge successfully": "성공적으로 병합 되었습니다.",
+"Minimum Audio Duration": "최소 오디오 길이",
+"Model Output Path": "모델 출력 경로",
+"Model Size": "모델 크기",
+"Move": "이동",
+"Move files successfully": "파일이 성공적으로 이동되었습니다.",
+"No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
+"No selected options": "옵션이 선택되지 않았습니다.",
+"Number of Workers": "작업자 수",
+"Open Inference Server": "추론 서버 열기",
+"Open Labeler WebUI": "라벨러 WebUI 열기",
+"Open Tensorboard": "Tensorboard 열기",
+"Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
+"Optional Label Language": "선택적 라벨 언어",
+"Optional online ver": "온라인 버전 선택",
+"Output Path": "출력 경로",
+"Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
+"Precision": "정밀도",
+"Probability of applying Speaker Condition": "화자 조건 적용 확률",
+"Put your text here.": "여기에 텍스트를 입력하세요.",
+"Reference Audio": "참고 오디오",
+"Reference Text": "참고 텍스트",
+"Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
+"Remove Selected Data": "선택한 데이터 제거",
+"Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
+"Repetition Penalty": "반복 패널티",
+"Save model every n steps": "n 단계마다 모델 저장",
+"Select LLAMA ckpt": "LLAMA ckpt 선택",
+"Select VITS ckpt": "VITS ckpt 선택",
+"Select VQGAN ckpt": "VQGAN ckpt 선택",
+"Select source file processing method": "소스 파일 처리 방법 선택",
+"Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
+"Selected: {}": "선택됨: {}",
+"Speaker": "화자",
+"Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
+"Start Training": "학습 시작",
+"Streaming Audio": "스트리밍 오디오",
+"Streaming Generate": "스트리밍 생성",
+"Tensorboard Host": "Tensorboard 호스트",
+"Tensorboard Log Path": "Tensorboard 로그 경로",
+"Tensorboard Port": "Tensorboard 포트",
+"Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
+"Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
+"Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
+"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
+"Training Configuration": "학습 설정",
+"Training Error": "학습 오류",
+"Training stopped": "학습이 중지되었습니다.",
+"Type name of the speaker": "화자의 이름을 입력하세요.",
+"Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
+"Use LoRA": "LoRA 사용",
+"Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
+"Use filelist": "파일 목록 사용",
+"Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
+"VITS Configuration": "VITS 설정",
+"VQGAN Configuration": "VQGAN 설정",
+"Validation Batch Size": "검증 배치 크기",
+"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
+"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
+"WebUI Host": "WebUI 호스트",
+"WebUI Port": "WebUI 포트",
+"Whisper Model": "Whisper 모델",
+"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
+"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
+"latest": "최신",
+"new": "새로운",
+"Realtime Transform Text": "실시간 텍스트 변환",
+"Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
+"Text Normalization": "텍스트 정규화",
+"Select Example Audio": "예시 오디오 선택"
+}
fish_speech/i18n/locale/pt_BR.json
CHANGED
@@ -1,133 +1,133 @@
|
|
1 |
+
{
|
2 |
+
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
|
3 |
+
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
|
4 |
+
"Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
|
5 |
+
"Add to Processing Area": "Adicionar à Área de Processamento",
|
6 |
+
"Added path successfully!": "Caminho adicionado com sucesso!",
|
7 |
+
"Advanced Config": "Configuração Avançada",
|
8 |
+
"Base LLAMA Model": "Modelo LLAMA Base",
|
9 |
+
"Batch Inference": "Inferência em Lote",
|
10 |
+
"Batch Size": "Tamanho do Lote",
|
11 |
+
"Changing with the Model Path": "Alterando com o Caminho do Modelo",
|
12 |
+
|
13 |
+
"Compile Model": "Compilar Modelo",
|
14 |
+
"Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
|
15 |
+
"Copy": "Copiar",
|
16 |
+
"Data Preprocessing": "Pré-processamento de Dados",
|
17 |
+
"Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
|
18 |
+
"Data Source": "Fonte de Dados",
|
19 |
+
"Decoder Model Config": "Configuração do Modelo Decodificador",
|
20 |
+
"Decoder Model Path": "Caminho do Modelo Decodificador",
|
21 |
+
"Disabled": "Desativado",
|
22 |
+
"Enable Initial Prompt": "Habilitar Prompt Inicial",
|
23 |
+
"Enable Reference Audio": "Habilitar Áudio de Referência",
|
24 |
+
"English": "Inglês",
|
25 |
+
"Japanese": "Japonês",
|
26 |
+
"Chinese": "Chinês",
|
27 |
+
"Portuguese": "Português",
|
28 |
+
"Spanish": "Espanhol",
|
29 |
+
"Error Message": "Mensagem de Erro",
|
30 |
+
"Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
|
31 |
+
"File Preprocessing": "Pré-processamento de Arquivos",
|
32 |
+
"Generate": "Gerar",
|
33 |
+
"Generated Audio": "Áudio Gerado",
|
34 |
+
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
|
35 |
+
"Infer interface is closed": "A interface de inferência foi fechada",
|
36 |
+
"Inference Configuration": "Configuração de Inferência",
|
37 |
+
"Inference Server Configuration": "Configuração do Servidor de Inferência",
|
38 |
+
"Inference Server Error": "Erro do Servidor de Inferência",
|
39 |
+
"Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
|
40 |
+
"Initial Learning Rate": "Taxa de Aprendizagem Inicial",
|
41 |
+
"Initial Prompt": "Prompt Inicial",
|
42 |
+
"Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
|
43 |
+
"Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
|
44 |
+
"Input Text": "Texto de Entrada",
|
45 |
+
"Invalid path: {}": "Caminho inválido: {}",
|
46 |
+
"It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
|
47 |
+
"Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
|
48 |
+
"LLAMA Configuration": "Configuração do LLAMA",
|
49 |
+
"LLAMA Model Config": "Configuração do Modelo LLAMA",
|
50 |
+
"LLAMA Model Path": "Caminho do Modelo LLAMA",
|
51 |
+
"Labeling Device": "Dispositivo de Rotulagem",
|
52 |
+
"LoRA Model to be merged": "Modelo LoRA para mesclagem",
|
53 |
+
"Maximum Length per Sample": "Comprimento Máximo por Amostra",
|
54 |
+
"Maximum Training Steps": "Etapas Máximas de Treinamento",
|
55 |
+
"Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
|
56 |
+
"Merge": "Mesclar",
|
57 |
+
"Merge LoRA": "Mesclar LoRA",
|
58 |
+
"Merge successfully": "Mesclado com sucesso",
|
59 |
+
"Model Output Path": "Caminho de Saída do Modelo",
|
60 |
+
"Model Quantization": "Quantização do Modelo",
|
61 |
+
"Model Size": "Tamanho do Modelo",
|
62 |
+
"Move": "Mover",
|
63 |
+
"Move files successfully": "Arquivos movidos com sucesso",
|
64 |
+
"No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
|
65 |
+
"No selected options": "Nenhuma opção selecionada",
|
66 |
+
"Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
|
67 |
+
"Number of Workers": "Número de Processos",
|
68 |
+
"Open Inference Server": "Abrir Servidor de Inferência",
|
69 |
+
"Open Labeler WebUI": "Abrir WebUI de Rotulagem",
|
70 |
+
"Open Tensorboard": "Abrir Tensorboard",
|
71 |
+
"Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
|
72 |
+
"Optional Label Language": "Idioma do Rótulo (Opcional)",
|
73 |
+
"Optional online ver": "Versão online (opcional)",
|
74 |
+
"Output Path": "Caminho de Saída",
|
75 |
+
"Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
|
76 |
+
"Post-quantification Precision": "Precisão Pós-quantização",
|
77 |
+
"Precision": "Precisão",
|
78 |
+
"Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
|
79 |
+
"Put your text here.": "Insira seu texto aqui.",
|
80 |
+
"Quantify": "Quantizar",
|
81 |
+
"Quantify successfully": "Quantizado com sucesso",
|
82 |
+
"Realtime Transform Text": "Transformar Texto em Tempo Real",
|
83 |
+
"Reference Audio": "Áudio de Referência",
|
84 |
+
"Reference Text": "Texto de Referência",
|
85 |
+
"warning": "Aviso",
|
86 |
+
"Pre-processing begins...": "O pré-processamento começou!",
|
87 |
+
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
|
88 |
+
"Remove Selected Data": "Remover Dados Selecionados",
|
89 |
+
"Removed path successfully!": "Caminho removido com sucesso!",
|
90 |
+
"Repetition Penalty": "Penalidade de Repetição",
|
91 |
+
"Save model every n steps": "Salvar modelo a cada n etapas",
|
92 |
+
"Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
|
93 |
+
"Select source file processing method": "Escolha como processar o arquivo de origem",
|
94 |
+
"Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
|
95 |
+
"Selected: {}": "Selecionado: {}",
|
96 |
+
"Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
|
97 |
+
"Start Training": "Iniciar Treinamento",
|
98 |
+
"Streaming Audio": "Áudio em Streaming",
|
99 |
+
"Streaming Generate": "Geração em Streaming",
|
100 |
+
"Tensorboard Host": "Host do Tensorboard",
|
101 |
+
"Tensorboard Log Path": "Caminho de Log do Tensorboard",
|
102 |
+
"Tensorboard Port": "Porta do Tensorboard",
|
103 |
+
"Tensorboard interface is closed": "A interface do Tensorboard está fechada",
|
104 |
+
"Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
|
105 |
+
"Text Normalization": "Normalização de Texto",
|
106 |
+
"Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
|
107 |
+
"The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
|
108 |
+
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
|
109 |
+
"Training Configuration": "Configuração de Treinamento",
|
110 |
+
"Training Error": "Erro de Treinamento",
|
111 |
+
"Training stopped": "Treinamento interrompido!",
|
112 |
+
"Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
|
113 |
+
"Use LoRA": "Usar LoRA",
|
114 |
+
"Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
|
115 |
+
"Use filelist": "Usar lista de arquivos",
|
116 |
+
"VQGAN Configuration": "Configuração do VQGAN",
|
117 |
+
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
|
118 |
+
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
|
119 |
+
"WebUI Host": "Host da WebUI",
|
120 |
+
"WebUI Port": "Porta da WebUI",
|
121 |
+
"Whisper Model": "Modelo Whisper",
|
122 |
+
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
|
123 |
+
"auto": "automático",
|
124 |
+
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
|
125 |
+
"latest": "mais recente",
|
126 |
+
"new": "novo",
|
127 |
+
"This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
|
128 |
+
"You don't need to train this model!": "Não é necessário treinar este modelo!",
|
129 |
+
"Yes": "Sim",
|
130 |
+
"No": "Não",
|
131 |
+
"version:": "versão:",
|
132 |
+
"author:": "autor:"
|
133 |
+
}
|
fish_speech/i18n/locale/zh_CN.json
CHANGED
@@ -1,122 +1,123 @@
|
|
1 |
-
{
|
2 |
-
"16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
|
3 |
-
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
|
4 |
-
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
|
5 |
-
"Accumulate Gradient Batches": "梯度累积批次",
|
6 |
-
"Add to Processing Area": "加入处理区",
|
7 |
-
"Added path successfully!": "添加路径成功!",
|
8 |
-
"Advanced Config": "高级参数",
|
9 |
-
"Base LLAMA Model": "基础 LLAMA 模型",
|
10 |
-
"Batch Inference": "批量推理",
|
11 |
-
"Batch Size": "批次大小",
|
12 |
-
"Changing with the Model Path": "随模型路径变化",
|
13 |
-
"Chinese": "中文",
|
14 |
-
"Compile Model": "编译模型",
|
15 |
-
"Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
|
16 |
-
"Copy": "复制",
|
17 |
-
"Data Preprocessing": "数据预处理",
|
18 |
-
"Data Preprocessing Path": "数据预处理路径",
|
19 |
-
"Data Source": "数据源",
|
20 |
-
"Decoder Model Config": "解码器模型配置",
|
21 |
-
"Decoder Model Path": "解码器模型路径",
|
22 |
-
"Disabled": "禁用",
|
23 |
-
"Enable Reference Audio": "启用参考音频",
|
24 |
-
"English": "英文",
|
25 |
-
"Error Message": "错误信息",
|
26 |
-
"File Preprocessing": "文件预处理",
|
27 |
-
"Generate": "生成",
|
28 |
-
"Generated Audio": "音频",
|
29 |
-
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
|
30 |
-
"Infer interface is closed": "推理界面已关闭",
|
31 |
-
"Inference Configuration": "推理配置",
|
32 |
-
"Inference Server Configuration": "推理服务器配置",
|
33 |
-
"Inference Server Error": "推理服务器错误",
|
34 |
-
"Inferring interface is launched at {}": "推理界面已在 {} 上启动",
|
35 |
-
"Initial Learning Rate": "初始学习率",
|
36 |
-
"Input Audio & Source Path for Transcription": "输入音频和转录源路径",
|
37 |
-
"Input Text": "输入文本",
|
38 |
-
"Invalid path: {}": "无效路径: {}",
|
39 |
-
"It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA
|
40 |
-
"Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
|
41 |
-
"Japanese": "日文",
|
42 |
-
"LLAMA Configuration": "LLAMA 配置",
|
43 |
-
"LLAMA Model Config": "LLAMA 模型配置",
|
44 |
-
"LLAMA Model Path": "LLAMA 模型路径",
|
45 |
-
"Labeling Device": "标注加速设备",
|
46 |
-
"LoRA Model to be merged": "要合并的 LoRA 模型",
|
47 |
-
"Maximum Audio Duration": "最大音频时长",
|
48 |
-
"Maximum Length per Sample": "每个样本的最大长度",
|
49 |
-
"Maximum Training Steps": "最大训练步数",
|
50 |
-
"Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
|
51 |
-
"Merge": "合并",
|
52 |
-
"Merge LoRA": "合并 LoRA",
|
53 |
-
"Merge successfully": "合并成功",
|
54 |
-
"Minimum Audio Duration": "最小音频时长",
|
55 |
-
"Model Output Path": "模型输出路径",
|
56 |
-
"Model Size": "模型规模",
|
57 |
-
"Move": "移动",
|
58 |
-
"Move files successfully": "移动文件成功",
|
59 |
-
"No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
|
60 |
-
"No selected options": "没有选择的选项",
|
61 |
-
"Number of Workers": "数据加载进程数",
|
62 |
-
"Open Inference Server": "打开推理服务器",
|
63 |
-
"Open Labeler WebUI": "打开标注工具",
|
64 |
-
"Open Tensorboard": "打开 Tensorboard",
|
65 |
-
"Opened labeler in browser": "在浏览器中打开标注工具",
|
66 |
-
"Optional Label Language": "[可选] 标注语言",
|
67 |
-
"Optional online ver": "[可选] 使用在线版",
|
68 |
-
"Output Path": "输出路径",
|
69 |
-
"Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
|
70 |
-
"Precision": "精度",
|
71 |
-
"Probability of applying Speaker Condition": "应用说话人条件的概率",
|
72 |
-
"Put your text here.": "在此处输入文本.",
|
73 |
-
"Reference Audio": "参考音频",
|
74 |
-
"Reference Text": "参考文本",
|
75 |
-
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
|
76 |
-
"Remove Selected Data": "移除选中数据",
|
77 |
-
"Removed path successfully!": "移除路径成功!",
|
78 |
-
"Repetition Penalty": "重复惩罚",
|
79 |
-
"Save model every n steps": "每 n 步保存模型",
|
80 |
-
"Select LLAMA ckpt": "选择 LLAMA 检查点",
|
81 |
-
"Select VITS ckpt": "选择 VITS 检查点",
|
82 |
-
"Select VQGAN ckpt": "选择 VQGAN 检查点",
|
83 |
-
"Select source file processing method": "选择源文件处理方法",
|
84 |
-
"Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
|
85 |
-
"Selected: {}": "已选择: {}",
|
86 |
-
"Speaker": "说话人",
|
87 |
-
"Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
|
88 |
-
"Start Training": "开始训练",
|
89 |
-
"Streaming Audio": "流式音频",
|
90 |
-
"Streaming Generate": "流式合成",
|
91 |
-
"Tensorboard Host": "Tensorboard 监听地址",
|
92 |
-
"Tensorboard Log Path": "Tensorboard 日志路径",
|
93 |
-
"Tensorboard Port": "Tensorboard 端口",
|
94 |
-
"Tensorboard interface is closed": "Tensorboard 界面已关闭",
|
95 |
-
"Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
|
96 |
-
"Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
|
97 |
-
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
|
98 |
-
"Training Configuration": "训练配置",
|
99 |
-
"Training Error": "训练错误",
|
100 |
-
"Training stopped": "训练已停止",
|
101 |
-
"Type name of the speaker": "输入说话人的名称",
|
102 |
-
"Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
|
103 |
-
"Use LoRA": "使用 LoRA",
|
104 |
-
"Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
|
105 |
-
"Use filelist": "使用文件列表",
|
106 |
-
"Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
|
107 |
-
"VITS Configuration": "VITS 配置",
|
108 |
-
"VQGAN Configuration": "VQGAN 配置",
|
109 |
-
"Validation Batch Size": "验证批次大小",
|
110 |
-
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
|
111 |
-
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
|
112 |
-
"WebUI Host": "WebUI 监听地址",
|
113 |
-
"WebUI Port": "WebUI 端口",
|
114 |
-
"Whisper Model": "Whisper 模型",
|
115 |
-
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
|
116 |
-
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
|
117 |
-
"latest": "最近的检查点",
|
118 |
-
"new": "创建新的检查点",
|
119 |
-
"Realtime Transform Text": "实时规范化文本",
|
120 |
-
"Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
|
121 |
-
"Text Normalization": "文本规范化"
|
122 |
-
}
|
|
|
|
1 |
+
{
|
2 |
+
"16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
|
3 |
+
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
|
4 |
+
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
|
5 |
+
"Accumulate Gradient Batches": "梯度累积批次",
|
6 |
+
"Add to Processing Area": "加入处理区",
|
7 |
+
"Added path successfully!": "添加路径成功!",
|
8 |
+
"Advanced Config": "高级参数",
|
9 |
+
"Base LLAMA Model": "基础 LLAMA 模型",
|
10 |
+
"Batch Inference": "批量推理",
|
11 |
+
"Batch Size": "批次大小",
|
12 |
+
"Changing with the Model Path": "随模型路径变化",
|
13 |
+
"Chinese": "中文",
|
14 |
+
"Compile Model": "编译模型",
|
15 |
+
"Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
|
16 |
+
"Copy": "复制",
|
17 |
+
"Data Preprocessing": "数据预处理",
|
18 |
+
"Data Preprocessing Path": "数据预处理路径",
|
19 |
+
"Data Source": "数据源",
|
20 |
+
"Decoder Model Config": "解码器模型配置",
|
21 |
+
"Decoder Model Path": "解码器模型路径",
|
22 |
+
"Disabled": "禁用",
|
23 |
+
"Enable Reference Audio": "启用参考音频",
|
24 |
+
"English": "英文",
|
25 |
+
"Error Message": "错误信息",
|
26 |
+
"File Preprocessing": "文件预处理",
|
27 |
+
"Generate": "生成",
|
28 |
+
"Generated Audio": "音频",
|
29 |
+
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
|
30 |
+
"Infer interface is closed": "推理界面已关闭",
|
31 |
+
"Inference Configuration": "推理配置",
|
32 |
+
"Inference Server Configuration": "推理服务器配置",
|
33 |
+
"Inference Server Error": "推理服务器错误",
|
34 |
+
"Inferring interface is launched at {}": "推理界面已在 {} 上启动",
|
35 |
+
"Initial Learning Rate": "初始学习率",
|
36 |
+
"Input Audio & Source Path for Transcription": "输入音频和转录源路径",
|
37 |
+
"Input Text": "输入文本",
|
38 |
+
"Invalid path: {}": "无效路径: {}",
|
39 |
+
"It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使��� CPU",
|
40 |
+
"Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
|
41 |
+
"Japanese": "日文",
|
42 |
+
"LLAMA Configuration": "LLAMA 配置",
|
43 |
+
"LLAMA Model Config": "LLAMA 模型配置",
|
44 |
+
"LLAMA Model Path": "LLAMA 模型路径",
|
45 |
+
"Labeling Device": "标注加速设备",
|
46 |
+
"LoRA Model to be merged": "要合并的 LoRA 模型",
|
47 |
+
"Maximum Audio Duration": "最大音频时长",
|
48 |
+
"Maximum Length per Sample": "每个样本的最大长度",
|
49 |
+
"Maximum Training Steps": "最大训练步数",
|
50 |
+
"Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
|
51 |
+
"Merge": "合并",
|
52 |
+
"Merge LoRA": "合并 LoRA",
|
53 |
+
"Merge successfully": "合并成功",
|
54 |
+
"Minimum Audio Duration": "最小音频时长",
|
55 |
+
"Model Output Path": "模型输出路径",
|
56 |
+
"Model Size": "模型规模",
|
57 |
+
"Move": "移动",
|
58 |
+
"Move files successfully": "移动文件成功",
|
59 |
+
"No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
|
60 |
+
"No selected options": "没有选择的选项",
|
61 |
+
"Number of Workers": "数据加载进程数",
|
62 |
+
"Open Inference Server": "打开推理服务器",
|
63 |
+
"Open Labeler WebUI": "打开标注工具",
|
64 |
+
"Open Tensorboard": "打开 Tensorboard",
|
65 |
+
"Opened labeler in browser": "在浏览器中打开标注工具",
|
66 |
+
"Optional Label Language": "[可选] 标注语言",
|
67 |
+
"Optional online ver": "[可选] 使用在线版",
|
68 |
+
"Output Path": "输出路径",
|
69 |
+
"Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
|
70 |
+
"Precision": "精度",
|
71 |
+
"Probability of applying Speaker Condition": "应用说话人条件的概率",
|
72 |
+
"Put your text here.": "在此处输入文本.",
|
73 |
+
"Reference Audio": "参考音频",
|
74 |
+
"Reference Text": "参考文本",
|
75 |
+
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
|
76 |
+
"Remove Selected Data": "移除选中数据",
|
77 |
+
"Removed path successfully!": "移除路径成功!",
|
78 |
+
"Repetition Penalty": "重复惩罚",
|
79 |
+
"Save model every n steps": "每 n 步保存模型",
|
80 |
+
"Select LLAMA ckpt": "选择 LLAMA 检查点",
|
81 |
+
"Select VITS ckpt": "选择 VITS 检查点",
|
82 |
+
"Select VQGAN ckpt": "选择 VQGAN 检查点",
|
83 |
+
"Select source file processing method": "选择源文件处理方法",
|
84 |
+
"Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
|
85 |
+
"Selected: {}": "已选择: {}",
|
86 |
+
"Speaker": "说话人",
|
87 |
+
"Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
|
88 |
+
"Start Training": "开始训练",
|
89 |
+
"Streaming Audio": "流式音频",
|
90 |
+
"Streaming Generate": "流式合成",
|
91 |
+
"Tensorboard Host": "Tensorboard 监听地址",
|
92 |
+
"Tensorboard Log Path": "Tensorboard 日志路径",
|
93 |
+
"Tensorboard Port": "Tensorboard 端口",
|
94 |
+
"Tensorboard interface is closed": "Tensorboard 界面已关闭",
|
95 |
+
"Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
|
96 |
+
"Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
|
97 |
+
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
|
98 |
+
"Training Configuration": "训练配置",
|
99 |
+
"Training Error": "训练错误",
|
100 |
+
"Training stopped": "训练已停止",
|
101 |
+
"Type name of the speaker": "输入说话人的名称",
|
102 |
+
"Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
|
103 |
+
"Use LoRA": "使用 LoRA",
|
104 |
+
"Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
|
105 |
+
"Use filelist": "使用文件列表",
|
106 |
+
"Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
|
107 |
+
"VITS Configuration": "VITS 配置",
|
108 |
+
"VQGAN Configuration": "VQGAN 配置",
|
109 |
+
"Validation Batch Size": "验证批次大小",
|
110 |
+
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
|
111 |
+
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
|
112 |
+
"WebUI Host": "WebUI 监听地址",
|
113 |
+
"WebUI Port": "WebUI 端口",
|
114 |
+
"Whisper Model": "Whisper 模型",
|
115 |
+
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
|
116 |
+
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
|
117 |
+
"latest": "最近的检查点",
|
118 |
+
"new": "创建新的检查点",
|
119 |
+
"Realtime Transform Text": "实时规范化文本",
|
120 |
+
"Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
|
121 |
+
"Text Normalization": "文本规范化",
|
122 |
+
"Select Example Audio": "选择参考音频"
|
123 |
+
}
|
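The locale files above are flat JSON maps from the English source string, exactly as it appears in the code, to its translation. As a rough illustration of how such a table is consumed, the sketch below resolves a key against one locale and falls back to the English key when no translation is available. This is only a sketch, not the actual `fish_speech.i18n.core` implementation; the `load_locale` helper and its arguments are invented for the example, while the `#!` placeholder convention comes from `scan.py` below.

```python
# Minimal sketch of a lookup over the locale JSON files shown above.
# Not the fish_speech.i18n.core implementation; names and paths are illustrative.
import json
from pathlib import Path


def load_locale(locale_dir: str, lang: str) -> dict:
    """Load e.g. locale/ko_KR.json into a {source string: translation} dict."""
    path = Path(locale_dir) / f"{lang}.json"
    if not path.exists():
        return {}
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def lookup(key: str, translations: dict) -> str:
    # Untranslated entries written by scan.py carry a "#!" prefix, so treat
    # them like missing keys and fall back to the English source string.
    value = translations.get(key, key)
    return key if value.startswith("#!") else value


translations = load_locale("fish_speech/i18n/locale", "zh_CN")
print(lookup("Reference Audio", translations))      # "参考音频" when the locale file is present
print(lookup("Some unknown string", translations))  # falls back to the key itself
```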
fish_speech/i18n/scan.py
CHANGED
@@ -1,122 +1,122 @@
|
|
1 |
+
import ast
|
2 |
+
import glob
|
3 |
+
import json
|
4 |
+
from collections import OrderedDict
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
|
10 |
+
|
11 |
+
|
12 |
+
def extract_i18n_strings(node):
|
13 |
+
i18n_strings = []
|
14 |
+
|
15 |
+
if (
|
16 |
+
isinstance(node, ast.Call)
|
17 |
+
and isinstance(node.func, ast.Name)
|
18 |
+
and node.func.id == "i18n"
|
19 |
+
):
|
20 |
+
for arg in node.args:
|
21 |
+
if isinstance(arg, ast.Str):
|
22 |
+
i18n_strings.append(arg.s)
|
23 |
+
|
24 |
+
for child_node in ast.iter_child_nodes(node):
|
25 |
+
i18n_strings.extend(extract_i18n_strings(child_node))
|
26 |
+
|
27 |
+
return i18n_strings
|
28 |
+
|
29 |
+
|
30 |
+
# scan the directory for all .py files (recursively)
|
31 |
+
# for each file, parse the code into an AST
|
32 |
+
# for each AST, extract the i18n strings
|
33 |
+
|
34 |
+
strings = []
|
35 |
+
folders = ["fish_speech", "tools"]
|
36 |
+
# for filename in glob.iglob("**/*.py", recursive=True):
|
37 |
+
for folder in folders:
|
38 |
+
for f in Path(folder).rglob("*.py"):
|
39 |
+
code = f.read_text(encoding="utf-8")
|
40 |
+
if "i18n(" in code:
|
41 |
+
tree = ast.parse(code)
|
42 |
+
i18n_strings = extract_i18n_strings(tree)
|
43 |
+
logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
|
44 |
+
strings.extend(i18n_strings)
|
45 |
+
|
46 |
+
code_keys = set(strings)
|
47 |
+
logger.info(f"Total unique: {len(code_keys)}")
|
48 |
+
|
49 |
+
|
50 |
+
standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
|
51 |
+
with open(standard_file, "r", encoding="utf-8") as f:
|
52 |
+
standard_data = json.load(f, object_pairs_hook=OrderedDict)
|
53 |
+
standard_keys = set(standard_data.keys())
|
54 |
+
|
55 |
+
# Define the standard file name
|
56 |
+
unused_keys = standard_keys - code_keys
|
57 |
+
logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
|
58 |
+
for unused_key in unused_keys:
|
59 |
+
logger.info(f"\t{unused_key}")
|
60 |
+
|
61 |
+
missing_keys = code_keys - standard_keys
|
62 |
+
logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
|
63 |
+
for missing_key in missing_keys:
|
64 |
+
logger.info(f"\t{missing_key}")
|
65 |
+
|
66 |
+
code_keys_dict = OrderedDict()
|
67 |
+
for s in strings:
|
68 |
+
code_keys_dict[s] = s
|
69 |
+
|
70 |
+
# write back
|
71 |
+
with open(standard_file, "w", encoding="utf-8") as f:
|
72 |
+
json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
|
73 |
+
f.write("\n")
|
74 |
+
|
75 |
+
logger.info(f"Updated {standard_file}")
|
76 |
+
|
77 |
+
|
78 |
+
# Define the standard file name
|
79 |
+
standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
|
80 |
+
|
81 |
+
# Find all JSON files in the directory
|
82 |
+
dir_path = I18N_FILE_PATH
|
83 |
+
languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
|
84 |
+
|
85 |
+
# Load the standard file
|
86 |
+
with open(standard_file, "r", encoding="utf-8") as f:
|
87 |
+
standard_data = json.load(f, object_pairs_hook=OrderedDict)
|
88 |
+
|
89 |
+
# Loop through each language file
|
90 |
+
for lang_file in languages:
|
91 |
+
# Load the language file
|
92 |
+
with open(lang_file, "r", encoding="utf-8") as f:
|
93 |
+
lang_data = json.load(f, object_pairs_hook=OrderedDict)
|
94 |
+
|
95 |
+
# Find the difference between the language file and the standard file
|
96 |
+
diff = set(standard_data.keys()) - set(lang_data.keys())
|
97 |
+
|
98 |
+
miss = set(lang_data.keys()) - set(standard_data.keys())
|
99 |
+
|
100 |
+
# Add any missing keys to the language file
|
101 |
+
for key in diff:
|
102 |
+
lang_data[key] = "#!" + key
|
103 |
+
logger.info(f"Added missing key: {key} to {lang_file}")
|
104 |
+
|
105 |
+
# Del any extra keys to the language file
|
106 |
+
for key in miss:
|
107 |
+
del lang_data[key]
|
108 |
+
logger.info(f"Del extra key: {key} from {lang_file}")
|
109 |
+
|
110 |
+
# Sort the keys of the language file to match the order of the standard file
|
111 |
+
lang_data = OrderedDict(
|
112 |
+
sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
|
113 |
+
)
|
114 |
+
|
115 |
+
# Save the updated language file
|
116 |
+
with open(lang_file, "w", encoding="utf-8") as f:
|
117 |
+
json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
|
118 |
+
f.write("\n")
|
119 |
+
|
120 |
+
logger.info(f"Updated {lang_file}")
|
121 |
+
|
122 |
+
logger.info("Done")
|
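`scan.py` runs as a module-level script: it walks `fish_speech` and `tools` for `i18n(...)` calls, rewrites the default-language JSON from the collected keys, then syncs every other locale file against it by adding missing keys with a `#!` prefix, deleting keys that no longer appear in the code, and re-serializing the file. A condensed, standalone version of that sync step for a single locale file is sketched below; the paths and the `en_US` default language are assumptions for the example.

```python
# Condensed sketch of the per-locale sync step in scan.py, applied to one file.
# Paths and the en_US default language are assumptions for this example.
import json
from collections import OrderedDict

standard_file = "fish_speech/i18n/locale/en_US.json"
lang_file = "fish_speech/i18n/locale/ko_KR.json"

with open(standard_file, "r", encoding="utf-8") as f:
    standard = json.load(f, object_pairs_hook=OrderedDict)
with open(lang_file, "r", encoding="utf-8") as f:
    lang = json.load(f, object_pairs_hook=OrderedDict)

for key in set(standard) - set(lang):
    lang[key] = "#!" + key  # placeholder for an untranslated key, as scan.py does
for key in set(lang) - set(standard):
    del lang[key]           # key no longer referenced by any i18n() call

# scan.py writes the result with ensure_ascii=False and sort_keys=True.
with open(lang_file, "w", encoding="utf-8") as f:
    json.dump(lang, f, ensure_ascii=False, indent=4, sort_keys=True)
    f.write("\n")
```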
fish_speech/models/text2semantic/lit_module.py
CHANGED
@@ -1,202 +1,202 @@
|
|
1 |
+
from typing import Any, Optional
|
2 |
+
|
3 |
+
import lightning as L
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from lightning.pytorch.utilities.types import OptimizerLRScheduler
|
7 |
+
|
8 |
+
import fish_speech.utils as utils
|
9 |
+
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
|
10 |
+
from fish_speech.models.text2semantic.llama import NaiveTransformer
|
11 |
+
|
12 |
+
log = utils.RankedLogger(__name__, rank_zero_only=True)
|
13 |
+
|
14 |
+
|
15 |
+
class TextToSemantic(L.LightningModule):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
model: NaiveTransformer,
|
19 |
+
optimizer: Any,
|
20 |
+
lr_scheduler: Any,
|
21 |
+
):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
self.model = model
|
25 |
+
self.optimizer_builder = optimizer
|
26 |
+
self.lr_scheduler_builder = lr_scheduler
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
return self.model(x)
|
30 |
+
|
31 |
+
def on_save_checkpoint(self, checkpoint):
|
32 |
+
# Save only LoRA parameters
|
33 |
+
state_dict = checkpoint["state_dict"]
|
34 |
+
use_lora = any("lora" in name for name in state_dict.keys())
|
35 |
+
if not use_lora:
|
36 |
+
return
|
37 |
+
|
38 |
+
for name in list(state_dict.keys()):
|
39 |
+
if "lora" not in name:
|
40 |
+
state_dict.pop(name)
|
41 |
+
|
42 |
+
def configure_optimizers(self) -> OptimizerLRScheduler:
|
43 |
+
# Get weight decay parameters
|
44 |
+
weight_decay_parameters, other_parameters = [], []
|
45 |
+
for name, param in self.named_parameters():
|
46 |
+
if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
|
47 |
+
other_parameters.append(param)
|
48 |
+
else:
|
49 |
+
weight_decay_parameters.append(param)
|
50 |
+
|
51 |
+
optimizer = self.optimizer_builder(
|
52 |
+
[
|
53 |
+
{"params": weight_decay_parameters},
|
54 |
+
{"params": other_parameters, "weight_decay": 0.0},
|
55 |
+
]
|
56 |
+
)
|
57 |
+
|
58 |
+
# Print the parameters and their weight decay
|
59 |
+
for i in optimizer.param_groups:
|
60 |
+
log.info(
|
61 |
+
f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
|
62 |
+
)
|
63 |
+
|
64 |
+
lr_scheduler = self.lr_scheduler_builder(optimizer)
|
65 |
+
|
66 |
+
return {
|
67 |
+
"optimizer": optimizer,
|
68 |
+
"lr_scheduler": {
|
69 |
+
"scheduler": lr_scheduler,
|
70 |
+
"interval": "step",
|
71 |
+
},
|
72 |
+
}
|
73 |
+
|
74 |
+
    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        assert logits.shape[:-1] == labels.shape

        labels = labels.clone()
        loss_mask = labels != -100

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == -100] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

    def _step(self, batch, batch_idx, stage: str):
        is_train = stage == "train"

        if is_train:
            # Key part to make LoRA work
            # Otherwise the parameters are merged, which leads to incorrect gradients
            self.model.train()

        # Do positive and negative samples in the same batch to speed up training
        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Generate labels
        base_loss = F.cross_entropy(
            token_logits.view(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.view(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )

        loss = base_loss + semantic_loss

        self.log(
            f"{stage}/loss",
            loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )

        self.log(
            f"{stage}/base_loss",
            base_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )

        self.log(
            f"{stage}/semantic_loss",
            semantic_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )

        # Top-5 accuracy
        accuracy = self.get_accuracy(codebook_logits, codebook_labels)
        self.log(
            f"{stage}/top_5_accuracy",
            accuracy,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )

        return loss

    def get_accuracy(self, logits, labels):
        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
        if mask.sum() == 0:
            return torch.tensor(0.0, device=logits.device)

        _, indices = logits.topk(5, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1))
        correct[~mask] = 0
        correct = correct.sum()
        accuracy = correct / mask.sum()

        return accuracy

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")
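For reference, a minimal sketch (not part of the diff) of what the masked log-probability above computes on toy tensors; the only assumptions carried over are the tensor layout and the -100 ignore index from the docstring:

import torch

# Toy logits: batch=1, seq=3, vocab=5; label -100 marks an ignored position.
logits = torch.randn(1, 3, 5)
labels = torch.tensor([[2, 4, -100]])

loss_mask = labels != -100
safe_labels = labels.clone()
safe_labels[safe_labels == -100] = 0  # dummy index, masked out below

per_token_logps = torch.gather(
    logits.log_softmax(-1), dim=-1, index=safe_labels.unsqueeze(-1)
).squeeze(-1)

# Sum of log-probs over non-masked tokens, i.e. get_batch_logps(..., average_log_prob=False)
print((per_token_logps * loss_mask).sum(-1))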
fish_speech/models/text2semantic/llama.py
CHANGED
@@ -1,779 +1,887 @@
import dataclasses
import json
import math
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from loguru import logger
from torch import Tensor
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.checkpoint import checkpoint
from transformers import AutoTokenizer

from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
from fish_speech.utils import RankedLogger

from .lora import LoraConfig, setup_lora

log = RankedLogger(__name__, rank_zero_only=True)


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


@dataclass
class BaseModelArgs:
    model_type: str = "base"

    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048
    dropout: float = 0.0
    tie_word_embeddings: bool = True
    attention_qkv_bias: bool = False

    # Codebook configs
    codebook_size: int = 160
    num_codebooks: int = 4

    # Gradient checkpointing
    use_gradient_checkpointing: bool = True

    # Initialize the model
    initializer_range: float = 0.02

    # Dummy vars
    is_reward_model: bool = False
    share_codebook_embeddings: bool = True
    scale_codebook_embeddings: bool = False

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head

    @staticmethod
    def from_pretrained(path: str):
        path = Path(path)

        if path.is_dir():
            path = path / "config.json"

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        match data["model_type"]:
            case "naive":
                cls = NaiveModelArgs
            case "dual_ar":
                cls = DualARModelArgs
            case _:
                raise ValueError(f"Unknown model type: {data['model_type']}")

        return cls(**data)

    def save(self, path: str):
        with open(path, "w") as f:
            json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)


@dataclass
class NaiveModelArgs(BaseModelArgs):
    model_type: str = "naive"


@dataclass
class DualARModelArgs(BaseModelArgs):
    model_type: str = "dual_ar"
    n_fast_layer: int = 4
    fast_dim: int | None = None
    fast_n_head: int | None = None
    fast_n_local_heads: int | None = None
    fast_head_dim: int | None = None
    fast_intermediate_size: int | None = None
    fast_attention_qkv_bias: bool | None = None

    def __post_init__(self):
        super().__post_init__()

        self.fast_dim = self.fast_dim or self.dim
        self.fast_n_head = self.fast_n_head or self.n_head
        self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
        self.fast_head_dim = self.fast_head_dim or self.head_dim
        self.fast_intermediate_size = (
            self.fast_intermediate_size or self.intermediate_size
        )
        self.fast_attention_qkv_bias = (
            self.fast_attention_qkv_bias
            if self.fast_attention_qkv_bias is not None
            else self.attention_qkv_bias
        )


class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out


@dataclass
class TransformerForwardResult:
    token_logits: Tensor
    codebook_logits: Tensor


@dataclass
class BaseTransformerForwardResult:
    logits: Tensor
    hidden_states: Tensor


class BaseTransformer(nn.Module):
    def __init__(
        self,
        config: BaseModelArgs,
        tokenizer: FishTokenizer | AutoTokenizer,
        init_weights: bool = True,
    ) -> None:
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.semantic_token_ids = [
            tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS
        ]

        # Slow transformer
        self.embeddings = nn.Embedding(
            config.vocab_size,
            config.dim,
        )
        self.codebook_embeddings = nn.Embedding(
            config.codebook_size * config.num_codebooks,
            config.dim,
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)

        if self.config.tie_word_embeddings is False:
            self.output = nn.Linear(
                config.dim,
                config.vocab_size,
                bias=False,
            )

        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
            persistent=False,
        )

        # For kv cache
        self.max_batch_size = -1
        self.max_seq_len = -1

        if init_weights:
            self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        head_dim = self.config.dim // self.config.n_head
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def embed(self, x: Tensor) -> Tensor:
        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
            semantic_token_ids_tensor = torch.tensor(
                self.semantic_token_ids, device=x.device
            )
            emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0

        x = torch.stack(vocab_embeds, dim=3)
        x = x.sum(dim=3)

        return x

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> BaseTransformerForwardResult:
        seq_len = inp.size(2)

        # Here we want to merge the embeddings of the codebooks
        x = self.embed(inp)

        freqs_cis = self.freqs_cis[:seq_len]

        # Note that the causal mask here follows the definition of scaled_dot_product_attention
        # That is, FALSE means masked out
        # To maintain consistency, key_padding_mask uses TRUE to mask out
        mask = None
        if key_padding_mask is not None:
            mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
            mask = mask & key_padding_mask[:, None, None, :].logical_not()

        for layer in self.layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
            else:
                x = layer(x, freqs_cis, mask)

        # We got slow_out here
        slow_out = self.norm(x)

        if self.config.tie_word_embeddings:
            token_logits = F.linear(slow_out, self.embeddings.weight)
        else:
            token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )

    def forward_generate(
        self,
        inp: Tensor,
        input_pos: Optional[Tensor] = None,
        vq_masks: Optional[Tensor] = None,  # this is not used in fact
        return_all: bool = False,
    ) -> BaseTransformerForwardResult:
        # This is used for generation, optimized for torch compile
        # assert (
        #     self.max_seq_len != -1 and self.max_batch_size != -1
        # ), "Please call setup_caches before forward_generate"

        embeds = []
        for i in range(self.config.num_codebooks):
            if self.config.share_codebook_embeddings:
                _tokens = inp[:, i + 1] + i * self.config.codebook_size
            else:
                _tokens = inp[:, i + 1]

            emb = self.codebook_embeddings(_tokens)
            embeds.append(emb)

        vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
        # if self.config.use_codebook_mlp:
        #     vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks
        #     vq_embeds_sum = self.codebook_mlp(vq_embeds_sum)

        vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
            inp[:, 0] <= self.tokenizer.semantic_end_id
        )

        vq_embeds_sum[~vq_masks] = 0
        x = self.embeddings(inp[:, 0]) + vq_embeds_sum

        if input_pos is None:
            input_pos = torch.arange(inp.shape[-1], device=x.device)
            max_seq_len = inp.shape[-1]
        else:
            max_seq_len = self.max_seq_len

        mask = self.causal_mask[None, None, input_pos, :max_seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        # If prefill, we only calculate the logits of the last token
        if x.size(1) > 1 and not return_all:
            x = x[:, -1:]

        # We got slow_out here
        slow_out = self.norm(x)

        if self.config.is_reward_model:
            token_logits = self.score_output(slow_out)
        elif self.config.tie_word_embeddings:
            token_logits = F.linear(slow_out, self.embeddings.weight)
        else:
            token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @staticmethod
    def from_pretrained(
        path: str,
        load_weights: bool = False,
        max_length: int | None = None,
        lora_config: LoraConfig | None = None,
        rope_base: int | None = None,
        is_agent: bool = False,
    ) -> "BaseTransformer":
        config = BaseModelArgs.from_pretrained(str(path))
        if max_length is not None:
            config.max_seq_len = max_length
            log.info(f"Override max_seq_len to {max_length}")

        if rope_base is not None:
            config.rope_base = rope_base
            log.info(f"Override rope_base to {rope_base}")

        match config.model_type:
            case "naive":
                model_cls = NaiveTransformer
            case "dual_ar":
                model_cls = DualARTransformer
            case _:
                raise ValueError(f"Unknown model type: {config.model_type}")

        if is_agent:
            tokenizer = AutoTokenizer.from_pretrained(str(path))
        else:
            tokenizer_path = str(path) + "/tokenizer.tiktoken"
            tokenizer = FishTokenizer(tokenizer_path)

        log.info(f"Loading model from {path}, config: {config}")
        model = model_cls(config, tokenizer=tokenizer)

        if lora_config is not None:
            setup_lora(model, lora_config)
            log.info(f"LoRA setup: {lora_config}")

        if load_weights is False:
            log.info("Randomly initialized model")
        else:

            if "int8" in str(Path(path)):
                logger.info("Using int8 weight-only quantization!")
                from tools.llama.quantize import WeightOnlyInt8QuantHandler

                simple_quantizer = WeightOnlyInt8QuantHandler(model)
                model = simple_quantizer.convert_for_runtime()

            if "int4" in str(Path(path)):
                logger.info("Using int4 quantization!")
                path_comps = path.name.split("-")
                assert path_comps[-2].startswith("g")
                groupsize = int(path_comps[-2][1:])
                from tools.llama.quantize import WeightOnlyInt4QuantHandler

                simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
                model = simple_quantizer.convert_for_runtime()

            weights = torch.load(
                Path(path) / "model.pth",
                map_location="cpu",
                mmap=True,
                weights_only=True,
            )

            if "state_dict" in weights:
                logger.warning(
                    "Using a TextToSemantic LightningModule checkpoint, "
                    "please make sure it is a full model, not a LoRA model."
                )
                weights = weights["state_dict"]

            if next(iter(weights.keys())).startswith("model."):
                logger.info(
                    f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
                )
                new_weights = OrderedDict()
                for k, v in weights.items():
                    new_weights[k.replace("model.", "")] = v
                weights = new_weights

            # Verify the name and shape of parameters since strict=False in load_state_dict.
            for k, v in model.named_parameters():
                if k not in weights:
                    logger.warning(f"No weight for {k}")
                elif v.shape != weights[k].shape:
                    logger.warning(
                        f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
                    )

            err = model.load_state_dict(weights, strict=False, assign=True)
            log.info(f"Loaded weights with error: {err}")

        return model

    def save_pretrained(self, path: str, drop_lora: bool = False):
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        self.config.save(path / "config.json")
        state_dict = self.state_dict()

        if drop_lora:
            for key in list(state_dict.keys()):
                if "lora" not in key:
                    continue

                state_dict.pop(key)
                log.info(f"Drop LoRA parameter: {key}")

        torch.save(state_dict, path / "model.pth")
        self.tokenizer.save_pretrained(path)


class NaiveTransformer(BaseTransformer):
    def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.codebook_output = nn.Linear(
            config.dim,
            config.codebook_size * config.num_codebooks,
            bias=False,
        )

        self.apply(self._init_weights)

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        token_logits = result.logits
        x = result.hidden_states

        # Codebook
        codebook_logits = self.codebook_output(self.codebook_norm(x))
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        result = super().forward(
            inp=inp,
            key_padding_mask=key_padding_mask,
        )
        return self.decode(result)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        result = super().forward_generate(x, input_pos)
        return self.decode(result)


class DualARTransformer(BaseTransformer):
    def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        # Project to fast dim if needed
        if config.fast_dim is not None and config.fast_dim != config.dim:
            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
        else:
            self.fast_project_in = nn.Identity()

        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)

        # The equivalent bs is so large that sdpa doesn't work
        override_config = dataclasses.replace(
            config,
            dim=config.fast_dim,
            n_head=config.fast_n_head,
            n_local_heads=config.fast_n_local_heads,
            head_dim=config.fast_head_dim,
            intermediate_size=config.fast_intermediate_size,
            attention_qkv_bias=config.fast_attention_qkv_bias,
        )

        self.fast_layers = nn.ModuleList(
            TransformerBlock(override_config, use_sdpa=False)
            for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.fast_dim,
            config.codebook_size,
            bias=False,
        )

        self.register_buffer(
            "fast_freqs_cis",
            precompute_freqs_cis(
                config.num_codebooks,
                config.fast_dim // config.fast_n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        head_dim = self.config.fast_dim // self.config.fast_n_head

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.fast_n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states
        x = self.fast_project_in(x)

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)

        # Drop the last token and rotate left
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=0)
        codebook_embeddings = self.fast_embeddings(codebooks)
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove padded part
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == 0).all(dim=-1)

        if torch.all(codebook_mask):
            # If all codebooks are padded, we keep first 8 to make sure the model runs
            codebook_mask[:8] = False

        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(
                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
                )
            else:
                x = layer(x, self.fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        # Fast transformer
        x = x.view(1, 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.fast_freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits

    def forward_generate(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        vq_masks: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        x = super().forward_generate(x, input_pos, vq_masks)
        x.hidden_states = self.fast_project_in(x.hidden_states)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
    ) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(
            config.dim, total_head_dim, bias=config.attention_qkv_bias
        )
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.use_sdpa:
            if mask is None:
                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                    y = F.scaled_dot_product_attention(
                        q,
                        k,
                        v,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=True,
                        # No third party attn_mask here to use flash_attention
                    )
            else:
                y = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=mask,
                    dropout_p=self.dropout if self.training else 0.0,
                )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        # This is a standard scaled dot product attention
        # It's less efficient, but it doesn't raise a CUDA error
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

        return attn_weight @ value


class FeedForward(nn.Module):
    def __init__(self, config: BaseModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
fish_speech/models/text2semantic/lora.py
CHANGED
@@ -1,92 +1,92 @@
(old and new contents of this file are identical; shown once below)

from dataclasses import dataclass

import loralib as lora


@dataclass
class LoraConfig:
    r: int
    lora_alpha: float
    lora_dropout: float = 0.0


def setup_lora(model, lora_config):
    # Replace the embedding layer with a LoRA layer
    model.embeddings = lora.Embedding(
        num_embeddings=model.embeddings.num_embeddings,
        embedding_dim=model.embeddings.embedding_dim,
        padding_idx=model.embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    model.codebook_embeddings = lora.Embedding(
        num_embeddings=model.codebook_embeddings.num_embeddings,
        embedding_dim=model.codebook_embeddings.embedding_dim,
        padding_idx=model.codebook_embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    # Replace output layer with a LoRA layer
    linears = [(model, "output")]

    # Replace all linear layers with LoRA layers
    for layer in model.layers:
        linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
        linears.extend(
            [
                (layer.feed_forward, "w1"),
                (layer.feed_forward, "w2"),
                (layer.feed_forward, "w3"),
            ]
        )

    if hasattr(model, "fast_layers"):
        model.fast_embeddings = lora.Embedding(
            num_embeddings=model.fast_embeddings.num_embeddings,
            embedding_dim=model.fast_embeddings.embedding_dim,
            padding_idx=model.fast_embeddings.padding_idx,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
        )

        # Dual-AR model
        linears.append((model, "fast_output"))

        for layer in model.fast_layers:
            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
            linears.extend(
                [
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

    for module, layer in linears:
        updated_linear = lora.Linear(
            in_features=getattr(module, layer).in_features,
            out_features=getattr(module, layer).out_features,
            bias=getattr(module, layer).bias,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
        )
        setattr(module, layer, updated_linear)

    # Mark only the LoRA layers as trainable
    lora.mark_only_lora_as_trainable(model, bias="none")


def get_merged_state_dict(model):
    # This line will merge the state dict of the model and the LoRA parameters
    model.eval()

    # Then we need to remove the LoRA parameters from the state dict
    state_dict = model.state_dict()
    for name in list(state_dict.keys()):
        if "lora" in name:
            state_dict.pop(name)

    return state_dict
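A minimal sketch (not part of the diff) of the loralib pattern that setup_lora relies on: wrap a single Linear with lora.Linear, then freeze everything except the low-rank factors. The layer sizes and rank below are arbitrary:

import loralib as lora
import torch.nn as nn

model = nn.Sequential()
model.add_module("proj", lora.Linear(in_features=64, out_features=64, r=8, lora_alpha=16))

# Freeze the base weight; only proj.lora_A / proj.lora_B stay trainable,
# which is what mark_only_lora_as_trainable does for the full transformer above.
lora.mark_only_lora_as_trainable(model, bias="none")
print([name for name, p in model.named_parameters() if p.requires_grad])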
fish_speech/models/vqgan/lit_module.py
DELETED
@@ -1,442 +0,0 @@
import itertools
import math
from typing import Any, Callable

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn

from fish_speech.models.vqgan.modules.discriminator import Discriminator
from fish_speech.models.vqgan.modules.wavenet import WaveNet
from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask


class VQGAN(L.LightningModule):
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        encoder: WaveNet,
        quantizer: nn.Module,
        decoder: WaveNet,
        discriminator: Discriminator,
        vocoder: nn.Module,
        encode_mel_transform: nn.Module,
        gt_mel_transform: nn.Module,
        weight_adv: float = 1.0,
        weight_vq: float = 1.0,
        weight_mel: float = 1.0,
        sampling_rate: int = 44100,
        freeze_encoder: bool = False,
    ):
        super().__init__()

        # Model parameters
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Modules
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
        self.vocoder = vocoder
        self.discriminator = discriminator
        self.encode_mel_transform = encode_mel_transform
        self.gt_mel_transform = gt_mel_transform

        # A simple linear layer to project quality to condition channels
        self.quality_projection = nn.Linear(1, 768)

        # Freeze vocoder
        for param in self.vocoder.parameters():
            param.requires_grad = False

        # Loss weights
        self.weight_adv = weight_adv
        self.weight_vq = weight_vq
        self.weight_mel = weight_mel

        # Other parameters
        self.sampling_rate = sampling_rate

        # Disable strict loading
        self.strict_loading = False

        # If encoder is frozen
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

            for param in self.quantizer.parameters():
                param.requires_grad = False

        self.automatic_optimization = False

    def on_save_checkpoint(self, checkpoint):
        # Do not save vocoder
        state_dict = checkpoint["state_dict"]
        for name in list(state_dict.keys()):
            if "vocoder" in name:
                state_dict.pop(name)

    def configure_optimizers(self):
        optimizer_generator = self.optimizer_builder(
            itertools.chain(
                self.encoder.parameters(),
                self.quantizer.parameters(),
                self.decoder.parameters(),
                self.quality_projection.parameters(),
            )
        )
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )

        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )

    def training_step(self, batch, batch_idx):
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        with torch.no_grad():
            encoded_mels = self.encode_mel_transform(audios)
            gt_mels = self.gt_mel_transform(audios)
            quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
            quality = quality.unsqueeze(-1)

        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        gt_mels = gt_mels * mel_masks_float_conv
        encoded_mels = encoded_mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv

        # Quantize
        vq_result = self.quantizer(encoded_features)
        loss_vq = getattr("vq_result", "loss", 0.0)
        vq_recon_features = vq_result.z * mel_masks_float_conv
        vq_recon_features = (
            vq_recon_features + self.quality_projection(quality)[:, :, None]
        )

        # VQ Decode
        gen_mel = (
            self.decoder(
                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
                condition=vq_recon_features,
            )
            * mel_masks_float_conv
        )

        # Discriminator
        real_logits = self.discriminator(gt_mels)
        fake_logits = self.discriminator(gen_mel.detach())
        d_mask = F.interpolate(
            mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
        )

        loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
        loss_fake = avg_with_mask(fake_logits**2, d_mask)

        loss_d = loss_real + loss_fake

        self.log(
            "train/discriminator/loss",
            loss_d,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )

        # Discriminator backward
        optim_d.zero_grad()
        self.manual_backward(loss_d)
        self.clip_gradients(
            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_d.step()

        # Mel Loss, applying l1, using a weighted sum
        mel_distance = (
            gen_mel - gt_mels
        ).abs()  # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5
        loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
        loss_mel_mid_freq = avg_with_mask(
            mel_distance[:, 40:70, :], mel_masks_float_conv
        )
        loss_mel_high_freq = avg_with_mask(
            mel_distance[:, 70:, :], mel_masks_float_conv
        )
        loss_mel = (
            loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
        )

        # Adversarial Loss
        fake_logits = self.discriminator(gen_mel)
        loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)

        # Total loss
        loss = (
            self.weight_vq * loss_vq
            + self.weight_mel * loss_mel
            + self.weight_adv * loss_adv
        )

        # Log losses
        self.log(
            "train/generator/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
        )
        self.log(
            "train/generator/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
        )
        self.log(
            "train/generator/loss_adv",
            loss_adv,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
        )

        # Generator backward
        optim_g.zero_grad()
        self.manual_backward(loss)
        self.clip_gradients(
            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_g.step()

        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()

    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        encoded_mels = self.encode_mel_transform(audios)
        gt_mels = self.gt_mel_transform(audios)

        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        gt_mels = gt_mels * mel_masks_float_conv
        encoded_mels = encoded_mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv

        # Quantize
        vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
        vq_recon_features = (
            vq_recon_features
            + self.quality_projection(
                torch.ones(
                    vq_recon_features.shape[0], 1, device=vq_recon_features.device
                )
                * 2
            )[:, :, None]
        )

        # VQ Decode
        gen_aux_mels = (
            self.decoder(
                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
                condition=vq_recon_features,
            )
            * mel_masks_float_conv
        )
        loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)

        self.log(
            "val/loss_mel",
            loss_mel,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        recon_audios = self.vocoder(gt_mels)
        gen_aux_audios = self.vocoder(gen_aux_mels)

        # only log the first batch
        if batch_idx != 0:
            return

        for idx, (
            gt_mel,
            gen_aux_mel,
            audio,
            gen_aux_audio,
            recon_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                gen_aux_mels,
                audios.cpu().float(),
                gen_aux_audios.cpu().float(),
                recon_audios.cpu().float(),
                audio_lengths,
            )
        ):
            if idx > 4:
                break

            mel_len = audio_len // self.gt_mel_transform.hop_length

            image_mels = plot_mel(
                [
                    gt_mel[:, :mel_len],
                    gen_aux_mel[:, :mel_len],
                ],
                [
                    "Ground-Truth",
                    "Auxiliary",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_aux_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="aux",
                            ),
                            wandb.Audio(
                                recon_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="recon",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gen",
                    gen_aux_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/recon",
                    recon_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)

    def encode(self, audios, audio_lengths):
        audios = audios.float()

        mels = self.encode_mel_transform(audios)
        mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        mels = mels * mel_masks_float_conv
-
|
412 |
-
# Encode
|
413 |
-
encoded_features = self.encoder(mels) * mel_masks_float_conv
|
414 |
-
feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
|
415 |
-
|
416 |
-
return self.quantizer.encode(encoded_features), feature_lengths
|
417 |
-
|
418 |
-
def decode(self, indices, feature_lengths, return_audios=False):
|
419 |
-
factor = math.prod(self.quantizer.downsample_factor)
|
420 |
-
mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
|
421 |
-
mel_masks_float_conv = mel_masks[:, None, :].float()
|
422 |
-
|
423 |
-
z = self.quantizer.decode(indices) * mel_masks_float_conv
|
424 |
-
z = (
|
425 |
-
z
|
426 |
-
+ self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
|
427 |
-
:, :, None
|
428 |
-
]
|
429 |
-
)
|
430 |
-
|
431 |
-
gen_mel = (
|
432 |
-
self.decoder(
|
433 |
-
torch.randn_like(z) * mel_masks_float_conv,
|
434 |
-
condition=z,
|
435 |
-
)
|
436 |
-
* mel_masks_float_conv
|
437 |
-
)
|
438 |
-
|
439 |
-
if return_audios:
|
440 |
-
return self.vocoder(gen_mel)
|
441 |
-
|
442 |
-
return gen_mel
|
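The removed training step above builds its mel reconstruction loss as a masked L1 distance, weighted toward the low-frequency bands. Below is a minimal, self-contained sketch of that weighting; the avg_with_mask helper shown here is a stand-in assumption (the real helper lives elsewhere in the deleted module), and the tensors are toy data.

import torch

def avg_with_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Mean of x over channels and over frames where mask == 1 (mask broadcasts over channels).
    return (x * mask).sum() / (mask.sum() * x.shape[1])

# Toy stand-ins for generated / ground-truth mels: (batch, n_mels, frames)
gen_mel = torch.randn(2, 128, 50)
gt_mels = torch.randn(2, 128, 50)
mel_masks_float_conv = torch.ones(2, 1, 50)  # all frames valid in this toy case

mel_distance = (gen_mel - gt_mels).abs()
loss_low = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
loss_mid = avg_with_mask(mel_distance[:, 40:70, :], mel_masks_float_conv)
loss_high = avg_with_mask(mel_distance[:, 70:, :], mel_masks_float_conv)
loss_mel = loss_low * 0.6 + loss_mid * 0.3 + loss_high * 0.1
print(loss_mel.item())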
fish_speech/models/vqgan/modules/discriminator.py
DELETED
@@ -1,44 +0,0 @@
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        blocks = []
        convs = [
            (1, 64, (3, 9), 1, (1, 4)),
            (64, 128, (3, 9), (1, 2), (1, 4)),
            (128, 256, (3, 9), (1, 2), (1, 4)),
            (256, 512, (3, 9), (1, 2), (1, 4)),
            (512, 1024, (3, 3), 1, (1, 1)),
            (1024, 1, (3, 3), 1, (1, 1)),
        ]

        for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
            convs
        ):
            blocks.append(
                weight_norm(
                    nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
                )
            )

            if idx != len(convs) - 1:
                blocks.append(nn.SiLU(inplace=True))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x[:, None])[:, 0]


if __name__ == "__main__":
    model = Discriminator()
    print(sum(p.numel() for p in model.parameters()) / 1_000_000)
    x = torch.randn(1, 128, 1024)
    y = model(x)
    print(y.shape)
    print(y)
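The deleted Discriminator treats the mel spectrogram as a one-channel image and maps it to a grid of patch logits, and the deleted lit_module trains against it with a least-squares objective (the (fake_logits - 1) ** 2 term above). A hedged sketch of that objective with generic stand-in logit maps follows; the shapes and function names are illustrative only, not the project's exact training loop.

import torch

def lsgan_d_loss(real_logits, fake_logits, mask):
    # Least-squares discriminator loss: push real logits toward 1, fake logits toward 0.
    loss_real = ((real_logits - 1) ** 2 * mask).sum() / mask.sum()
    loss_fake = (fake_logits**2 * mask).sum() / mask.sum()
    return loss_real + loss_fake

def lsgan_g_loss(fake_logits, mask):
    # Generator side, matching the masked (fake_logits - 1) ** 2 adversarial term above.
    return ((fake_logits - 1) ** 2 * mask).sum() / mask.sum()

real_logits = torch.randn(1, 16, 128)  # stand-in logit maps; shapes are illustrative only
fake_logits = torch.randn(1, 16, 128)
mask = torch.ones_like(real_logits)
print(lsgan_d_loss(real_logits, fake_logits, mask).item())
print(lsgan_g_loss(fake_logits, mask).item())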
fish_speech/models/vqgan/modules/firefly.py
CHANGED
@@ -1,596 +1,596 @@
import math
from functools import partial
from math import prod
from typing import Callable

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from torch.utils.checkpoint import checkpoint


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv1D") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return (kernel_size * dilation - dilation) // 2


def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]


def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad1d(
    x: torch.Tensor,
    paddings: tuple[int, int],
    mode: str = "zeros",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right
    before the reflection happen.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


class FishConvNet(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
    ):
        super(FishConvNet, self).__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
        )
        self.stride = stride
        self.kernel_size = (kernel_size - 1) * dilation + 1
        self.dilation = dilation

    def forward(self, x):
        pad = self.kernel_size - self.stride
        extra_padding = get_extra_padding_for_conv1d(
            x, self.kernel_size, self.stride, pad
        )
        x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
        return self.conv(x).contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

-    def
-        self.conv = remove_parametrizations(self.conv)
+    def remove_parametrizations(self, name="weight"):
+        self.conv = remove_parametrizations(self.conv, name)
        return self


class FishTransConvNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
        super(FishTransConvNet, self).__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
        )
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x):
        x = self.conv(x)
        pad = self.kernel_size - self.stride
        padding_right = math.ceil(pad)
        padding_left = pad - padding_right
        x = unpad1d(x, (padding_left, padding_right))
        return x.contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

-    def
-        self.conv = remove_parametrizations(self.conv)
+    def remove_parametrizations(self, name="weight"):
+        self.conv = remove_parametrizations(self.conv, name)
        return self


class ResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        self.convs1 = nn.ModuleList(
            [
                FishConvNet(
                    channels, channels, kernel_size, stride=1, dilation=dilation[0]
                ).weight_norm(),
                FishConvNet(
                    channels, channels, kernel_size, stride=1, dilation=dilation[1]
                ).weight_norm(),
                FishConvNet(
                    channels, channels, kernel_size, stride=1, dilation=dilation[2]
                ).weight_norm(),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                FishConvNet(
                    channels, channels, kernel_size, stride=1, dilation=dilation[0]
                ).weight_norm(),
                FishConvNet(
                    channels, channels, kernel_size, stride=1, dilation=dilation[1]
                ).weight_norm(),
                FishConvNet(
                    channels, channels, kernel_size, stride=1, dilation=dilation[2]
                ).weight_norm(),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.silu(x)
            xt = c1(xt)
            xt = F.silu(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_parametrizations(self):
        for conv in self.convs1:
-            remove_parametrizations(
+            conv.remove_parametrizations()
        for conv in self.convs2:
-            remove_parametrizations(
+            conv.remove_parametrizations()


class ParallelBlock(nn.Module):
    def __init__(
        self,
        channels: int,
        kernel_sizes: tuple[int] = (3, 7, 11),
        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
    ):
        super().__init__()

        assert len(kernel_sizes) == len(dilation_sizes)

        self.blocks = nn.ModuleList()
        for k, d in zip(kernel_sizes, dilation_sizes):
            self.blocks.append(ResBlock1(channels, k, d))

    def forward(self, x):
        return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)

    def remove_parametrizations(self):
        for block in self.blocks:
            block.remove_parametrizations()


class HiFiGANGenerator(nn.Module):
    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.conv_pre = FishConvNet(
            num_mels,
            upsample_initial_channel,
            pre_conv_kernel_size,
            stride=1,
        ).weight_norm()

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        self.noise_convs = nn.ModuleList()
        self.ups = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                FishTransConvNet(
                    upsample_initial_channel // (2**i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    k,
                    stride=u,
                ).weight_norm()
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            self.resblocks.append(
                ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
            )

        self.activation_post = post_activation()
        self.conv_post = FishConvNet(
            ch, 1, post_conv_kernel_size, stride=1
        ).weight_norm()
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.training and self.checkpointing:
                x = checkpoint(
                    self.resblocks[i],
                    x,
                    use_reentrant=False,
                )
            else:
                x = self.resblocks[i](x)

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_parametrizations(self):
        for up in self.ups:
-            remove_parametrizations(
+            up.remove_parametrizations()
        for block in self.resblocks:
            block.remove_parametrizations()
+        self.conv_pre.remove_parametrizations()
+        self.conv_post.remove_parametrizations()


# DropPath copied from timm library
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """  # noqa: E501

    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""  # noqa: E501

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob,3):0.3f}"


class LayerNorm(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """  # noqa: E501

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None] * x + self.bias[:, None]
            return x


# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()

        self.dwconv = FishConvNet(
            dim,
            dim,
            kernel_size=kernel_size,
            # padding=int(dilation * (kernel_size - 1) / 2),
            groups=dim,
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, int(mlp_ratio * dim)
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        input = x

        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.gamma is not None:
            x = self.gamma * x

        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        x = self.drop_path(x)

        if apply_residual:
            x = input + x

        return x


class ConvNeXtEncoder(nn.Module):
    def __init__(
        self,
        input_channels: int = 3,
        depths: list[int] = [3, 3, 9, 3],
        dims: list[int] = [96, 192, 384, 768],
        drop_path_rate: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        kernel_size: int = 7,
    ):
        super().__init__()
        assert len(depths) == len(dims)

        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            FishConvNet(
                input_channels,
                dims[0],
                kernel_size=7,
                # padding=3,
                # padding_mode="replicate",
                # padding_mode="zeros",
            ),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)

        for i in range(len(depths) - 1):
            mid_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
            )
            self.downsample_layers.append(mid_layer)

        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        cur = 0
        for i in range(len(depths)):
            stage = nn.Sequential(
                *[
                    ConvNeXtBlock(
                        dim=dims[i],
                        drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                        kernel_size=kernel_size,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        for i in range(len(self.downsample_layers)):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)

        return self.norm(x)


class FireflyArchitecture(nn.Module):
    def __init__(
        self,
        backbone: nn.Module,
        head: nn.Module,
        quantizer: nn.Module,
        spec_transform: nn.Module,
    ):
        super().__init__()

        self.backbone = backbone
        self.head = head
        self.quantizer = quantizer
        self.spec_transform = spec_transform
        self.downsample_factor = math.prod(self.quantizer.downsample_factor)

    def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
        if self.spec_transform is not None:
            x = self.spec_transform(x)

        x = self.backbone(x)
        if mask is not None:
            x = x * mask

        if self.quantizer is not None:
            vq_result = self.quantizer(x)
            x = vq_result.z

            if mask is not None:
                x = x * mask

        x = self.head(x, template=template)

        if x.ndim == 2:
            x = x[:, None, :]

        if self.vq is not None:
            return x, vq_result

        return x

    def encode(self, audios, audio_lengths):
        audios = audios.float()

        mels = self.spec_transform(audios)
        mel_lengths = audio_lengths // self.spec_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        mels = mels * mel_masks_float_conv

        # Encode
        encoded_features = self.backbone(mels) * mel_masks_float_conv
        feature_lengths = mel_lengths // self.downsample_factor

        return self.quantizer.encode(encoded_features), feature_lengths

    def decode(self, indices, feature_lengths) -> torch.Tensor:
        mel_masks = sequence_mask(
            feature_lengths * self.downsample_factor,
            indices.shape[2] * self.downsample_factor,
        )
        mel_masks_float_conv = mel_masks[:, None, :].float()
        audio_lengths = (
            feature_lengths * self.downsample_factor * self.spec_transform.hop_length
        )

        audio_masks = sequence_mask(
            audio_lengths,
            indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
        )
        audio_masks_float_conv = audio_masks[:, None, :].float()

        z = self.quantizer.decode(indices) * mel_masks_float_conv
        x = self.head(z) * audio_masks_float_conv

        return x, audio_lengths

    def remove_parametrizations(self):
        if hasattr(self.backbone, "remove_parametrizations"):
            self.backbone.remove_parametrizations()

        if hasattr(self.head, "remove_parametrizations"):
            self.head.remove_parametrizations()

    @property
    def device(self):
        return next(self.parameters()).device
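FishConvNet.forward left-pads by kernel_size - stride and then asks get_extra_padding_for_conv1d for whatever extra right padding makes the convolution produce exactly ceil(length / stride) frames. The following is a small, self-contained check of that arithmetic in plain Python; it mirrors the formulas above rather than importing the module, so the function name output_length is just an illustrative stand-in.

import math

def output_length(length: int, kernel_size: int, stride: int) -> int:
    # Mirrors FishConvNet.forward: left-pad by (kernel_size - stride), then add the
    # extra right padding computed by get_extra_padding_for_conv1d.
    pad = kernel_size - stride
    n_frames = (length - kernel_size + pad) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - pad)
    extra_padding = ideal_length - length
    padded = length + pad + extra_padding
    return (padded - kernel_size) // stride + 1

for length in (100, 101, 102, 103):
    # The padded conv always yields ceil(length / stride) frames, independent of kernel size.
    assert output_length(length, kernel_size=7, stride=2) == math.ceil(length / 2)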
fish_speech/models/vqgan/modules/fsq.py
CHANGED
@@ -1,116 +1,116 @@
|
|
1 |
-
from dataclasses import dataclass
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
import torch.nn.functional as F
|
6 |
-
from einops import rearrange
|
7 |
-
from vector_quantize_pytorch import GroupedResidualFSQ
|
8 |
-
|
9 |
-
from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
|
10 |
-
|
11 |
-
|
12 |
-
@dataclass
|
13 |
-
class FSQResult:
|
14 |
-
z: torch.Tensor
|
15 |
-
codes: torch.Tensor
|
16 |
-
latents: torch.Tensor
|
17 |
-
|
18 |
-
|
19 |
-
class DownsampleFiniteScalarQuantize(nn.Module):
|
20 |
-
def __init__(
|
21 |
-
self,
|
22 |
-
input_dim: int = 512,
|
23 |
-
n_codebooks: int = 9,
|
24 |
-
n_groups: int = 1,
|
25 |
-
levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
|
26 |
-
downsample_factor: tuple[int] = (2, 2),
|
27 |
-
downsample_dims: tuple[int] | None = None,
|
28 |
-
):
|
29 |
-
super().__init__()
|
30 |
-
|
31 |
-
if downsample_dims is None:
|
32 |
-
downsample_dims = [input_dim for _ in range(len(downsample_factor))]
|
33 |
-
|
34 |
-
all_dims = (input_dim,) + tuple(downsample_dims)
|
35 |
-
|
36 |
-
self.residual_fsq = GroupedResidualFSQ(
|
37 |
-
dim=all_dims[-1],
|
38 |
-
levels=levels,
|
39 |
-
num_quantizers=n_codebooks,
|
40 |
-
groups=n_groups,
|
41 |
-
)
|
42 |
-
|
43 |
-
self.downsample_factor = downsample_factor
|
44 |
-
self.downsample_dims = downsample_dims
|
45 |
-
|
46 |
-
self.downsample = nn.Sequential(
|
47 |
-
*[
|
48 |
-
nn.Sequential(
|
49 |
-
FishConvNet(
|
50 |
-
all_dims[idx],
|
51 |
-
all_dims[idx + 1],
|
52 |
-
kernel_size=factor,
|
53 |
-
stride=factor,
|
54 |
-
),
|
55 |
-
ConvNeXtBlock(dim=all_dims[idx + 1]),
|
56 |
-
)
|
57 |
-
for idx, factor in enumerate(downsample_factor)
|
58 |
-
]
|
59 |
-
)
|
60 |
-
|
61 |
-
self.upsample = nn.Sequential(
|
62 |
-
*[
|
63 |
-
nn.Sequential(
|
64 |
-
FishTransConvNet(
|
65 |
-
all_dims[idx + 1],
|
66 |
-
all_dims[idx],
|
67 |
-
kernel_size=factor,
|
68 |
-
stride=factor,
|
69 |
-
),
|
70 |
-
ConvNeXtBlock(dim=all_dims[idx]),
|
71 |
-
)
|
72 |
-
for idx, factor in reversed(list(enumerate(downsample_factor)))
|
73 |
-
]
|
74 |
-
)
|
75 |
-
|
76 |
-
self.apply(self._init_weights)
|
77 |
-
|
78 |
-
def _init_weights(self, m):
|
79 |
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vector_quantize_pytorch import GroupedResidualFSQ

from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet


@dataclass
class FSQResult:
    z: torch.Tensor
    codes: torch.Tensor
    latents: torch.Tensor


class DownsampleFiniteScalarQuantize(nn.Module):
    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        n_groups: int = 1,
        levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
        downsample_factor: tuple[int] = (2, 2),
        downsample_dims: tuple[int] | None = None,
    ):
        super().__init__()

        if downsample_dims is None:
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]

        all_dims = (input_dim,) + tuple(downsample_dims)

        self.residual_fsq = GroupedResidualFSQ(
            dim=all_dims[-1],
            levels=levels,
            num_quantizers=n_codebooks,
            groups=n_groups,
        )

        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims

        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishConvNet(
                        all_dims[idx],
                        all_dims[idx + 1],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )

        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishTransConvNet(
                        all_dims[idx + 1],
                        all_dims[idx],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, z) -> FSQResult:
        original_shape = z.shape
        z = self.downsample(z)
        quantized, indices = self.residual_fsq(z.mT)
        result = FSQResult(
            z=quantized.mT,
            codes=indices.mT,
            latents=z,
        )
        result.z = self.upsample(result.z)

        # Pad or crop z to match original shape
        diff = original_shape[-1] - result.z.shape[-1]
        left = diff // 2
        right = diff - left

        if diff > 0:
            result.z = F.pad(result.z, (left, right))
        elif diff < 0:
            result.z = result.z[..., -left:right]

        return result

    def encode(self, z):
        z = self.downsample(z)
        _, indices = self.residual_fsq(z.mT)
        indices = rearrange(indices, "g b l r -> b (g r) l")
        return indices

    def decode(self, indices: torch.Tensor):
        indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
        z_q = self.residual_fsq.get_output_from_indices(indices)
        z_q = self.upsample(z_q.mT)
        return z_q
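For orientation, a minimal round-trip sketch of the quantizer defined above; the import path matches the file location in this diff, but the shapes and parameter values are illustrative assumptions rather than the release configuration.

```python
# Hypothetical round trip through DownsampleFiniteScalarQuantize; shapes are illustrative.
import torch

from fish_speech.models.vqgan.modules.fsq import DownsampleFiniteScalarQuantize

quantizer = DownsampleFiniteScalarQuantize(input_dim=512, n_codebooks=9)
quantizer.eval()

features = torch.randn(1, 512, 200)  # (batch, input_dim, frames)

with torch.no_grad():
    result = quantizer(features)        # FSQResult with z, codes, latents
    codes = quantizer.encode(features)  # integer indices, roughly (batch, n_codebooks, frames / 4)
    recon = quantizer.decode(codes)     # continuous features, upsampled back to the input rate

print(result.z.shape, codes.shape, recon.shape)
```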
fish_speech/models/vqgan/modules/reference.py
DELETED
@@ -1,113 +0,0 @@
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from .wavenet import WaveNet


class ReferenceEncoder(WaveNet):
    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        num_heads: int = 8,
        latent_len: int = 4,
    ):
        super().__init__(
            input_channels=input_channels,
            residual_channels=residual_channels,
            residual_layers=residual_layers,
            dilation_cycle=dilation_cycle,
        )

        self.head_dim = residual_channels // num_heads
        self.num_heads = num_heads

        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))

        self.q = nn.Linear(residual_channels, residual_channels, bias=True)
        self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.proj = nn.Linear(residual_channels, residual_channels)
        self.proj_drop = nn.Dropout(0.1)

        self.norm = nn.LayerNorm(residual_channels)
        self.mlp = nn.Sequential(
            nn.Linear(residual_channels, residual_channels * 4),
            nn.SiLU(),
            nn.Linear(residual_channels * 4, residual_channels),
        )
        self.output_projection_attn = nn.Linear(residual_channels, output_channels)

        torch.nn.init.trunc_normal_(self.latent, std=0.02)
        self.apply(self.init_weights)

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x, attn_mask=None):
        x = super().forward(x).mT
        B, N, C = x.shape

        # Calculate mask
        if attn_mask is not None:
            assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool

            attn_mask = attn_mask[:, None, None, :].expand(
                B, self.num_heads, self.latent_len, N
            )

        q_latent = self.latent.expand(B, -1, -1)
        q = (
            self.q(q_latent)
            .reshape(B, self.latent_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        kv = (
            self.kv(x)
            .reshape(B, N, 2, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))
        x = self.output_projection_attn(x)
        x = x.mean(1)

        return x


if __name__ == "__main__":
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        model = ReferenceEncoder(
            input_channels=128,
            output_channels=64,
            residual_channels=384,
            residual_layers=20,
            dilation_cycle=4,
            num_heads=8,
        )
        x = torch.randn(4, 128, 64)
        mask = torch.ones(4, 64, dtype=torch.bool)
        y = model(x, mask)
        print(y.shape)
        loss = F.mse_loss(y, torch.randn(4, 64))
        loss.backward()
fish_speech/models/vqgan/modules/wavenet.py
DELETED
@@ -1,225 +0,0 @@
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class DiffusionEmbedding(nn.Module):
    """Diffusion Step Embedding"""

    def __init__(self, d_denoiser):
        super(DiffusionEmbedding, self).__init__()
        self.dim = d_denoiser

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class LinearNorm(nn.Module):
    """LinearNorm Projection"""

    def __init__(self, in_features, out_features, bias=False):
        super(LinearNorm, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

        nn.init.xavier_uniform_(self.linear.weight)
        if bias:
            nn.init.constant_(self.linear.bias, 0.0)

    def forward(self, x):
        x = self.linear(x)
        return x


class ConvNorm(nn.Module):
    """1D Convolution"""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super(ConvNorm, self).__init__()

        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, signal):
        conv_signal = self.conv(signal)

        return conv_signal


class ResidualBlock(nn.Module):
    """Residual Block"""

    def __init__(
        self,
        residual_channels,
        use_linear_bias=False,
        dilation=1,
        condition_channels=None,
    ):
        super(ResidualBlock, self).__init__()
        self.conv_layer = ConvNorm(
            residual_channels,
            2 * residual_channels,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        )

        if condition_channels is not None:
            self.diffusion_projection = LinearNorm(
                residual_channels, residual_channels, use_linear_bias
            )
            self.condition_projection = ConvNorm(
                condition_channels, 2 * residual_channels, kernel_size=1
            )

        self.output_projection = ConvNorm(
            residual_channels, 2 * residual_channels, kernel_size=1
        )

    def forward(self, x, condition=None, diffusion_step=None):
        y = x

        if diffusion_step is not None:
            diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
            y = y + diffusion_step

        y = self.conv_layer(y)

        if condition is not None:
            condition = self.condition_projection(condition)
            y = y + condition

        gate, filter = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter)

        y = self.output_projection(y)
        residual, skip = torch.chunk(y, 2, dim=1)

        return (x + residual) / math.sqrt(2.0), skip


class WaveNet(nn.Module):
    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        is_diffusion: bool = False,
        condition_channels: Optional[int] = None,
    ):
        super().__init__()

        # Input projection
        self.input_projection = None
        if input_channels is not None and input_channels != residual_channels:
            self.input_projection = ConvNorm(
                input_channels, residual_channels, kernel_size=1
            )

        if input_channels is None:
            input_channels = residual_channels

        self.input_channels = input_channels

        # Residual layers
        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    residual_channels=residual_channels,
                    use_linear_bias=False,
                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
                    condition_channels=condition_channels,
                )
                for i in range(residual_layers)
            ]
        )

        # Skip projection
        self.skip_projection = ConvNorm(
            residual_channels, residual_channels, kernel_size=1
        )

        # Output projection
        self.output_projection = None
        if output_channels is not None and output_channels != residual_channels:
            self.output_projection = ConvNorm(
                residual_channels, output_channels, kernel_size=1
            )

        if is_diffusion:
            self.diffusion_embedding = DiffusionEmbedding(residual_channels)
            self.mlp = nn.Sequential(
                LinearNorm(residual_channels, residual_channels * 4, False),
                Mish(),
                LinearNorm(residual_channels * 4, residual_channels, False),
            )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, t=None, condition=None):
        if self.input_projection is not None:
            x = self.input_projection(x)
            x = F.silu(x)

        if t is not None:
            t = self.diffusion_embedding(t)
            t = self.mlp(t)

        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(x, condition, t)
            skip.append(skip_connection)

        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
        x = self.skip_projection(x)

        if self.output_projection is not None:
            x = F.silu(x)
            x = self.output_projection(x)

        return x
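Since wavenet.py is removed in this update, a short sketch of how the old WaveNet backbone was typically driven may help readers of the deleted code above; the import path is the pre-update location, and the channel counts and shapes are illustrative assumptions.

```python
# Illustrative forward pass through the removed WaveNet module (pre-1.5 path); values are placeholders.
import torch

from fish_speech.models.vqgan.modules.wavenet import WaveNet

net = WaveNet(
    input_channels=128,
    output_channels=64,
    residual_channels=256,
    residual_layers=8,
    dilation_cycle=4,
)

x = torch.randn(2, 128, 100)  # (batch, input_channels, frames)
y = net(x)                    # (batch, output_channels, frames); length is preserved
print(y.shape)
```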
fish_speech/models/vqgan/spectrogram.py
DELETED
@@ -1,122 +0,0 @@
import torch
import torchaudio.functional as F
from torch import Tensor, nn
from torchaudio.transforms import MelScale


class LinearSpectrogram(nn.Module):
    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode

        self.register_buffer("window", torch.hann_window(win_length), persistent=False)

    def forward(self, y: Tensor) -> Tensor:
        if y.ndim == 3:
            y = y.squeeze(1)

        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                (self.win_length - self.hop_length) // 2,
                (self.win_length - self.hop_length + 1) // 2,
            ),
            mode="reflect",
        ).squeeze(1)

        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        spec = torch.view_as_real(spec)

        if self.mode == "pow2_sqrt":
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

        return spec


class LogMelSpectrogram(nn.Module):
    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        self.f_max = f_max or float(sample_rate // 2)

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)

        fb = F.melscale_fbanks(
            n_freqs=self.n_fft // 2 + 1,
            f_min=self.f_min,
            f_max=self.f_max,
            n_mels=self.n_mels,
            sample_rate=self.sample_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        self.register_buffer(
            "fb",
            fb,
            persistent=False,
        )

    def compress(self, x: Tensor) -> Tensor:
        return torch.log(torch.clamp(x, min=1e-5))

    def decompress(self, x: Tensor) -> Tensor:
        return torch.exp(x)

    def apply_mel_scale(self, x: Tensor) -> Tensor:
        return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)

    def forward(
        self, x: Tensor, return_linear: bool = False, sample_rate: int = None
    ) -> Tensor:
        if sample_rate is not None and sample_rate != self.sample_rate:
            x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)

        linear = self.spectrogram(x)
        x = self.apply_mel_scale(linear)
        x = self.compress(x)

        if return_linear:
            return x, self.compress(linear)

        return x
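Likewise, a brief sketch of the removed LogMelSpectrogram front end as it was called before this update; the one-second dummy waveform and the approximate frame count in the comments are illustrative assumptions.

```python
# Illustrative mel extraction with the removed spectrogram module (pre-1.5 path).
import torch

from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram

mel_transform = LogMelSpectrogram(sample_rate=44100, n_mels=128, hop_length=512)

audio = torch.randn(1, 44100)               # (batch, samples), one second at 44.1 kHz
mel = mel_transform(audio)                  # (batch, 128, ~86) log-mel frames
mel, linear = mel_transform(audio, return_linear=True)  # also return the compressed linear spectrogram
print(mel.shape, linear.shape)
```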
fish_speech/models/vqgan/utils.py
CHANGED
@@ -1,94 +1,94 @@
import matplotlib
import torch
from matplotlib import pyplot as plt

matplotlib.use("Agg")


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def plot_mel(data, titles=None):
    fig, axes = plt.subplots(len(data), 1, squeeze=False)

    if titles is None:
        titles = [None for i in range(len(data))]

    plt.tight_layout()

    for i in range(len(data)):
        mel = data[i]

        if isinstance(mel, torch.Tensor):
            mel = mel.float().detach().cpu().numpy()

        axes[i][0].imshow(mel, origin="lower")
        axes[i][0].set_aspect(2.5, adjustable="box")
        axes[i][0].set_ylim(0, mel.shape[0])
        axes[i][0].set_title(titles[i], fontsize="medium")
        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
        axes[i][0].set_anchor("W")

    return fig


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]

    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
    ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
    n_channels_int = n_channels[0]
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act

    return acts


def avg_with_mask(x, mask):
    assert mask.dtype == torch.float, "Mask should be float"

    if mask.ndim == 2:
        mask = mask.unsqueeze(1)

    if mask.shape[1] == 1:
        mask = mask.expand_as(x)

    return (x * mask).sum() / mask.sum()
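A quick sketch of the two helpers above that other modules lean on, sequence_mask and rand_slice_segments; the tensors are dummies and the segment size is an arbitrary choice for illustration.

```python
# Illustrative use of the VQ-GAN utility helpers with dummy tensors.
import torch

from fish_speech.models.vqgan.utils import rand_slice_segments, sequence_mask

lengths = torch.tensor([80, 100, 60])
mask = sequence_mask(lengths, max_length=100)    # (3, 100) boolean mask
print(mask.shape, mask.sum(dim=1))               # row sums equal the lengths

features = torch.randn(3, 192, 100)              # (batch, channels, frames)
segments, start_ids = rand_slice_segments(features, lengths, segment_size=32)
print(segments.shape, start_ids)                 # (3, 192, 32) and the per-item start offsets
```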
fish_speech/scheduler.py
CHANGED
@@ -1,40 +1,40 @@
import math


def get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int | float,
    num_training_steps: int,
    num_cycles: float = 0.5,
    final_lr_ratio: float = 0.0,
):
    if 0 < num_warmup_steps < 1:  # float mode
        num_warmup_steps = int(num_warmup_steps * num_training_steps)

    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))

    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps)
    )

    return max(
        final_lr_ratio,
        0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
    )


def get_constant_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int | float,
    num_training_steps: int | None = None,
):
    if 0 < num_warmup_steps < 1:  # float mode
        num_warmup_steps = int(num_warmup_steps * num_training_steps)

    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))

    return 1.0
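These lambdas are written to be bound with functools.partial and handed to torch.optim.lr_scheduler.LambdaLR; a minimal sketch follows, with a placeholder model, optimizer, and step counts.

```python
# Illustrative wiring of the cosine warmup lambda into LambdaLR; numbers are placeholders.
from functools import partial

import torch

from fish_speech.scheduler import get_cosine_schedule_with_warmup_lr_lambda

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_lambda = partial(
    get_cosine_schedule_with_warmup_lr_lambda,
    num_warmup_steps=0.01,       # float mode: 1% of training as warmup
    num_training_steps=10_000,
    final_lr_ratio=0.1,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for step in range(100):
    optimizer.step()
    scheduler.step()
```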
fish_speech/text/__init__.py
CHANGED
@@ -1,4 +1,4 @@
from .clean import clean_text
from .spliter import split_text

__all__ = ["clean_text", "split_text"]
fish_speech/text/chn_text_norm/.gitignore
CHANGED
@@ -1,114 +1,114 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# JetBrains PyCharm
.idea

# Customize
references
url.txt

# Git
.git
fish_speech/text/chn_text_norm/README.md
CHANGED
@@ -1,36 +1,36 @@
# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works.

# Chn Text Norm

this is a repository for chinese text normalization (no longer maintained).

## Quick Start ##

### Git Clone Repo ###

git clone this repo to the root directory of your project which need to use it.

    cd /path/to/proj
    git clone https://github.com/Joee1995/chn-text-norm.git

after that, your doc tree should be:
```
proj                     # root of your project
|--- chn_text_norm       # this chn-text-norm tool
     |--- text.py
     |--- ...
|--- text_normalize.py   # your text normalization code
|--- ...
```

### How to Use ? ###

    # text_normalize.py
    from chn_text_norm.text import *

    raw_text = 'your raw text'
    text = Text(raw_text=raw_text).normalize()

### How to add quantums ###

打开test.py,然后你就知道怎么做了。
fish_speech/text/chn_text_norm/basic_class.py
CHANGED
@@ -1,172 +1,172 @@
# -*- coding: utf-8 -*-
"""基本类
中文字符类
中文数字/数位类
中文数字类
中文数位类
中文数字系统类
中文数学符号类
*中文其他符号类
"""

__author__ = "Zhiyang Zhou <[email protected]>"
__data__ = "2019-05-02"

from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES


class ChineseChar(object):
    """
    中文字符
    每个字符对应简体和繁体,
    e.g. 简体 = '负', 繁体 = '負'
    转换时可转换为简体或繁体
    """

    def __init__(self, simplified, traditional):
        self.simplified = simplified
        self.traditional = traditional
        self.__repr__ = self.__str__

    def __str__(self):
        return self.simplified or self.traditional or None

    def __repr__(self):
        return self.__str__()


class ChineseNumberUnit(ChineseChar):
    """
    中文数字/数位字符
    每个字符除繁简体外还有一个额外的大写字符
    e.g. '陆' 和 '陸'
    """

    def __init__(self, power, simplified, traditional, big_s, big_t):
        super(ChineseNumberUnit, self).__init__(simplified, traditional)
        self.power = power
        self.big_s = big_s
        self.big_t = big_t

    def __str__(self):
        return "10^{}".format(self.power)

    @classmethod
    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):

        if small_unit:
            return ChineseNumberUnit(
                power=index + 1,
                simplified=value[0],
                traditional=value[1],
                big_s=value[1],
                big_t=value[1],
            )
        elif numbering_type == NUMBERING_TYPES[0]:
            return ChineseNumberUnit(
                power=index + 8,
                simplified=value[0],
                traditional=value[1],
                big_s=value[0],
                big_t=value[1],
            )
        elif numbering_type == NUMBERING_TYPES[1]:
            return ChineseNumberUnit(
                power=(index + 2) * 4,
                simplified=value[0],
                traditional=value[1],
                big_s=value[0],
                big_t=value[1],
            )
        elif numbering_type == NUMBERING_TYPES[2]:
            return ChineseNumberUnit(
                power=pow(2, index + 3),
                simplified=value[0],
                traditional=value[1],
                big_s=value[0],
                big_t=value[1],
            )
        else:
            raise ValueError(
                "Counting type should be in {0} ({1} provided).".format(
                    NUMBERING_TYPES, numbering_type
                )
            )


class ChineseNumberDigit(ChineseChar):
    """
    中文数字字符
    """

    def __init__(
        self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None
    ):
        super(ChineseNumberDigit, self).__init__(simplified, traditional)
        self.value = value
        self.big_s = big_s
        self.big_t = big_t
        self.alt_s = alt_s
        self.alt_t = alt_t

    def __str__(self):
        return str(self.value)

    @classmethod
    def create(cls, i, v):
        return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])


class ChineseMath(ChineseChar):
    """
    中文数位字符
    """

    def __init__(self, simplified, traditional, symbol, expression=None):
        super(ChineseMath, self).__init__(simplified, traditional)
        self.symbol = symbol
        self.expression = expression
        self.big_s = simplified
        self.big_t = traditional


CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath


class NumberSystem(object):
    """
    中文数字系统
    """

    pass


class MathSymbol(object):
    """
    用于中文数字系统的数学符号 (繁/简体), e.g.
    positive = ['正', '正']
    negative = ['负', '負']
    point = ['点', '點']
    """

    def __init__(self, positive, negative, point):
        self.positive = positive
        self.negative = negative
        self.point = point

    def __iter__(self):
        for v in self.__dict__.values():
            yield v


# class OtherSymbol(object):
#     """
#     其他符号
#     """
#
#     def __init__(self, sil):
#         self.sil = sil
#
#     def __iter__(self):
#         for v in self.__dict__.values():
#             yield v