Upload 116 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +47 -43
- .gitignore +28 -0
- README.md +12 -12
- app_v1v2.py +175 -0
- configs/astral_quantization/default_2048.yml +40 -0
- configs/astral_quantization/default_32.yml +40 -0
- configs/config.json +1 -0
- configs/inuse/.gitignore +0 -0
- configs/inuse/config.json +1 -0
- configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml +98 -0
- configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml +91 -0
- configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml +82 -0
- configs/v2/ar_base.yaml +0 -0
- configs/v2/dit_small.yaml +17 -0
- configs/v2/vc_wrapper.yaml +105 -0
- hf_utils.py +1 -1
- modules/__pycache__/audio.cpython-310.pyc +0 -0
- modules/__pycache__/commons.cpython-310.pyc +0 -0
- modules/__pycache__/commons.cpython-38.pyc +0 -0
- modules/__pycache__/diffusion_transformer.cpython-310.pyc +0 -0
- modules/__pycache__/flow_matching.cpython-310.pyc +0 -0
- modules/__pycache__/length_regulator.cpython-310.pyc +0 -0
- modules/__pycache__/rmvpe.cpython-310.pyc +0 -0
- modules/astral_quantization/__pycache__/bsq.cpython-310.pyc +0 -0
- modules/astral_quantization/__pycache__/convnext.cpython-310.pyc +0 -0
- modules/astral_quantization/__pycache__/default_model.cpython-310.pyc +0 -0
- modules/astral_quantization/bsq.py +569 -0
- modules/astral_quantization/convnext.py +209 -0
- modules/astral_quantization/default_model.py +73 -0
- modules/astral_quantization/transformer.py +254 -0
- modules/audio.py +82 -82
- modules/bigvgan/__pycache__/activations.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/env.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/meldataset.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/utils.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/cuda/activation1d.py +2 -2
- modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps +3 -0
- modules/bigvgan/alias_free_activation/cuda/build/.ninja_log +7 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o +3 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o +3 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp +0 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib +0 -0
- modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd +3 -0
- modules/bigvgan/alias_free_activation/cuda/build/build.ninja +38 -0
- modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc +0 -0
    	
.gitattributes
CHANGED

@@ -1,43 +1,47 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
 *.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
 *.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
 *.npy filter=lfs diff=lfs merge=lfs -text
 *.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 examples/reference/dingzhen_0.wav filter=lfs diff=lfs merge=lfs -text
 examples/reference/s3p2.wav filter=lfs diff=lfs merge=lfs -text
 examples/source/source_s3.wav filter=lfs diff=lfs merge=lfs -text
 examples/source/source_s4.wav filter=lfs diff=lfs merge=lfs -text
 examples/source/Wiz[[:space:]]Khalifa,Charlie[[:space:]]Puth[[:space:]]-[[:space:]]See[[:space:]]You[[:space:]]Again[[:space:]]\[vocals\]_\[cut_28sec\].wav filter=lfs diff=lfs merge=lfs -text
 examples/reference/trump_0.wav filter=lfs diff=lfs merge=lfs -text
 examples/source/jay_0.wav filter=lfs diff=lfs merge=lfs -text
 examples/source/TECHNOPOLIS[[:space:]]-[[:space:]]2085[[:space:]]\[vocals\]_\[cut_14sec\].wav filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o filter=lfs diff=lfs merge=lfs -text
    	
.gitignore
ADDED

@@ -0,0 +1,28 @@
+# general things to ignore
+.DS_Store
+build/
+build_contrib/
+dist/
+.cache/
+*.egg-info/
+*.egg
+*.py[cod]
+__pycache__/
+*.so
+*~
+
+# IDE
+.vscode/
+.idea/
+
+# misc
+checkpoints/
+test_waves/
+reconstructed/
+.python-version
+ruff.log
+/configs/inuse/
+runs/
+/garbages/
+/flagged/
+/experimental/
    	
README.md
CHANGED

@@ -1,13 +1,13 @@
 ---
 title: Seed Voice Conversion
 emoji: 🎤🔄
 colorFrom: green
 colorTo: green
 sdk: gradio
-sdk_version:
-app_file:
+sdk_version: 5.23.0
+app_file: app_v1v2.py
 pinned: false
 license: gpl-3.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
    	
app_v1v2.py
ADDED

@@ -0,0 +1,175 @@
+import spaces
+import gradio as gr
+import torch
+import yaml
+import argparse
+from seed_vc_wrapper import SeedVCWrapper
+
+# Set up device and torch configurations
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
+
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+    # Experimental feature to reduce compilation times, will be on by default in future
+    torch._inductor.config.fx_graph_cache = True
+
+dtype = torch.float16
+
+def load_v2_models(args):
+    from hydra.utils import instantiate
+    from omegaconf import DictConfig
+    cfg = DictConfig(yaml.safe_load(open("configs/v2/vc_wrapper.yaml", "r")))
+    vc_wrapper = instantiate(cfg)
+    vc_wrapper.load_checkpoints()
+    vc_wrapper.to(device)
+    vc_wrapper.eval()
+
+    vc_wrapper.setup_ar_caches(max_batch_size=1, max_seq_len=4096, dtype=dtype, device=device)
+
+    if args.compile:
+        vc_wrapper.compile_ar()
+        # vc_wrapper.compile_cfm()
+
+    return vc_wrapper
+
+def create_v1_interface():
+    # Initialize the V1 wrapper
+    vc_wrapper = SeedVCWrapper()
+
+    # Set up Gradio interface
+    description = ("Zero-shot voice conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
+                   "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br> "
+                   "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br> "
+                   "无需训练的 zero-shot 语音/歌声转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
+                   "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。")
+
+    inputs = [
+        gr.Audio(type="filepath", label="Source Audio / 源音频"),
+        gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
+        gr.Slider(minimum=1, maximum=200, value=10, step=1, label="Diffusion Steps / 扩散步数",
+                 info="10 by default, 50~100 for best quality / 默认为 10,50~100 为最佳质量"),
+        gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整",
+                 info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
+        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate",
+                 info="has subtle influence / 有微小影响"),
+        gr.Checkbox(label="Use F0 conditioned model / 启用F0输入", value=False,
+                   info="Must set to true for singing voice conversion / 歌声转换时必须勾选"),
+        gr.Checkbox(label="Auto F0 adjust / 自动F0调整", value=True,
+                   info="Roughly adjust F0 to match target voice. Only works when F0 conditioned model is used. / 粗略调整 F0 以匹配目标音色,仅在勾选 '启用F0输入' 时生效"),
+        gr.Slider(label='Pitch shift / 音调变换', minimum=-24, maximum=24, step=1, value=0,
+                 info="Pitch shift in semitones, only works when F0 conditioned model is used / 半音数的音高变换,仅在勾选 '启用F0输入' 时生效"),
+    ]
+
+    examples = [
+        ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, False, True, 0],
+        ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 25, 1.0, 0.7, True, True, 0],
+        ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
+         "examples/reference/teio_0.wav", 100, 1.0, 0.7, True, False, 0],
+        ["examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav",
+         "examples/reference/trump_0.wav", 50, 1.0, 0.7, True, False, -12],
+    ]
+
+    outputs = [
+        gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
+        gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')
+    ]
+
+    return gr.Interface(
+        fn=vc_wrapper.convert_voice,
+        description=description,
+        inputs=inputs,
+        outputs=outputs,
+        title="Seed Voice Conversion V1 (Voice & Singing Voice Conversion)",
+        examples=examples,
+        cache_examples=False,
+    )
+
+def create_v2_interface(vc_wrapper):
+    # Set up Gradio interface
+    description = ("Zero-shot voice/style conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
+                   "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br> "
+                   "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br> "
+                   "Please click the 'convert style/emotion/accent' checkbox to convert the style, emotion, or accent of the source audio, or else only timbre conversion will be performed.<br> "
+                   "Click the 'anonymization only' checkbox will ignore reference audio but convert source to an 'average voice' determined by model itself.<br> "
+                   "无需训练的 zero-shot 语音/口音转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
+                   "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。"
+                   "<br>请勾选 'convert style/emotion/accent' 以转换源音频的风格、情感或口音,否则仅执行音色转换。<br>"
+                   "勾选 'anonymization only' 会无视参考音频而将源音频转换为某种由模型自身决定的 '平均音色'。<br>"
+
+                   "Credits to [Vevo](https://github.com/open-mmlab/Amphion/tree/main/models/vc/vevo)"
+                   )
+    inputs = [
+        gr.Audio(type="filepath", label="Source Audio / 源音频"),
+        gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
+        gr.Slider(minimum=1, maximum=200, value=30, step=1, label="Diffusion Steps / 扩散步数",
+                 info="30 by default, 50~100 for best quality / 默认为 30,50~100 为最佳质量"),
+        gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整",
+                 info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
+        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="Intelligibility CFG Rate",
+                 info="controls pronunciation intelligibility / 控制发音清晰度"),
+        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Similarity CFG Rate",
+                  info="controls similarity to reference audio / 控制与参考音频的相似度"),
+        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.9, label="Top-p",
+                 info="AR model sampling top P"),
+        gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature",
+                 info="AR model sampling temperature"),
+        gr.Slider(minimum=1.0, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty",
+                 info="AR model sampling repetition penalty"),
+        gr.Checkbox(label="convert style/emotion/accent", value=False),
+        gr.Checkbox(label="anonymization only", value=False),
+    ]
+
+    examples = [
+        ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 50, 1.0, 0.0, 0.7, 0.9, 1.0, 1.0, False, False],
+        ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 50, 1.0, 0.0, 0.7, 0.9, 1.0, 1.0, False, False],
+    ]
+
+    outputs = [
+        gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
+        gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')
+    ]
+
+    return gr.Interface(
+        fn=vc_wrapper.convert_voice_with_streaming,
+        description=description,
+        inputs=inputs,
+        outputs=outputs,
+        title="Seed Voice Conversion V2 (Voice & Style Conversion)",
+        examples=examples,
+        cache_examples=False,
+    )
+
+def main(args):
+    # Load V2 models
+    vc_wrapper_v2 = load_v2_models(args)
+
+    # Create interfaces
+    v1_interface = create_v1_interface()
+    v2_interface = create_v2_interface(vc_wrapper_v2)
+
+    # Create tabs
+    with gr.Blocks(title="Seed Voice Conversion") as demo:
+        gr.Markdown("# Seed Voice Conversion")
+        gr.Markdown("Choose between V1 (Voice & Singing Voice Conversion) or V2 (Voice & Style Conversion)")
+
+        with gr.Tabs():
+            with gr.TabItem("V2 - Voice & Style Conversion"):
+                v2_interface.render()
+            with gr.TabItem("V1 - Voice & Singing Voice Conversion"):
+                v1_interface.render()
+
+    # Launch the combined interface
+    demo.launch()
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--compile", type=bool, default=True)
+    args = parser.parse_args()
+    main(args)
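One detail worth flagging in the entry point above: argparse's `type=bool` applies Python's `bool()` to the raw argument string, so any non-empty value, including `--compile False`, parses as True, and with `default=True` the flag effectively cannot be switched off from the command line. The snippet below is a sketch of that behaviour and of a `store_false` alternative; it is not part of this commit.

# Sketch only (not in the commit): how --compile behaves with type=bool,
# plus an explicit flag that can actually disable compilation.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--compile", type=bool, default=True)
print(parser.parse_args(["--compile", "False"]).compile)  # True: bool("False") is truthy

explicit = argparse.ArgumentParser()
explicit.add_argument("--no-compile", dest="compile", action="store_false", default=True)
print(explicit.parse_args(["--no-compile"]).compile)      # False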
    	
configs/astral_quantization/default_2048.yml
ADDED

@@ -0,0 +1,40 @@
+_target_: modules.astral_quantization.default_model.AstralQuantizer
+tokenizer_name: "openai/whisper-small"
+ssl_model_name: "facebook/hubert-large-ll60k"
+ssl_output_layer: 18
+encoder:
+  _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
+  dim: 512
+  num_blocks: 12
+  intermediate_dim: 1536
+  dilation: 1
+  input_dim: 1024
+quantizer:
+  _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
+  codebook_size: 2048  # codebook size, must be a power of 2
+  dim: 512
+  entropy_loss_weight: 0.1
+  diversity_gamma: 1.0
+  spherical: True
+  enable_entropy_loss: True
+  soft_entropy_loss: True
+decoder:
+  _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
+  dim: 512
+  num_blocks: 12
+  intermediate_dim: 1536
+  dilation: 1
+  output_dim: 1024
+  gin_channels: 192
+asr_decoder:
+  _target_: modules.astral_quantization.asr_decoder.ASRDecoder
+  hidden_dim: 768
+  num_heads: 12
+  depth: 12
+  block_size: 4096
+  in_channels: 512
+  n_vocab: 51866
+  bos_id: 50528
+  eos_id: 50527
+  dropout_rate: 0.0
+  attn_dropout_rate: 0.0
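The `_target_` keys in this config follow the same Hydra convention that app_v1v2.py uses for configs/v2/vc_wrapper.yaml, so the whole tree can be built with `hydra.utils.instantiate`. The sketch below mirrors `load_v2_models()` from this commit; whether the repo constructs the quantizer exactly this way elsewhere is an assumption.

# Hedged sketch: build the astral quantizer from this YAML the same way
# app_v1v2.py builds the v2 wrapper (yaml.safe_load -> DictConfig -> instantiate).
import yaml
from omegaconf import DictConfig
from hydra.utils import instantiate

with open("configs/astral_quantization/default_2048.yml", "r") as f:
    cfg = DictConfig(yaml.safe_load(f))

# instantiate() resolves every nested _target_ (encoder, quantizer, decoder,
# asr_decoder) and passes the remaining keys as constructor kwargs.
quantizer = instantiate(cfg)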
    	
configs/astral_quantization/default_32.yml
ADDED

@@ -0,0 +1,40 @@
+_target_: default_model.AstralQuantizer
+tokenizer_name: "openai/whisper-small"
+ssl_model_name: "facebook/hubert-large-ll60k"
+ssl_output_layer: 18
+encoder:
+  _target_: modules.convnext.ConvNeXtV2Stage
+  dim: 512
+  num_blocks: 12
+  intermediate_dim: 1536
+  dilation: 1
+  input_dim: 1024
+quantizer:
+  _target_: modules.bsq.BinarySphericalQuantize
+  codebook_size: 32  # codebook size, must be a power of 2
+  dim: 512
+  entropy_loss_weight: 0.1
+  diversity_gamma: 1.0
+  spherical: True
+  enable_entropy_loss: True
+  soft_entropy_loss: True
+decoder:
+  _target_: modules.convnext.ConvNeXtV2Stage
+  dim: 512
+  num_blocks: 12
+  intermediate_dim: 1536
+  dilation: 1
+  output_dim: 1024
+  gin_channels: 192
+asr_decoder:
+  _target_: modules.asr_decoder.ASRDecoder
+  hidden_dim: 768
+  num_heads: 12
+  depth: 12
+  block_size: 4096
+  in_channels: 512
+  n_vocab: 51866
+  bos_id: 50528
+  eos_id: 50527
+  dropout_rate: 0.0
+  attn_dropout_rate: 0.0
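The substantive difference from default_2048.yml is the codebook size (the `_target_` paths are also written relative rather than fully qualified). Since the codebook size must be a power of 2, the two presets correspond to 11-bit and 5-bit codes; a one-line arithmetic check, not repo code:

# Plain arithmetic on the two quantizer presets above.
import math
for name, codebook_size in [("default_2048.yml", 2048), ("default_32.yml", 32)]:
    print(f"{name}: log2({codebook_size}) = {int(math.log2(codebook_size))} bits per code")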
    	
configs/config.json
ADDED

@@ -0,0 +1 @@
+{"reference_audio_path": "D:/FAcodec/test_waves/kobe_0.wav", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "\u9ea6\u514b\u98ce (Razer BlackShark V2 HS 2.4", "sg_output_device": "\u626c\u58f0\u5668 (Razer BlackShark V2 HS 2.4", "sr_type": "sr_model", "diffusion_steps": 10.0, "inference_cfg_rate": 0.0, "max_prompt_length": 3.0, "block_time": 0.7, "crossfade_length": 0.04, "extra_time": 0.5, "extra_time_right": 0.02}
    	
configs/inuse/.gitignore
ADDED

File without changes
    	
configs/inuse/config.json
ADDED

@@ -0,0 +1 @@
+{"reference_audio_path": "D:/seed-vc/examples/reference/trump_0.wav", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "\u9ea6\u514b\u98ce (Razer BlackShark V2 HS USB", "sg_output_device": "\u626c\u58f0\u5668 (Razer BlackShark V2 HS USB", "sr_type": "sr_model", "diffusion_steps": 8.0, "inference_cfg_rate": 0.7, "max_prompt_length": 3.0, "block_time": 0.58, "crossfade_length": 0.04, "extra_time_ce": 2.5, "extra_time": 0.5, "extra_time_right": 0.02}
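Both JSON files above look like saved settings from a real-time GUI session (note the Windows paths and sound-device names); that reading of the fields is inferred from the key names, not documented in this commit. The device strings are ordinary \u-escaped UTF-8: \u9ea6\u514b\u98ce is 麦克风 ("microphone") and \u626c\u58f0\u5668 is 扬声器 ("speaker"). A small sketch of reading one of them:

# Sketch: parse the committed settings file; json.load undoes the \u escapes.
import json

with open("configs/inuse/config.json", encoding="utf-8") as f:
    settings = json.load(f)

print(settings["sg_input_device"])   # 麦克风 (Razer BlackShark V2 HS USB
print(settings["diffusion_steps"])   # 8.0
print(settings["block_time"])        # 0.58 (assumed unit: seconds per block)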
    	
configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml
ADDED

@@ -0,0 +1,98 @@
+log_dir: "./runs"
+save_freq: 1
+log_interval: 10
+save_interval: 1000
+device: "cuda"
+epochs: 1000 # number of epochs for first stage training (pre-training)
+batch_size: 1
+batch_length: 100 # maximum duration of audio in a batch (in seconds)
+max_len: 80 # maximum number of frames
+pretrained_model: "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth"
+pretrained_encoder: ""
+load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
+
+preprocess_params:
+  sr: 44100
+  spect_params:
+    n_fft: 2048
+    win_length: 2048
+    hop_length: 512
+    n_mels: 128
+    fmin: 0
+    fmax: "None"
+
+model_params:
+  dit_type: "DiT" # uDiT or DiT
+  reg_loss_type: "l1" # l1 or l2
+
+  timbre_shifter:
+    se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
+    ckpt_path: './modules/openvoice/checkpoints_v2/converter'
+
+  vocoder:
+    type: "bigvgan"
+    name: "nvidia/bigvgan_v2_44khz_128band_512x"
+
+  speech_tokenizer:
+    type: 'whisper'
+    name: "openai/whisper-small"
+
+  style_encoder:
+    dim: 192
+    campplus_path: "campplus_cn_common.bin"
+
+  DAC:
+    encoder_dim: 64
+    encoder_rates: [2, 5, 5, 6]
+    decoder_dim: 1536
+    decoder_rates: [ 6, 5, 5, 2 ]
+    sr: 24000
+
+  length_regulator:
+    channels: 768
+    is_discrete: false
+    in_channels: 768
+    content_codebook_size: 2048
+    sampling_ratios: [1, 1, 1, 1]
+    vector_quantize: false
+    n_codebooks: 1
+    quantizer_dropout: 0.0
+    f0_condition: true
+    n_f0_bins: 256
+
+  DiT:
+    hidden_dim: 768
+    num_heads: 12
+    depth: 17
+    class_dropout_prob: 0.1
+    block_size: 8192
+    in_channels: 128
+    style_condition: true
+    final_layer_type: 'mlp'
+    target: 'mel' # mel or codec
+    content_dim: 768
+    content_codebook_size: 1024
+    content_type: 'discrete'
+    f0_condition: true
+    n_f0_bins: 256
+    content_codebooks: 1
+    is_causal: false
+    long_skip_connection: false
+    zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
+    time_as_token: false
+    style_as_token: false
+    uvit_skip_connection: true
+    add_resblock_in_transformer: false
+
+  wavenet:
+    hidden_dim: 768
+    num_layers: 8
+    kernel_size: 5
+    dilation_rate: 1
+    p_dropout: 0.2
+    style_condition: true
+
+loss_params:
+  base_lr: 0.0001
+  lambda_mel: 45
+  lambda_kl: 1.0
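A quick sanity check on the mel settings in this preset and the 22.05 kHz preset that follows: despite the different sample rates, both hops work out to the same ~86 mel frames per second. Plain arithmetic from the config values, not repo code:

# Arithmetic sketch based on the preset values above and below.
presets = {
    "config_dit_mel_seed_uvit_whisper_base_f0_44k": (44100, 512),
    "config_dit_mel_seed_uvit_whisper_small_wavenet": (22050, 256),
}
for name, (sr, hop_length) in presets.items():
    print(f"{name}: {sr / hop_length:.2f} mel frames/s, "
          f"{1000 * hop_length / sr:.2f} ms per hop")
# Both give ~86.13 frames/s; if max_len: 80 is counted in these frames (an
# assumption based on its comment), it corresponds to roughly 0.93 s of audio.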
    	
configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml
ADDED

@@ -0,0 +1,91 @@
+log_dir: "./runs"
+save_freq: 1
+log_interval: 10
+save_interval: 1000
+device: "cuda"
+epochs: 1000 # number of epochs for first stage training (pre-training)
+batch_size: 2
+batch_length: 100 # maximum duration of audio in a batch (in seconds)
+max_len: 80 # maximum number of frames
+pretrained_model: "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth"
+pretrained_encoder: ""
+load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
+
+preprocess_params:
+  sr: 22050
+  spect_params:
+    n_fft: 1024
+    win_length: 1024
+    hop_length: 256
+    n_mels: 80
+    fmin: 0
+    fmax: "None"
+
+model_params:
+  dit_type: "DiT" # uDiT or DiT
+  reg_loss_type: "l1" # l1 or l2
+
+  timbre_shifter:
+    se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
+    ckpt_path: './modules/openvoice/checkpoints_v2/converter'
+
+  speech_tokenizer:
+    type: 'whisper'
+    name: "openai/whisper-small"
+
+  style_encoder:
+    dim: 192
+    campplus_path: "campplus_cn_common.bin"
+
+  vocoder:
+    type: "bigvgan"
+    name: "nvidia/bigvgan_v2_22khz_80band_256x"
+
+  length_regulator:
+    channels: 512
+    is_discrete: false
+    in_channels: 768
+    content_codebook_size: 2048
+    sampling_ratios: [1, 1, 1, 1]
+    vector_quantize: false
+    n_codebooks: 1
+    quantizer_dropout: 0.0
+    f0_condition: false
+    n_f0_bins: 512
+
+  DiT:
+    hidden_dim: 512
+    num_heads: 8
+    depth: 13
+    class_dropout_prob: 0.1
+    block_size: 8192
+    in_channels: 80
+    style_condition: true
+    final_layer_type: 'wavenet'
+    target: 'mel' # mel or codec
+    content_dim: 512
+    content_codebook_size: 1024
+    content_type: 'discrete'
+    f0_condition: false
+    n_f0_bins: 512
+    content_codebooks: 1
+    is_causal: false
+    long_skip_connection: true
+    zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
+    time_as_token: false
+    style_as_token: false
+    uvit_skip_connection: true
+    add_resblock_in_transformer: false
+
+  wavenet:
+    hidden_dim: 512
+    num_layers: 8
+    kernel_size: 5
+    dilation_rate: 1
+    p_dropout: 0.2
+    style_condition: true
+
+loss_params:
+  base_lr: 0.0001
+  lambda_mel: 45
+  lambda_kl: 1.0
    	
        configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml
    ADDED
    
    | @@ -0,0 +1,82 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            log_dir: "./runs/"
         | 
| 2 | 
            +
            save_freq: 1
         | 
| 3 | 
            +
            log_interval: 10
         | 
| 4 | 
            +
            save_interval: 500
         | 
| 5 | 
            +
            device: "cuda"
         | 
| 6 | 
            +
            epochs: 1000 # number of epochs for first stage training (pre-training)
         | 
| 7 | 
            +
            batch_size: 2
         | 
| 8 | 
            +
            batch_length: 100 # maximum duration of audio in a batch (in seconds)
         | 
| 9 | 
            +
            max_len: 80 # maximum number of frames
         | 
| 10 | 
            +
            pretrained_model: "DiT_uvit_tat_xlsr_ema.pth"
         | 
| 11 | 
            +
            pretrained_encoder: ""
         | 
| 12 | 
            +
             load_only_params: False # set to true if you do not want to load epoch numbers and optimizer parameters
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            preprocess_params:
         | 
| 15 | 
            +
              sr: 22050
         | 
| 16 | 
            +
              spect_params:
         | 
| 17 | 
            +
                n_fft: 1024
         | 
| 18 | 
            +
                win_length: 1024
         | 
| 19 | 
            +
                hop_length: 256
         | 
| 20 | 
            +
                n_mels: 80
         | 
| 21 | 
            +
                fmin: 0
         | 
| 22 | 
            +
                fmax: 8000
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            model_params:
         | 
| 25 | 
            +
              dit_type: "DiT" # uDiT or DiT
         | 
| 26 | 
            +
              reg_loss_type: "l1" # l1 or l2
         | 
| 27 | 
            +
              diffusion_type: "flow"
         | 
| 28 | 
            +
             | 
| 29 | 
            +
              timbre_shifter:
         | 
| 30 | 
            +
                se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
         | 
| 31 | 
            +
                ckpt_path: './modules/openvoice/checkpoints_v2/converter'
         | 
| 32 | 
            +
             | 
| 33 | 
            +
              vocoder:
         | 
| 34 | 
            +
                type: "hifigan"
         | 
| 35 | 
            +
             | 
| 36 | 
            +
              speech_tokenizer:
         | 
| 37 | 
            +
                type: 'xlsr'
         | 
| 38 | 
            +
                output_layer: 12
         | 
| 39 | 
            +
                name: 'facebook/wav2vec2-xls-r-300m'
         | 
| 40 | 
            +
             | 
| 41 | 
            +
              style_encoder:
         | 
| 42 | 
            +
                dim: 192
         | 
| 43 | 
            +
                campplus_path: "campplus_cn_common.bin"
         | 
| 44 | 
            +
             | 
| 45 | 
            +
              length_regulator:
         | 
| 46 | 
            +
                channels: 384
         | 
| 47 | 
            +
                is_discrete: false
         | 
| 48 | 
            +
                in_channels: 1024
         | 
| 49 | 
            +
                content_codebook_size: 1024
         | 
| 50 | 
            +
                sampling_ratios: [1, 1, 1, 1]
         | 
| 51 | 
            +
                vector_quantize: false
         | 
| 52 | 
            +
                n_codebooks: 2
         | 
| 53 | 
            +
                quantizer_dropout: 0.0
         | 
| 54 | 
            +
                f0_condition: false
         | 
| 55 | 
            +
                n_f0_bins: 512
         | 
| 56 | 
            +
             | 
| 57 | 
            +
              DiT:
         | 
| 58 | 
            +
                hidden_dim: 384
         | 
| 59 | 
            +
                num_heads: 6
         | 
| 60 | 
            +
                depth: 9
         | 
| 61 | 
            +
                class_dropout_prob: 0.1
         | 
| 62 | 
            +
                block_size: 8192
         | 
| 63 | 
            +
                in_channels: 80
         | 
| 64 | 
            +
                style_condition: true
         | 
| 65 | 
            +
                final_layer_type: 'mlp'
         | 
| 66 | 
            +
                target: 'mel' # mel or betavae
         | 
| 67 | 
            +
                content_dim: 384
         | 
| 68 | 
            +
                content_codebook_size: 1024
         | 
| 69 | 
            +
                content_type: 'discrete'
         | 
| 70 | 
            +
                f0_condition: false
         | 
| 71 | 
            +
                n_f0_bins: 512
         | 
| 72 | 
            +
                content_codebooks: 1
         | 
| 73 | 
            +
                is_causal: false
         | 
| 74 | 
            +
                long_skip_connection: false
         | 
| 75 | 
            +
                zero_prompt_speech_token: false # for the prompt segment, do not feed the corresponding speech tokens
         | 
| 76 | 
            +
                time_as_token: true
         | 
| 77 | 
            +
                style_as_token: true
         | 
| 78 | 
            +
                uvit_skip_connection: true
         | 
| 79 | 
            +
                add_resblock_in_transformer: false
         | 
| 80 | 
            +
             | 
| 81 | 
            +
            loss_params:
         | 
| 82 | 
            +
              base_lr: 0.0001
         | 
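The preset files above are plain YAML; a minimal sketch of reading `config_dit_mel_seed_uvit_xlsr_tiny.yml` with PyYAML and pulling out a few of the fields visible in this diff (the loader script is illustrative and not part of the upload):

```python
# Illustrative only: load a preset config and inspect a few of its fields.
# Assumes PyYAML is installed and the file sits at this relative path.
import yaml

with open("configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml") as f:
    cfg = yaml.safe_load(f)

spect = cfg["preprocess_params"]["spect_params"]
print(cfg["preprocess_params"]["sr"])            # 22050
print(spect["n_fft"], spect["hop_length"])       # 1024 256
print(cfg["model_params"]["DiT"]["hidden_dim"])  # 384
```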
    	
        configs/v2/ar_base.yaml
    ADDED
    
    | 
            File without changes
         | 
    	
        configs/v2/dit_small.yaml
    ADDED
    
    | @@ -0,0 +1,17 @@ | |
| 1 | 
            +
            _target_: modules.v2.cfm.CFM
         | 
| 2 | 
            +
            estimator:
         | 
| 3 | 
            +
              _target_: modules.v2.dit_wrapper.DiT
         | 
| 4 | 
            +
              time_as_token: true
         | 
| 5 | 
            +
              style_as_token: true
         | 
| 6 | 
            +
              uvit_skip_connection: false
         | 
| 7 | 
            +
              block_size: 8192
         | 
| 8 | 
            +
              depth: 13
         | 
| 9 | 
            +
              num_heads: 8
         | 
| 10 | 
            +
              hidden_dim: 512
         | 
| 11 | 
            +
              in_channels: 80
         | 
| 12 | 
            +
              content_dim: 512
         | 
| 13 | 
            +
              style_encoder_dim: 192
         | 
| 14 | 
            +
              class_dropout_prob: 0.1
         | 
| 15 | 
            +
              dropout_rate: 0.0
         | 
| 16 | 
            +
              attn_dropout_rate: 0.0
         | 
| 17 | 
            +
             | 
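The v2 configs use Hydra/OmegaConf-style `_target_` (and `_partial_`) keys: each node names the class or factory to construct, and its sibling keys become constructor arguments. A minimal sketch of instantiating `dit_small.yaml` under that assumption (the loader script is illustrative; it additionally assumes `hydra-core`, `omegaconf`, and the `modules.v2` package are importable):

```python
# Illustrative only: build the CFM/DiT described by configs/v2/dit_small.yaml.
from omegaconf import OmegaConf
from hydra.utils import instantiate

cfg = OmegaConf.load("configs/v2/dit_small.yaml")
cfm = instantiate(cfg)      # recursively constructs modules.v2.cfm.CFM with its DiT estimator
print(type(cfm).__name__)   # CFM
```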
    	
        configs/v2/vc_wrapper.yaml
    ADDED
    
    | @@ -0,0 +1,105 @@ | |
| 1 | 
            +
            _target_: modules.v2.vc_wrapper.VoiceConversionWrapper
         | 
| 2 | 
            +
            sr: 22050
         | 
| 3 | 
            +
            hop_size: 256
         | 
| 4 | 
            +
            mel_fn:
         | 
| 5 | 
            +
              _target_: modules.audio.mel_spectrogram
         | 
| 6 | 
            +
              _partial_: true
         | 
| 7 | 
            +
              n_fft: 1024
         | 
| 8 | 
            +
              win_size: 1024
         | 
| 9 | 
            +
              hop_size: 256
         | 
| 10 | 
            +
              num_mels: 80
         | 
| 11 | 
            +
              sampling_rate: 22050
         | 
| 12 | 
            +
              fmin: 0
         | 
| 13 | 
            +
              fmax: null
         | 
| 14 | 
            +
              center: False
         | 
| 15 | 
            +
            cfm:
         | 
| 16 | 
            +
              _target_: modules.v2.cfm.CFM
         | 
| 17 | 
            +
              estimator:
         | 
| 18 | 
            +
                _target_: modules.v2.dit_wrapper.DiT
         | 
| 19 | 
            +
                time_as_token: true
         | 
| 20 | 
            +
                style_as_token: true
         | 
| 21 | 
            +
                uvit_skip_connection: false
         | 
| 22 | 
            +
                block_size: 8192
         | 
| 23 | 
            +
                depth: 13
         | 
| 24 | 
            +
                num_heads: 8
         | 
| 25 | 
            +
                hidden_dim: 512
         | 
| 26 | 
            +
                in_channels: 80
         | 
| 27 | 
            +
                content_dim: 512
         | 
| 28 | 
            +
                style_encoder_dim: 192
         | 
| 29 | 
            +
                class_dropout_prob: 0.1
         | 
| 30 | 
            +
                dropout_rate: 0.0
         | 
| 31 | 
            +
                attn_dropout_rate: 0.0
         | 
| 32 | 
            +
            cfm_length_regulator:
         | 
| 33 | 
            +
              _target_: modules.v2.length_regulator.InterpolateRegulator
         | 
| 34 | 
            +
              channels: 512
         | 
| 35 | 
            +
              is_discrete: true
         | 
| 36 | 
            +
              codebook_size: 2048
         | 
| 37 | 
            +
              sampling_ratios: [ 1, 1, 1, 1 ]
         | 
| 38 | 
            +
              f0_condition: false
         | 
| 39 | 
            +
            ar:
         | 
| 40 | 
            +
              _target_: modules.v2.ar.NaiveWrapper
         | 
| 41 | 
            +
              model:
         | 
| 42 | 
            +
                _target_: modules.v2.ar.NaiveTransformer
         | 
| 43 | 
            +
                config:
         | 
| 44 | 
            +
                  _target_: modules.v2.ar.NaiveModelArgs
         | 
| 45 | 
            +
                  dropout: 0.0
         | 
| 46 | 
            +
                  rope_base: 10000.0
         | 
| 47 | 
            +
                  dim: 768
         | 
| 48 | 
            +
                  head_dim: 64
         | 
| 49 | 
            +
                  n_local_heads: 2
         | 
| 50 | 
            +
                  intermediate_size: 2304
         | 
| 51 | 
            +
                  n_head: 12
         | 
| 52 | 
            +
                  n_layer: 12
         | 
| 53 | 
            +
                  vocab_size: 2049  # 2048 codes + 1 for eos
         | 
| 54 | 
            +
            ar_length_regulator:
         | 
| 55 | 
            +
              _target_: modules.v2.length_regulator.InterpolateRegulator
         | 
| 56 | 
            +
              channels: 768
         | 
| 57 | 
            +
              is_discrete: true
         | 
| 58 | 
            +
              codebook_size: 32
         | 
| 59 | 
            +
              sampling_ratios: [ ]
         | 
| 60 | 
            +
              f0_condition: false
         | 
| 61 | 
            +
            style_encoder:
         | 
| 62 | 
            +
              _target_: modules.campplus.DTDNN.CAMPPlus
         | 
| 63 | 
            +
              feat_dim: 80
         | 
| 64 | 
            +
              embedding_size: 192
         | 
| 65 | 
            +
            content_extractor_narrow:
         | 
| 66 | 
            +
              _target_: modules.astral_quantization.default_model.AstralQuantizer
         | 
| 67 | 
            +
              tokenizer_name: "openai/whisper-small"
         | 
| 68 | 
            +
              ssl_model_name: "facebook/hubert-large-ll60k"
         | 
| 69 | 
            +
              ssl_output_layer: 18
         | 
| 70 | 
            +
              skip_ssl: true
         | 
| 71 | 
            +
              encoder: &bottleneck_encoder
         | 
| 72 | 
            +
                _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
         | 
| 73 | 
            +
                dim: 512
         | 
| 74 | 
            +
                num_blocks: 12
         | 
| 75 | 
            +
                intermediate_dim: 1536
         | 
| 76 | 
            +
                dilation: 1
         | 
| 77 | 
            +
                input_dim: 1024
         | 
| 78 | 
            +
              quantizer:
         | 
| 79 | 
            +
                _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
         | 
| 80 | 
            +
                codebook_size: 32  # codebook size, must be a power of 2
         | 
| 81 | 
            +
                dim: 512
         | 
| 82 | 
            +
                entropy_loss_weight: 0.1
         | 
| 83 | 
            +
                diversity_gamma: 1.0
         | 
| 84 | 
            +
                spherical: True
         | 
| 85 | 
            +
                enable_entropy_loss: True
         | 
| 86 | 
            +
                soft_entropy_loss: True
         | 
| 87 | 
            +
            content_extractor_wide:
         | 
| 88 | 
            +
              _target_: modules.astral_quantization.default_model.AstralQuantizer
         | 
| 89 | 
            +
              tokenizer_name: "openai/whisper-small"
         | 
| 90 | 
            +
              ssl_model_name: "facebook/hubert-large-ll60k"
         | 
| 91 | 
            +
              ssl_output_layer: 18
         | 
| 92 | 
            +
              encoder: *bottleneck_encoder
         | 
| 93 | 
            +
              quantizer:
         | 
| 94 | 
            +
                _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
         | 
| 95 | 
            +
                codebook_size: 2048  # codebook size, must be a power of 2
         | 
| 96 | 
            +
                dim: 512
         | 
| 97 | 
            +
                entropy_loss_weight: 0.1
         | 
| 98 | 
            +
                diversity_gamma: 1.0
         | 
| 99 | 
            +
                spherical: True
         | 
| 100 | 
            +
                enable_entropy_loss: True
         | 
| 101 | 
            +
                soft_entropy_loss: True
         | 
| 102 | 
            +
            vocoder:
         | 
| 103 | 
            +
              _target_: modules.bigvgan.bigvgan.BigVGAN.from_pretrained
         | 
| 104 | 
            +
              pretrained_model_name_or_path: "nvidia/bigvgan_v2_22khz_80band_256x"
         | 
| 105 | 
            +
              use_cuda_kernel: false
         | 
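`encoder: &bottleneck_encoder` and `encoder: *bottleneck_encoder` in `vc_wrapper.yaml` are a standard YAML anchor/alias pair, so the narrow (codebook_size 32) and wide (codebook_size 2048) content extractors share an identical ConvNeXtV2Stage encoder configuration. A toy sketch of that anchor mechanism (the snippet is self-contained and not taken from this file):

```python
# Illustrative only: YAML anchors (&name) and aliases (*name) reuse a node on load.
import yaml

doc = """
narrow:
  encoder: &enc {dim: 512, num_blocks: 12}
wide:
  encoder: *enc
"""
cfg = yaml.safe_load(doc)
assert cfg["narrow"]["encoder"] == cfg["wide"]["encoder"]  # both see the same encoder settings
```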
    	
        hf_utils.py
    CHANGED
    
    | @@ -2,7 +2,7 @@ import os | |
| 2 | 
             
            from huggingface_hub import hf_hub_download
         | 
| 3 |  | 
| 4 |  | 
| 5 | 
            -
            def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename= | 
| 6 | 
             
                os.makedirs("./checkpoints", exist_ok=True)
         | 
| 7 | 
             
                model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints")
         | 
| 8 | 
             
                if config_filename is None:
         | 
|  | |
| 2 | 
             
            from huggingface_hub import hf_hub_download
         | 
| 3 |  | 
| 4 |  | 
| 5 | 
            +
            def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename=None):
         | 
| 6 | 
             
                os.makedirs("./checkpoints", exist_ok=True)
         | 
| 7 | 
             
                model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints")
         | 
| 8 | 
             
                if config_filename is None:
         | 
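The `hf_utils.py` change above makes `config_filename` default to `None` rather than a hard-coded name. A hedged usage sketch (the repo id and file name are placeholders, and returning only the checkpoint path when no config is requested is an assumption based on the visible `if config_filename is None:` branch):

```python
# Illustrative only: fetch just a checkpoint through the helper.
# "some-user/some-vc-model" is a placeholder repo id, not a real artifact.
from hf_utils import load_custom_model_from_hf

model_path = load_custom_model_from_hf(
    "some-user/some-vc-model",
    model_filename="pytorch_model.bin",
    config_filename=None,  # assumed: config download is skipped and only the model path is returned
)
print(model_path)  # cached under ./checkpoints
```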
    	
        modules/__pycache__/audio.cpython-310.pyc
    CHANGED
    
    | Binary files a/modules/__pycache__/audio.cpython-310.pyc and b/modules/__pycache__/audio.cpython-310.pyc differ | 
|  | 
    	
        modules/__pycache__/commons.cpython-310.pyc
    CHANGED
    
    | Binary files a/modules/__pycache__/commons.cpython-310.pyc and b/modules/__pycache__/commons.cpython-310.pyc differ | 
|  | 
    	
        modules/__pycache__/commons.cpython-38.pyc
    ADDED
    
    | Binary file (14.2 kB). View file | 
|  | 
    	
        modules/__pycache__/diffusion_transformer.cpython-310.pyc
    CHANGED
    
    | Binary files a/modules/__pycache__/diffusion_transformer.cpython-310.pyc and b/modules/__pycache__/diffusion_transformer.cpython-310.pyc differ | 
|  | 
    	
        modules/__pycache__/flow_matching.cpython-310.pyc
    CHANGED
    
    | Binary files a/modules/__pycache__/flow_matching.cpython-310.pyc and b/modules/__pycache__/flow_matching.cpython-310.pyc differ | 
|  | 
    	
        modules/__pycache__/length_regulator.cpython-310.pyc
    CHANGED
    
    | Binary files a/modules/__pycache__/length_regulator.cpython-310.pyc and b/modules/__pycache__/length_regulator.cpython-310.pyc differ | 
|  | 
    	
        modules/__pycache__/rmvpe.cpython-310.pyc
    ADDED
    
    | Binary file (17.6 kB). View file | 
|  | 
    	
        modules/astral_quantization/__pycache__/bsq.cpython-310.pyc
    ADDED
    
    | Binary file (12.7 kB). View file | 
|  | 
    	
        modules/astral_quantization/__pycache__/convnext.cpython-310.pyc
    ADDED
    
    | Binary file (6.87 kB). View file | 
|  | 
    	
        modules/astral_quantization/__pycache__/default_model.cpython-310.pyc
    ADDED
    
    | Binary file (2.8 kB). View file | 
|  | 
    	
        modules/astral_quantization/bsq.py
    ADDED
    
    | @@ -0,0 +1,569 @@ | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Lookup Free Quantization
         | 
| 3 | 
            +
            Proposed in https://arxiv.org/abs/2310.05737
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            In the simplest setup, each dimension is quantized into {-1, 1}.
         | 
| 6 | 
            +
            An entropy penalty is used to encourage utilization.
         | 
| 7 | 
            +
            """
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from math import log2, ceil
         | 
| 10 | 
            +
            from functools import partial, cache
         | 
| 11 | 
            +
            from collections import namedtuple
         | 
| 12 | 
            +
            from contextlib import nullcontext
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            import torch.distributed as dist
         | 
| 15 | 
            +
            from torch.distributed import nn as dist_nn
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import torch
         | 
| 18 | 
            +
            from torch import nn, einsum
         | 
| 19 | 
            +
            import torch.nn.functional as F
         | 
| 20 | 
            +
            from torch.nn import Module
         | 
| 21 | 
            +
            from torch.amp import autocast
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            from einops import rearrange, reduce, pack, unpack
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            # constants
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            # distributed helpers
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            @cache
         | 
| 34 | 
            +
            def is_distributed():
         | 
| 35 | 
            +
                return dist.is_initialized() and dist.get_world_size() > 1
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            def maybe_distributed_mean(t):
         | 
| 38 | 
            +
                if not is_distributed():
         | 
| 39 | 
            +
                    return t
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                dist_nn.all_reduce(t)
         | 
| 42 | 
            +
                t = t / dist.get_world_size()
         | 
| 43 | 
            +
                return t
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            # helper functions
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            def exists(v):
         | 
| 48 | 
            +
                return v is not None
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            def identity(t):
         | 
| 51 | 
            +
                return t
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            def default(*args):
         | 
| 54 | 
            +
                for arg in args:
         | 
| 55 | 
            +
                    if exists(arg):
         | 
| 56 | 
            +
                        return arg() if callable(arg) else arg
         | 
| 57 | 
            +
                return None
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            def pack_one(t, pattern):
         | 
| 60 | 
            +
                return pack([t], pattern)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
            def unpack_one(t, ps, pattern):
         | 
| 63 | 
            +
                return unpack(t, ps, pattern)[0]
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            def l2norm(t):
         | 
| 66 | 
            +
                return F.normalize(t, dim = -1)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            # entropy
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            def log(t, eps = 1e-5):
         | 
| 71 | 
            +
                return t.clamp(min = eps).log()
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            def entropy(prob):
         | 
| 74 | 
            +
                return (-prob * log(prob)).sum(dim=-1)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
            # cosine sim linear
         | 
| 77 | 
            +
             | 
| 78 | 
            +
            class CosineSimLinear(Module):
         | 
| 79 | 
            +
                def __init__(
         | 
| 80 | 
            +
                    self,
         | 
| 81 | 
            +
                    dim_in,
         | 
| 82 | 
            +
                    dim_out,
         | 
| 83 | 
            +
                    scale = 1.
         | 
| 84 | 
            +
                ):
         | 
| 85 | 
            +
                    super().__init__()
         | 
| 86 | 
            +
                    self.scale = scale
         | 
| 87 | 
            +
                    self.weight = nn.Parameter(torch.randn(dim_in, dim_out))
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                def forward(self, x):
         | 
| 90 | 
            +
                    x = F.normalize(x, dim = -1)
         | 
| 91 | 
            +
                    w = F.normalize(self.weight, dim = 0)
         | 
| 92 | 
            +
                    return (x @ w) * self.scale
         | 
| 93 | 
            +
             | 
| 94 | 
            +
            def soft_entropy_loss(u, tau=1.0, gamma=1.0):
         | 
| 95 | 
            +
                """
         | 
| 96 | 
            +
                Compute the soft entropy loss for Binary Spherical Quantization (BSQ).
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                Args:
         | 
| 99 | 
            +
                    u (torch.Tensor): Input latent embeddings of shape (batch_size, L).
         | 
| 100 | 
            +
                    tau (float): Temperature scaling factor.
         | 
| 101 | 
            +
                    gamma (float): Weight for the second entropy term.
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                Returns:
         | 
| 104 | 
            +
                    torch.Tensor: Soft entropy loss.
         | 
| 105 | 
            +
                """
         | 
| 106 | 
            +
                # Binary quantization: Generate implicit codebook corners
         | 
| 107 | 
            +
                L = u.size(1)  # Dimensionality of codebook
         | 
| 108 | 
            +
                corners = torch.tensor([-1.0, 1.0], device=u.device) / (L**0.5)
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                # Compute soft quantization probabilities for all dimensions
         | 
| 111 | 
            +
                # q_hat(c|u) for each dimension
         | 
| 112 | 
            +
                prob_matrix = torch.sigmoid(2 * tau * corners.unsqueeze(1) * u.unsqueeze(2))  # Shape: (batch_size, L, 2)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                # Entropy of q_hat(c|u) (independent along each dimension)
         | 
| 115 | 
            +
                entropy_per_dim = -torch.sum(prob_matrix * prob_matrix.log(), dim=-1)  # Shape: (batch_size, L)
         | 
| 116 | 
            +
                entropy_term1 = entropy_per_dim.mean()
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                # Expected probabilities for dataset entropy (approximation)
         | 
| 119 | 
            +
                expected_probs = prob_matrix.mean(dim=0)  # Mean across batch, shape: (L, 2)
         | 
| 120 | 
            +
                entropy_term2 = -torch.sum(expected_probs * expected_probs.log(), dim=-1).mean()
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                # Final entropy loss
         | 
| 123 | 
            +
                loss = entropy_term1 - gamma * entropy_term2
         | 
| 124 | 
            +
                return loss
         | 
| 125 | 
            +
             | 
| 126 | 
            +
            # class
         | 
| 127 | 
            +
             | 
| 128 | 
            +
            class BinarySphericalQuantize(Module):
         | 
| 129 | 
            +
                def __init__(
         | 
| 130 | 
            +
                    self,
         | 
| 131 | 
            +
                    *,
         | 
| 132 | 
            +
                    dim = None,
         | 
| 133 | 
            +
                    codebook_size = None,
         | 
| 134 | 
            +
                    entropy_loss_weight = 0.1,
         | 
| 135 | 
            +
                    commitment_loss_weight = 0.,
         | 
| 136 | 
            +
                    diversity_gamma = 1.,
         | 
| 137 | 
            +
                    straight_through_activation = nn.Identity(),
         | 
| 138 | 
            +
                    num_codebooks = 1,
         | 
| 139 | 
            +
                    keep_num_codebooks_dim = None,
         | 
| 140 | 
            +
                    codebook_scale = 1.,                        # for residual LFQ, codebook scaled down by 2x at each layer
         | 
| 141 | 
            +
                    frac_per_sample_entropy = 0.25,               # make less than 1. to only use a random fraction of the probs for per sample entropy
         | 
| 142 | 
            +
                    has_projections = None,
         | 
| 143 | 
            +
                    projection_has_bias = True,
         | 
| 144 | 
            +
                    soft_clamp_input_value = None,
         | 
| 145 | 
            +
                    cosine_sim_project_in = False,
         | 
| 146 | 
            +
                    cosine_sim_project_in_scale = None,
         | 
| 147 | 
            +
                    channel_first = None,
         | 
| 148 | 
            +
                    experimental_softplus_entropy_loss = False,
         | 
| 149 | 
            +
                    entropy_loss_offset = 5.,                   # how much to shift the loss before softplus
         | 
| 150 | 
            +
                    spherical = True,                          # from https://arxiv.org/abs/2406.07548
         | 
| 151 | 
            +
                    force_quantization_f32 = True,               # will force the quantization step to be full precision
         | 
| 152 | 
            +
                    enable_entropy_loss = True,
         | 
| 153 | 
            +
                    soft_entropy_loss = True,
         | 
| 154 | 
            +
                ):
         | 
| 155 | 
            +
                    super().__init__()
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    # some assert validations
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
         | 
| 160 | 
            +
                    assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    codebook_size = default(codebook_size, lambda: 2 ** dim)
         | 
| 163 | 
            +
                    self.codebook_size = codebook_size
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    codebook_dim = int(log2(codebook_size))
         | 
| 166 | 
            +
                    codebook_dims = codebook_dim * num_codebooks
         | 
| 167 | 
            +
                    dim = default(dim, codebook_dims)
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    has_projections = default(has_projections, dim != codebook_dims)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    if cosine_sim_project_in:
         | 
| 172 | 
            +
                        cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
         | 
| 173 | 
            +
                        project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
         | 
| 174 | 
            +
                    else:
         | 
| 175 | 
            +
                        project_in_klass = partial(nn.Linear, bias = projection_has_bias)
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
         | 
| 178 | 
            +
                    self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
         | 
| 179 | 
            +
                    self.has_projections = has_projections
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                    self.dim = dim
         | 
| 182 | 
            +
                    self.codebook_dim = codebook_dim
         | 
| 183 | 
            +
                    self.num_codebooks = num_codebooks
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
         | 
| 186 | 
            +
                    assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
         | 
| 187 | 
            +
                    self.keep_num_codebooks_dim = keep_num_codebooks_dim
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                    # channel first
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    self.channel_first = channel_first
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    # straight through activation
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    self.activation = straight_through_activation
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    # whether to use BSQ (binary spherical quantization)
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    self.spherical = spherical
         | 
| 200 | 
            +
                    self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    # entropy aux loss related weights
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                    assert 0 < frac_per_sample_entropy <= 1.
         | 
| 205 | 
            +
                    self.frac_per_sample_entropy = frac_per_sample_entropy
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    self.diversity_gamma = diversity_gamma
         | 
| 208 | 
            +
                    self.entropy_loss_weight = entropy_loss_weight
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    # codebook scale
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                    self.codebook_scale = codebook_scale
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                    # commitment loss
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    self.commitment_loss_weight = commitment_loss_weight
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    # whether to soft clamp the input value from -value to value
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    self.soft_clamp_input_value = soft_clamp_input_value
         | 
| 221 | 
            +
                    assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                    # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    self.entropy_loss_offset = entropy_loss_offset
         | 
| 226 | 
            +
                    self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    # for no auxiliary loss, during inference
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                    self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
         | 
| 231 | 
            +
                    self.register_buffer('zero', torch.tensor(0.), persistent = False)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    # whether to force quantization step to be f32
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                    self.force_quantization_f32 = force_quantization_f32
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    # codes
         | 
| 238 | 
            +
                    self.enable_entropy_loss = enable_entropy_loss
         | 
| 239 | 
            +
                    self.soft_entropy_loss = soft_entropy_loss
         | 
| 240 | 
            +
                    if codebook_size <= 100000:
         | 
| 241 | 
            +
                        all_codes = torch.arange(codebook_size)
         | 
| 242 | 
            +
                        bits = ((all_codes[..., None].int() & self.mask) != 0).float()
         | 
| 243 | 
            +
                        codebook = self.bits_to_codes(bits)
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                        self.register_buffer('codebook', codebook.float(), persistent = False)
         | 
| 246 | 
            +
                    else:
         | 
| 247 | 
            +
                        all_codes = torch.arange(pow(2, 16))
         | 
| 248 | 
            +
                        mask = 2 ** torch.arange(16 - 1, -1, -1)
         | 
| 249 | 
            +
                        bits = ((all_codes[..., None].int() & mask) != 0).float()
         | 
| 250 | 
            +
                        codebook = self.bits_to_codes(bits)
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                        self.register_buffer('codebook', codebook.float(), persistent = False)
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                def bits_to_codes(self, bits):
         | 
| 255 | 
            +
                    return bits * self.codebook_scale * 2 - self.codebook_scale
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                @property
         | 
| 258 | 
            +
                def dtype(self):
         | 
| 259 | 
            +
                    return self.codebook.dtype
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                def indices_to_codes(
         | 
| 262 | 
            +
                    self,
         | 
| 263 | 
            +
                    indices,
         | 
| 264 | 
            +
                    project_out = True
         | 
| 265 | 
            +
                ):
         | 
| 266 | 
            +
                    is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
         | 
| 267 | 
            +
                    should_transpose = default(self.channel_first, is_img_or_video)
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                    if not self.keep_num_codebooks_dim:
         | 
| 270 | 
            +
                        indices = rearrange(indices, '... -> ... 1')
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                    # indices to codes, which are bits of either -1 or 1
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                    bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                    codes = self.bits_to_codes(bits)
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                    codes = self.maybe_l2norm(codes)
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                    codes = rearrange(codes, '... c d -> ... (c d)')
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    # whether to project codes out to original dimensions
         | 
| 283 | 
            +
                    # if the input feature dimensions were not log2(codebook size)
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                    if project_out:
         | 
| 286 | 
            +
                        codes = self.project_out(codes)
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    # rearrange codes back to original shape
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                    if should_transpose:
         | 
| 291 | 
            +
                        codes = rearrange(codes, 'b ... d -> b d ...')
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                    return codes
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                def bits_to_z(self, bits):
         | 
| 296 | 
            +
                    # assert bits must contain only -1 and 1
         | 
| 297 | 
            +
                    assert torch.all(bits.abs() == 1)
         | 
| 298 | 
            +
                    quantized = bits.float()
         | 
| 299 | 
            +
                    quantized = self.maybe_l2norm(quantized)
         | 
| 300 | 
            +
                    z = self.project_out(quantized)
         | 
| 301 | 
            +
                    return z
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                def forward(
         | 
| 304 | 
            +
                    self,
         | 
| 305 | 
            +
                    x,
         | 
| 306 | 
            +
                    inv_temperature = 100.,
         | 
| 307 | 
            +
                    return_loss_breakdown = False,
         | 
| 308 | 
            +
                    mask = None,
         | 
| 309 | 
            +
                    return_bits = False
         | 
| 310 | 
            +
                ):
         | 
| 311 | 
            +
                    """
         | 
| 312 | 
            +
                    einstein notation
         | 
| 313 | 
            +
                    b - batch
         | 
| 314 | 
            +
                    n - sequence (or flattened spatial dimensions)
         | 
| 315 | 
            +
                    d - feature dimension, which is also log2(codebook size)
         | 
| 316 | 
            +
                    c - number of codebook dim
         | 
| 317 | 
            +
                    """
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                    is_img_or_video = x.ndim >= 4
         | 
| 320 | 
            +
                    should_transpose = default(self.channel_first, is_img_or_video)
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                    # standardize image or video into (batch, seq, dimension)
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                    if should_transpose:
         | 
| 325 | 
            +
                        x = rearrange(x, 'b d ... -> b ... d')
         | 
| 326 | 
            +
                        x, ps = pack_one(x, 'b * d')
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                    assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                    x = self.project_in(x)
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                    # maybe soft clamp
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                    if exists(self.soft_clamp_input_value):
         | 
| 335 | 
            +
                        clamp_value = self.soft_clamp_input_value
         | 
| 336 | 
            +
                        x = (x / clamp_value).tanh() * clamp_value
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                    # split out number of codebooks
         | 
| 339 | 
            +
             | 
| 340 | 
            +
                    x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                    # maybe l2norm
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                    x = self.maybe_l2norm(x)
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                    # whether to force quantization step to be full precision or not
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                    force_f32 = self.force_quantization_f32
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                    quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                    with quantization_context():
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                        if force_f32:
         | 
| 355 | 
            +
                            orig_dtype = x.dtype
         | 
| 356 | 
            +
                            x = x.float()
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                        # quantize by eq 3.
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                        original_input = x
         | 
| 361 | 
            +
             | 
| 362 | 
            +
                        codebook_value = torch.ones_like(x) * self.codebook_scale
         | 
| 363 | 
            +
                        quantized = torch.where(x > 0, codebook_value, -codebook_value)
         | 
| 364 | 
            +
                        if return_bits:
         | 
| 365 | 
            +
                            return quantized
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                        # calculate indices
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                        indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                        # maybe l2norm
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                        quantized = self.maybe_l2norm(quantized)
         | 
| 374 | 
            +
             | 
| 375 | 
            +
                        # use straight-through gradients (optionally with custom activation fn) if training
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                        if self.training:
         | 
| 378 | 
            +
                            x = self.activation(x)
         | 
| 379 | 
            +
                            x = x + (quantized - x).detach()
         | 
| 380 | 
            +
                        else:
         | 
| 381 | 
            +
                            x = quantized
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                        # entropy aux loss
         | 
| 384 | 
            +
                        if self.soft_entropy_loss:
         | 
| 385 | 
            +
                            entropy_aux_loss = soft_entropy_loss(x, tau=1.0, gamma=1.0)
         | 
| 386 | 
            +
                        elif self.training and self.enable_entropy_loss:
         | 
| 387 | 
            +
             | 
| 388 | 
            +
                            if force_f32:
         | 
| 389 | 
            +
                                codebook = self.codebook.float()
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                            codebook = self.maybe_l2norm(codebook)
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                            # whether to only use a fraction of probs, for reducing memory
         | 
| 394 | 
            +
             | 
| 395 | 
            +
                            if self.frac_per_sample_entropy < 1.:
         | 
| 396 | 
            +
                                # account for mask
         | 
| 397 | 
            +
                                if exists(mask):
         | 
| 398 | 
            +
                                    original_input = original_input[mask]
         | 
| 399 | 
            +
                                original_input = rearrange(original_input, 'b n ... -> (b n) ...')
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                                rand_mask = torch.randn(self.codebook_dim).argsort(dim = -1) < 16
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                                sampled_input = original_input[..., rand_mask]
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                                sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                                sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                                per_sample_probs = sampled_prob
         | 
| 410 | 
            +
                            else:
         | 
| 411 | 
            +
                                if exists(mask):
         | 
| 412 | 
            +
                                    original_input = original_input[mask]
         | 
| 413 | 
            +
                                original_input = rearrange(original_input, 'b n ... -> (b n) ...')
         | 
| 414 | 
            +
                                # the same as euclidean distance up to a constant
         | 
| 415 | 
            +
                                distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
         | 
| 416 | 
            +
             | 
| 417 | 
            +
                                prob = (-distance * inv_temperature).softmax(dim = -1)
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                                per_sample_probs = prob
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                            # calculate per sample entropy
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                            per_sample_entropy = entropy(per_sample_probs).mean()
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                            # distribution over all available tokens in the batch
         | 
| 426 | 
            +
             | 
| 427 | 
            +
                            avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                            avg_prob = maybe_distributed_mean(avg_prob)
         | 
| 430 | 
            +
             | 
| 431 | 
            +
                            codebook_entropy = entropy(avg_prob).mean()
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                            # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
         | 
| 434 | 
            +
                            # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
         | 
| 435 | 
            +
             | 
| 436 | 
            +
                            entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
         | 
| 437 | 
            +
                        else:
         | 
| 438 | 
            +
                            # if not training, just return dummy 0
         | 
| 439 | 
            +
                            entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                        # whether to make the entropy loss positive or not through a (shifted) softplus
         | 
| 442 | 
            +
             | 
| 443 | 
            +
                        if self.training and self.experimental_softplus_entropy_loss:
         | 
| 444 | 
            +
                            entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                        # commit loss
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                        if self.training and self.commitment_loss_weight > 0.:
         | 
| 449 | 
            +
             | 
| 450 | 
            +
                            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                            if exists(mask):
         | 
| 453 | 
            +
                                commit_loss = commit_loss[mask]
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                            commit_loss = commit_loss.mean()
         | 
| 456 | 
            +
                        else:
         | 
| 457 | 
            +
                            commit_loss = self.zero
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                        # input back to original dtype if needed
         | 
| 460 | 
            +
             | 
| 461 | 
            +
                        if force_f32:
         | 
| 462 | 
            +
                            x = x.type(orig_dtype)
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                    # merge back codebook dim
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                    x = rearrange(x, 'b n c d -> b n (c d)')
         | 
| 467 | 
            +
             | 
| 468 | 
            +
                    # project out to feature dimension if needed
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                    x = self.project_out(x)
         | 
| 471 | 
            +
             | 
| 472 | 
            +
                    # reconstitute image or video dimensions
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                    if should_transpose:
         | 
| 475 | 
            +
                        x = unpack_one(x, ps, 'b * d')
         | 
| 476 | 
            +
                        x = rearrange(x, 'b ... d -> b d ...')
         | 
| 477 | 
            +
             | 
| 478 | 
            +
                        indices = unpack_one(indices, ps, 'b * c')
         | 
| 479 | 
            +
             | 
| 480 | 
            +
                    # whether to remove single codebook dim
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                    if not self.keep_num_codebooks_dim:
         | 
| 483 | 
            +
                        indices = rearrange(indices, '... 1 -> ...')
         | 
| 484 | 
            +
             | 
| 485 | 
            +
                    # complete aux loss
         | 
| 486 | 
            +
             | 
| 487 | 
            +
                    aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
         | 
| 488 | 
            +
             | 
| 489 | 
            +
                    # returns
         | 
| 490 | 
            +
             | 
| 491 | 
            +
                    ret = Return(x, indices, aux_loss)
         | 
| 492 | 
            +
             | 
| 493 | 
            +
                    if not return_loss_breakdown:
         | 
| 494 | 
            +
                        return ret
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                    return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
         | 
| 497 | 
            +
             | 
| 498 | 
            +
            class GroupedResidualBSQ(Module):
         | 
| 499 | 
            +
                def __init__(
         | 
| 500 | 
            +
                    self,
         | 
| 501 | 
            +
                    *,
         | 
| 502 | 
            +
                    dim,
         | 
| 503 | 
            +
                    groups = 1,
         | 
| 504 | 
            +
                    accept_image_fmap = False,
         | 
| 505 | 
            +
                    **kwargs
         | 
| 506 | 
            +
                ):
         | 
| 507 | 
            +
                    super().__init__()
         | 
| 508 | 
            +
                    self.dim = dim
         | 
| 509 | 
            +
                    self.groups = groups
         | 
| 510 | 
            +
                    assert (dim % groups) == 0
         | 
| 511 | 
            +
                    dim_per_group = dim // groups
         | 
| 512 | 
            +
             | 
| 513 | 
            +
                    self.accept_image_fmap = accept_image_fmap
         | 
| 514 | 
            +
             | 
| 515 | 
            +
                    self.rvqs = nn.ModuleList([])
         | 
| 516 | 
            +
             | 
| 517 | 
            +
                    for _ in range(groups):
         | 
| 518 | 
            +
                        self.rvqs.append(LFQ(
         | 
| 519 | 
            +
                            dim = dim_per_group,
         | 
| 520 | 
            +
                            **kwargs
         | 
| 521 | 
            +
                        ))
         | 
| 522 | 
            +
             | 
| 523 | 
            +
                    self.codebook_size = self.rvqs[0].codebook_size
         | 
| 524 | 
            +
             | 
| 525 | 
            +
                @property
         | 
| 526 | 
            +
                def codebooks(self):
         | 
| 527 | 
            +
                    return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
         | 
| 528 | 
            +
             | 
| 529 | 
            +
                @property
         | 
| 530 | 
            +
                def split_dim(self):
         | 
| 531 | 
            +
                    return 1 if self.accept_image_fmap else -1
         | 
| 532 | 
            +
             | 
| 533 | 
            +
                def get_codes_from_indices(self, indices):
         | 
| 534 | 
            +
                    codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
         | 
| 535 | 
            +
                    return torch.stack(codes)
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                def get_output_from_indices(self, indices):
         | 
| 538 | 
            +
                    outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
         | 
| 539 | 
            +
                    return torch.cat(outputs, dim = self.split_dim)
         | 
| 540 | 
            +
             | 
| 541 | 
            +
                def forward(
         | 
| 542 | 
            +
                    self,
         | 
| 543 | 
            +
                    x,
         | 
| 544 | 
            +
                    return_all_codes = False
         | 
| 545 | 
            +
                ):
         | 
| 546 | 
            +
                    shape, split_dim = x.shape, self.split_dim
         | 
| 547 | 
            +
                    assert shape[split_dim] == self.dim
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                    # split the feature dimension into groups
         | 
| 550 | 
            +
             | 
| 551 | 
            +
                    x = x.chunk(self.groups, dim = split_dim)
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                    forward_kwargs = dict(
         | 
| 554 | 
            +
                    )
         | 
| 555 | 
            +
             | 
| 556 | 
            +
                    # invoke residual vq on each group
         | 
| 557 | 
            +
             | 
| 558 | 
            +
                    out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
         | 
| 559 | 
            +
                    out = tuple(zip(*out))
         | 
| 560 | 
            +
             | 
| 561 | 
            +
                    # gather the zipped outputs from each group and combine them
         | 
| 562 | 
            +
             | 
| 563 | 
            +
                    quantized, all_indices, *maybe_aux_loss = out
         | 
| 564 | 
            +
             | 
| 565 | 
            +
                    quantized = torch.cat(quantized, dim = split_dim)
         | 
| 566 | 
            +
                    all_indices = torch.stack(all_indices)
         | 
| 567 | 
            +
             | 
| 568 | 
            +
                    ret = (quantized, all_indices, *maybe_aux_loss)
         | 
| 569 | 
            +
                    return ret
         | 
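
The entropy auxiliary loss computed above balances two terms: the per-sample entropy is pushed low so each input commits confidently to a single code, while the entropy of the batch-averaged code distribution is pushed high so the whole codebook gets used. A minimal standalone sketch of that computation (the `entropy` helper, the temperature, and the toy shapes here are illustrative only, not part of this file):

    import torch

    def entropy(prob, eps=1e-5):
        # -sum p * log p over the code dimension
        return (-prob * torch.log(prob.clamp(min=eps))).sum(dim=-1)

    def entropy_aux_loss(similarity, inv_temperature=100.0, diversity_gamma=1.0):
        # similarity: (batch, n, codebook_size) scores, e.g. negative squared distances
        prob = (similarity * inv_temperature).softmax(dim=-1)
        per_sample_entropy = entropy(prob).mean()                 # nudged low: confident assignments
        avg_prob = prob.reshape(-1, prob.shape[-1]).mean(dim=0)   # code usage across the batch
        codebook_entropy = entropy(avg_prob)                      # nudged high: uniform code usage
        return per_sample_entropy - diversity_gamma * codebook_entropy

    loss = entropy_aux_loss(torch.randn(2, 16, 64))
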
    	
        modules/astral_quantization/convnext.py
    ADDED
    
    | @@ -0,0 +1,209 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
            from typing import List
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            class ConvNextV2LayerNorm(nn.Module):
         | 
| 8 | 
            +
                r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
         | 
| 9 | 
            +
                The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
         | 
| 10 | 
            +
                width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
         | 
| 11 | 
            +
                """
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
         | 
| 14 | 
            +
                    super().__init__()
         | 
| 15 | 
            +
                    self.weight = nn.Parameter(torch.ones(normalized_shape))
         | 
| 16 | 
            +
                    self.bias = nn.Parameter(torch.zeros(normalized_shape))
         | 
| 17 | 
            +
                    self.eps = eps
         | 
| 18 | 
            +
                    self.data_format = data_format
         | 
| 19 | 
            +
                    if self.data_format not in ["channels_last", "channels_first"]:
         | 
| 20 | 
            +
                        raise NotImplementedError(f"Unsupported data format: {self.data_format}")
         | 
| 21 | 
            +
                    self.normalized_shape = (normalized_shape,)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 24 | 
            +
                    if self.data_format == "channels_last":
         | 
| 25 | 
            +
                        x = torch.nn.functional.layer_norm(
         | 
| 26 | 
            +
                            x, self.normalized_shape, self.weight, self.bias, self.eps
         | 
| 27 | 
            +
                        )
         | 
| 28 | 
            +
                    elif self.data_format == "channels_first":
         | 
| 29 | 
            +
                        input_dtype = x.dtype
         | 
| 30 | 
            +
                        x = x.float()
         | 
| 31 | 
            +
                        u = x.mean(1, keepdim=True)
         | 
| 32 | 
            +
                        s = (x - u).pow(2).mean(1, keepdim=True)
         | 
| 33 | 
            +
                        x = (x - u) / torch.sqrt(s + self.eps)
         | 
| 34 | 
            +
                        x = x.to(dtype=input_dtype)
         | 
| 35 | 
            +
                        x = self.weight[None, :, None] * x + self.bias[None, :, None]
         | 
| 36 | 
            +
                    return x
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            class GRN(nn.Module):
         | 
| 40 | 
            +
                def __init__(self, dim):
         | 
| 41 | 
            +
                    super().__init__()
         | 
| 42 | 
            +
                    self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
         | 
| 43 | 
            +
                    self.beta = nn.Parameter(torch.zeros(1, 1, dim))
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                def forward(self, x):
         | 
| 46 | 
            +
                    Gx = torch.norm(x, p=2, dim=1, keepdim=True)
         | 
| 47 | 
            +
                    Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
         | 
| 48 | 
            +
                    return self.gamma * (x * Nx) + self.beta + x
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            class InterpolationLayer(nn.Module):
         | 
| 51 | 
            +
                def __init__(self, ):  # this is a default of 1 / 50 * (44100 / 512) / 4
         | 
| 52 | 
            +
                    super().__init__()
         | 
| 53 | 
            +
                    pass
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                def forward(self, x: torch.Tensor, target_len: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         | 
| 56 | 
            +
                    x = F.interpolate(x, size=target_len, mode='linear')
         | 
| 57 | 
            +
                    return x
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            class ConvNeXtV2Stage(nn.Module):
         | 
| 60 | 
            +
                def __init__(
         | 
| 61 | 
            +
                    self,
         | 
| 62 | 
            +
                    dim: int = 512,
         | 
| 63 | 
            +
                    intermediate_dim: int = 2048,
         | 
| 64 | 
            +
                    num_blocks: int = 1,
         | 
| 65 | 
            +
                    dilation: int = 1,
         | 
| 66 | 
            +
                    downsample_layer_indices: List[int] = None,
         | 
| 67 | 
            +
                    downsample_factors: List[int] = None,
         | 
| 68 | 
            +
                    upsample_layer_indices: List[int] = None,
         | 
| 69 | 
            +
                    upsample_factors: List[int] = None,
         | 
| 70 | 
            +
                    interpolation_layer_indices: List[int] = None,
         | 
| 71 | 
            +
                    input_dim: int = None,
         | 
| 72 | 
            +
                    output_dim: int = None,
         | 
| 73 | 
            +
                    gin_channels: int = 0,
         | 
| 74 | 
            +
                ):
         | 
| 75 | 
            +
                    super().__init__()
         | 
| 76 | 
            +
                    # maybe downsample layers
         | 
| 77 | 
            +
                    if downsample_layer_indices is not None:
         | 
| 78 | 
            +
                        assert downsample_factors is not None
         | 
| 79 | 
            +
                        self.downsample_blocks = nn.ModuleList(
         | 
| 80 | 
            +
                            [
         | 
| 81 | 
            +
                                nn.Sequential(
         | 
| 82 | 
            +
                                    ConvNextV2LayerNorm(dim, data_format="channels_first"),
         | 
| 83 | 
            +
                                    nn.Conv1d(
         | 
| 84 | 
            +
                                        dim, dim, kernel_size=downsample_factor, stride=downsample_factor
         | 
| 85 | 
            +
                                    ),
         | 
| 86 | 
            +
                                ) for _, downsample_factor in zip(downsample_layer_indices, downsample_factors)
         | 
| 87 | 
            +
                            ]
         | 
| 88 | 
            +
                        )
         | 
| 89 | 
            +
                        self.downsample_layer_indices = downsample_layer_indices
         | 
| 90 | 
            +
                    else:
         | 
| 91 | 
            +
                        self.downsample_blocks = nn.ModuleList()
         | 
| 92 | 
            +
                        self.downsample_layer_indices = []
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    # maybe upsample layers
         | 
| 95 | 
            +
                    if upsample_layer_indices is not None:
         | 
| 96 | 
            +
                        assert upsample_factors is not None
         | 
| 97 | 
            +
                        self.upsample_blocks = nn.ModuleList(
         | 
| 98 | 
            +
                            [
         | 
| 99 | 
            +
                                nn.Sequential(
         | 
| 100 | 
            +
                                    ConvNextV2LayerNorm(dim, data_format="channels_first"),
         | 
| 101 | 
            +
                                    nn.ConvTranspose1d(
         | 
| 102 | 
            +
                                        dim, dim, kernel_size=upsample_factor, stride=upsample_factor
         | 
| 103 | 
            +
                                    ),
         | 
| 104 | 
            +
                                ) for _, upsample_factor in zip(upsample_layer_indices, upsample_factors)
         | 
| 105 | 
            +
                            ]
         | 
| 106 | 
            +
                        )
         | 
| 107 | 
            +
                        self.upsample_layer_indices = upsample_layer_indices
         | 
| 108 | 
            +
                    else:
         | 
| 109 | 
            +
                        self.upsample_blocks = nn.ModuleList()
         | 
| 110 | 
            +
                        self.upsample_layer_indices = []
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    # maybe interpolation layers
         | 
| 113 | 
            +
                    if interpolation_layer_indices is not None:
         | 
| 114 | 
            +
                        self.interpolation_blocks = nn.ModuleList(
         | 
| 115 | 
            +
                            [
         | 
| 116 | 
            +
                                InterpolationLayer()
         | 
| 117 | 
            +
                                for _ in interpolation_layer_indices
         | 
| 118 | 
            +
                            ]
         | 
| 119 | 
            +
                        )
         | 
| 120 | 
            +
                        self.interpolation_layer_indices = interpolation_layer_indices
         | 
| 121 | 
            +
                    else:
         | 
| 122 | 
            +
                        self.interpolation_blocks = nn.ModuleList()
         | 
| 123 | 
            +
                        self.interpolation_layer_indices = []
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    # main blocks
         | 
| 126 | 
            +
                    self.blocks = nn.ModuleList(
         | 
| 127 | 
            +
                        [
         | 
| 128 | 
            +
                            ConvNeXtV2Block(
         | 
| 129 | 
            +
                                dim=dim,
         | 
| 130 | 
            +
                                intermediate_dim=intermediate_dim,
         | 
| 131 | 
            +
                                dilation=dilation,
         | 
| 132 | 
            +
                            )
         | 
| 133 | 
            +
                            for _ in range(num_blocks)
         | 
| 134 | 
            +
                        ]
         | 
| 135 | 
            +
                    )
         | 
| 136 | 
            +
                    # maybe input and output projections
         | 
| 137 | 
            +
                    if input_dim is not None and input_dim != dim:
         | 
| 138 | 
            +
                        self.input_projection = nn.Conv1d(input_dim, dim, kernel_size=1)
         | 
| 139 | 
            +
                    else:
         | 
| 140 | 
            +
                        self.input_projection = nn.Identity()
         | 
| 141 | 
            +
                    if output_dim is not None and output_dim != dim:
         | 
| 142 | 
            +
                        self.output_projection = nn.Conv1d(dim, output_dim, kernel_size=1)
         | 
| 143 | 
            +
                    else:
         | 
| 144 | 
            +
                        self.output_projection = nn.Identity()
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    if gin_channels > 0:
         | 
| 147 | 
            +
                        self.gin = nn.Conv1d(gin_channels, dim, kernel_size=1)
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         | 
| 150 | 
            +
                    x = self.input_projection(x)  # B, D, T
         | 
| 151 | 
            +
                    if hasattr(self, 'gin'):
         | 
| 152 | 
            +
                        g = kwargs['g']
         | 
| 153 | 
            +
                        x = x + self.gin(g)
         | 
| 154 | 
            +
                    # pad to a multiple of cumprod(downsample_factors)
         | 
| 155 | 
            +
                    if len(self.downsample_blocks) > 0:
         | 
| 156 | 
            +
                        downsample_factor = 1
         | 
| 157 | 
            +
                        for factor in self.downsample_blocks:
         | 
| 158 | 
            +
                            downsample_factor *= factor[1].stride[0]
         | 
| 159 | 
            +
                        pad_len = (downsample_factor - x.size(-1) % downsample_factor) % downsample_factor  # no-op when already a multiple
         | 
| 160 | 
            +
                        if pad_len > 0:
         | 
| 161 | 
            +
                            x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1)
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    # main blocks
         | 
| 164 | 
            +
                    for layer_idx, block in enumerate(self.blocks):
         | 
| 165 | 
            +
                        if layer_idx in self.downsample_layer_indices:
         | 
| 166 | 
            +
                            x = self.downsample_blocks[self.downsample_layer_indices.index(layer_idx)](x)
         | 
| 167 | 
            +
                        if layer_idx in self.upsample_layer_indices:
         | 
| 168 | 
            +
                            x = self.upsample_blocks[self.upsample_layer_indices.index(layer_idx)](x)
         | 
| 169 | 
            +
                        if layer_idx in self.interpolation_layer_indices:
         | 
| 170 | 
            +
                            x = self.interpolation_blocks[self.interpolation_layer_indices.index(layer_idx)](x, target_len=kwargs['target_len'])
         | 
| 171 | 
            +
                        x = block(x)
         | 
| 172 | 
            +
                    x = self.output_projection(x)
         | 
| 173 | 
            +
                    return x
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                def setup_caches(self, *args, **kwargs):
         | 
| 176 | 
            +
                    pass
         | 
| 177 | 
            +
             | 
| 178 | 
            +
             | 
| 179 | 
            +
            class ConvNeXtV2Block(nn.Module):
         | 
| 180 | 
            +
                def __init__(
         | 
| 181 | 
            +
                    self,
         | 
| 182 | 
            +
                    dim: int,
         | 
| 183 | 
            +
                    intermediate_dim: int,
         | 
| 184 | 
            +
                    dilation: int = 1,
         | 
| 185 | 
            +
                ):
         | 
| 186 | 
            +
                    super().__init__()
         | 
| 187 | 
            +
                    padding = (dilation * (7 - 1)) // 2
         | 
| 188 | 
            +
                    self.dwconv = nn.Conv1d(
         | 
| 189 | 
            +
                        dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
         | 
| 190 | 
            +
                    )  # depthwise conv
         | 
| 191 | 
            +
                    self.norm = ConvNextV2LayerNorm(dim, data_format="channels_first")
         | 
| 192 | 
            +
                    self.pwconv1 = nn.Linear(
         | 
| 193 | 
            +
                        dim, intermediate_dim
         | 
| 194 | 
            +
                    )  # pointwise/1x1 convs, implemented with linear layers
         | 
| 195 | 
            +
                    self.act = nn.GELU()
         | 
| 196 | 
            +
                    self.grn = GRN(intermediate_dim)
         | 
| 197 | 
            +
                    self.pwconv2 = nn.Linear(intermediate_dim, dim)
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 200 | 
            +
                    residual = x
         | 
| 201 | 
            +
                    x = self.dwconv(x)
         | 
| 202 | 
            +
                    x = self.norm(x)
         | 
| 203 | 
            +
                    x = x.transpose(1, 2)  # b d n -> b n d
         | 
| 204 | 
            +
                    x = self.pwconv1(x)
         | 
| 205 | 
            +
                    x = self.act(x)
         | 
| 206 | 
            +
                    x = self.grn(x)
         | 
| 207 | 
            +
                    x = self.pwconv2(x)
         | 
| 208 | 
            +
                    x = x.transpose(1, 2)  # b n d -> b d n
         | 
| 209 | 
            +
                    return residual + x
         | 
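
Each ConvNeXtV2Block above is a residual unit over (batch, channels, time) features: a depthwise 7-tap Conv1d, a channels-first LayerNorm, a pointwise expansion with GELU and GRN, and a pointwise projection back. A quick shape check, sketched under the assumption that the repo root is on PYTHONPATH so this file imports as modules.astral_quantization.convnext (the sizes are arbitrary):

    import torch
    from modules.astral_quantization.convnext import ConvNeXtV2Block, ConvNeXtV2Stage

    block = ConvNeXtV2Block(dim=512, intermediate_dim=2048)
    x = torch.randn(2, 512, 100)            # (batch, channels, time)
    assert block(x).shape == x.shape        # residual block preserves shape

    # a stage with an input projection (768 -> 512) and no down/upsampling
    stage = ConvNeXtV2Stage(dim=512, intermediate_dim=2048, num_blocks=4, input_dim=768)
    y = stage(torch.randn(2, 768, 100))
    print(y.shape)                          # torch.Size([2, 512, 100])
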
    	
        modules/astral_quantization/default_model.py
    ADDED
    
    | @@ -0,0 +1,73 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from transformers import AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            class AstralQuantizer(torch.nn.Module):
         | 
| 5 | 
            +
                def __init__(
         | 
| 6 | 
            +
                        self,
         | 
| 7 | 
            +
                        tokenizer_name: str,
         | 
| 8 | 
            +
                        ssl_model_name: str,
         | 
| 9 | 
            +
                        ssl_output_layer: int,
         | 
| 10 | 
            +
                        encoder: torch.nn.Module,
         | 
| 11 | 
            +
                        quantizer: torch.nn.Module,
         | 
| 12 | 
            +
                        skip_ssl: bool = False,
         | 
| 13 | 
            +
                ):
         | 
| 14 | 
            +
                    super().__init__()
         | 
| 15 | 
            +
                    self.encoder = encoder
         | 
| 16 | 
            +
                    self.quantizer = quantizer
         | 
| 17 | 
            +
                    self.tokenizer_name = tokenizer_name
         | 
| 18 | 
            +
                    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                    # Load SSL model from Huggingface
         | 
| 21 | 
            +
                    self.ssl_model_name = ssl_model_name
         | 
| 22 | 
            +
                    self.ssl_output_layer = ssl_output_layer
         | 
| 23 | 
            +
                    self.ssl_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ssl_model_name)
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                    if skip_ssl:  # in case the same SSL model has been loaded somewhere else
         | 
| 26 | 
            +
                        self.ssl_model = None
         | 
| 27 | 
            +
                    else:
         | 
| 28 | 
            +
                        self.ssl_model = AutoModel.from_pretrained(ssl_model_name).eval()
         | 
| 29 | 
            +
                        self.ssl_model.encoder.layers = self.ssl_model.encoder.layers[:ssl_output_layer]
         | 
| 30 | 
            +
                        self.ssl_model.encoder.layer_norm = torch.nn.Identity()
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                def load_separate_checkpoint(self, checkpoint_path):
         | 
| 33 | 
            +
                    params = torch.load(checkpoint_path, map_location='cpu')['net']
         | 
| 34 | 
            +
                    for key in params.keys():
         | 
| 35 | 
            +
                        for k in list(params[key].keys()):
         | 
| 36 | 
            +
                            if k.startswith("module."):
         | 
| 37 | 
            +
                                params[key][k[len("module."):]] = params[key][k]
         | 
| 38 | 
            +
                                del params[key][k]
         | 
| 39 | 
            +
                    self.encoder.load_state_dict(params['encoder'])
         | 
| 40 | 
            +
                    self.quantizer.load_state_dict(params['vq'])
         | 
| 41 | 
            +
                if getattr(self, 'decoder', None) is not None:  # decoder is not defined in __init__ of this class
         | 
| 42 | 
            +
                        self.decoder.load_state_dict(params['decoder'])
         | 
| 43 | 
            +
                if getattr(self, 'asr_decoder', None) is not None:  # asr_decoder is likewise optional
         | 
| 44 | 
            +
                        self.asr_decoder.load_state_dict(params['predictor'], strict=False)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                def forward(self, waves_16k, wave_16k_lens, ssl_model=None):
         | 
| 47 | 
            +
                    ssl_fn = self.ssl_model if self.ssl_model else ssl_model
         | 
| 48 | 
            +
                    assert ssl_fn is not None, "In case in-class SSL model loading is skipped, external ssl_model must be provided"
         | 
| 49 | 
            +
                    waves_16k_input_list = [
         | 
| 50 | 
            +
                        waves_16k[bib, :wave_16k_lens[bib]].cpu().numpy()
         | 
| 51 | 
            +
                        for bib in range(len(waves_16k))
         | 
| 52 | 
            +
                    ]
         | 
| 53 | 
            +
                    alt_inputs = self.ssl_feature_extractor(
         | 
| 54 | 
            +
                        waves_16k_input_list,
         | 
| 55 | 
            +
                        return_tensors='pt',
         | 
| 56 | 
            +
                        return_attention_mask=True,
         | 
| 57 | 
            +
                        padding=True,
         | 
| 58 | 
            +
                        sampling_rate=16000
         | 
| 59 | 
            +
                    ).to(waves_16k.device)
         | 
| 60 | 
            +
                    feature_lens = alt_inputs.data['attention_mask'].sum(-1) // 320  # frame rate of hubert is 50 Hz
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    outputs = ssl_fn(
         | 
| 63 | 
            +
                        alt_inputs.input_values,
         | 
| 64 | 
            +
                        attention_mask=alt_inputs.attention_mask,
         | 
| 65 | 
            +
                    )
         | 
| 66 | 
            +
                    last_hidden_states = outputs.last_hidden_state
         | 
| 67 | 
            +
                    last_hidden_states = last_hidden_states[:, :feature_lens.max(), :]
         | 
| 68 | 
            +
                    feature_lens = feature_lens.clamp(max=last_hidden_states.size(1))
         | 
| 69 | 
            +
                    last_hidden_states = last_hidden_states.transpose(1, 2)
         | 
| 70 | 
            +
                    x_hidden = self.encoder(last_hidden_states, feature_lens)
         | 
| 71 | 
            +
                    x_hidden = x_hidden.transpose(1, 2)
         | 
| 72 | 
            +
                    x_quantized, indices = self.quantizer(x_hidden)[:2]
         | 
| 73 | 
            +
                    return x_quantized, indices, feature_lens
         | 
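
AstralQuantizer.forward expects padded 16 kHz waveforms plus their valid lengths, runs them through the (truncated) SSL encoder, and returns (x_quantized, indices, feature_lens). The feature-length bookkeeping relies on the 320-sample hop of HuBERT-style models at 16 kHz, i.e. roughly 50 frames per second. A small sketch of just that arithmetic (the lengths are made up):

    import torch

    wave_16k_lens = torch.tensor([16000, 24000, 8000])   # 1.0 s, 1.5 s, 0.5 s at 16 kHz
    feature_lens = wave_16k_lens // 320                   # 320-sample hop -> ~50 Hz frames
    print(feature_lens)                                    # tensor([50, 75, 25])
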
    	
        modules/astral_quantization/transformer.py
    ADDED
    
    | @@ -0,0 +1,254 @@ | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
            from dataclasses import dataclass
         | 
| 7 | 
            +
            from typing import Optional
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
            import torch.nn as nn
         | 
| 11 | 
            +
            from torch import Tensor
         | 
| 12 | 
            +
            from torch.nn import functional as F
         | 
| 13 | 
            +
            import time
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            def find_multiple(n: int, k: int) -> int:
         | 
| 16 | 
            +
                if n % k == 0:
         | 
| 17 | 
            +
                    return n
         | 
| 18 | 
            +
                return n + k - (n % k)
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            class AdaptiveLayerNorm(nn.Module):
         | 
| 21 | 
            +
                r"""Adaptive Layer Normalization"""
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                def __init__(self, d_model, norm) -> None:
         | 
| 24 | 
            +
                    super(AdaptiveLayerNorm, self).__init__()
         | 
| 25 | 
            +
                    self.project_layer = nn.Linear(d_model, 2 * d_model)
         | 
| 26 | 
            +
                    self.norm = norm
         | 
| 27 | 
            +
                    self.d_model = d_model
         | 
| 28 | 
            +
                    self.eps = self.norm.eps
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
         | 
| 31 | 
            +
                    if embedding is None:
         | 
| 32 | 
            +
                        return self.norm(input)
         | 
| 33 | 
            +
                    weight, bias = torch.split(
         | 
| 34 | 
            +
                        self.project_layer(embedding),
         | 
| 35 | 
            +
                        split_size_or_sections=self.d_model,
         | 
| 36 | 
            +
                        dim=-1,
         | 
| 37 | 
            +
                    )
         | 
| 38 | 
            +
                    return weight * self.norm(input) + bias
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            @dataclass
         | 
| 42 | 
            +
            class ModelArgs:
         | 
| 43 | 
            +
                block_size: int = 2048
         | 
| 44 | 
            +
                vocab_size: int = 32000
         | 
| 45 | 
            +
                n_layer: int = 32
         | 
| 46 | 
            +
                n_head: int = 32
         | 
| 47 | 
            +
                dim: int = 4096
         | 
| 48 | 
            +
                intermediate_size: int = None
         | 
| 49 | 
            +
                n_local_heads: int = -1
         | 
| 50 | 
            +
                head_dim: int = 64
         | 
| 51 | 
            +
                rope_base: float = 10000
         | 
| 52 | 
            +
                norm_eps: float = 1e-5
         | 
| 53 | 
            +
                has_cross_attention: bool = False
         | 
| 54 | 
            +
                context_dim: int = 0
         | 
| 55 | 
            +
                is_causal: bool = False
         | 
| 56 | 
            +
                dropout_rate: float = 0.1
         | 
| 57 | 
            +
                attn_dropout_rate: float = 0.1
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def __post_init__(self):
         | 
| 60 | 
            +
                    if self.n_local_heads == -1:
         | 
| 61 | 
            +
                        self.n_local_heads = self.n_head
         | 
| 62 | 
            +
                    if self.intermediate_size is None:
         | 
| 63 | 
            +
                        hidden_dim = 4 * self.dim
         | 
| 64 | 
            +
                        n_hidden = int(2 * hidden_dim / 3)
         | 
| 65 | 
            +
                        self.intermediate_size = find_multiple(n_hidden, 256)
         | 
| 66 | 
            +
                    # self.head_dim = self.dim // self.n_head
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            class Transformer(nn.Module):
         | 
| 69 | 
            +
                def __init__(self, config: ModelArgs) -> None:
         | 
| 70 | 
            +
                    super().__init__()
         | 
| 71 | 
            +
                    self.config = config
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                    self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
         | 
| 74 | 
            +
                    self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    self.max_batch_size = -1
         | 
| 77 | 
            +
                    self.max_seq_length = config.block_size
         | 
| 78 | 
            +
                    freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
         | 
| 79 | 
            +
                                                          self.config.rope_base)
         | 
| 80 | 
            +
                    self.register_buffer("freqs_cis", freqs_cis)
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    causal_mask = torch.tril(
         | 
| 83 | 
            +
                        torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
         | 
| 84 | 
            +
                    )
         | 
| 85 | 
            +
                    self.register_buffer("causal_mask", causal_mask)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                def forward(self,
         | 
| 88 | 
            +
                            x: Tensor,
         | 
| 89 | 
            +
                            c: Tensor,
         | 
| 90 | 
            +
                            input_pos: Optional[Tensor] = None,
         | 
| 91 | 
            +
                            mask: Optional[Tensor] = None,
         | 
| 92 | 
            +
                            context: Optional[Tensor] = None,
         | 
| 93 | 
            +
                            context_input_pos: Optional[Tensor] = None,
         | 
| 94 | 
            +
                            cross_attention_mask: Optional[Tensor] = None,
         | 
| 95 | 
            +
                            ) -> Tensor:
         | 
| 96 | 
            +
                    if mask is None:
         | 
| 97 | 
            +
                        mask = self.causal_mask[:x.size(1), :x.size(1)]
         | 
| 98 | 
            +
                    else:
         | 
| 99 | 
            +
                        mask = mask[..., input_pos]
         | 
| 100 | 
            +
                    freqs_cis = self.freqs_cis[input_pos]
         | 
| 101 | 
            +
                    if context is not None:
         | 
| 102 | 
            +
                        context_freqs_cis = self.freqs_cis[context_input_pos]
         | 
| 103 | 
            +
                    else:
         | 
| 104 | 
            +
                        context_freqs_cis = None
         | 
| 105 | 
            +
                    skip_in_x_list = []
         | 
| 106 | 
            +
                    for i, layer in enumerate(self.layers):
         | 
| 107 | 
            +
                        x = layer(x, c, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask)
         | 
| 108 | 
            +
                    x = self.norm(x, c)
         | 
| 109 | 
            +
                    return x
         | 
| 110 | 
            +
             | 
| 111 | 
            +
             | 
| 112 | 
            +
            class TransformerBlock(nn.Module):
         | 
| 113 | 
            +
                def __init__(self, config: ModelArgs) -> None:
         | 
| 114 | 
            +
                    super().__init__()
         | 
| 115 | 
            +
                    self.attention = Attention(config)
         | 
| 116 | 
            +
                    self.feed_forward = FeedForward(config)
         | 
| 117 | 
            +
                    self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
         | 
| 118 | 
            +
                    self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    if config.has_cross_attention:
         | 
| 121 | 
            +
                        self.has_cross_attention = True
         | 
| 122 | 
            +
                        self.cross_attention = Attention(config, is_cross_attention=True)
         | 
| 123 | 
            +
                        self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
         | 
| 124 | 
            +
                    else:
         | 
| 125 | 
            +
                        self.has_cross_attention = False
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                def forward(self,
         | 
| 128 | 
            +
                            x: Tensor,
         | 
| 129 | 
            +
                            c: Tensor,
         | 
| 130 | 
            +
                            freqs_cis: Tensor,
         | 
| 131 | 
            +
                            mask: Tensor,
         | 
| 132 | 
            +
                            context: Optional[Tensor] = None,
         | 
| 133 | 
            +
                            context_freqs_cis: Optional[Tensor] = None,
         | 
| 134 | 
            +
                            cross_attention_mask: Optional[Tensor] = None,
         | 
| 135 | 
            +
                            ) -> Tensor:
         | 
| 136 | 
            +
                    #time_attn_start = time.time()
         | 
| 137 | 
            +
                    h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask)
         | 
| 138 | 
            +
                    #print(f"time take for attention of sequence length {x.shape[1]} is {time.time() - time_attn_start}")
         | 
| 139 | 
            +
                    if self.has_cross_attention:
         | 
| 140 | 
            +
                        h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, context, context_freqs_cis)
         | 
| 141 | 
            +
                    out = h + self.feed_forward(self.ffn_norm(h, c))
         | 
| 142 | 
            +
                    return out
         | 
| 143 | 
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
            class Attention(nn.Module):
         | 
| 146 | 
            +
                def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
         | 
| 147 | 
            +
                    super().__init__()
         | 
| 148 | 
            +
                    assert config.dim % config.n_head == 0
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
         | 
| 151 | 
            +
                    # key, query, value projections for all heads, but in a batch
         | 
| 152 | 
            +
                    if is_cross_attention:
         | 
| 153 | 
            +
                        self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
         | 
| 154 | 
            +
                        self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
         | 
| 155 | 
            +
                    else:
         | 
| 156 | 
            +
                        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
         | 
| 157 | 
            +
                    self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
         | 
| 158 | 
            +
                    self.kv_cache = None
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    self.n_head = config.n_head
         | 
| 161 | 
            +
                    self.head_dim = config.head_dim
         | 
| 162 | 
            +
                    self.n_local_heads = config.n_local_heads
         | 
| 163 | 
            +
                    self.dim = config.dim
         | 
| 164 | 
            +
                    self.attn_dropout_rate = config.attn_dropout_rate
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                def forward(self,
         | 
| 167 | 
            +
                            x: Tensor,
         | 
| 168 | 
            +
                            freqs_cis: Tensor,
         | 
| 169 | 
            +
                            mask: Tensor,
         | 
| 170 | 
            +
                            context: Optional[Tensor] = None,
         | 
| 171 | 
            +
                            context_freqs_cis: Optional[Tensor] = None,
         | 
| 172 | 
            +
                            ) -> Tensor:
         | 
| 173 | 
            +
                    bsz, seqlen, _ = x.shape
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    kv_size = self.n_local_heads * self.head_dim
         | 
| 176 | 
            +
                    if context is None:
         | 
| 177 | 
            +
                        q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
         | 
| 178 | 
            +
                        context_seqlen = seqlen
         | 
| 179 | 
            +
                    else:
         | 
| 180 | 
            +
                        q = self.wq(x)
         | 
| 181 | 
            +
                        k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
         | 
| 182 | 
            +
                        context_seqlen = context.shape[1]
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         | 
| 185 | 
            +
                    k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
         | 
| 186 | 
            +
                    v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                    q = apply_rotary_emb(q, freqs_cis)
         | 
| 189 | 
            +
                    k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)
         | 
| 190 | 
            +
             | 
| 191 | 
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_dropout_rate if self.training else 0.0)
+
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
+
+        y = self.wo(y)
+        return y
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x: Tensor) -> Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+def precompute_freqs_cis(
+        seq_len: int, n_elem: int, base: int = 10000,
+        dtype: torch.dtype = torch.bfloat16
+) -> Tensor:
+    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    t = torch.arange(seq_len, device=freqs.device)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+    return cache.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+        ],
+        -1,
+    )
+
+    x_out2 = x_out2.flatten(3)
+    return x_out2.type_as(x)
+
    	
modules/audio.py
CHANGED
@@ -1,82 +1,82 @@
-import numpy as np
-import torch
-import torch.utils.data
-from librosa.filters import mel as librosa_mel_fn
-from scipy.io.wavfile import read
-
-MAX_WAV_VALUE = 32768.0
-
-
-def load_wav(full_path):
-    sampling_rate, data = read(full_path)
-    return data, sampling_rate
-
-
-def dynamic_range_compression(x, C=1, clip_val=1e-5):
-    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
-
-
-def dynamic_range_decompression(x, C=1):
-    return np.exp(x) / C
-
-
-def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
-    return torch.log(torch.clamp(x, min=clip_val) * C)
-
-
-def dynamic_range_decompression_torch(x, C=1):
-    return torch.exp(x) / C
-
-
-def spectral_normalize_torch(magnitudes):
-    output = dynamic_range_compression_torch(magnitudes)
-    return output
-
-
-def spectral_de_normalize_torch(magnitudes):
-    output = dynamic_range_decompression_torch(magnitudes)
-    return output
-
-
-mel_basis = {}
-hann_window = {}
-
-
-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
-    if torch.min(y) < -1.0:
-        print("min value is ", torch.min(y))
-    if torch.max(y) > 1.0:
-        print("max value is ", torch.max(y))
-
-    global mel_basis, hann_window  # pylint: disable=global-statement
-    if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
-        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
-        mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
-        hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)
-
-    y = torch.nn.functional.pad(
-        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
-    )
-    y = y.squeeze(1)
-
-    spec = torch.view_as_real(
-        torch.stft(
-            y,
-            n_fft,
-            hop_length=hop_size,
-            win_length=win_size,
-            window=hann_window[str(sampling_rate) + "_" + str(y.device)],
-            center=center,
-            pad_mode="reflect",
-            normalized=False,
-            onesided=True,
-            return_complex=True,
-        )
-    )
-
-    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
-
-    spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
-    spec = spectral_normalize_torch(spec)
-
-    return spec
+import numpy as np
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+from scipy.io.wavfile import read
+
+MAX_WAV_VALUE = 32768.0
+
+
+def load_wav(full_path):
+    sampling_rate, data = read(full_path)
+    return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    output = dynamic_range_compression_torch(magnitudes)
+    return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+    output = dynamic_range_decompression_torch(magnitudes)
+    return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+    if torch.min(y) < -1.0:
+        print("min value is ", torch.min(y))
+    if torch.max(y) > 1.0:
+        print("max value is ", torch.max(y))
+
+    global mel_basis, hann_window  # pylint: disable=global-statement
+    if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+        mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+    )
+    y = y.squeeze(1)
+
+    spec = torch.view_as_real(
+        torch.stft(
+            y,
+            n_fft,
+            hop_length=hop_size,
+            win_length=win_size,
+            window=hann_window[str(sampling_rate) + "_" + str(y.device)],
+            center=center,
+            pad_mode="reflect",
+            normalized=False,
+            onesided=True,
+            return_complex=True,
+        )
+    )
+
+    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+    spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
+    spec = spectral_normalize_torch(spec)
+
+    return spec
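For reference, a minimal usage sketch of the mel_spectrogram helper above; the STFT and mel parameters below are illustrative placeholders, not the values used by this repository's presets.

# Hedged usage sketch for mel_spectrogram above; parameter values are illustrative only.
import torch

sr = 22050
y = torch.rand(1, sr) * 2.0 - 1.0  # one second of synthetic audio kept inside [-1, 1]

mel = mel_spectrogram(
    y,
    n_fft=1024, num_mels=80, sampling_rate=sr,
    hop_size=256, win_size=1024,
    fmin=0, fmax=8000, center=False,
)
print(mel.shape)  # (1, 80, n_frames): log-compressed mel spectrogram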
    	
modules/bigvgan/__pycache__/activations.cpython-310.pyc
ADDED
Binary file (4 kB)

modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc
ADDED
Binary file (11.8 kB)

modules/bigvgan/__pycache__/env.cpython-310.pyc
ADDED
Binary file (796 Bytes)

modules/bigvgan/__pycache__/meldataset.cpython-310.pyc
ADDED
Binary file (8.54 kB)

modules/bigvgan/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.84 kB)

modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (158 Bytes)

modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc
ADDED
Binary file (2.34 kB)

modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc
ADDED
Binary file (1.99 kB)
    	
modules/bigvgan/alias_free_activation/cuda/activation1d.py
CHANGED
@@ -3,10 +3,10 @@
 
 import torch
 import torch.nn as nn
-from …
+from ..torch.resample import UpSample1d, DownSample1d
 
 # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
-from …
+from ..cuda import load
 
 anti_alias_activation_cuda = load.load()
 
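The edit above swaps two absolute imports (truncated in this view, shown as "from …") for package-relative ones, so the vendored BigVGAN code resolves its siblings through the package it lives in rather than through sys.path. For illustration only, the absolute equivalents implied by the file path shown above would be:

# Hedged illustration: absolute equivalents of the new relative imports, derived from
# the path modules/bigvgan/alias_free_activation/cuda/activation1d.py. Importing the
# real module also JIT-builds the CUDA kernel, so treat this as a sketch.
from modules.bigvgan.alias_free_activation.torch.resample import UpSample1d, DownSample1d
from modules.bigvgan.alias_free_activation.cuda import load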
    	
modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e233713716a5778577f244b0f310944ff26d3079ce0e42491791da7d42e363c1
+size 522068
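This entry and several of the build outputs that follow are committed as Git LFS pointer stubs (the three version/oid/size lines), not as the binaries themselves. A hedged sketch of checking a fetched file against its pointer metadata, using the oid and size from the entry above:

# Hedged sketch: verify an LFS-tracked file against its pointer's oid and size.
import hashlib
from pathlib import Path

path = Path("modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps")
expected_oid = "e233713716a5778577f244b0f310944ff26d3079ce0e42491791da7d42e363c1"
expected_size = 522068

data = path.read_bytes()
assert len(data) == expected_size, "size mismatch - still a pointer stub? run `git lfs pull`"
assert hashlib.sha256(data).hexdigest() == expected_oid, "content hash mismatch"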
    	
modules/bigvgan/alias_free_activation/cuda/build/.ninja_log
ADDED
@@ -0,0 +1,7 @@
+# ninja log v5
+9	39554	7516864785377831	anti_alias_activation.o	3a177f31dd72c43c
+13	152601	7516865914203767	anti_alias_activation_cuda.cuda.o	2d613e7382d803fd
+152628	153062	7516865920541751	anti_alias_activation_cuda.pyd	f6366e9bdfb27f7
+128	50503	7654004565901584	anti_alias_activation.o	9ed3213f2e0d0858
+133	176837	7654005827401976	anti_alias_activation_cuda.cuda.o	a679b6661c609136
+176839	177401	7654005835005523	anti_alias_activation_cuda.pyd	f6366e9bdfb27f7
    	
modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74c2824b05582070b69f51ec588aadb268c4fddf18fbb4590f901d1cdf32185c
+size 3246655
    	
modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86c48de557041de7ebaff7926b5f346cc5e4e2dddc6cf5b88409f6cb161db0f4
+size 4724513
    	
modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp
ADDED
Binary file (25.1 kB)

modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib
ADDED
Binary file (43.7 kB)
    	
modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db37ea2dd31dfe67e68ee6019877d14638c41724ff9342c55f638f4d2cda3d03
+size 2454528
    	
modules/bigvgan/alias_free_activation/cuda/build/build.ninja
ADDED
@@ -0,0 +1,38 @@
+ninja_required_version = 1.3
+cxx = cl
+nvcc = C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin\nvcc
+
+cflags = -DTORCH_EXTENSION_NAME=anti_alias_activation_cuda -DTORCH_API_INCLUDE_EXTENSION_H -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\include" -ID:\Anaconda\envs\vocos\Include /std:c++17 -O3 /MD /wd4819 /wd4251 /wd4244 /wd4267 /wd4275 /wd4018 /wd4190 /wd4624 /wd4067 /wd4068 /EHsc
+post_cflags = 
+cuda_cflags = -Xcudafe --diag_suppress=dll_interface_conflict_dllexport_assumed -Xcudafe --diag_suppress=dll_interface_conflict_none_assumed -Xcudafe --diag_suppress=field_without_dll_interface -Xcudafe --diag_suppress=base_class_has_different_dll_interface -Xcompiler /EHsc -Xcompiler /wd4068 -Xcompiler /wd4067 -Xcompiler /wd4624 -Xcompiler /wd4190 -Xcompiler /wd4018 -Xcompiler /wd4275 -Xcompiler /wd4267 -Xcompiler /wd4244 -Xcompiler /wd4251 -Xcompiler /wd4819 -Xcompiler /MD -DTORCH_EXTENSION_NAME=anti_alias_activation_cuda -DTORCH_API_INCLUDE_EXTENSION_H -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\include" -ID:\Anaconda\envs\vocos\Include -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++17 -O3 -gencode arch=compute_70,code=sm_70 --use_fast_math -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda -gencode arch=compute_80,code=sm_80
+cuda_post_cflags = 
+cuda_dlink_post_cflags = 
+sycl_dlink_post_cflags = 
+ldflags = /DLL c10.lib c10_cuda.lib torch_cpu.lib torch_cuda.lib -INCLUDE:?warp_size@cuda@at@@YAHXZ torch.lib /LIBPATH:D:\Anaconda\envs\vocos\lib\site-packages\torch\lib torch_python.lib /LIBPATH:D:\Anaconda\envs\vocos\libs "/LIBPATH:C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\lib\x64" cudart.lib
+
+rule compile
+  command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags
+  deps = msvc
+
+rule cuda_compile
+  depfile = $out.d
+  deps = gcc
+  command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags
+
+
+
+
+
+rule link
+  command = "D$:\Visual Studio\VC\Tools\MSVC\14.29.30133\bin\Hostx86\x64/link.exe" $in /nologo $ldflags /out:$out
+
+build anti_alias_activation.o: compile D$:\seed-vc\modules\bigvgan\alias_free_activation\cuda\anti_alias_activation.cpp
+build anti_alias_activation_cuda.cuda.o: cuda_compile D$:\seed-vc\modules\bigvgan\alias_free_activation\cuda\anti_alias_activation_cuda.cu
+
+
+
+
+
+build anti_alias_activation_cuda.pyd: link anti_alias_activation.o anti_alias_activation_cuda.cuda.o
+
+default anti_alias_activation_cuda.pyd
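The build.ninja above is the kind of build file PyTorch's JIT extension builder writes when it compiles the fused anti-alias activation kernel. A hedged sketch of the call that produces such a build; the source paths are taken from the build rules above, the flags are illustrative, and this is not necessarily the repository's exact loader code:

# Hedged sketch: JIT-compiling the extension with torch.utils.cpp_extension.load,
# which generates a build.ninja similar to the one committed above.
from torch.utils.cpp_extension import load

anti_alias_activation_cuda = load(
    name="anti_alias_activation_cuda",
    sources=[
        "modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp",
        "modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu",
    ],
    extra_cuda_cflags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"],
    verbose=True,
)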
    	
modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (217 Bytes)

modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc
ADDED
Binary file (1.05 kB)
