Spaces: Running on Zero
Upload 131 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +20 -0
- audio_separator/__init__.py +0 -0
- audio_separator/model-data.json +22 -0
- audio_separator/models-scores.json +0 -0
- audio_separator/models.json +216 -0
- audio_separator/separator/__init__.py +1 -0
- audio_separator/separator/architectures/__init__.py +0 -0
- audio_separator/separator/architectures/demucs_separator.py +195 -0
- audio_separator/separator/architectures/mdx_separator.py +451 -0
- audio_separator/separator/architectures/mdxc_separator.py +423 -0
- audio_separator/separator/architectures/vr_separator.py +357 -0
- audio_separator/separator/common_separator.py +403 -0
- audio_separator/separator/separator.py +959 -0
- audio_separator/separator/uvr_lib_v5/__init__.py +0 -0
- audio_separator/separator/uvr_lib_v5/demucs/__init__.py +5 -0
- audio_separator/separator/uvr_lib_v5/demucs/__main__.py +212 -0
- audio_separator/separator/uvr_lib_v5/demucs/apply.py +294 -0
- audio_separator/separator/uvr_lib_v5/demucs/demucs.py +453 -0
- audio_separator/separator/uvr_lib_v5/demucs/filtering.py +451 -0
- audio_separator/separator/uvr_lib_v5/demucs/hdemucs.py +783 -0
- audio_separator/separator/uvr_lib_v5/demucs/htdemucs.py +620 -0
- audio_separator/separator/uvr_lib_v5/demucs/model.py +204 -0
- audio_separator/separator/uvr_lib_v5/demucs/model_v2.py +222 -0
- audio_separator/separator/uvr_lib_v5/demucs/pretrained.py +181 -0
- audio_separator/separator/uvr_lib_v5/demucs/repo.py +146 -0
- audio_separator/separator/uvr_lib_v5/demucs/spec.py +38 -0
- audio_separator/separator/uvr_lib_v5/demucs/states.py +131 -0
- audio_separator/separator/uvr_lib_v5/demucs/tasnet.py +401 -0
- audio_separator/separator/uvr_lib_v5/demucs/tasnet_v2.py +404 -0
- audio_separator/separator/uvr_lib_v5/demucs/transformer.py +675 -0
- audio_separator/separator/uvr_lib_v5/demucs/utils.py +496 -0
- audio_separator/separator/uvr_lib_v5/mdxnet.py +136 -0
- audio_separator/separator/uvr_lib_v5/mixer.ckpt +3 -0
- audio_separator/separator/uvr_lib_v5/modules.py +74 -0
- audio_separator/separator/uvr_lib_v5/playsound.py +241 -0
- audio_separator/separator/uvr_lib_v5/pyrb.py +92 -0
- audio_separator/separator/uvr_lib_v5/results.py +48 -0
- audio_separator/separator/uvr_lib_v5/roformer/attend.py +112 -0
- audio_separator/separator/uvr_lib_v5/roformer/bs_roformer.py +535 -0
- audio_separator/separator/uvr_lib_v5/roformer/mel_band_roformer.py +445 -0
- audio_separator/separator/uvr_lib_v5/spec_utils.py +1327 -0
- audio_separator/separator/uvr_lib_v5/stft.py +126 -0
- audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py +253 -0
- audio_separator/separator/uvr_lib_v5/vr_network/__init__.py +1 -0
- audio_separator/separator/uvr_lib_v5/vr_network/layers.py +294 -0
- audio_separator/separator/uvr_lib_v5/vr_network/layers_new.py +149 -0
- audio_separator/separator/uvr_lib_v5/vr_network/model_param_init.py +71 -0
- audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr16000_hl512.json +19 -0
- audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr32000_hl512.json +19 -0
- audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr33075_hl384.json +19 -0
.gitattributes
CHANGED
@@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tests/inputs/mardy20s.flac filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Bass)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Drum-Bass)_model_bs_roformer_ep_937_sdr_10_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Drums)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Guitar)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Instrumental)_2_HP-UVR_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Instrumental)_kuielab_b_vocals_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Instrumental)_MGM_MAIN_v4_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Instrumental)_model_bs_roformer_ep_317_sdr_12_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Instrumental)_UVR-MDX-NET-Inst_HQ_4_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(No[[:space:]]Drum-Bass)_model_bs_roformer_ep_937_sdr_10_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Other)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Piano)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Vocals)_2_HP-UVR_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Vocals)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Vocals)_kuielab_b_vocals_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Vocals)_MGM_MAIN_v4_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Vocals)_model_bs_roformer_ep_317_sdr_12_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_(Vocals)_UVR-MDX-NET-Inst_HQ_4_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+tests/inputs/reference/expected_mardy20s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
audio_separator/__init__.py
ADDED
File without changes
audio_separator/model-data.json
ADDED
@@ -0,0 +1,22 @@
+{
+    "vr_model_data": {
+        "97dc361a7a88b2c4542f68364b32c7f6": {
+            "vr_model_param": "4band_v4_ms_fullband",
+            "primary_stem": "Dry",
+            "nout": 32,
+            "nout_lstm": 128,
+            "is_karaoke": false,
+            "is_bv_model": false,
+            "is_bv_model_rebalanced": 0.0
+        }
+    },
+    "mdx_model_data": {
+        "cb790d0c913647ced70fc6b38f5bea1a": {
+            "compensate": 1.010,
+            "mdx_dim_f_set": 2560,
+            "mdx_dim_t_set": 8,
+            "mdx_n_fft_scale_set": 5120,
+            "primary_stem": "Instrumental"
+        }
+    }
+}
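For context on how this file is likely consumed: the two top-level sections group per-architecture parameters, keyed by what appear to be MD5-style hashes of the model files. A minimal lookup sketch, assuming whole-file MD5 hashing (the function name and the exact hashing scheme are illustrative assumptions, not code from this upload):

import hashlib
import json

def load_model_params(model_path, model_data_path="audio_separator/model-data.json"):
    # Hash the model file; the keys above look like MD5 hex digests, but
    # whole-file hashing is an assumption about the upstream scheme.
    with open(model_path, "rb") as f:
        model_hash = hashlib.md5(f.read()).hexdigest()
    with open(model_data_path, encoding="utf-8") as f:
        model_data = json.load(f)
    # Search each architecture section for the hash.
    for section in ("vr_model_data", "mdx_model_data"):
        if model_hash in model_data[section]:
            return model_data[section][model_hash]
    raise KeyError(f"No parameters found for model hash {model_hash}")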
audio_separator/models-scores.json
ADDED
The diff for this file is too large to render.
See raw diff
audio_separator/models.json
ADDED
@@ -0,0 +1,216 @@
+{
+    "vr_download_list": {
+        "VR Arch Single Model v4: UVR-De-Reverb by aufr33-jarredou": "UVR-De-Reverb-aufr33-jarredou.pth"
+    },
+    "mdx_download_list": {
+        "MDX-Net Model: UVR-MDX-NET Inst HQ 5": "UVR-MDX-NET-Inst_HQ_5.onnx"
+    },
+    "mdx23c_download_list": {
+        "MDX23C Model: MDX23C De-Reverb by aufr33-jarredou": {
+            "MDX23C-De-Reverb-aufr33-jarredou.ckpt": "config_dereverb_mdx23c.yaml"
+        },
+        "MDX23C Model: MDX23C DrumSep by aufr33-jarredou": {
+            "MDX23C-DrumSep-aufr33-jarredou.ckpt": "config_drumsep_mdx23c.yaml"
+        }
+    },
+    "roformer_download_list": {
+        "Roformer Model: Mel-Roformer-Karaoke-Aufr33-Viperx": {
+            "mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt": "mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956_config.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Karaoke by Gabox": {
+            "mel_band_roformer_karaoke_gabox.ckpt": "mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956_config.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Karaoke by becruily": {
+            "mel_band_roformer_karaoke_becruily.ckpt": "config_mel_band_roformer_karaoke_becruily.yaml"
+        },
+        "Roformer Model: Mel-Roformer-Denoise-Aufr33": {
+            "denoise_mel_band_roformer_aufr33_sdr_27.9959.ckpt": "denoise_mel_band_roformer_aufr33_sdr_27.9959_config.yaml"
+        },
+        "Roformer Model: Mel-Roformer-Denoise-Aufr33-Aggr": {
+            "denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt": "denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768_config.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Denoise-Debleed by Gabox": {
+            "mel_band_roformer_denoise_debleed_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: Mel-Roformer-Crowd-Aufr33-Viperx": {
+            "mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144.ckpt": "mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144_config.yaml"
+        },
+        "Roformer Model: BS-Roformer-De-Reverb": {
+            "deverb_bs_roformer_8_384dim_10depth.ckpt": "deverb_bs_roformer_8_384dim_10depth_config.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Vocals by Kimberley Jensen": {
+            "vocals_mel_band_roformer.ckpt": "vocals_mel_band_roformer.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | FT by unwa": {
+            "mel_band_roformer_kim_ft_unwa.ckpt": "config_mel_band_roformer_kim_ft_unwa.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | FT 2 by unwa": {
+            "mel_band_roformer_kim_ft2_unwa.ckpt": "config_mel_band_roformer_kim_ft_unwa.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | FT 2 Bleedless by unwa": {
+            "mel_band_roformer_kim_ft2_bleedless_unwa.ckpt": "config_mel_band_roformer_kim_ft_unwa.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | FT 3 by unwa": {
+            "mel_band_roformer_kim_ft3_unwa.ckpt": "config_mel_band_roformer_kim_ft_unwa.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | Inst V1 Plus by Unwa": {
+            "melband_roformer_inst_v1_plus.ckpt": "config_melbandroformer_inst.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | Inst V1 (E) by Unwa": {
+            "melband_roformer_inst_v1e.ckpt": "config_melbandroformer_inst.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | Inst V1 (E) Plus by Unwa": {
+            "melband_roformer_inst_v1e_plus.ckpt": "config_melbandroformer_inst.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Vocals by becruily": {
+            "mel_band_roformer_vocals_becruily.ckpt": "config_mel_band_roformer_vocals_becruily.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Instrumental by becruily": {
+            "mel_band_roformer_instrumental_becruily.ckpt": "config_mel_band_roformer_instrumental_becruily.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Vocals Fullness by Aname": {
+            "mel_band_roformer_vocal_fullness_aname.ckpt": "config_mel_band_roformer_vocal_fullness_aname.yaml"
+        },
+        "Roformer Model: BS Roformer | Vocals by Gabox": {
+            "bs_roformer_vocals_gabox.ckpt": "config_bs_roformer_vocals_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Vocals by Gabox": {
+            "mel_band_roformer_vocals_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Vocals FV1 by Gabox": {
+            "mel_band_roformer_vocals_fv1_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Vocals FV2 by Gabox": {
+            "mel_band_roformer_vocals_fv2_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Vocals FV3 by Gabox": {
+            "mel_band_roformer_vocals_fv3_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Vocals FV4 by Gabox": {
+            "mel_band_roformer_vocals_fv4_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Instrumental by Gabox": {
+            "mel_band_roformer_instrumental_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Instrumental 2 by Gabox": {
+            "mel_band_roformer_instrumental_2_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Instrumental 3 by Gabox": {
+            "mel_band_roformer_instrumental_3_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Instrumental Bleedless V1 by Gabox": {
+            "mel_band_roformer_instrumental_bleedless_v1_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Instrumental Bleedless V2 by Gabox": {
+            "mel_band_roformer_instrumental_bleedless_v2_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Instrumental Bleedless V3 by Gabox": {
+            "mel_band_roformer_instrumental_bleedless_v3_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Instrumental Fullness V1 by Gabox": {
+            "mel_band_roformer_instrumental_fullness_v1_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Instrumental Fullness V2 by Gabox": {
+            "mel_band_roformer_instrumental_fullness_v2_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Instrumental Fullness V3 by Gabox": {
+            "mel_band_roformer_instrumental_fullness_v3_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Instrumental Fullness Noisy V4 by Gabox": {
+            "mel_band_roformer_instrumental_fullness_noise_v4_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | INSTV5 by Gabox": {
+            "mel_band_roformer_instrumental_instv5_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | INSTV5N by Gabox": {
+            "mel_band_roformer_instrumental_instv5n_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | INSTV6 by Gabox": {
+            "mel_band_roformer_instrumental_instv6_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | INSTV6N by Gabox": {
+            "mel_band_roformer_instrumental_instv6n_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | INSTV7 by Gabox": {
+            "mel_band_roformer_instrumental_instv7_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | INSTV7N by Gabox": {
+            "mel_band_roformer_instrumental_instv7n_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | INSTV8 by Gabox": {
+            "mel_band_roformer_instrumental_instv8_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | INSTV8N by Gabox": {
+            "mel_band_roformer_instrumental_instv8n_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | FVX by Gabox": {
+            "mel_band_roformer_instrumental_fvx_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+        },
+        "Roformer Model: MelBand Roformer | De-Reverb by anvuew": {
+            "dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt": "dereverb_mel_band_roformer_anvuew.yaml"
+        },
+        "Roformer Model: MelBand Roformer | De-Reverb Less Aggressive by anvuew": {
+            "dereverb_mel_band_roformer_less_aggressive_anvuew_sdr_18.8050.ckpt": "dereverb_mel_band_roformer_anvuew.yaml"
+        },
+        "Roformer Model: MelBand Roformer | De-Reverb Mono by anvuew": {
+            "dereverb_mel_band_roformer_mono_anvuew.ckpt": "dereverb_mel_band_roformer_anvuew.yaml"
+        },
+        "Roformer Model: MelBand Roformer | De-Reverb Big by Sucial": {
+            "dereverb_big_mbr_ep_362.ckpt": "config_dereverb_echo_mel_band_roformer_v2.yaml"
+        },
+        "Roformer Model: MelBand Roformer | De-Reverb Super Big by Sucial": {
+            "dereverb_super_big_mbr_ep_346.ckpt": "config_dereverb_echo_mel_band_roformer_v2.yaml"
+        },
+        "Roformer Model: MelBand Roformer | De-Reverb-Echo by Sucial": {
+            "dereverb-echo_mel_band_roformer_sdr_10.0169.ckpt": "config_dereverb-echo_mel_band_roformer.yaml"
+        },
+        "Roformer Model: MelBand Roformer | De-Reverb-Echo V2 by Sucial": {
+            "dereverb-echo_mel_band_roformer_sdr_13.4843_v2.ckpt": "config_dereverb-echo_mel_band_roformer_sdr_13.4843_v2.yaml"
+        },
+        "Roformer Model: MelBand Roformer | De-Reverb-Echo Fused by Sucial": {
+            "dereverb_echo_mbr_fused.ckpt": "config_dereverb_echo_mel_band_roformer_v2.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | SYHFT by SYH99999": {
+            "MelBandRoformerSYHFT.ckpt": "config_vocals_mel_band_roformer_ft.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | SYHFT V2 by SYH99999": {
+            "MelBandRoformerSYHFTV2.ckpt": "config_vocals_mel_band_roformer_ft.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | SYHFT V2.5 by SYH99999": {
+            "MelBandRoformerSYHFTV2.5.ckpt": "config_vocals_mel_band_roformer_ft.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | SYHFT V3 by SYH99999": {
+            "MelBandRoformerSYHFTV3Epsilon.ckpt": "config_vocals_mel_band_roformer_ft.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | Big SYHFT V1 by SYH99999": {
+            "MelBandRoformerBigSYHFTV1.ckpt": "config_vocals_mel_band_roformer_big_v1_ft.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | Big Beta 4 FT by unwa": {
+            "melband_roformer_big_beta4.ckpt": "config_melbandroformer_big_beta4.yaml"
+        },
+        "Roformer Model: MelBand Roformer Kim | Big Beta 5e FT by unwa": {
+            "melband_roformer_big_beta5e.ckpt": "config_melband_roformer_big_beta5e.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Big Beta 6 by unwa": {
+            "melband_roformer_big_beta6.ckpt": "config_melbandroformer_big_beta6.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Big Beta 6X by unwa": {
+            "melband_roformer_big_beta6x.ckpt": "config_melbandroformer_big_beta6x.yaml"
+        },
+        "Roformer Model: BS Roformer | Chorus Male-Female by Sucial": {
+            "model_chorus_bs_roformer_ep_267_sdr_24.1275.ckpt": "config_chorus_male_female_bs_roformer.yaml"
+        },
+        "Roformer Model: BS Roformer | Male-Female by aufr33": {
+            "bs_roformer_male_female_by_aufr33_sdr_7.2889.ckpt": "config_chorus_male_female_bs_roformer.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Aspiration by Sucial": {
+            "aspiration_mel_band_roformer_sdr_18.9845.ckpt": "config_aspiration_mel_band_roformer.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Aspiration Less Aggressive by Sucial": {
+            "aspiration_mel_band_roformer_less_aggr_sdr_18.1201.ckpt": "config_aspiration_mel_band_roformer.yaml"
+        },
+        "Roformer Model: MelBand Roformer | Bleed Suppressor V1 by unwa-97chris": {
+            "mel_band_roformer_bleed_suppressor_v1.ckpt": "config_mel_band_roformer_bleed_suppressor_v1.yaml"
+        }
+    }
+}
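The structure above maps friendly model names to downloadable files: VR and MDX entries map directly to a single weights file, while MDX23C and Roformer entries map to a dict pairing a checkpoint with its YAML config. A small sketch of flattening this into a name-to-files index (the helper name is illustrative, not part of the package):

import json

def build_model_index(models_json_path="audio_separator/models.json"):
    # Flatten every download list into {friendly name: [file names]}.
    with open(models_json_path, encoding="utf-8") as f:
        download_lists = json.load(f)
    index = {}
    for architecture, models in download_lists.items():
        for friendly_name, files in models.items():
            if isinstance(files, str):
                index[friendly_name] = [files]  # single weights file
            else:
                # checkpoint file(s) followed by paired YAML config(s)
                index[friendly_name] = [*files.keys(), *files.values()]
    return index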
audio_separator/separator/__init__.py
ADDED
@@ -0,0 +1 @@
+from .separator import Separator
audio_separator/separator/architectures/__init__.py
ADDED
File without changes
audio_separator/separator/architectures/demucs_separator.py
ADDED
@@ -0,0 +1,195 @@
+import os
+import sys
+from pathlib import Path
+import torch
+import numpy as np
+from audio_separator.separator.common_separator import CommonSeparator
+from audio_separator.separator.uvr_lib_v5.demucs.apply import apply_model, demucs_segments
+from audio_separator.separator.uvr_lib_v5.demucs.hdemucs import HDemucs
+from audio_separator.separator.uvr_lib_v5.demucs.pretrained import get_model as get_demucs_model
+from audio_separator.separator.uvr_lib_v5 import spec_utils
+
+DEMUCS_4_SOURCE = ["drums", "bass", "other", "vocals"]
+
+DEMUCS_2_SOURCE_MAPPER = {CommonSeparator.INST_STEM: 0, CommonSeparator.VOCAL_STEM: 1}
+DEMUCS_4_SOURCE_MAPPER = {CommonSeparator.BASS_STEM: 0, CommonSeparator.DRUM_STEM: 1, CommonSeparator.OTHER_STEM: 2, CommonSeparator.VOCAL_STEM: 3}
+DEMUCS_6_SOURCE_MAPPER = {
+    CommonSeparator.BASS_STEM: 0,
+    CommonSeparator.DRUM_STEM: 1,
+    CommonSeparator.OTHER_STEM: 2,
+    CommonSeparator.VOCAL_STEM: 3,
+    CommonSeparator.GUITAR_STEM: 4,
+    CommonSeparator.PIANO_STEM: 5,
+}
+
+
+class DemucsSeparator(CommonSeparator):
+    """
+    DemucsSeparator is responsible for separating audio sources using Demucs models.
+    It initializes with configuration parameters and prepares the model for separation tasks.
+    """
+
+    def __init__(self, common_config, arch_config):
+        # Any configuration values which can be shared between architectures should be set already in CommonSeparator,
+        # e.g. user-specified functionality choices (self.output_single_stem) or common model parameters (self.primary_stem_name)
+        super().__init__(config=common_config)
+
+        # Initializing user-configurable parameters, passed through from the CLI or Separator instance
+
+        # Adjust segments to manage RAM or V-RAM usage:
+        # - Smaller sizes consume less resources.
+        # - Bigger sizes consume more resources, but may provide better results.
+        # - "Default" picks the optimal size.
+        # DEMUCS_SEGMENTS = (DEF_OPT, '1', '5', '10', '15', '20',
+        #                    '25', '30', '35', '40', '45', '50',
+        #                    '55', '60', '65', '70', '75', '80',
+        #                    '85', '90', '95', '100')
+        self.segment_size = arch_config.get("segment_size", "Default")
+
+        # Performs multiple predictions with random shifts of the input and averages them.
+        # The higher the number of shifts, the longer the prediction will take.
+        # Not recommended unless you have a GPU.
+        # DEMUCS_SHIFTS = (0, 1, 2, 3, 4, 5,
+        #                  6, 7, 8, 9, 10, 11,
+        #                  12, 13, 14, 15, 16, 17,
+        #                  18, 19, 20)
+        self.shifts = arch_config.get("shifts", 2)
+
+        # This option controls the amount of overlap between prediction windows.
+        # - Higher values can provide better results, but will lead to longer processing times.
+        # - You can choose between 0.001-0.999
+        # DEMUCS_OVERLAP = (0.25, 0.50, 0.75, 0.99)
+        self.overlap = arch_config.get("overlap", 0.25)
+
+        # Enables "Segments". Deselecting this option is only recommended for those with powerful PCs.
+        self.segments_enabled = arch_config.get("segments_enabled", True)
+
+        self.logger.debug(f"Demucs arch params: segment_size={self.segment_size}, segments_enabled={self.segments_enabled}")
+        self.logger.debug(f"Demucs arch params: shifts={self.shifts}, overlap={self.overlap}")
+
+        self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
+
+        self.audio_file_path = None
+        self.audio_file_base = None
+        self.demucs_model_instance = None
+
+        # Add uvr_lib_v5 folder to system path so pytorch serialization can find the demucs module
+        current_dir = os.path.dirname(__file__)
+        uvr_lib_v5_path = os.path.join(current_dir, "..", "uvr_lib_v5")
+        sys.path.insert(0, uvr_lib_v5_path)
+
+        self.logger.info("Demucs Separator initialisation complete")
+
+    def separate(self, audio_file_path, custom_output_names=None):
+        """
+        Separates the audio file into its component stems using the Demucs model.
+
+        Args:
+            audio_file_path (str): The path to the audio file to be processed.
+            custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
+
+        Returns:
+            list: A list of paths to the output files generated by the separation process.
+        """
+        self.logger.debug("Starting separation process...")
+        source = None
+        stem_source = None
+        inst_source = {}
+
+        self.audio_file_path = audio_file_path
+        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
+
+        # Prepare the mix for processing
+        self.logger.debug("Preparing mix...")
+        mix = self.prepare_mix(self.audio_file_path)
+
+        self.logger.debug(f"Mix prepared for demixing. Shape: {mix.shape}")
+
+        self.logger.debug("Loading model for demixing...")
+
+        self.demucs_model_instance = HDemucs(sources=DEMUCS_4_SOURCE)
+        self.demucs_model_instance = get_demucs_model(name=os.path.splitext(os.path.basename(self.model_path))[0], repo=Path(os.path.dirname(self.model_path)))
+        self.demucs_model_instance = demucs_segments(self.segment_size, self.demucs_model_instance)
+        self.demucs_model_instance.to(self.torch_device)
+        self.demucs_model_instance.eval()
+
+        self.logger.debug("Model loaded and set to evaluation mode.")
+
+        source = self.demix_demucs(mix)
+
+        del self.demucs_model_instance
+        self.clear_gpu_cache()
+        self.logger.debug("Model and GPU cache cleared after demixing.")
+
+        output_files = []
+        self.logger.debug("Processing output files...")
+
+        if isinstance(inst_source, np.ndarray):
+            self.logger.debug("Processing instance source...")
+            source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]], source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]])
+            inst_source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]] = source_reshape
+            source = inst_source
+
+        if isinstance(source, np.ndarray):
+            source_length = len(source)
+            self.logger.debug(f"Processing source array, source length is {source_length}")
+            match source_length:
+                case 2:
+                    self.logger.debug("Setting source map to 2-stem...")
+                    self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER
+                case 6:
+                    self.logger.debug("Setting source map to 6-stem...")
+                    self.demucs_source_map = DEMUCS_6_SOURCE_MAPPER
+                case _:
+                    self.logger.debug("Setting source map to 4-stem...")
+                    self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
+
+        self.logger.debug("Processing for all stems...")
+        for stem_name, stem_value in self.demucs_source_map.items():
+            if self.output_single_stem is not None:
+                if stem_name.lower() != self.output_single_stem.lower():
+                    self.logger.debug(f"Skipping writing stem {stem_name} as output_single_stem is set to {self.output_single_stem}...")
+                    continue
+
+            stem_path = self.get_stem_output_path(stem_name, custom_output_names)
+            stem_source = source[stem_value].T
+
+            self.final_process(stem_path, stem_source, stem_name)
+            output_files.append(stem_path)
+
+        return output_files
+
+    def demix_demucs(self, mix):
+        """
+        Demixes the input mix using the demucs model.
+        """
+        self.logger.debug("Starting demixing process in demix_demucs...")
+
+        processed = {}
+        mix = torch.tensor(mix, dtype=torch.float32)
+        ref = mix.mean(0)
+        mix = (mix - ref.mean()) / ref.std()
+        mix_infer = mix
+
+        with torch.no_grad():
+            self.logger.debug("Running model inference...")
+            sources = apply_model(
+                model=self.demucs_model_instance,
+                mix=mix_infer[None],
+                shifts=self.shifts,
+                split=self.segments_enabled,
+                overlap=self.overlap,
+                static_shifts=1 if self.shifts == 0 else self.shifts,
+                set_progress_bar=None,
+                device=self.torch_device,
+                progress=True,
+            )[0]
+
+        sources = (sources * ref.std() + ref.mean()).cpu().numpy()
+        sources[[0, 1]] = sources[[1, 0]]
+        processed[mix] = sources[:, :, 0:None].copy()
+        sources = list(processed.values())
+        sources = [s[:, :, 0:None] for s in sources]
+        sources = np.concatenate(sources, axis=-1)
+
+        return sources
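DemucsSeparator is not normally instantiated directly; per separator/__init__.py above, the package's public entry point is the Separator class, which constructs the appropriate architecture class for the loaded model. A hedged usage sketch follows; the model filename and loader signature are assumptions based on this snapshot's file list and test fixture names, not verified against the truncated separator.py:

from audio_separator.separator import Separator

# Minimal sketch: separate a file with the 6-stem HTDemucs model referenced
# by the test fixtures in .gitattributes above. load_model()'s exact
# signature is an assumption.
separator = Separator()
separator.load_model(model_filename="htdemucs_6s.yaml")
output_files = separator.separate("tests/inputs/mardy20s.flac")
print(output_files)  # one output path per stem (Vocals, Drums, Bass, ...)

Note the normalization trick in demix_demucs: the mix is standardized by the mono reference's mean and standard deviation before inference, then the model output is de-standardized with the same statistics, so stem levels remain consistent with the input.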
audio_separator/separator/architectures/mdx_separator.py
ADDED
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Module for separating audio sources using MDX architecture models."""
|
2 |
+
|
3 |
+
import os
|
4 |
+
import platform
|
5 |
+
import torch
|
6 |
+
import onnx
|
7 |
+
import onnxruntime as ort
|
8 |
+
import numpy as np
|
9 |
+
import onnx2torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
from audio_separator.separator.uvr_lib_v5 import spec_utils
|
12 |
+
from audio_separator.separator.uvr_lib_v5.stft import STFT
|
13 |
+
from audio_separator.separator.common_separator import CommonSeparator
|
14 |
+
|
15 |
+
|
16 |
+
class MDXSeparator(CommonSeparator):
|
17 |
+
"""
|
18 |
+
MDXSeparator is responsible for separating audio sources using MDX models.
|
19 |
+
It initializes with configuration parameters and prepares the model for separation tasks.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, common_config, arch_config):
|
23 |
+
# Any configuration values which can be shared between architectures should be set already in CommonSeparator,
|
24 |
+
# e.g. user-specified functionality choices (self.output_single_stem) or common model parameters (self.primary_stem_name)
|
25 |
+
super().__init__(config=common_config)
|
26 |
+
|
27 |
+
# Initializing user-configurable parameters, passed through with an mdx_from the CLI or Separator instance
|
28 |
+
|
29 |
+
# Pick a segment size to balance speed, resource use, and quality:
|
30 |
+
# - Smaller sizes consume less resources.
|
31 |
+
# - Bigger sizes consume more resources, but may provide better results.
|
32 |
+
# - Default size is 256. Quality can change based on your pick.
|
33 |
+
self.segment_size = arch_config.get("segment_size")
|
34 |
+
|
35 |
+
# This option controls the amount of overlap between prediction windows.
|
36 |
+
# - Higher values can provide better results, but will lead to longer processing times.
|
37 |
+
# - For Non-MDX23C models: You can choose between 0.001-0.999
|
38 |
+
self.overlap = arch_config.get("overlap")
|
39 |
+
|
40 |
+
# Number of batches to be processed at a time.
|
41 |
+
# - Higher values mean more RAM usage but slightly faster processing times.
|
42 |
+
# - Lower values mean less RAM usage but slightly longer processing times.
|
43 |
+
# - Batch size value has no effect on output quality.
|
44 |
+
# BATCH_SIZE = ('1', ''2', '3', '4', '5', '6', '7', '8', '9', '10')
|
45 |
+
self.batch_size = arch_config.get("batch_size", 1)
|
46 |
+
|
47 |
+
# hop_length is equivalent to the more commonly used term "stride" in convolutional neural networks
|
48 |
+
# In machine learning, particularly in the context of convolutional neural networks (CNNs),
|
49 |
+
# the term "stride" refers to the number of pixels by which we move the filter across the input image.
|
50 |
+
# Strides are a crucial component in the convolution operation, a fundamental building block of CNNs used primarily in the field of computer vision.
|
51 |
+
# Stride is a parameter that dictates the movement of the kernel, or filter, across the input data, such as an image.
|
52 |
+
# When performing a convolution operation, the stride determines how many units the filter shifts at each step.
|
53 |
+
# The choice of stride affects the model in several ways:
|
54 |
+
# Output Size: A larger stride will result in a smaller output spatial dimension.
|
55 |
+
# Computational Efficiency: Increasing the stride can decrease the computational load.
|
56 |
+
# Field of View: A higher stride means that each step of the filter takes into account a wider area of the input image.
|
57 |
+
# This can be beneficial when the model needs to capture more global features rather than focusing on finer details.
|
58 |
+
self.hop_length = arch_config.get("hop_length")
|
59 |
+
|
60 |
+
# If enabled, model will be run twice to reduce noise in output audio.
|
61 |
+
self.enable_denoise = arch_config.get("enable_denoise")
|
62 |
+
|
63 |
+
self.logger.debug(f"MDX arch params: batch_size={self.batch_size}, segment_size={self.segment_size}")
|
64 |
+
self.logger.debug(f"MDX arch params: overlap={self.overlap}, hop_length={self.hop_length}, enable_denoise={self.enable_denoise}")
|
65 |
+
|
66 |
+
# Initializing model-specific parameters from model_data JSON
|
67 |
+
self.compensate = self.model_data["compensate"]
|
68 |
+
self.dim_f = self.model_data["mdx_dim_f_set"]
|
69 |
+
self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
|
70 |
+
self.n_fft = self.model_data["mdx_n_fft_scale_set"]
|
71 |
+
self.config_yaml = self.model_data.get("config_yaml", None)
|
72 |
+
|
73 |
+
self.logger.debug(f"MDX arch params: compensate={self.compensate}, dim_f={self.dim_f}, dim_t={self.dim_t}, n_fft={self.n_fft}")
|
74 |
+
self.logger.debug(f"MDX arch params: config_yaml={self.config_yaml}")
|
75 |
+
|
76 |
+
# In UVR, these variables are set but either aren't useful or are better handled in audio-separator.
|
77 |
+
# Leaving these comments explaining to help myself or future developers understand why these aren't in audio-separator.
|
78 |
+
|
79 |
+
# "chunks" is not actually used for anything in UVR...
|
80 |
+
# self.chunks = 0
|
81 |
+
|
82 |
+
# "adjust" is hard-coded to 1 in UVR, and only used as a multiplier in run_model, so it does nothing.
|
83 |
+
# self.adjust = 1
|
84 |
+
|
85 |
+
# "hop" is hard-coded to 1024 in UVR. We have a "hop_length" parameter instead
|
86 |
+
# self.hop = 1024
|
87 |
+
|
88 |
+
# "margin" maps to sample rate and is set from the GUI in UVR (default: 44100). We have a "sample_rate" parameter instead.
|
89 |
+
# self.margin = 44100
|
90 |
+
|
91 |
+
# "dim_c" is hard-coded to 4 in UVR, seems to be a parameter for the number of channels, and is only used for checkpoint models.
|
92 |
+
# We haven't implemented support for the checkpoint models here, so we're not using it.
|
93 |
+
# self.dim_c = 4
|
94 |
+
|
95 |
+
self.load_model()
|
96 |
+
|
97 |
+
self.n_bins = 0
|
98 |
+
self.trim = 0
|
99 |
+
self.chunk_size = 0
|
100 |
+
self.gen_size = 0
|
101 |
+
self.stft = None
|
102 |
+
|
103 |
+
self.primary_source = None
|
104 |
+
self.secondary_source = None
|
105 |
+
self.audio_file_path = None
|
106 |
+
self.audio_file_base = None
|
107 |
+
|
108 |
+
def load_model(self):
|
109 |
+
"""
|
110 |
+
Load the model into memory from file on disk, initialize it with config from the model data,
|
111 |
+
and prepare for inferencing using hardware accelerated Torch device.
|
112 |
+
"""
|
113 |
+
self.logger.debug("Loading ONNX model for inference...")
|
114 |
+
|
115 |
+
if self.segment_size == self.dim_t:
|
116 |
+
ort_session_options = ort.SessionOptions()
|
117 |
+
if self.log_level > 10:
|
118 |
+
ort_session_options.log_severity_level = 3
|
119 |
+
else:
|
120 |
+
ort_session_options.log_severity_level = 0
|
121 |
+
|
122 |
+
ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
|
123 |
+
self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
|
124 |
+
self.logger.debug("Model loaded successfully using ONNXruntime inferencing session.")
|
125 |
+
else:
|
126 |
+
if platform.system() == 'Windows':
|
127 |
+
onnx_model = onnx.load(self.model_path)
|
128 |
+
self.model_run = onnx2torch.convert(onnx_model)
|
129 |
+
else:
|
130 |
+
self.model_run = onnx2torch.convert(self.model_path)
|
131 |
+
|
132 |
+
self.model_run.to(self.torch_device).eval()
|
133 |
+
self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")
|
134 |
+
|
135 |
+
def separate(self, audio_file_path, custom_output_names=None):
|
136 |
+
"""
|
137 |
+
Separates the audio file into primary and secondary sources based on the model's configuration.
|
138 |
+
It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
audio_file_path (str): The path to the audio file to be processed.
|
142 |
+
custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
list: A list of paths to the output files generated by the separation process.
|
146 |
+
"""
|
147 |
+
self.audio_file_path = audio_file_path
|
148 |
+
self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
|
149 |
+
|
150 |
+
# Prepare the mix for processing
|
151 |
+
self.logger.debug(f"Preparing mix for input audio file {self.audio_file_path}...")
|
152 |
+
mix = self.prepare_mix(self.audio_file_path)
|
153 |
+
|
154 |
+
self.logger.debug("Normalizing mix before demixing...")
|
155 |
+
mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)
|
156 |
+
|
157 |
+
# Start the demixing process
|
158 |
+
source = self.demix(mix)
|
159 |
+
self.logger.debug("Demixing completed.")
|
160 |
+
|
161 |
+
# In UVR, the source is cached here if it's a vocal split model, but we're not supporting that yet
|
162 |
+
|
163 |
+
# Initialize the list for output files
|
164 |
+
output_files = []
|
165 |
+
self.logger.debug("Processing output files...")
|
166 |
+
|
167 |
+
# Normalize and transpose the primary source if it's not already an array
|
168 |
+
if not isinstance(self.primary_source, np.ndarray):
|
169 |
+
self.logger.debug("Normalizing primary source...")
|
170 |
+
self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T
|
171 |
+
|
172 |
+
# Process the secondary source if not already an array
|
173 |
+
if not isinstance(self.secondary_source, np.ndarray):
|
174 |
+
self.logger.debug("Producing secondary source: demixing in match_mix mode")
|
175 |
+
raw_mix = self.demix(mix, is_match_mix=True)
|
176 |
+
|
177 |
+
if self.invert_using_spec:
|
178 |
+
self.logger.debug("Inverting secondary stem using spectogram as invert_using_spec is set to True")
|
179 |
+
self.secondary_source = spec_utils.invert_stem(raw_mix, source)
|
180 |
+
else:
|
181 |
+
self.logger.debug("Inverting secondary stem by subtracting of transposed demixed stem from transposed original mix")
|
182 |
+
self.secondary_source = mix.T - source.T
|
183 |
+
|
184 |
+
# Save and process the secondary stem if needed
|
185 |
+
if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
|
186 |
+
self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)
|
187 |
+
|
188 |
+
self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
|
189 |
+
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
|
190 |
+
output_files.append(self.secondary_stem_output_path)
|
191 |
+
|
192 |
+
# Save and process the primary stem if needed
|
193 |
+
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
|
194 |
+
self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)
|
195 |
+
|
196 |
+
if not isinstance(self.primary_source, np.ndarray):
|
197 |
+
self.primary_source = source.T
|
198 |
+
|
199 |
+
self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
|
200 |
+
self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
|
201 |
+
output_files.append(self.primary_stem_output_path)
|
202 |
+
|
203 |
+
# Not yet implemented from UVR features:
|
204 |
+
# self.process_vocal_split_chain(secondary_sources)
|
205 |
+
# self.logger.debug("Vocal split chain processed.")
|
206 |
+
|
207 |
+
return output_files
|
208 |
+
|
209 |
+
def initialize_model_settings(self):
|
210 |
+
"""
|
211 |
+
This function sets up the necessary parameters for the model, like the number of frequency bins (n_bins), the trimming size (trim),
|
212 |
+
the size of each audio chunk (chunk_size), and the window function for spectral transformations (window).
|
213 |
+
It ensures that the model is configured with the correct settings for processing the audio data.
|
214 |
+
"""
|
215 |
+
self.logger.debug("Initializing model settings...")
|
216 |
+
|
217 |
+
# n_bins is half the FFT size plus one (self.n_fft // 2 + 1).
|
218 |
+
self.n_bins = self.n_fft // 2 + 1
|
219 |
+
|
220 |
+
# trim is half the FFT size (self.n_fft // 2).
|
221 |
+
self.trim = self.n_fft // 2
|
222 |
+
|
223 |
+
# chunk_size is the hop_length size times the segment size minus one
|
224 |
+
self.chunk_size = self.hop_length * (self.segment_size - 1)
|
225 |
+
|
226 |
+
# gen_size is the chunk size minus twice the trim size
|
227 |
+
self.gen_size = self.chunk_size - 2 * self.trim
|
228 |
+
|
229 |
+
self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)
|
230 |
+
|
231 |
+
self.logger.debug(f"Model input params: n_fft={self.n_fft} hop_length={self.hop_length} dim_f={self.dim_f}")
|
232 |
+
self.logger.debug(f"Model settings: n_bins={self.n_bins}, trim={self.trim}, chunk_size={self.chunk_size}, gen_size={self.gen_size}")
|
233 |
+
|
234 |
+
def initialize_mix(self, mix, is_ckpt=False):
|
235 |
+
"""
|
236 |
+
After prepare_mix segments the audio, initialize_mix further processes each segment.
|
237 |
+
It ensures each audio segment is in the correct format for the model, applies necessary padding,
|
238 |
+
and converts the segments into tensors for processing with the model.
|
239 |
+
This step is essential for preparing the audio data in a format that the neural network can process.
|
240 |
+
"""
|
241 |
+
# Log the initialization of the mix and whether checkpoint mode is used
|
242 |
+
self.logger.debug(f"Initializing mix with is_ckpt={is_ckpt}. Initial mix shape: {mix.shape}")
|
243 |
+
|
244 |
+
# Ensure the mix is a 2-channel (stereo) audio signal
|
245 |
+
if mix.shape[0] != 2:
|
246 |
+
error_message = f"Expected a 2-channel audio signal, but got {mix.shape[0]} channels"
|
247 |
+
self.logger.error(error_message)
|
248 |
+
raise ValueError(error_message)
|
249 |
+
|
250 |
+
# If in checkpoint mode, process the mix differently
|
251 |
+
if is_ckpt:
|
252 |
+
self.logger.debug("Processing in checkpoint mode...")
|
253 |
+
# Calculate padding based on the generation size and trim
|
254 |
+
pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
|
255 |
+
self.logger.debug(f"Padding calculated: {pad}")
|
256 |
+
# Add padding at the beginning and the end of the mix
|
257 |
+
mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
|
258 |
+
# Determine the number of chunks based on the mixture's length
|
259 |
+
num_chunks = mixture.shape[-1] // self.gen_size
|
260 |
+
self.logger.debug(f"Mixture shape after padding: {mixture.shape}, Number of chunks: {num_chunks}")
|
261 |
+
# Split the mixture into chunks
|
262 |
+
mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
|
263 |
+
else:
|
264 |
+
# If not in checkpoint mode, process normally
|
265 |
+
self.logger.debug("Processing in non-checkpoint mode...")
|
266 |
+
mix_waves = []
|
267 |
+
n_sample = mix.shape[1]
|
268 |
+
# Calculate necessary padding to make the total length divisible by the generation size
|
269 |
+
pad = self.gen_size - n_sample % self.gen_size
|
270 |
+
self.logger.debug(f"Number of samples: {n_sample}, Padding calculated: {pad}")
|
271 |
+
# Apply padding to the mix
|
272 |
+
mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
|
273 |
+
self.logger.debug(f"Shape of mix after padding: {mix_p.shape}")
|
274 |
+
|
275 |
+
# Process the mix in chunks
|
276 |
+
i = 0
|
277 |
+
while i < n_sample + pad:
|
278 |
+
waves = np.array(mix_p[:, i : i + self.chunk_size])
|
279 |
+
mix_waves.append(waves)
|
280 |
+
self.logger.debug(f"Processed chunk {len(mix_waves)}: Start {i}, End {i + self.chunk_size}")
|
281 |
+
i += self.gen_size
|
282 |
+
|
283 |
+
# Convert the list of wave chunks into a tensor for processing on the specified device
|
284 |
+
mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
|
285 |
+
self.logger.debug(f"Converted mix_waves to tensor. Tensor shape: {mix_waves_tensor.shape}")
|
286 |
+
|
287 |
+
return mix_waves_tensor, pad
|
288 |
+
|
289 |
+
def demix(self, mix, is_match_mix=False):
|
290 |
+
"""
|
291 |
+
Demixes the input mix into its constituent sources. If is_match_mix is True, the function adjusts the processing
|
292 |
+
to better match the mix, affecting chunk sizes and overlaps. The demixing process involves padding the mix,
|
293 |
+
processing it in chunks, applying windowing for overlaps, and accumulating the results to separate the sources.
|
294 |
+
"""
|
295 |
+
self.logger.debug(f"Starting demixing process with is_match_mix: {is_match_mix}...")
|
296 |
+
self.initialize_model_settings()
|
297 |
+
|
298 |
+
# Preserves the original mix for later use.
|
299 |
+
# In UVR, this is used for the pitch fix and VR denoise processes, which aren't yet implemented here.
|
300 |
+
org_mix = mix
|
301 |
+
self.logger.debug(f"Original mix stored. Shape: {org_mix.shape}")
|
302 |
+
|
303 |
+
# Initializes a list to store the separated waveforms.
|
304 |
+
tar_waves_ = []
|
305 |
+
|
306 |
+
# Handling different chunk sizes and overlaps based on the matching requirement.
|
307 |
+
if is_match_mix:
|
308 |
+
# Sets a smaller chunk size specifically for matching the mix.
|
309 |
+
chunk_size = self.hop_length * (self.segment_size - 1)
|
310 |
+
# Sets a small overlap for the chunks.
|
311 |
+
overlap = 0.02
|
312 |
+
self.logger.debug(f"Chunk size for matching mix: {chunk_size}, Overlap: {overlap}")
|
313 |
+
else:
|
314 |
+
# Uses the regular chunk size defined in model settings.
|
315 |
+
chunk_size = self.chunk_size
|
316 |
+
# Uses the overlap specified in the model settings.
|
317 |
+
overlap = self.overlap
|
318 |
+
self.logger.debug(f"Standard chunk size: {chunk_size}, Overlap: {overlap}")
|
319 |
+
|
320 |
+
# Calculates the generated size after subtracting the trim from both ends of the chunk.
|
321 |
+
gen_size = chunk_size - 2 * self.trim
|
322 |
+
self.logger.debug(f"Generated size calculated: {gen_size}")
|
323 |
+
|
324 |
+
# Calculates padding to make the mix length a multiple of the generated size.
|
325 |
+
pad = gen_size + self.trim - ((mix.shape[-1]) % gen_size)
|
326 |
+
# Prepares the mixture with padding at the beginning and the end.
|
327 |
+
mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
|
328 |
+
self.logger.debug(f"Mixture prepared with padding. Mixture shape: {mixture.shape}")
|
329 |
+
|
330 |
+
# Calculates the step size for processing chunks based on the overlap.
|
331 |
+
step = int((1 - overlap) * chunk_size)
|
332 |
+
self.logger.debug(f"Step size for processing chunks: {step} as overlap is set to {overlap}.")
|
333 |
+
|
334 |
+
# Initializes arrays to store the results and to account for overlap.
|
335 |
+
result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
|
336 |
+
divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
|
337 |
+
|
338 |
+
# Initializes counters for processing chunks.
|
339 |
+
total = 0
|
340 |
+
total_chunks = (mixture.shape[-1] + step - 1) // step
|
341 |
+
self.logger.debug(f"Total chunks to process: {total_chunks}")
|
342 |
+
|
343 |
+
# Processes each chunk of the mixture.
|
344 |
+
for i in tqdm(range(0, mixture.shape[-1], step)):
|
345 |
+
total += 1
|
346 |
+
start = i
|
347 |
+
end = min(i + chunk_size, mixture.shape[-1])
|
348 |
+
self.logger.debug(f"Processing chunk {total}/{total_chunks}: Start {start}, End {end}")
|
349 |
+
|
350 |
+
# Handles windowing for overlapping chunks.
|
351 |
+
chunk_size_actual = end - start
|
352 |
+
window = None
|
353 |
+
if overlap != 0:
|
354 |
+
window = np.hanning(chunk_size_actual)
|
355 |
+
window = np.tile(window[None, None, :], (1, 2, 1))
|
356 |
+
self.logger.debug("Window applied to the chunk.")
|
357 |
+
|
358 |
+
# Zero-pad the chunk to prepare it for processing.
|
359 |
+
mix_part_ = mixture[:, start:end]
|
360 |
+
if end != i + chunk_size:
|
361 |
+
pad_size = (i + chunk_size) - end
|
362 |
+
mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)
|
363 |
+
|
364 |
+
# Converts the chunk to a tensor for processing.
|
365 |
+
mix_part = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device)
|
366 |
+
# Splits the chunk into smaller batches if necessary.
|
367 |
+
mix_waves = mix_part.split(self.batch_size)
|
368 |
+
total_batches = len(mix_waves)
|
369 |
+
self.logger.debug(f"Mix part split into batches. Number of batches: {total_batches}")
|
370 |
+
|
371 |
+
with torch.no_grad():
|
372 |
+
# Processes each batch in the chunk.
|
373 |
+
batches_processed = 0
|
374 |
+
for mix_wave in mix_waves:
|
375 |
+
batches_processed += 1
|
376 |
+
self.logger.debug(f"Processing mix_wave batch {batches_processed}/{total_batches}")
|
377 |
+
|
378 |
+
# Runs the model to separate the sources.
|
379 |
+
tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)
|
380 |
+
|
381 |
+
# Applies windowing if needed and accumulates the results.
|
382 |
+
if window is not None:
|
383 |
+
tar_waves[..., :chunk_size_actual] *= window
|
384 |
+
divider[..., start:end] += window
|
385 |
+
else:
|
386 |
+
divider[..., start:end] += 1
|
387 |
+
|
388 |
+
result[..., start:end] += tar_waves[..., : end - start]
|
389 |
+
|
390 |
+
# Normalizes the results by the divider to account for overlap.
|
391 |
+
self.logger.debug("Normalizing result by dividing result by divider.")
|
392 |
+
tar_waves = result / divider
|
393 |
+
tar_waves_.append(tar_waves)
|
394 |
+
|
395 |
+
# Reshapes the results to match the original dimensions.
|
396 |
+
tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim : -self.trim]
|
397 |
+
tar_waves = np.concatenate(tar_waves_, axis=-1)[:, : mix.shape[-1]]
|
398 |
+
|
399 |
+
# Extracts the source from the results.
|
400 |
+
source = tar_waves[:, 0:None]
|
401 |
+
self.logger.debug(f"Concatenated tar_waves. Shape: {tar_waves.shape}")
|
402 |
+
|
403 |
+
# TODO: In UVR, pitch changing happens here. Consider implementing this as a feature.
|
404 |
+
|
405 |
+
# Compensates the source if not matching the mix.
|
406 |
+
if not is_match_mix:
|
407 |
+
source *= self.compensate
|
408 |
+
self.logger.debug("Match mix mode; compensate multiplier applied.")
|
409 |
+
|
410 |
+
# TODO: In UVR, VR denoise model gets applied here. Consider implementing this as a feature.
|
411 |
+
|
412 |
+
self.logger.debug("Demixing process completed.")
|
413 |
+
return source
|
414 |
+
|
415 |
+
    def run_model(self, mix, is_match_mix=False):
        """
        Processes the input mix through the model to separate the sources.
        Applies the STFT, handles spectrum modifications, and runs the model for source separation.
        """
        # Apply the STFT to the mix. The mix is moved to the specified device (e.g. GPU) before processing.
        # self.logger.debug(f"Running STFT on the mix. Mix shape before STFT: {mix.shape}")
        spek = self.stft(mix.to(self.torch_device))
        self.logger.debug(f"STFT applied on mix. Spectrum shape: {spek.shape}")

        # Zero out the first 3 bins of the spectrum. This is often done to reduce low-frequency noise.
        spek[:, :, :3, :] *= 0
        # self.logger.debug("First 3 bins of the spectrum zeroed out.")

        # Handle the case where the mix needs to be matched (is_match_mix = True)
        if is_match_mix:
            # self.logger.debug("Match mix mode is enabled. Converting spectrum to NumPy array.")
            spec_pred = spek.cpu().numpy()
            self.logger.debug("is_match_mix: spectrum prediction obtained directly from STFT output.")
        else:
            # If denoising is enabled, the model is run on both the negated and the original spectrum.
            if self.enable_denoise:
                spec_pred_neg = self.model_run(-spek)  # Prediction on the negated spectrum
                spec_pred_pos = self.model_run(spek)   # Prediction on the original spectrum
                # Averaging the two (with the sign of the negated prediction flipped back) cancels
                # additive model noise that does not change sign with the input.
                spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)
                self.logger.debug("Model run on both negative and positive spectrums for denoising.")
            else:
                spec_pred = self.model_run(spek)
                self.logger.debug("Model run on the spectrum without denoising.")

        # Apply the inverse STFT to convert the spectrum back to the time domain.
        result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
        self.logger.debug(f"Inverse STFT applied. Returning result with shape: {result.shape}")

        return result
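The enable_denoise branch above relies on a sign-flip trick: a separation network is roughly sign-equivariant for the signal it extracts, so averaging its prediction on the spectrum with the sign-flipped prediction on the negated spectrum keeps the signal while cancelling additive artifacts. A minimal standalone sketch of that combination, where model stands in for any hypothetical prediction callable:

    import numpy as np

    def denoised_prediction(model, spek: np.ndarray) -> np.ndarray:
        # Signal content flips sign with the input and survives the average;
        # additive artifacts that do not flip sign cancel out.
        pred_neg = model(-spek)
        pred_pos = model(spek)
        return (pred_neg * -0.5) + (pred_pos * 0.5)

    # With a perfectly sign-equivariant "model", the output equals the plain prediction:
    spek = np.random.randn(4, 4)
    assert np.allclose(denoised_prediction(lambda s: 0.5 * s, spek), 0.5 * spek)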
audio_separator/separator/architectures/mdxc_separator.py
ADDED
@@ -0,0 +1,423 @@
import os
import sys

import torch
import numpy as np
from tqdm import tqdm
from ml_collections import ConfigDict
from scipy import signal

from audio_separator.separator.common_separator import CommonSeparator
from audio_separator.separator.uvr_lib_v5 import spec_utils
from audio_separator.separator.uvr_lib_v5.tfc_tdf_v3 import TFC_TDF_net
from audio_separator.separator.uvr_lib_v5.roformer.mel_band_roformer import MelBandRoformer
from audio_separator.separator.uvr_lib_v5.roformer.bs_roformer import BSRoformer


class MDXCSeparator(CommonSeparator):
    """
    MDXCSeparator is responsible for separating audio sources using MDXC models.
    It initializes with configuration parameters and prepares the model for separation tasks.
    """

    def __init__(self, common_config, arch_config):
        # Any configuration values which can be shared between architectures should be set already in CommonSeparator,
        # e.g. user-specified functionality choices (self.output_single_stem) or common model parameters (self.primary_stem_name)
        super().__init__(config=common_config)

        # Model data is basic overview metadata about the model, e.g. which stem is primary and whether it's a karaoke model.
        # It's loaded in from model_data_new.json in Separator.load_model and there are JSON examples in that method.
        # The instance variable self.model_data is passed through from Separator and set in CommonSeparator.
        self.logger.debug(f"Model data: {self.model_data}")

        # Arch config contains the MDXC architecture-specific user configuration options, which should all be configurable
        # by the user, either through their Separator class instantiation or by passing in a CLI parameter.
        # While there are similarities between architectures for some of these (e.g. batch_size), they are deliberately
        # configured this way as they have architecture-specific default values.
        self.segment_size = arch_config.get("segment_size", 256)

        # Whether to use the configured segment size or the model's default.
        # The default segment size is taken from the chosen model's associated config file (yaml).
        self.override_model_segment_size = arch_config.get("override_model_segment_size", False)

        self.overlap = arch_config.get("overlap", 8)
        self.batch_size = arch_config.get("batch_size", 1)

        # Amount of pitch shift to apply during processing (this does NOT affect the pitch of the output audio):
        # - Whole numbers indicate semitones.
        # - Using higher pitches may cut the upper bandwidth, even in high-quality models.
        # - Upping the pitch can be better for tracks with deeper vocals.
        # - Dropping the pitch may take more processing time but works well for tracks with high-pitched vocals.
        self.pitch_shift = arch_config.get("pitch_shift", 0)

        self.process_all_stems = arch_config.get("process_all_stems", True)

        self.logger.debug(f"MDXC arch params: batch_size={self.batch_size}, segment_size={self.segment_size}, overlap={self.overlap}")
        self.logger.debug(f"MDXC arch params: override_model_segment_size={self.override_model_segment_size}, pitch_shift={self.pitch_shift}")
        self.logger.debug(f"MDXC multi-stem params: process_all_stems={self.process_all_stems}")

        self.is_roformer = "is_roformer" in self.model_data

        self.load_model()

        self.primary_source = None
        self.secondary_source = None
        self.audio_file_path = None
        self.audio_file_base = None

        self.is_primary_stem_main_target = False
        if self.model_data_cfgdict.training.target_instrument == "Vocals" or len(self.model_data_cfgdict.training.instruments) > 1:
            self.is_primary_stem_main_target = True

        self.logger.debug(f"is_primary_stem_main_target: {self.is_primary_stem_main_target}")

        self.logger.info("MDXC Separator initialisation complete")

    def load_model(self):
        """
        Load the model into memory from file on disk, initialize it with config from the model data,
        and prepare for inferencing using a hardware-accelerated Torch device.
        """
        self.logger.debug("Loading checkpoint model for inference...")

        self.model_data_cfgdict = ConfigDict(self.model_data)

        try:
            if self.is_roformer:
                self.logger.debug("Loading Roformer model...")

                # Determine the model type based on the configuration and instantiate it
                if "num_bands" in self.model_data_cfgdict.model:
                    self.logger.debug("Loading MelBandRoformer model...")
                    model = MelBandRoformer(**self.model_data_cfgdict.model)
                elif "freqs_per_bands" in self.model_data_cfgdict.model:
                    self.logger.debug("Loading BSRoformer model...")
                    model = BSRoformer(**self.model_data_cfgdict.model)
                else:
                    raise ValueError("Unknown Roformer model type in the configuration.")

                # Load the model checkpoint
                checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=True)
                self.model_run = model if not isinstance(model, torch.nn.DataParallel) else model.module
                self.model_run.load_state_dict(checkpoint)
                self.model_run.to(self.torch_device).eval()

            else:
                self.logger.debug("Loading TFC_TDF_net model...")
                self.model_run = TFC_TDF_net(self.model_data_cfgdict, device=self.torch_device)
                self.logger.debug("Loading model onto CPU")
                # For some reason loading the state onto a hardware-accelerated device causes issues,
                # so we load it onto CPU first and then move it to the device
                self.model_run.load_state_dict(torch.load(self.model_path, map_location="cpu"))
                self.model_run.to(self.torch_device).eval()

        except RuntimeError as e:
            self.logger.error(f"Error: {e}")
            self.logger.error("An error occurred while loading the model file. This often occurs when the model file is corrupt or incomplete.")
            self.logger.error(f"Please try deleting the model file from {self.model_path} and run audio-separator again to re-download it.")
            sys.exit(1)

    def separate(self, audio_file_path, custom_output_names=None):
        """
        Separates the audio file into primary and secondary sources based on the model's configuration.
        It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.

        Args:
            audio_file_path (str): The path to the audio file to be processed.
            custom_output_names (dict, optional): Custom names for the output files. Defaults to None.

        Returns:
            list: A list of paths to the output files generated by the separation process.
        """
        self.primary_source = None
        self.secondary_source = None

        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

        self.logger.debug(f"Preparing mix for input audio file {self.audio_file_path}...")
        mix = self.prepare_mix(self.audio_file_path)

        self.logger.debug("Normalizing mix before demixing...")
        mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)

        source = self.demix(mix=mix)
        self.logger.debug("Demixing completed.")

        output_files = []
        self.logger.debug("Processing output files...")

        if isinstance(source, dict):
            self.logger.debug("Source is a dict, processing each stem...")

            stem_list = []
            if self.model_data_cfgdict.training.target_instrument:
                stem_list = [self.model_data_cfgdict.training.target_instrument]
            else:
                stem_list = self.model_data_cfgdict.training.instruments

            self.logger.debug(f"Available stems: {stem_list}")

            is_multi_stem_model = len(stem_list) > 2
            should_process_all_stems = self.process_all_stems and is_multi_stem_model

            if should_process_all_stems:
                self.logger.debug("Processing all stems from multi-stem model...")
                for stem_name in stem_list:
                    stem_output_path = self.get_stem_output_path(stem_name, custom_output_names)
                    stem_source = spec_utils.normalize(
                        wave=source[stem_name],
                        max_peak=self.normalization_threshold,
                        min_peak=self.amplification_threshold
                    ).T

                    self.logger.info(f"Saving {stem_name} stem to {stem_output_path}...")
                    self.final_process(stem_output_path, stem_source, stem_name)
                    output_files.append(stem_output_path)
            else:
                # Standard processing for primary and secondary stems
                if not isinstance(self.primary_source, np.ndarray):
                    self.logger.debug(f"Normalizing primary source for primary stem {self.primary_stem_name}...")
                    self.primary_source = spec_utils.normalize(
                        wave=source[self.primary_stem_name],
                        max_peak=self.normalization_threshold,
                        min_peak=self.amplification_threshold
                    ).T

                if not isinstance(self.secondary_source, np.ndarray):
                    self.logger.debug(f"Normalizing secondary source for secondary stem {self.secondary_stem_name}...")
                    self.secondary_source = spec_utils.normalize(
                        wave=source[self.secondary_stem_name],
                        max_peak=self.normalization_threshold,
                        min_peak=self.amplification_threshold
                    ).T

                if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
                    self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)

                    self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
                    self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
                    output_files.append(self.secondary_stem_output_path)

                if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
                    self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)

                    self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
                    self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
                    output_files.append(self.primary_stem_output_path)

        else:
            # Handle the case where source is not a dictionary (single-source model)
            if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
                self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)

                if not isinstance(self.primary_source, np.ndarray):
                    self.primary_source = source.T

                self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
                self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
                output_files.append(self.primary_stem_output_path)

        return output_files

    def pitch_fix(self, source, sr_pitched, orig_mix):
        """
        Change the pitch of the source audio by a number of semitones.

        Args:
            source (np.ndarray): The source audio to be pitch-shifted.
            sr_pitched (int): The sample rate of the pitch-shifted audio.
            orig_mix (np.ndarray): The original mix, used to match the shape of the pitch-shifted audio.

        Returns:
            np.ndarray: The pitch-shifted source audio.
        """
        source = spec_utils.change_pitch_semitones(source, sr_pitched, semitone_shift=self.pitch_shift)[0]
        source = spec_utils.match_array_shapes(source, orig_mix)
        return source

    def overlap_add(self, result, x, weights, start, length):
        """
        Adds the overlapping part of the chunk output x, scaled by the window weights, into the result tensor.
        """
        result[..., start : start + length] += x[..., :length] * weights[:length]
        return result

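    # Editor's note, not part of the original source: a minimal sketch of what
    # overlap_add does, with hypothetical sizes.
    #
    #   result = torch.zeros(1, 2, 12)
    #   chunk = torch.ones(1, 2, 4)
    #   weights = torch.tensor(signal.windows.hamming(4), dtype=torch.float32)
    #   result = overlap_add(result, chunk, weights, start=3, length=4)
    #
    # demix() below accumulates the same window values into a matching `counter`
    # tensor, so dividing result by counter restores unity gain wherever chunks
    # overlap.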
    def demix(self, mix: np.ndarray) -> dict:
        """
        Demixes the input mix into primary and secondary sources using the model and model data.

        Args:
            mix (np.ndarray): The mix to be demixed.

        Returns:
            dict: A dictionary containing the demixed sources.
        """
        orig_mix = mix

        if self.pitch_shift != 0:
            self.logger.debug(f"Shifting pitch by -{self.pitch_shift} semitones...")
            mix, sample_rate = spec_utils.change_pitch_semitones(mix, self.sample_rate, semitone_shift=-self.pitch_shift)

        if self.is_roformer:
            # Note: currently, for Roformer models, batch_size is not utilized due to negligible performance improvements.

            mix = torch.tensor(mix, dtype=torch.float32)

            if self.override_model_segment_size:
                mdx_segment_size = self.segment_size
                self.logger.debug(f"Using configured segment size: {mdx_segment_size}")
            else:
                mdx_segment_size = self.model_data_cfgdict.inference.dim_t
                self.logger.debug(f"Using model default segment size: {mdx_segment_size}")

            # num_stems aka "S" in UVR
            num_stems = 1 if self.model_data_cfgdict.training.target_instrument else len(self.model_data_cfgdict.training.instruments)
            self.logger.debug(f"Number of stems: {num_stems}")

            # chunk_size aka "C" in UVR
            chunk_size = self.model_data_cfgdict.audio.hop_length * (mdx_segment_size - 1)
            self.logger.debug(f"Chunk size: {chunk_size}")

            step = int(self.overlap * self.model_data_cfgdict.audio.sample_rate)
            self.logger.debug(f"Step: {step}")

            # Create a weighting window and convert it to a PyTorch tensor
            window = torch.tensor(signal.windows.hamming(chunk_size), dtype=torch.float32)

            device = next(self.model_run.parameters()).device

            with torch.no_grad():
                req_shape = (len(self.model_data_cfgdict.training.instruments),) + tuple(mix.shape)
                result = torch.zeros(req_shape, dtype=torch.float32)
                counter = torch.zeros(req_shape, dtype=torch.float32)

                for i in tqdm(range(0, mix.shape[1], step)):
                    part = mix[:, i : i + chunk_size]
                    length = part.shape[-1]
                    if i + chunk_size > mix.shape[1]:
                        part = mix[:, -chunk_size:]
                        length = chunk_size
                    part = part.to(device)
                    x = self.model_run(part.unsqueeze(0))[0]
                    x = x.cpu()
                    # Perform overlap_add on CPU
                    if i + chunk_size > mix.shape[1]:
                        # Add the final chunk at the end of the tensor
                        result = self.overlap_add(result, x, window, result.shape[-1] - chunk_size, length)
                        counter[..., result.shape[-1] - chunk_size :] += window[:length]
                    else:
                        result = self.overlap_add(result, x, window, i, length)
                        counter[..., i : i + length] += window[:length]

                inferenced_outputs = result / counter.clamp(min=1e-10)

        else:
            mix = torch.tensor(mix, dtype=torch.float32)

            try:
                num_stems = self.model_run.num_target_instruments
            except AttributeError:
                num_stems = self.model_run.module.num_target_instruments
            self.logger.debug(f"Number of stems: {num_stems}")

            if self.override_model_segment_size:
                mdx_segment_size = self.segment_size
                self.logger.debug(f"Using configured segment size: {mdx_segment_size}")
            else:
                mdx_segment_size = self.model_data_cfgdict.inference.dim_t
                self.logger.debug(f"Using model default segment size: {mdx_segment_size}")

            chunk_size = self.model_data_cfgdict.audio.hop_length * (mdx_segment_size - 1)
            self.logger.debug(f"Chunk size: {chunk_size}")

            hop_size = chunk_size // self.overlap
            self.logger.debug(f"Hop size: {hop_size}")

            mix_shape = mix.shape[1]
            pad_size = hop_size - (mix_shape - chunk_size) % hop_size
            self.logger.debug(f"Pad size: {pad_size}")

            mix = torch.cat([torch.zeros(2, chunk_size - hop_size), mix, torch.zeros(2, pad_size + chunk_size - hop_size)], 1)
            self.logger.debug(f"Mix shape: {mix.shape}")

            chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1)
            self.logger.debug(f"Chunks length: {len(chunks)} and shape: {chunks.shape}")

            batches = [chunks[i : i + self.batch_size] for i in range(0, len(chunks), self.batch_size)]
            self.logger.debug(f"Batch size: {self.batch_size}, number of batches: {len(batches)}")

            # accumulated_outputs is used to accumulate the output from processing each batch of chunks through the model.
            # It starts as a tensor of zeros and is updated in place as the model processes each batch.
            # After post-processing, it holds the combined result of all batches: the separated audio sources.
            accumulated_outputs = torch.zeros(num_stems, *mix.shape) if num_stems > 1 else torch.zeros_like(mix)

            with torch.no_grad():
                count = 0
                for batch in tqdm(batches):
                    # Since the model processes the audio data in batches, single_batch_result temporarily holds the model's
                    # output for each batch before it is accumulated into accumulated_outputs.
                    single_batch_result = self.model_run(batch.to(self.torch_device))

                    # single_batch_result can contain multiple output tensors (one for each piece of audio in the batch),
                    # so individual_output iterates through them, accumulating each into accumulated_outputs.
                    for individual_output in single_batch_result:
                        individual_output_cpu = individual_output.cpu()
                        # Accumulate outputs on CPU
                        accumulated_outputs[..., count * hop_size : count * hop_size + chunk_size] += individual_output_cpu
                        count += 1

            self.logger.debug("Calculating inferenced outputs based on accumulated outputs and overlap")
            inferenced_outputs = accumulated_outputs[..., chunk_size - hop_size : -(pad_size + chunk_size - hop_size)] / self.overlap
            self.logger.debug("Deleting accumulated outputs to free up memory")
            del accumulated_outputs

        if num_stems > 1 or self.is_primary_stem_main_target:
            self.logger.debug("Number of stems is greater than 1 or the primary stem is the main target; detaching individual sources and correcting pitch if necessary...")

            sources = {}

            # Iterate over each instrument specified in the model's configuration and its corresponding separated audio source.
            # self.model_data_cfgdict.training.instruments provides the list of stems, and
            # inferenced_outputs.cpu().detach().numpy() converts the separated-sources tensor to a NumPy array.
            # Each iteration provides an instrument name ('key') and its separated audio ('value') for further processing.
            for key, value in zip(self.model_data_cfgdict.training.instruments, inferenced_outputs.cpu().detach().numpy()):
                self.logger.debug(f"Processing instrument: {key}")
                if self.pitch_shift != 0:
                    self.logger.debug(f"Applying pitch correction for {key}")
                    sources[key] = self.pitch_fix(value, sample_rate, orig_mix)
                else:
                    sources[key] = value

            if self.is_primary_stem_main_target:
                self.logger.debug(f"Primary stem: {self.primary_stem_name} is the main target; matching array shapes if necessary...")
                if sources[self.primary_stem_name].shape[1] != orig_mix.shape[1]:
                    sources[self.primary_stem_name] = spec_utils.match_array_shapes(sources[self.primary_stem_name], orig_mix)
                sources[self.secondary_stem_name] = orig_mix - sources[self.primary_stem_name]

            self.logger.debug("Deleting inferenced outputs to free up memory")
            del inferenced_outputs

            self.logger.debug("Returning separated sources")
            return sources
        else:
            self.logger.debug("Processing single source...")

            if self.is_roformer:
                sources = {k: v.cpu().detach().numpy() for k, v in zip([self.model_data_cfgdict.training.target_instrument], inferenced_outputs)}
                inferenced_output = sources[self.model_data_cfgdict.training.target_instrument]
            else:
                inferenced_output = inferenced_outputs.cpu().detach().numpy()

            self.logger.debug("Demix process completed for single source.")

            self.logger.debug("Deleting inferenced outputs to free up memory")
            del inferenced_outputs

            if self.pitch_shift != 0:
                self.logger.debug("Applying pitch correction for single instrument")
                return self.pitch_fix(inferenced_output, sample_rate, orig_mix)
            else:
                self.logger.debug("Returning inferenced output for single instrument")
                return inferenced_output
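The non-Roformer branch of demix divides the accumulated outputs by the fixed overlap count, while the Roformer branch weights every chunk with a Hamming window and divides by the accumulated window values, so overlapped regions come back at unity gain. A self-contained sketch of that windowed overlap-add scheme, using a stand-in identity "model" (the function name and sizes here are hypothetical):

    import torch
    from scipy import signal

    def windowed_overlap_add(mix: torch.Tensor, model, chunk_size: int, step: int) -> torch.Tensor:
        # Weight each chunk's output with a Hamming window, tracking the weights applied.
        window = torch.tensor(signal.windows.hamming(chunk_size), dtype=torch.float32)
        result = torch.zeros_like(mix)
        counter = torch.zeros_like(mix)
        for i in range(0, mix.shape[-1], step):
            start = min(i, mix.shape[-1] - chunk_size)  # clamp the final chunk to the end
            chunk = mix[..., start : start + chunk_size]
            result[..., start : start + chunk_size] += model(chunk) * window
            counter[..., start : start + chunk_size] += window
        return result / counter.clamp(min=1e-10)

    # With an identity "model", the input is reconstructed almost exactly:
    mix = torch.randn(2, 44100)
    restored = windowed_overlap_add(mix, lambda x: x, chunk_size=4096, step=1024)
    assert torch.allclose(restored, mix, atol=1e-5)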
audio_separator/separator/architectures/vr_separator.py
ADDED
@@ -0,0 +1,357 @@
"""Module for separating audio sources using VR architecture models."""

import os
import math

import torch
import librosa
import numpy as np
from tqdm import tqdm

# Check if we really need the rerun_mp3 function; remove if not
import audioread

from audio_separator.separator.common_separator import CommonSeparator
from audio_separator.separator.uvr_lib_v5 import spec_utils
from audio_separator.separator.uvr_lib_v5.vr_network import nets
from audio_separator.separator.uvr_lib_v5.vr_network import nets_new
from audio_separator.separator.uvr_lib_v5.vr_network.model_param_init import ModelParameters


class VRSeparator(CommonSeparator):
    """
    VRSeparator is responsible for separating audio sources using VR models.
    It initializes with configuration parameters and prepares the model for separation tasks.
    """

    def __init__(self, common_config, arch_config: dict):
        # Any configuration values which can be shared between architectures should be set already in CommonSeparator,
        # e.g. user-specified functionality choices (self.output_single_stem) or common model parameters (self.primary_stem_name)
        super().__init__(config=common_config)

        # Model data is basic overview metadata about the model, e.g. which stem is primary and whether it's a karaoke model.
        # It's loaded in from model_data_new.json in Separator.load_model and there are JSON examples in that method.
        # The instance variable self.model_data is passed through from Separator and set in CommonSeparator.
        self.logger.debug(f"Model data: {self.model_data}")

        # Most of the VR models use the same number of output channels, but the VR 5.1 models have specific values set in the model_data JSON
        self.model_capacity = 32, 128
        self.is_vr_51_model = False

        if "nout" in self.model_data.keys() and "nout_lstm" in self.model_data.keys():
            self.model_capacity = self.model_data["nout"], self.model_data["nout_lstm"]
            self.is_vr_51_model = True

        # Model params are additional technical parameter values from JSON files in separator/uvr_lib_v5/vr_network/modelparams/*.json,
        # with filenames referenced by the model_data["vr_model_param"] value
        package_root_filepath = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        vr_params_json_dir = os.path.join(package_root_filepath, "uvr_lib_v5", "vr_network", "modelparams")
        vr_params_json_filename = f"{self.model_data['vr_model_param']}.json"
        vr_params_json_filepath = os.path.join(vr_params_json_dir, vr_params_json_filename)
        self.model_params = ModelParameters(vr_params_json_filepath)

        self.logger.debug(f"Model params: {self.model_params.param}")

        # Arch config contains the VR architecture-specific user configuration options, which should all be configurable
        # by the user, either through their Separator class instantiation or by passing in a CLI parameter.
        # While there are similarities between architectures for some of these (e.g. batch_size), they are deliberately
        # configured this way as they have architecture-specific default values.

        # This option performs Test-Time Augmentation to improve the separation quality.
        # Note: having this selected will increase the time it takes to complete a conversion
        self.enable_tta = arch_config.get("enable_tta", False)

        # This option can potentially identify leftover instrumental artifacts within the vocal outputs and may improve the separation of some songs.
        # Note: selecting this option can adversely affect the conversion process, depending on the track. Because of this, it is only recommended as a last resort.
        self.enable_post_process = arch_config.get("enable_post_process", False)

        # post_process_threshold values = ('0.1', '0.2', '0.3')
        self.post_process_threshold = arch_config.get("post_process_threshold", 0.2)

        # Number of batches to be processed at a time.
        # - Higher values mean more RAM usage but slightly faster processing times.
        # - Lower values mean less RAM usage but slightly longer processing times.
        # - The batch size value has no effect on output quality.

        # Andrew note: for some reason, lower batch sizes seem to cause broken output for the VR arch; need to investigate why
        self.batch_size = arch_config.get("batch_size", 1)

        # Select a window size to balance quality and speed:
        # - 1024 - Quick but lesser quality.
        # - 512 - Medium speed and quality.
        # - 320 - Takes longer but may offer better quality.
        self.window_size = arch_config.get("window_size", 512)

        # The application will mirror the missing frequency range of the output.
        self.high_end_process = arch_config.get("high_end_process", False)
        self.input_high_end_h = None
        self.input_high_end = None

        # Adjust the intensity of primary stem extraction:
        # - Ranges from -100 to 100.
        # - Bigger values mean deeper extractions.
        # - Typically, it's set to 5 for vocals & instrumentals.
        # - Values beyond 5 might muddy the sound for non-vocal models.
        self.aggression = float(int(arch_config.get("aggression", 5)) / 100)

        self.aggressiveness = {"value": self.aggression, "split_bin": self.model_params.param["band"][1]["crop_stop"], "aggr_correction": self.model_params.param.get("aggr_correction")}

        self.model_samplerate = self.model_params.param["sr"]

        self.logger.debug(f"VR arch params: enable_tta={self.enable_tta}, enable_post_process={self.enable_post_process}, post_process_threshold={self.post_process_threshold}")
        self.logger.debug(f"VR arch params: batch_size={self.batch_size}, window_size={self.window_size}")
        self.logger.debug(f"VR arch params: high_end_process={self.high_end_process}, aggression={self.aggression}")
        self.logger.debug(f"VR arch params: is_vr_51_model={self.is_vr_51_model}, model_samplerate={self.model_samplerate}, model_capacity={self.model_capacity}")

        self.model_run = lambda *args, **kwargs: self.logger.error("Model run method is not initialised yet.")

        # This should go away once we refactor to remove soundfile.write and replace it with pydub, like we did for the MDX rewrite
        self.wav_subtype = "PCM_16"

        self.logger.info("VR Separator initialisation complete")

    def separate(self, audio_file_path, custom_output_names=None):
        """
        Separates the audio file into primary and secondary sources based on the model's configuration.
        It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.

        Args:
            audio_file_path (str): The path to the audio file to be processed.
            custom_output_names (dict, optional): Custom names for the output files. Defaults to None.

        Returns:
            list: A list of paths to the output files generated by the separation process.
        """
        self.primary_source = None
        self.secondary_source = None

        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

        self.logger.debug(f"Starting separation for input audio file {self.audio_file_path}...")

        nn_arch_sizes = [31191, 33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227]  # default
        vr_5_1_models = [56817, 218409]
        model_size = math.ceil(os.stat(self.model_path).st_size / 1024)
        nn_arch_size = min(nn_arch_sizes, key=lambda x: abs(x - model_size))
        self.logger.debug(f"Model size determined: {model_size}, NN architecture size: {nn_arch_size}")

        if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
            self.logger.debug("Using CascadedNet for VR 5.1 model...")
            self.model_run = nets_new.CascadedNet(self.model_params.param["bins"] * 2, nn_arch_size, nout=self.model_capacity[0], nout_lstm=self.model_capacity[1])
            self.is_vr_51_model = True
        else:
            self.logger.debug("Determining model capacity...")
            self.model_run = nets.determine_model_capacity(self.model_params.param["bins"] * 2, nn_arch_size)

        self.model_run.load_state_dict(torch.load(self.model_path, map_location="cpu"))
        self.model_run.to(self.torch_device)
        self.logger.debug("Model loaded and moved to device.")

        y_spec, v_spec = self.inference_vr(self.loading_mix(), self.torch_device, self.aggressiveness)
        self.logger.debug("Inference completed.")

        # Sanitize y_spec and v_spec to replace NaN and infinite values
        y_spec = np.nan_to_num(y_spec, nan=0.0, posinf=0.0, neginf=0.0)
        v_spec = np.nan_to_num(v_spec, nan=0.0, posinf=0.0, neginf=0.0)

        self.logger.debug("Sanitization completed. Replaced NaN and infinite values in y_spec and v_spec.")

        # After the inference_vr call
        self.logger.debug(f"Inference VR completed. y_spec shape: {y_spec.shape}, v_spec shape: {v_spec.shape}")
        self.logger.debug(f"y_spec stats - min: {np.min(y_spec)}, max: {np.max(y_spec)}, isnan: {np.isnan(y_spec).any()}, isinf: {np.isinf(y_spec).any()}")
        self.logger.debug(f"v_spec stats - min: {np.min(v_spec)}, max: {np.max(v_spec)}, isnan: {np.isnan(v_spec).any()}, isinf: {np.isinf(v_spec).any()}")

        # Not yet implemented from UVR features:
        #
        # if not self.is_vocal_split_model:
        #     self.cache_source((y_spec, v_spec))

        # if self.is_secondary_model_activated and self.secondary_model:
        #     self.logger.debug("Processing secondary model...")
        #     self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(
        #         self.secondary_model, self.process_data, main_process_method=self.process_method, main_model_primary=self.primary_stem
        #     )

        # Initialize the list for output files
        output_files = []
        self.logger.debug("Processing output files...")

        # Note: logic similar to the following should probably be added to the other architectures.
        # Check if output_single_stem is set to a value that would result in no output files
        if self.output_single_stem and (self.output_single_stem.lower() != self.primary_stem_name.lower() and self.output_single_stem.lower() != self.secondary_stem_name.lower()):
            self.logger.warning(f"The output_single_stem setting '{self.output_single_stem}' does not match any of the output files: '{self.primary_stem_name}' and '{self.secondary_stem_name}'. For this model '{self.model_name}', the output_single_stem setting will be ignored and all output files will be saved.")
            # Reset output_single_stem to None so both stems are saved
            self.output_single_stem = None

        # Save and process the primary stem if needed
        if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
            self.logger.debug(f"Processing primary stem: {self.primary_stem_name}")
            if not isinstance(self.primary_source, np.ndarray):
                self.logger.debug(f"Preparing to convert spectrogram to waveform. Spec shape: {y_spec.shape}")

                self.primary_source = self.spec_to_wav(y_spec).T
                self.logger.debug("Converted primary source spectrogram to waveform.")
                if not self.model_samplerate == 44100:
                    self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
                    self.logger.debug("Resampled primary source to 44100Hz.")

            self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)

            self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
            self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
            output_files.append(self.primary_stem_output_path)

        # Save and process the secondary stem if needed
        if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
            self.logger.debug(f"Processing secondary stem: {self.secondary_stem_name}")
            if not isinstance(self.secondary_source, np.ndarray):
                self.logger.debug(f"Preparing to convert spectrogram to waveform. Spec shape: {v_spec.shape}")

                self.secondary_source = self.spec_to_wav(v_spec).T
                self.logger.debug("Converted secondary source spectrogram to waveform.")
                if not self.model_samplerate == 44100:
                    self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
                    self.logger.debug("Resampled secondary source to 44100Hz.")

            self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)

            self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
            self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
            output_files.append(self.secondary_stem_output_path)

        # Not yet implemented from UVR features:
        # self.process_vocal_split_chain(secondary_sources)
        # self.logger.debug("Vocal split chain processed.")

        return output_files

    def loading_mix(self):
        X_wave, X_spec_s = {}, {}

        bands_n = len(self.model_params.param["band"])

        audio_file = spec_utils.write_array_to_mem(self.audio_file_path, subtype=self.wav_subtype)
        is_mp3 = audio_file.endswith(".mp3") if isinstance(audio_file, str) else False

        self.logger.debug(f"loading_mix iterating through {bands_n} bands")
        for d in tqdm(range(bands_n, 0, -1)):
            bp = self.model_params.param["band"][d]

            wav_resolution = bp["res_type"]

            if self.torch_device_mps is not None:
                wav_resolution = "polyphase"

            if d == bands_n:  # high-end band
                X_wave[d], _ = librosa.load(audio_file, sr=bp["sr"], mono=False, dtype=np.float32, res_type=wav_resolution)
                X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp["hl"], bp["n_fft"], self.model_params, band=d, is_v51_model=self.is_vr_51_model)

                if not np.any(X_wave[d]) and is_mp3:
                    X_wave[d] = rerun_mp3(audio_file, bp["sr"])

                if X_wave[d].ndim == 1:
                    X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
            else:  # lower bands
                X_wave[d] = librosa.resample(X_wave[d + 1], orig_sr=self.model_params.param["band"][d + 1]["sr"], target_sr=bp["sr"], res_type=wav_resolution)
                X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp["hl"], bp["n_fft"], self.model_params, band=d, is_v51_model=self.is_vr_51_model)

            if d == bands_n and self.high_end_process:
                self.input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (self.model_params.param["pre_filter_stop"] - self.model_params.param["pre_filter_start"])
                self.input_high_end = X_spec_s[d][:, bp["n_fft"] // 2 - self.input_high_end_h : bp["n_fft"] // 2, :]

        X_spec = spec_utils.combine_spectrograms(X_spec_s, self.model_params, is_v51_model=self.is_vr_51_model)

        del X_wave, X_spec_s, audio_file

        return X_spec

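    # Editor's note, not part of the original source: loading_mix builds a
    # multi-band cascade, resampling each lower band from the band above it.
    # Roughly, with hypothetical band sample rates:
    #
    #   wave[3], _ = librosa.load(path, sr=44100, mono=False)               # top band
    #   wave[2] = librosa.resample(wave[3], orig_sr=44100, target_sr=32000)
    #   wave[1] = librosa.resample(wave[2], orig_sr=32000, target_sr=16000)
    #
    # Each band is converted to a spectrogram, and the per-band spectrograms
    # are then combined into the single model input X_spec.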
    def inference_vr(self, X_spec, device, aggressiveness):
        def _execute(X_mag_pad, roi_size):
            X_dataset = []
            patches = (X_mag_pad.shape[2] - 2 * self.model_run.offset) // roi_size

            self.logger.debug(f"inference_vr appending to X_dataset for each of {patches} patches")
            for i in tqdm(range(patches)):
                start = i * roi_size
                X_mag_window = X_mag_pad[:, :, start : start + self.window_size]
                X_dataset.append(X_mag_window)

            total_iterations = patches // self.batch_size if not self.enable_tta else (patches // self.batch_size) * 2
            self.logger.debug(f"inference_vr iterating through {total_iterations} batches, batch_size = {self.batch_size}")

            X_dataset = np.asarray(X_dataset)
            self.model_run.eval()
            with torch.no_grad():
                mask = []

                for i in tqdm(range(0, patches, self.batch_size)):
                    X_batch = X_dataset[i : i + self.batch_size]
                    X_batch = torch.from_numpy(X_batch).to(device)
                    pred = self.model_run.predict_mask(X_batch)
                    if not pred.size()[3] > 0:
                        raise ValueError("Window size error: h1_shape[3] must be greater than h2_shape[3]")
                    pred = pred.detach().cpu().numpy()
                    pred = np.concatenate(pred, axis=2)
                    mask.append(pred)
                if len(mask) == 0:
                    raise ValueError("Window size error: h1_shape[3] must be greater than h2_shape[3]")

                mask = np.concatenate(mask, axis=2)
            return mask

        def postprocess(mask, X_mag, X_phase):
            is_non_accom_stem = False
            for stem in CommonSeparator.NON_ACCOM_STEMS:
                if stem == self.primary_stem_name:
                    is_non_accom_stem = True

            mask = spec_utils.adjust_aggr(mask, is_non_accom_stem, aggressiveness)

            if self.enable_post_process:
                mask = spec_utils.merge_artifacts(mask, thres=self.post_process_threshold)

            y_spec = mask * X_mag * np.exp(1.0j * X_phase)
            v_spec = (1 - mask) * X_mag * np.exp(1.0j * X_phase)

            return y_spec, v_spec

        X_mag, X_phase = spec_utils.preprocess(X_spec)
        n_frame = X_mag.shape[2]
        pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset)
        X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
        X_mag_pad /= X_mag_pad.max()
        mask = _execute(X_mag_pad, roi_size)

        if self.enable_tta:
            pad_l += roi_size // 2
            pad_r += roi_size // 2
            X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
            X_mag_pad /= X_mag_pad.max()
            mask_tta = _execute(X_mag_pad, roi_size)
            mask_tta = mask_tta[:, :, roi_size // 2 :]
            mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5
        else:
            mask = mask[:, :, :n_frame]

        y_spec, v_spec = postprocess(mask, X_mag, X_phase)

        return y_spec, v_spec

    def spec_to_wav(self, spec):
        if self.high_end_process and isinstance(self.input_high_end, np.ndarray) and self.input_high_end_h:
            input_high_end_ = spec_utils.mirroring("mirroring", spec, self.input_high_end, self.model_params)
            wav = spec_utils.cmb_spectrogram_to_wave(spec, self.model_params, self.input_high_end_h, input_high_end_, is_v51_model=self.is_vr_51_model)
        else:
            wav = spec_utils.cmb_spectrogram_to_wave(spec, self.model_params, is_v51_model=self.is_vr_51_model)

        return wav


# Check if we really need the rerun_mp3 function; refactor or remove if not
def rerun_mp3(audio_file, sample_rate=44100):
    with audioread.audio_open(audio_file) as f:
        track_length = int(f.duration)

    return librosa.load(audio_file, duration=track_length, mono=False, sr=sample_rate)[0]
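inference_vr runs the network over fixed-size patches of the padded magnitude spectrogram; with enable_tta, a second pass is made with the patch grid shifted by half a patch, and the two masks are averaged, which smooths artifacts tied to patch boundaries. A minimal sketch of that shifted-grid averaging, assuming a hypothetical mask_fn that maps a padded magnitude spectrogram to a mask over the same frames:

    import numpy as np

    def tta_mask(X_mag: np.ndarray, mask_fn, pad_l: int, pad_r: int, roi_size: int) -> np.ndarray:
        n_frame = X_mag.shape[2]

        # First pass: standard padding.
        mask = mask_fn(np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant"))

        # Second pass: shift the patch grid by half a patch, then realign the mask.
        shifted = np.pad(X_mag, ((0, 0), (0, 0), (pad_l + roi_size // 2, pad_r + roi_size // 2)), mode="constant")
        mask_tta = mask_fn(shifted)[:, :, roi_size // 2 :]

        # Average the aligned masks over the original frame range.
        return (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5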
audio_separator/separator/common_separator.py
ADDED
@@ -0,0 +1,403 @@
"""This file contains the CommonSeparator class, common to all architecture-specific Separator classes."""

from logging import Logger
import os
import re
import gc
import numpy as np
import librosa
import torch
from pydub import AudioSegment
import soundfile as sf
from audio_separator.separator.uvr_lib_v5 import spec_utils


class CommonSeparator:
    """
    This class contains the methods and attributes common to all architecture-specific Separator classes.
    """

    ALL_STEMS = "All Stems"
    VOCAL_STEM = "Vocals"
    INST_STEM = "Instrumental"
    OTHER_STEM = "Other"
    BASS_STEM = "Bass"
    DRUM_STEM = "Drums"
    GUITAR_STEM = "Guitar"
    PIANO_STEM = "Piano"
    SYNTH_STEM = "Synthesizer"
    STRINGS_STEM = "Strings"
    WOODWINDS_STEM = "Woodwinds"
    BRASS_STEM = "Brass"
    WIND_INST_STEM = "Wind Inst"
    NO_OTHER_STEM = "No Other"
    NO_BASS_STEM = "No Bass"
    NO_DRUM_STEM = "No Drums"
    NO_GUITAR_STEM = "No Guitar"
    NO_PIANO_STEM = "No Piano"
    NO_SYNTH_STEM = "No Synthesizer"
    NO_STRINGS_STEM = "No Strings"
    NO_WOODWINDS_STEM = "No Woodwinds"
    NO_WIND_INST_STEM = "No Wind Inst"
    NO_BRASS_STEM = "No Brass"
    PRIMARY_STEM = "Primary Stem"
    SECONDARY_STEM = "Secondary Stem"
    LEAD_VOCAL_STEM = "lead_only"
    BV_VOCAL_STEM = "backing_only"
    LEAD_VOCAL_STEM_I = "with_lead_vocals"
    BV_VOCAL_STEM_I = "with_backing_vocals"
    LEAD_VOCAL_STEM_LABEL = "Lead Vocals"
    BV_VOCAL_STEM_LABEL = "Backing Vocals"
    NO_STEM = "No "

    STEM_PAIR_MAPPER = {VOCAL_STEM: INST_STEM, INST_STEM: VOCAL_STEM, LEAD_VOCAL_STEM: BV_VOCAL_STEM, BV_VOCAL_STEM: LEAD_VOCAL_STEM, PRIMARY_STEM: SECONDARY_STEM}

    NON_ACCOM_STEMS = (VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM, SYNTH_STEM, STRINGS_STEM, WOODWINDS_STEM, BRASS_STEM, WIND_INST_STEM)

    def __init__(self, config):

        self.logger: Logger = config.get("logger")
        self.log_level: int = config.get("log_level")

        # Inferencing device / acceleration config
        self.torch_device = config.get("torch_device")
        self.torch_device_cpu = config.get("torch_device_cpu")
        self.torch_device_mps = config.get("torch_device_mps")
        self.onnx_execution_provider = config.get("onnx_execution_provider")

        # Model data
        self.model_name = config.get("model_name")
        self.model_path = config.get("model_path")
        self.model_data = config.get("model_data")

        # Output directory and format
        self.output_dir = config.get("output_dir")
        self.output_format = config.get("output_format")
        self.output_bitrate = config.get("output_bitrate")

        # Functional options which are applicable to all architectures and which the user may tweak to affect the output
        self.normalization_threshold = config.get("normalization_threshold")
        self.amplification_threshold = config.get("amplification_threshold")
        self.enable_denoise = config.get("enable_denoise")
        self.output_single_stem = config.get("output_single_stem")
        self.invert_using_spec = config.get("invert_using_spec")
        self.sample_rate = config.get("sample_rate")
        self.use_soundfile = config.get("use_soundfile")

        # Model-specific properties

        # Check if model_data has a "training" key with an "instruments" list
        self.primary_stem_name = None
        self.secondary_stem_name = None

        if "training" in self.model_data and "instruments" in self.model_data["training"]:
            instruments = self.model_data["training"]["instruments"]
            if instruments:
                self.primary_stem_name = instruments[0]
                self.secondary_stem_name = instruments[1] if len(instruments) > 1 else self.secondary_stem(self.primary_stem_name)

        if self.primary_stem_name is None:
            self.primary_stem_name = self.model_data.get("primary_stem", "Vocals")
            self.secondary_stem_name = self.secondary_stem(self.primary_stem_name)

        self.is_karaoke = self.model_data.get("is_karaoke", False)
        self.is_bv_model = self.model_data.get("is_bv_model", False)
        self.bv_model_rebalance = self.model_data.get("is_bv_model_rebalanced", 0)

        self.logger.debug(f"Common params: model_name={self.model_name}, model_path={self.model_path}")
        self.logger.debug(f"Common params: output_dir={self.output_dir}, output_format={self.output_format}")
        self.logger.debug(f"Common params: normalization_threshold={self.normalization_threshold}, amplification_threshold={self.amplification_threshold}")
        self.logger.debug(f"Common params: enable_denoise={self.enable_denoise}, output_single_stem={self.output_single_stem}")
        self.logger.debug(f"Common params: invert_using_spec={self.invert_using_spec}, sample_rate={self.sample_rate}")

        self.logger.debug(f"Common params: primary_stem_name={self.primary_stem_name}, secondary_stem_name={self.secondary_stem_name}")
        self.logger.debug(f"Common params: is_karaoke={self.is_karaoke}, is_bv_model={self.is_bv_model}, bv_model_rebalance={self.bv_model_rebalance}")

        # File-specific variables which need to be cleared between processing different audio inputs
        self.audio_file_path = None
        self.audio_file_base = None

        self.primary_source = None
        self.secondary_source = None

        self.primary_stem_output_path = None
        self.secondary_stem_output_path = None

        self.cached_sources_map = {}

    def secondary_stem(self, primary_stem: str):
        """Determines the secondary stem name based on the primary stem name."""
        primary_stem = primary_stem if primary_stem else self.NO_STEM

        if primary_stem in self.STEM_PAIR_MAPPER:
            secondary_stem = self.STEM_PAIR_MAPPER[primary_stem]
        else:
            secondary_stem = primary_stem.replace(self.NO_STEM, "") if self.NO_STEM in primary_stem else f"{self.NO_STEM}{primary_stem}"

        return secondary_stem

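    # Editor's note, not part of the original source: examples of how
    # secondary_stem resolves names.
    #
    #   secondary_stem("Vocals")   -> "Instrumental"   (via STEM_PAIR_MAPPER)
    #   secondary_stem("Drums")    -> "No Drums"       ("No " prefixing fallback)
    #   secondary_stem("No Bass")  -> "Bass"           ("No " prefix-stripping fallback)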
    def separate(self, audio_file_path):
        """
        Placeholder method for separating audio sources. Should be overridden by subclasses.
        """
        raise NotImplementedError("This method should be overridden by subclasses.")

    def final_process(self, stem_path, source, stem_name):
        """
        Finalizes the processing of a stem by writing the audio to a file and returning the processed source.
        """
        self.logger.debug(f"Finalizing {stem_name} stem processing and writing audio...")
        self.write_audio(stem_path, source)

        return {stem_name: source}

    def cached_sources_clear(self):
        """
        Clears the cache dictionaries for VR, MDX, and Demucs models.

        This function is essential for ensuring that the cache does not hold outdated or irrelevant data
        between different processing sessions or when a new batch of audio files is processed.
        It helps in managing memory efficiently and prevents potential errors due to stale data.
        """
        self.cached_sources_map = {}

    def cached_source_callback(self, model_architecture, model_name=None):
        """
        Retrieves the model and sources from the cache based on the processing method and model name.

        Args:
            model_architecture: The architecture type (VR, MDX, or Demucs) being used for processing.
            model_name: The specific model name within the architecture type, if applicable.

        Returns:
            A tuple containing the model and its sources if found in the cache; otherwise, None.

        This function is crucial for optimizing performance by avoiding redundant processing.
        If the requested model and its sources are already in the cache, they can be reused directly,
        saving time and computational resources.
        """
        model, sources = None, None

        mapper = self.cached_sources_map[model_architecture]

        for key, value in mapper.items():
            if model_name in key:
                model = key
                sources = value

        return model, sources

    def cached_model_source_holder(self, model_architecture, sources, model_name=None):
        """
        Update the dictionary for the given model_architecture with the new model name and its sources.
        Use the model_architecture as a key to access the corresponding cache source mapper dictionary.
        """
        self.cached_sources_map[model_architecture] = {**self.cached_sources_map.get(model_architecture, {}), **{model_name: sources}}

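For illustration (not part of this file): the shape of the cache these three methods manage, with made-up architecture keys and model names. Note that cached_source_callback indexes self.cached_sources_map[model_architecture] directly, so it assumes cached_model_source_holder has already stored an entry for that architecture.

# Minimal sketch of the cache structure, using a hypothetical model name
cached_sources_map = {}

def hold(arch, sources, model_name):
    cached_sources_map[arch] = {**cached_sources_map.get(arch, {}), model_name: sources}

hold("MDX", {"Vocals": "ndarray..."}, "UVR-MDX-NET-Inst_HQ_3")

# Lookup mirrors cached_source_callback: substring match on the model name
for key, value in cached_sources_map["MDX"].items():
    if "Inst_HQ_3" in key:
        print(key, "->", list(value.keys()))  # UVR-MDX-NET-Inst_HQ_3 -> ['Vocals']
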
    def prepare_mix(self, mix):
        """
        Prepares the mix for processing. This includes loading the audio from a file if necessary,
        ensuring the mix is in the correct format, and converting mono to stereo if needed.
        """
        # Store the original path or the mix itself for later checks
        audio_path = mix

        # Check if the input is a file path (string) and needs to be loaded
        if not isinstance(mix, np.ndarray):
            self.logger.debug(f"Loading audio from file: {mix}")
            mix, sr = librosa.load(mix, mono=False, sr=self.sample_rate)
            self.logger.debug(f"Audio loaded. Sample rate: {sr}, Audio shape: {mix.shape}")
        else:
            # Transpose the mix if it's already an ndarray (expected shape: [channels, samples])
            self.logger.debug("Transposing the provided mix array.")
            mix = mix.T
            self.logger.debug(f"Transposed mix shape: {mix.shape}")

        # If the original input was a filepath, check if the loaded mix is empty
        if isinstance(audio_path, str):
            if not np.any(mix):
                error_msg = f"Audio file {audio_path} is empty or not valid"
                self.logger.error(error_msg)
                raise ValueError(error_msg)
            else:
                self.logger.debug("Audio file is valid and contains data.")

        # Ensure the mix is in stereo format
        if mix.ndim == 1:
            self.logger.debug("Mix is mono. Converting to stereo.")
            mix = np.asfortranarray([mix, mix])
            self.logger.debug("Converted to stereo mix.")

        # Final log indicating successful preparation of the mix
        self.logger.debug("Mix preparation completed.")
        return mix

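For illustration (not part of this file): the mono-to-stereo conversion above duplicates the single channel into a (2, samples) Fortran-ordered array, which is the layout the downstream separation code expects. A small self-contained sketch:

import numpy as np

mono = np.sin(np.linspace(0, 3.14, 44100, dtype=np.float32))  # one second of mono audio
stereo = np.asfortranarray([mono, mono])  # duplicate the channel, shape (2, 44100)
print(stereo.shape, stereo.flags["F_CONTIGUOUS"])  # (2, 44100) True
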
    def write_audio(self, stem_path: str, stem_source):
        """
        Writes the separated audio source to a file using pydub or soundfile.
        Pydub supports a much wider range of audio formats and produces better encoded lossy files for some formats.
        Soundfile is used for very large files (longer than 1 hour), as pydub has memory issues with large files:
        https://github.com/jiaaro/pydub/issues/135
        """
        # Get the duration of the input audio file
        duration_seconds = librosa.get_duration(filename=self.audio_file_path)
        duration_hours = duration_seconds / 3600
        self.logger.info(f"Audio duration is {duration_hours:.2f} hours ({duration_seconds:.2f} seconds).")

        if self.use_soundfile:
            self.logger.warning("Using soundfile for writing.")
            self.write_audio_soundfile(stem_path, stem_source)
        else:
            self.logger.info("Using pydub for writing.")
            self.write_audio_pydub(stem_path, stem_source)

    def write_audio_pydub(self, stem_path: str, stem_source):
        """
        Writes the separated audio source to a file using pydub (ffmpeg).
        """
        self.logger.debug(f"Entering write_audio_pydub with stem_path: {stem_path}")

        stem_source = spec_utils.normalize(wave=stem_source, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)

        # Check if the numpy array is empty or contains very low values
        if np.max(np.abs(stem_source)) < 1e-6:
            self.logger.warning("Warning: stem_source array is near-silent or empty.")
            return

        # If output_dir is specified, create it and join it with stem_path
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
            stem_path = os.path.join(self.output_dir, stem_path)

        self.logger.debug(f"Audio data shape before processing: {stem_source.shape}")
        self.logger.debug(f"Data type before conversion: {stem_source.dtype}")

        # Ensure the audio data is in the correct format (e.g., int16)
        if stem_source.dtype != np.int16:
            stem_source = (stem_source * 32767).astype(np.int16)
            self.logger.debug("Converted stem_source to int16.")

        # Correctly interleave stereo channels
        stem_source_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
        stem_source_interleaved[0::2] = stem_source[:, 0]  # Left channel
        stem_source_interleaved[1::2] = stem_source[:, 1]  # Right channel

        self.logger.debug(f"Interleaved audio data shape: {stem_source_interleaved.shape}")

        # Create a pydub AudioSegment
        try:
            audio_segment = AudioSegment(stem_source_interleaved.tobytes(), frame_rate=self.sample_rate, sample_width=stem_source.dtype.itemsize, channels=2)
            self.logger.debug("Created AudioSegment successfully.")
        except (IOError, ValueError) as e:
            self.logger.error(f"Specific error creating AudioSegment: {e}")
            return

        # Determine file format based on the file extension
        file_format = stem_path.lower().split(".")[-1]

        # For m4a files, specify mp4 as the container format as the extension doesn't match the format name
        if file_format == "m4a":
            file_format = "mp4"
        elif file_format == "mka":
            file_format = "matroska"

        # Set the bitrate to 320k for mp3 files if output_bitrate is not specified
        bitrate = "320k" if file_format == "mp3" and self.output_bitrate is None else self.output_bitrate

        # Export using the determined format
        try:
            audio_segment.export(stem_path, format=file_format, bitrate=bitrate)
            self.logger.debug(f"Exported audio file successfully to {stem_path}")
        except (IOError, ValueError) as e:
            self.logger.error(f"Error exporting audio file: {e}")

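For illustration (not part of this file): a sketch of the int16 conversion and channel interleaving performed above, run on synthetic data rather than a real separated stem.

import numpy as np

stem = np.random.uniform(-1, 1, size=(8, 2)).astype(np.float32)  # (samples, channels) float audio
pcm16 = (stem * 32767).astype(np.int16)

interleaved = np.empty((2 * pcm16.shape[0],), dtype=np.int16)
interleaved[0::2] = pcm16[:, 0]  # left channel samples at even indices
interleaved[1::2] = pcm16[:, 1]  # right channel samples at odd indices
# interleaved.tobytes() is the raw PCM that AudioSegment consumes
# (frame_rate=sample_rate, sample_width=2, channels=2)
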
    def write_audio_soundfile(self, stem_path: str, stem_source):
        """
        Writes the separated audio source to a file using the soundfile library.
        """
        self.logger.debug(f"Entering write_audio_soundfile with stem_path: {stem_path}")

        # Correctly interleave stereo channels if needed
        if stem_source.shape[1] == 2:
            # If the audio is already interleaved, ensure it's in the correct order
            # Check if the array is Fortran contiguous (column-major)
            if stem_source.flags["F_CONTIGUOUS"]:
                # Convert to C contiguous (row-major)
                stem_source = np.ascontiguousarray(stem_source)
            # Otherwise, perform interleaving
            else:
                stereo_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
                # Left channel
                stereo_interleaved[0::2] = stem_source[:, 0]
                # Right channel
                stereo_interleaved[1::2] = stem_source[:, 1]
                stem_source = stereo_interleaved

        self.logger.debug(f"Interleaved audio data shape: {stem_source.shape}")

        # Write audio using soundfile (for formats other than M4A)
        try:
            # Specify the subtype to define the sample width
            sf.write(stem_path, stem_source, self.sample_rate)
            self.logger.debug(f"Exported audio file successfully to {stem_path}")
        except Exception as e:
            self.logger.error(f"Error exporting audio file: {e}")

    def clear_gpu_cache(self):
        """
        This method clears the GPU cache to free up memory.
        """
        self.logger.debug("Running garbage collection...")
        gc.collect()
        if self.torch_device == torch.device("mps"):
            self.logger.debug("Clearing MPS cache...")
            torch.mps.empty_cache()
        if self.torch_device == torch.device("cuda"):
            self.logger.debug("Clearing CUDA cache...")
            torch.cuda.empty_cache()

    def clear_file_specific_paths(self):
        """
        Clears the file-specific variables which need to be cleared between processing different audio inputs.
        """
        self.logger.info("Clearing input audio file paths, sources and stems...")

        self.audio_file_path = None
        self.audio_file_base = None

        self.primary_source = None
        self.secondary_source = None

        self.primary_stem_output_path = None
        self.secondary_stem_output_path = None

    def sanitize_filename(self, filename):
        """
        Cleans the filename by replacing invalid characters with underscores.
        """
        sanitized = re.sub(r'[<>:"/\\|?*]', '_', filename)
        sanitized = re.sub(r'_+', '_', sanitized)
        sanitized = sanitized.strip('_. ')
        return sanitized

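For illustration (not part of this file): the sanitisation above applied to a made-up filename.

import re

def sanitize_filename(filename):
    sanitized = re.sub(r'[<>:"/\\|?*]', '_', filename)
    sanitized = re.sub(r'_+', '_', sanitized)
    return sanitized.strip('_. ')

print(sanitize_filename('Artist: "Song" (Live)/Take?1'))  # Artist_ _Song_ (Live)_Take_1
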
    def get_stem_output_path(self, stem_name, custom_output_names):
        """
        Gets the output path for a stem based on the stem name and custom output names.
        """
        # Convert custom_output_names keys to lowercase for case-insensitive comparison
        if custom_output_names:
            custom_output_names_lower = {k.lower(): v for k, v in custom_output_names.items()}
            stem_name_lower = stem_name.lower()
            if stem_name_lower in custom_output_names_lower:
                sanitized_custom_name = self.sanitize_filename(custom_output_names_lower[stem_name_lower])
                return os.path.join(f"{sanitized_custom_name}.{self.output_format.lower()}")

        sanitized_audio_base = self.sanitize_filename(self.audio_file_base)
        sanitized_stem_name = self.sanitize_filename(stem_name)
        sanitized_model_name = self.sanitize_filename(self.model_name)

        filename = f"{sanitized_audio_base}_({sanitized_stem_name})_{sanitized_model_name}.{self.output_format.lower()}"
        return os.path.join(filename)
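For illustration (not part of this file): with made-up values for the instance attributes, the default naming scheme above produces the following.

audio_file_base, stem_name, model_name, output_format = "MyTrack", "Vocals", "UVR-MDX-NET-Inst_HQ_3", "flac"
print(f"{audio_file_base}_({stem_name})_{model_name}.{output_format}")
# MyTrack_(Vocals)_UVR-MDX-NET-Inst_HQ_3.flac

If custom_output_names contains a (case-insensitive) match for the stem name, that name wins instead; e.g. {"vocals": "lead vocals"} yields "lead vocals.flac".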
audio_separator/separator/separator.py
ADDED
@@ -0,0 +1,959 @@
""" This file contains the Separator class, to facilitate the separation of stems from audio. """

from importlib import metadata, resources
import os
import sys
import platform
import subprocess
import time
import logging
import warnings
import importlib
import io
from typing import Optional

import hashlib
import json
import yaml
import requests
import torch
import torch.amp.autocast_mode as autocast_mode
import onnxruntime as ort
from tqdm import tqdm


class Separator:
    """
    The Separator class is designed to facilitate the separation of audio sources from a given audio file.
    It supports various separation architectures and models, including MDX, VR, and Demucs. The class provides
    functionalities to configure separation parameters, load models, and perform audio source separation.
    It also handles logging, normalization, and output formatting of the separated audio stems.

    The actual separation task is handled by one of the architecture-specific classes in the `architectures` module;
    this class is responsible for initialising logging, configuring hardware acceleration, loading the model,
    initiating the separation process and passing outputs back to the caller.

    Common Attributes:
        log_level (int): The logging level.
        log_formatter (logging.Formatter): The logging formatter.
        model_file_dir (str): The directory where model files are stored.
        output_dir (str): The directory where output files will be saved.
        output_format (str): The format of the output audio file.
        output_bitrate (str): The bitrate of the output audio file.
        amplification_threshold (float): The threshold for audio amplification.
        normalization_threshold (float): The threshold for audio normalization.
        output_single_stem (str): Option to output a single stem.
        invert_using_spec (bool): Flag to invert using spectrogram.
        sample_rate (int): The sample rate of the audio.
        use_soundfile (bool): Use soundfile for audio writing, can solve OOM issues.
        use_autocast (bool): Flag to use PyTorch autocast for faster inference.

    MDX Architecture Specific Attributes:
        hop_length (int): The hop length for STFT.
        segment_size (int): The segment size for processing.
        overlap (float): The overlap between segments.
        batch_size (int): The batch size for processing.
        enable_denoise (bool): Flag to enable or disable denoising.

    VR Architecture Specific Attributes & Defaults:
        batch_size: 16
        window_size: 512
        aggression: 5
        enable_tta: False
        enable_post_process: False
        post_process_threshold: 0.2
        high_end_process: False

    Demucs Architecture Specific Attributes & Defaults:
        segment_size: "Default"
        shifts: 2
        overlap: 0.25
        segments_enabled: True

    MDXC Architecture Specific Attributes & Defaults:
        segment_size: 256
        override_model_segment_size: False
        batch_size: 1
        overlap: 8
        pitch_shift: 0
    """

    def __init__(
        self,
        log_level=logging.INFO,
        log_formatter=None,
        model_file_dir="/tmp/audio-separator-models/",
        output_dir=None,
        output_format="WAV",
        output_bitrate=None,
        normalization_threshold=0.9,
        amplification_threshold=0.0,
        output_single_stem=None,
        invert_using_spec=False,
        sample_rate=44100,
        use_soundfile=False,
        use_autocast=False,
        use_directml=False,
        mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
        vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
        demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
        mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0},
        info_only=False,
    ):
        """Initialize the separator."""
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(log_level)
        self.log_level = log_level
        self.log_formatter = log_formatter

        self.log_handler = logging.StreamHandler()

        if self.log_formatter is None:
            self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")

        self.log_handler.setFormatter(self.log_formatter)

        if not self.logger.hasHandlers():
            self.logger.addHandler(self.log_handler)

        # Filter out noisy warnings from PyTorch for users who don't care about them
        if log_level > logging.DEBUG:
            warnings.filterwarnings("ignore")

        # Skip initialization logs if info_only is True
        if not info_only:
            package_version = self.get_package_distribution("audio-separator").version
            self.logger.info(f"Separator version {package_version} instantiating with output_dir: {output_dir}, output_format: {output_format}")

        if output_dir is None:
            output_dir = os.getcwd()
            if not info_only:
                self.logger.info("Output directory not specified. Using current working directory.")

        self.output_dir = output_dir

        # Check for environment variable to override model_file_dir
        env_model_dir = os.environ.get("AUDIO_SEPARATOR_MODEL_DIR")
        if env_model_dir:
            self.model_file_dir = env_model_dir
            self.logger.info(f"Using model directory from AUDIO_SEPARATOR_MODEL_DIR env var: {self.model_file_dir}")
            if not os.path.exists(self.model_file_dir):
                raise FileNotFoundError(f"The specified model directory does not exist: {self.model_file_dir}")
        else:
            self.logger.info(f"Using model directory from model_file_dir parameter: {model_file_dir}")
            self.model_file_dir = model_file_dir

        # Create the model and output directories if they do not exist
        os.makedirs(self.model_file_dir, exist_ok=True)
        os.makedirs(self.output_dir, exist_ok=True)

        self.output_format = output_format
        self.output_bitrate = output_bitrate

        if self.output_format is None:
            self.output_format = "WAV"

        self.normalization_threshold = normalization_threshold
        if normalization_threshold <= 0 or normalization_threshold > 1:
            raise ValueError("The normalization_threshold must be greater than 0 and less than or equal to 1.")

        self.amplification_threshold = amplification_threshold
        if amplification_threshold < 0 or amplification_threshold > 1:
            raise ValueError("The amplification_threshold must be greater than or equal to 0 and less than or equal to 1.")

        self.output_single_stem = output_single_stem
        if output_single_stem is not None:
            self.logger.debug(f"Single stem output requested, so only one output file ({output_single_stem}) will be written")

        self.invert_using_spec = invert_using_spec
        if self.invert_using_spec:
            self.logger.debug("Secondary stem will be inverted using spectrogram rather than waveform. This may improve quality but is slightly slower.")

        try:
            self.sample_rate = int(sample_rate)
            if self.sample_rate <= 0:
                raise ValueError(f"The sample rate setting is {self.sample_rate} but it must be a non-zero whole number.")
            if self.sample_rate > 12800000:
                raise ValueError(f"The sample rate setting is {self.sample_rate}. Enter something less ambitious.")
        except ValueError:
            raise ValueError("The sample rate must be a non-zero whole number. Please provide a valid integer.")

        self.use_soundfile = use_soundfile
        self.use_autocast = use_autocast
        self.use_directml = use_directml

        # These are parameters which users may want to configure, so we expose them to the top-level Separator class,
        # even though they are specific to a single model architecture
        self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}

        self.torch_device = None
        self.torch_device_cpu = None
        self.torch_device_mps = None

        self.onnx_execution_provider = None
        self.model_instance = None

        self.model_is_uvr_vip = False
        self.model_friendly_name = None

        if not info_only:
            self.setup_accelerated_inferencing_device()

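For illustration (not part of this file): a typical construction of this class, assuming the package is installed and exports Separator from audio_separator.separator; parameter values are examples, not recommendations.

import logging
from audio_separator.separator import Separator

separator = Separator(
    log_level=logging.DEBUG,
    output_dir="/tmp/stems",
    output_format="FLAC",
    normalization_threshold=0.9,
    mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
)
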
    def setup_accelerated_inferencing_device(self):
        """
        This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
        """
        system_info = self.get_system_info()
        self.check_ffmpeg_installed()
        self.log_onnxruntime_packages()
        self.setup_torch_device(system_info)

    def get_system_info(self):
        """
        This method logs the system information, including the operating system, CPU architecture and Python version.
        """
        os_name = platform.system()
        os_version = platform.version()
        self.logger.info(f"Operating System: {os_name} {os_version}")

        system_info = platform.uname()
        self.logger.info(f"System: {system_info.system} Node: {system_info.node} Release: {system_info.release} Machine: {system_info.machine} Proc: {system_info.processor}")

        python_version = platform.python_version()
        self.logger.info(f"Python Version: {python_version}")

        pytorch_version = torch.__version__
        self.logger.info(f"PyTorch Version: {pytorch_version}")
        return system_info

    def check_ffmpeg_installed(self):
        """
        This method checks if ffmpeg is installed and logs its version.
        """
        try:
            ffmpeg_version_output = subprocess.check_output(["ffmpeg", "-version"], text=True)
            first_line = ffmpeg_version_output.splitlines()[0]
            self.logger.info(f"FFmpeg installed: {first_line}")
        except FileNotFoundError:
            self.logger.error("FFmpeg is not installed. Please install FFmpeg to use this package.")
            # Raise an exception if this is being run by a user, as ffmpeg is required for pydub to write audio,
            # but if we're just running unit tests in CI there's no reason to throw
            if "PYTEST_CURRENT_TEST" not in os.environ:
                raise

    def log_onnxruntime_packages(self):
        """
        This method logs the ONNX Runtime package versions, including the GPU and Silicon packages if available.
        """
        onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
        onnxruntime_silicon_package = self.get_package_distribution("onnxruntime-silicon")
        onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")
        onnxruntime_dml_package = self.get_package_distribution("onnxruntime-directml")

        if onnxruntime_gpu_package is not None:
            self.logger.info(f"ONNX Runtime GPU package installed with version: {onnxruntime_gpu_package.version}")
        if onnxruntime_silicon_package is not None:
            self.logger.info(f"ONNX Runtime Silicon package installed with version: {onnxruntime_silicon_package.version}")
        if onnxruntime_cpu_package is not None:
            self.logger.info(f"ONNX Runtime CPU package installed with version: {onnxruntime_cpu_package.version}")
        if onnxruntime_dml_package is not None:
            self.logger.info(f"ONNX Runtime DirectML package installed with version: {onnxruntime_dml_package.version}")

    def setup_torch_device(self, system_info):
        """
        This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
        """
        hardware_acceleration_enabled = False
        ort_providers = ort.get_available_providers()
        has_torch_dml_installed = self.get_package_distribution("torch_directml")

        self.torch_device_cpu = torch.device("cpu")

        if torch.cuda.is_available():
            self.configure_cuda(ort_providers)
            hardware_acceleration_enabled = True
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
            self.configure_mps(ort_providers)
            hardware_acceleration_enabled = True
        elif self.use_directml and has_torch_dml_installed:
            import torch_directml

            if torch_directml.is_available():
                self.configure_dml(ort_providers)
                hardware_acceleration_enabled = True

        if not hardware_acceleration_enabled:
            self.logger.info("No hardware acceleration could be configured, running in CPU mode")
            self.torch_device = self.torch_device_cpu
            self.onnx_execution_provider = ["CPUExecutionProvider"]

    def configure_cuda(self, ort_providers):
        """
        This method configures the CUDA device for PyTorch and ONNX Runtime, if available.
        """
        self.logger.info("CUDA is available in Torch, setting Torch device to CUDA")
        self.torch_device = torch.device("cuda")
        if "CUDAExecutionProvider" in ort_providers:
            self.logger.info("ONNXruntime has CUDAExecutionProvider available, enabling acceleration")
            self.onnx_execution_provider = ["CUDAExecutionProvider"]
        else:
            self.logger.warning("CUDAExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")

    def configure_mps(self, ort_providers):
        """
        This method configures the Apple Silicon MPS/CoreML device for PyTorch and ONNX Runtime, if available.
        """
        self.logger.info("Apple Silicon MPS/CoreML is available in Torch and processor is ARM, setting Torch device to MPS")
        self.torch_device_mps = torch.device("mps")

        self.torch_device = self.torch_device_mps

        if "CoreMLExecutionProvider" in ort_providers:
            self.logger.info("ONNXruntime has CoreMLExecutionProvider available, enabling acceleration")
            self.onnx_execution_provider = ["CoreMLExecutionProvider"]
        else:
            self.logger.warning("CoreMLExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")

    def configure_dml(self, ort_providers):
        """
        This method configures the DirectML device for PyTorch and ONNX Runtime, if available.
        """
        import torch_directml

        self.logger.info("DirectML is available in Torch, setting Torch device to DirectML")
        self.torch_device_dml = torch_directml.device()
        self.torch_device = self.torch_device_dml

        if "DmlExecutionProvider" in ort_providers:
            self.logger.info("ONNXruntime has DmlExecutionProvider available, enabling acceleration")
            self.onnx_execution_provider = ["DmlExecutionProvider"]
        else:
            self.logger.warning("DmlExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")

    def get_package_distribution(self, package_name):
        """
        This method returns the package distribution for a given package name if installed, or None otherwise.
        """
        try:
            return metadata.distribution(package_name)
        except metadata.PackageNotFoundError:
            self.logger.debug(f"Python package: {package_name} not installed")
            return None

    def get_model_hash(self, model_path):
        """
        This method returns the MD5 hash of a given model file.
        """
        self.logger.debug(f"Calculating hash of model file {model_path}")
        # Use the specific byte count from the original logic
        BYTES_TO_HASH = 10000 * 1024  # 10,240,000 bytes

        try:
            file_size = os.path.getsize(model_path)

            with open(model_path, "rb") as f:
                if file_size < BYTES_TO_HASH:
                    # Hash the entire file if smaller than the target byte count
                    self.logger.debug(f"File size {file_size} < {BYTES_TO_HASH}, hashing entire file.")
                    hash_value = hashlib.md5(f.read()).hexdigest()
                else:
                    # Seek to the specific position before the end (from the beginning) and hash
                    seek_pos = file_size - BYTES_TO_HASH
                    self.logger.debug(f"File size {file_size} >= {BYTES_TO_HASH}, seeking to {seek_pos} and hashing remaining bytes.")
                    f.seek(seek_pos, io.SEEK_SET)
                    hash_value = hashlib.md5(f.read()).hexdigest()

            # Log the calculated hash
            self.logger.info(f"Hash of model file {model_path} is {hash_value}")
            return hash_value

        except FileNotFoundError:
            self.logger.error(f"Model file not found at {model_path}")
            raise  # Re-raise the specific error
        except Exception as e:
            # Catch other potential errors (e.g., permissions, other IOErrors)
            self.logger.error(f"Error calculating hash for {model_path}: {e}")
            raise  # Re-raise other errors

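For illustration (not part of this file): a standalone version of the partial-hash scheme above. Files of at least 10,240,000 bytes are identified by the MD5 of their final 10,240,000 bytes; smaller files by the MD5 of their full contents.

import hashlib
import io
import os

def tail_md5(path, bytes_to_hash=10000 * 1024):
    size = os.path.getsize(path)
    with open(path, "rb") as f:
        if size >= bytes_to_hash:
            f.seek(size - bytes_to_hash, io.SEEK_SET)  # hash only the tail of large files
        return hashlib.md5(f.read()).hexdigest()
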
    def download_file_if_not_exists(self, url, output_path):
        """
        This method downloads a file from a given URL to a given output path, if the file does not already exist.
        """

        if os.path.isfile(output_path):
            self.logger.debug(f"File already exists at {output_path}, skipping download")
            return

        self.logger.debug(f"Downloading file from {url} to {output_path} with timeout 300s")
        response = requests.get(url, stream=True, timeout=300)

        if response.status_code == 200:
            total_size_in_bytes = int(response.headers.get("content-length", 0))
            progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)

            with open(output_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    progress_bar.update(len(chunk))
                    f.write(chunk)
            progress_bar.close()
        else:
            raise RuntimeError(f"Failed to download file from {url}, response code: {response.status_code}")

    def list_supported_model_files(self):
        """
        This method lists the supported model files for audio-separator, by fetching the same file UVR uses to list these.
        Also includes model performance scores where available.

        Example response object:

        {
            "MDX": {
                "MDX-Net Model VIP: UVR-MDX-NET-Inst_full_292": {
                    "filename": "UVR-MDX-NET-Inst_full_292.onnx",
                    "scores": {
                        "vocals": {"SDR": 10.6497, "SIR": 20.3786, "SAR": 10.692, "ISR": 14.848},
                        "instrumental": {"SDR": 15.2149, "SIR": 25.6075, "SAR": 17.1363, "ISR": 17.7893}
                    },
                    "download_files": [
                        "UVR-MDX-NET-Inst_full_292.onnx"
                    ]
                }
            },
            "Demucs": {
                "Demucs v4: htdemucs_ft": {
                    "filename": "htdemucs_ft.yaml",
                    "scores": {
                        "vocals": {"SDR": 11.2685, "SIR": 21.257, "SAR": 11.0359, "ISR": 19.3753},
                        "drums": {"SDR": 13.235, "SIR": 23.3053, "SAR": 13.0313, "ISR": 17.2889},
                        "bass": {"SDR": 9.72743, "SIR": 19.5435, "SAR": 9.20801, "ISR": 13.5037}
                    },
                    "download_files": [
                        "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th",
                        "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/d12395a8-e57c48e6.th",
                        "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/92cfc3b6-ef3bcb9c.th",
                        "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th",
                        "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs_ft.yaml"
                    ]
                }
            },
            "MDXC": {
                "MDX23C Model: MDX23C-InstVoc HQ": {
                    "filename": "MDX23C-8KFFT-InstVoc_HQ.ckpt",
                    "scores": {
                        "vocals": {"SDR": 11.9504, "SIR": 23.1166, "SAR": 12.093, "ISR": 15.4782},
                        "instrumental": {"SDR": 16.3035, "SIR": 26.6161, "SAR": 18.5167, "ISR": 18.3939}
                    },
                    "download_files": [
                        "MDX23C-8KFFT-InstVoc_HQ.ckpt",
                        "model_2_stem_full_band_8k.yaml"
                    ]
                }
            }
        }
        """
        download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")

        self.download_file_if_not_exists("https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json", download_checks_path)

        model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
        self.logger.debug("UVR model download list loaded")

        # Load the model scores with error handling
        model_scores = {}
        try:
            with resources.open_text("audio_separator", "models-scores.json") as f:
                model_scores = json.load(f)
            self.logger.debug("Model scores loaded")
        except json.JSONDecodeError as e:
            self.logger.warning(f"Failed to load model scores: {str(e)}")
            self.logger.warning("Continuing without model scores")

        # Only show Demucs v4 models as we've only implemented support for v4
        filtered_demucs_v4 = {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}

        # Modified Demucs handling to use YAML files as identifiers and include download files
        demucs_models = {}
        for name, files in filtered_demucs_v4.items():
            # Find the YAML file in the model files
            yaml_file = next((filename for filename in files.keys() if filename.endswith(".yaml")), None)
            if yaml_file:
                model_score_data = model_scores.get(yaml_file, {})
                demucs_models[name] = {
                    "filename": yaml_file,
                    "scores": model_score_data.get("median_scores", {}),
                    "stems": model_score_data.get("stems", []),
                    "target_stem": model_score_data.get("target_stem"),
                    "download_files": list(files.values()),  # List of all download URLs/filenames
                }

        # Load the JSON file using importlib.resources
        with resources.open_text("audio_separator", "models.json") as f:
            audio_separator_models_list = json.load(f)
        self.logger.debug("Audio-Separator model list loaded")

        # Return object with list of model names
        model_files_grouped_by_type = {
            "VR": {
                name: {
                    "filename": filename,
                    "scores": model_scores.get(filename, {}).get("median_scores", {}),
                    "stems": model_scores.get(filename, {}).get("stems", []),
                    "target_stem": model_scores.get(filename, {}).get("target_stem"),
                    "download_files": [filename],
                }  # Just the filename for VR models
                for name, filename in {**model_downloads_list["vr_download_list"], **audio_separator_models_list["vr_download_list"]}.items()
            },
            "MDX": {
                name: {
                    "filename": filename,
                    "scores": model_scores.get(filename, {}).get("median_scores", {}),
                    "stems": model_scores.get(filename, {}).get("stems", []),
                    "target_stem": model_scores.get(filename, {}).get("target_stem"),
                    "download_files": [filename],
                }  # Just the filename for MDX models
                for name, filename in {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"], **audio_separator_models_list["mdx_download_list"]}.items()
            },
            "Demucs": demucs_models,
            "MDXC": {
                name: {
                    "filename": next(iter(files.keys())),
                    "scores": model_scores.get(next(iter(files.keys())), {}).get("median_scores", {}),
                    "stems": model_scores.get(next(iter(files.keys())), {}).get("stems", []),
                    "target_stem": model_scores.get(next(iter(files.keys())), {}).get("target_stem"),
                    "download_files": list(files.keys()) + list(files.values()),  # List of both model filenames and config filenames
                }
                for name, files in {
                    **model_downloads_list["mdx23c_download_list"],
                    **model_downloads_list["mdx23c_download_vip_list"],
                    **model_downloads_list["roformer_download_list"],
                    **audio_separator_models_list["mdx23c_download_list"],
                    **audio_separator_models_list["roformer_download_list"],
                }.items()
            },
        }

        return model_files_grouped_by_type

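For illustration (not part of this file): printing the supported models grouped by architecture. This needs network access to fetch the UVR download list; info_only=True skips device setup.

separator = Separator(info_only=True)
for arch, models in separator.list_supported_model_files().items():
    for friendly_name, info in models.items():
        print(arch, "|", friendly_name, "|", info["filename"])
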
    def print_uvr_vip_message(self):
        """
        This method prints a message to the user if they have downloaded a VIP model, reminding them to support Anjok07 on Patreon.
        """
        if self.model_is_uvr_vip:
            self.logger.warning(f"The model: '{self.model_friendly_name}' is a VIP model, intended by Anjok07 for access by paying subscribers only.")
            self.logger.warning("If you are not already subscribed, please consider supporting the developer of UVR, Anjok07, by subscribing here: https://patreon.com/uvr")

    def download_model_files(self, model_filename):
        """
        This method downloads the model files for a given model filename, if they are not already present.
        Returns a tuple of (model_filename, model_type, model_friendly_name, model_path, yaml_config_filename).
        """
        model_path = os.path.join(self.model_file_dir, f"{model_filename}")

        supported_model_files_grouped = self.list_supported_model_files()
        public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
        vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
        audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"

        yaml_config_filename = None

        self.logger.debug(f"Searching for model_filename {model_filename} in supported_model_files_grouped")

        # Iterate through model types (VR, MDX, Demucs, MDXC)
        for model_type, models in supported_model_files_grouped.items():
            # Iterate through each model in this type
            for model_friendly_name, model_info in models.items():
                self.model_is_uvr_vip = "VIP" in model_friendly_name
                model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix

                # Check if this model matches our target filename
                if model_info["filename"] == model_filename or model_filename in model_info["download_files"]:
                    self.logger.debug(f"Found matching model: {model_friendly_name}")
                    self.model_friendly_name = model_friendly_name
                    self.print_uvr_vip_message()

                    # Download each required file for this model
                    for file_to_download in model_info["download_files"]:
                        # For URLs, extract just the filename portion
                        if file_to_download.startswith("http"):
                            filename = file_to_download.split("/")[-1]
                            download_path = os.path.join(self.model_file_dir, filename)
                            self.download_file_if_not_exists(file_to_download, download_path)
                            continue

                        download_path = os.path.join(self.model_file_dir, file_to_download)

                        # For MDXC models, handle YAML config files specially
                        if model_type == "MDXC" and file_to_download.endswith(".yaml"):
                            yaml_config_filename = file_to_download
                            try:
                                yaml_url = f"{model_repo_url_prefix}/mdx_model_data/mdx_c_configs/{file_to_download}"
                                self.download_file_if_not_exists(yaml_url, download_path)
                            except RuntimeError:
                                self.logger.debug("YAML config not found in UVR repo, trying audio-separator models repo...")
                                yaml_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
                                self.download_file_if_not_exists(yaml_url, download_path)
                            continue

                        # For regular model files, try the UVR repo first, then the audio-separator repo
                        try:
                            download_url = f"{model_repo_url_prefix}/{file_to_download}"
                            self.download_file_if_not_exists(download_url, download_path)
                        except RuntimeError:
                            self.logger.debug("Model not found in UVR repo, trying audio-separator models repo...")
                            download_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
                            self.download_file_if_not_exists(download_url, download_path)

                    return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename

        raise ValueError(f"Model file {model_filename} not found in supported model files")

    def load_model_data_from_yaml(self, yaml_config_filename):
        """
        This method loads model-specific parameters from the YAML file for that model.
        The parameters in the YAML are critical to inferencing, as they need to match whatever was used during training.
        """
        # Verify if the YAML filename includes a full path or just the filename
        if not os.path.exists(yaml_config_filename):
            model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename)
        else:
            model_data_yaml_filepath = yaml_config_filename

        self.logger.debug(f"Loading model data from YAML at path {model_data_yaml_filepath}")

        model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
        self.logger.debug(f"Model data loaded from YAML file: {model_data}")

        if "roformer" in model_data_yaml_filepath:
            model_data["is_roformer"] = True

        return model_data

    def load_model_data_using_hash(self, model_path):
        """
        This method loads model-specific parameters from UVR model data files.
        These parameters are critical to inferencing using a given model, as they need to match whatever was used during training.
        The correct parameters are identified by calculating the hash of the model file and looking up the hash in the UVR data files.
        """
        # Model data and configuration sources from UVR
        model_data_url_prefix = "https://raw.githubusercontent.com/TRvlvr/application_data/main"

        vr_model_data_url = f"{model_data_url_prefix}/vr_model_data/model_data_new.json"
        mdx_model_data_url = f"{model_data_url_prefix}/mdx_model_data/model_data_new.json"

        # Calculate hash for the downloaded model
        self.logger.debug("Calculating MD5 hash for model file to identify model parameters from UVR data...")
        model_hash = self.get_model_hash(model_path)
        self.logger.debug(f"Model {model_path} has hash {model_hash}")

        # Setting up the path for model data and checking its existence
        vr_model_data_path = os.path.join(self.model_file_dir, "vr_model_data.json")
        self.logger.debug(f"VR model data path set to {vr_model_data_path}")
        self.download_file_if_not_exists(vr_model_data_url, vr_model_data_path)

        mdx_model_data_path = os.path.join(self.model_file_dir, "mdx_model_data.json")
        self.logger.debug(f"MDX model data path set to {mdx_model_data_path}")
        self.download_file_if_not_exists(mdx_model_data_url, mdx_model_data_path)

        # Loading model data from UVR
        self.logger.debug("Loading MDX and VR model parameters from UVR model data files...")
        vr_model_data_object = json.load(open(vr_model_data_path, encoding="utf-8"))
        mdx_model_data_object = json.load(open(mdx_model_data_path, encoding="utf-8"))

        # Load additional model data from audio-separator
        self.logger.debug("Loading additional model parameters from audio-separator model data file...")
        with resources.open_text("audio_separator", "model-data.json") as f:
            audio_separator_model_data = json.load(f)

        # Merge the model data objects, with audio-separator data taking precedence
        vr_model_data_object = {**vr_model_data_object, **audio_separator_model_data.get("vr_model_data", {})}
        mdx_model_data_object = {**mdx_model_data_object, **audio_separator_model_data.get("mdx_model_data", {})}

        if model_hash in mdx_model_data_object:
            model_data = mdx_model_data_object[model_hash]
        elif model_hash in vr_model_data_object:
            model_data = vr_model_data_object[model_hash]
        else:
            raise ValueError(f"Unsupported Model File: parameters for MD5 hash {model_hash} could not be found in the UVR model data files for the MDX or VR arch.")

        self.logger.debug(f"Model data loaded using hash {model_hash}: {model_data}")

        return model_data

    def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt"):
        """
        This method instantiates the architecture-specific separation class,
        loading the separation model into memory, downloading it first if necessary.
        """
        self.logger.info(f"Loading model {model_filename}...")

        load_model_start_time = time.perf_counter()

        # Setting up the model path
        model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
        model_name = model_filename.split(".")[0]
        self.logger.debug(f"Model downloaded, friendly name: {model_friendly_name}, model_path: {model_path}")

        if model_path.lower().endswith(".yaml"):
            yaml_config_filename = model_path

        if yaml_config_filename is not None:
            model_data = self.load_model_data_from_yaml(yaml_config_filename)
        else:
            model_data = self.load_model_data_using_hash(model_path)

        common_params = {
            "logger": self.logger,
            "log_level": self.log_level,
            "torch_device": self.torch_device,
            "torch_device_cpu": self.torch_device_cpu,
            "torch_device_mps": self.torch_device_mps,
            "onnx_execution_provider": self.onnx_execution_provider,
            "model_name": model_name,
            "model_path": model_path,
            "model_data": model_data,
            "output_format": self.output_format,
            "output_bitrate": self.output_bitrate,
            "output_dir": self.output_dir,
            "normalization_threshold": self.normalization_threshold,
            "amplification_threshold": self.amplification_threshold,
            "output_single_stem": self.output_single_stem,
            "invert_using_spec": self.invert_using_spec,
            "sample_rate": self.sample_rate,
            "use_soundfile": self.use_soundfile,
        }

        # Instantiate the appropriate separator class depending on the model type
        separator_classes = {"MDX": "mdx_separator.MDXSeparator", "VR": "vr_separator.VRSeparator", "Demucs": "demucs_separator.DemucsSeparator", "MDXC": "mdxc_separator.MDXCSeparator"}

        if model_type not in self.arch_specific_params or model_type not in separator_classes:
            raise ValueError(f"Model type not supported (yet): {model_type}")

        if model_type == "Demucs" and sys.version_info < (3, 10):
            raise Exception("Demucs models require Python version 3.10 or newer.")

        self.logger.debug(f"Importing module for model type {model_type}: {separator_classes[model_type]}")

        module_name, class_name = separator_classes[model_type].split(".")
        module = importlib.import_module(f"audio_separator.separator.architectures.{module_name}")
        separator_class = getattr(module, class_name)

        self.logger.debug(f"Instantiating separator class for model type {model_type}: {separator_class}")
        self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])

        # Log the completion of the model load process
        self.logger.debug("Loading model completed.")
        self.logger.info(f'Load model duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - load_model_start_time)))}')

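For illustration (not part of this file): an end-to-end sketch of loading a model by filename and separating a track. The model and audio filenames are examples, not defaults you must use.

separator = Separator(output_dir="/tmp/stems", output_format="WAV")
separator.load_model("UVR-MDX-NET-Inst_HQ_3.onnx")
output_files = separator.separate("/path/to/song.mp3")
print(output_files)  # e.g. ['song_(Vocals)_UVR-MDX-NET-Inst_HQ_3.wav', 'song_(Instrumental)_UVR-MDX-NET-Inst_HQ_3.wav']
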
def separate(self, audio_file_path, custom_output_names=None):
|
781 |
+
"""
|
782 |
+
Separates the audio file(s) into different stems (e.g., vocals, instruments) using the loaded model.
|
783 |
+
|
784 |
+
This method takes the path to an audio file or a directory containing audio files, processes them through
|
785 |
+
the loaded separation model, and returns the paths to the output files containing the separated audio stems.
|
786 |
+
It handles the entire flow from loading the audio, running the separation, clearing up resources, and logging the process.
|
787 |
+
|
788 |
+
Parameters:
|
789 |
+
- audio_file_path (str or list): The path to the audio file or directory, or a list of paths.
|
790 |
+
- custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
|
791 |
+
|
792 |
+
Returns:
|
793 |
+
- output_files (list of str): A list containing the paths to the separated audio stem files.
|
794 |
+
"""
|
795 |
+
# Check if the model and device are properly initialized
|
796 |
+
if not (self.torch_device and self.model_instance):
|
797 |
+
raise ValueError("Initialization failed or model not loaded. Please load a model before attempting to separate.")
|
798 |
+
|
799 |
+
# If audio_file_path is a string, convert it to a list for uniform processing
|
800 |
+
if isinstance(audio_file_path, str):
|
801 |
+
audio_file_path = [audio_file_path]
|
802 |
+
|
803 |
+
# Initialize a list to store paths of all output files
|
804 |
+
output_files = []
|
805 |
+
|
806 |
+
# Process each path in the list
|
807 |
+
for path in audio_file_path:
|
808 |
+
if os.path.isdir(path):
|
809 |
+
# If the path is a directory, recursively search for all audio files
|
810 |
+
for root, dirs, files in os.walk(path):
|
811 |
+
for file in files:
|
812 |
+
# Check the file extension to ensure it's an audio file
|
813 |
+
if file.endswith((".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff", ".ac3")): # Add other formats if needed
|
814 |
+
full_path = os.path.join(root, file)
|
815 |
+
self.logger.info(f"Processing file: {full_path}")
|
816 |
+
try:
|
817 |
+
# Perform separation for each file
|
818 |
+
files_output = self._separate_file(full_path, custom_output_names)
|
819 |
+
output_files.extend(files_output)
|
820 |
+
except Exception as e:
|
821 |
+
self.logger.error(f"Failed to process file {full_path}: {e}")
|
822 |
+
else:
|
823 |
+
# If the path is a file, process it directly
|
824 |
+
self.logger.info(f"Processing file: {path}")
|
825 |
+
try:
|
826 |
+
files_output = self._separate_file(path, custom_output_names)
|
827 |
+
output_files.extend(files_output)
|
828 |
+
except Exception as e:
|
829 |
+
self.logger.error(f"Failed to process file {path}: {e}")
|
830 |
+
|
831 |
+
return output_files
|
832 |
+
|
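    # Illustrative usage (a sketch, not part of the class): assuming a
    # default-constructed Separator and the model-loading flow above, a typical
    # call sequence looks like this (the model filename is a hypothetical example):
    #
    #     separator = Separator()
    #     separator.load_model(model_filename="UVR-MDX-NET-Inst_HQ_3.onnx")
    #     output_files = separator.separate("song.mp3")  # one output path per stem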
    def _separate_file(self, audio_file_path, custom_output_names=None):
        """
        Internal method to handle separation for a single audio file.

        This method performs the actual separation process for a single audio file. It logs the start and end of the process,
        handles autocast if enabled, and ensures the GPU cache is cleared after processing.

        Parameters:
        - audio_file_path (str): The path to the audio file.
        - custom_output_names (dict, optional): Custom names for the output files. Defaults to None.

        Returns:
        - output_files (list of str): A list containing the paths to the separated audio stem files.
        """
        # Log the start of the separation process
        self.logger.info(f"Starting separation process for audio_file_path: {audio_file_path}")
        separate_start_time = time.perf_counter()

        # Log normalization and amplification thresholds
        self.logger.debug(f"Normalization threshold set to {self.normalization_threshold}, waveform will be lowered to this max amplitude to avoid clipping.")
        self.logger.debug(f"Amplification threshold set to {self.amplification_threshold}, waveform will be scaled up to this max amplitude if below it.")

        # Run separation method for the loaded model with autocast enabled if supported by the device
        output_files = None
        if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type):
            self.logger.debug("Autocast available.")
            with autocast_mode.autocast(self.torch_device.type):
                output_files = self.model_instance.separate(audio_file_path, custom_output_names)
        else:
            self.logger.debug("Autocast unavailable.")
            output_files = self.model_instance.separate(audio_file_path, custom_output_names)

        # Clear GPU cache to free up memory
        self.model_instance.clear_gpu_cache()

        # Unset separation parameters to prevent accidentally re-using the wrong source files or output paths
        self.model_instance.clear_file_specific_paths()

        # Remind the user one more time if they used a VIP model, so the message doesn't get lost in the logs
        self.print_uvr_vip_message()

        # Log the completion of the separation process
        self.logger.debug("Separation process completed.")
        self.logger.info(f'Separation duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - separate_start_time)))}')

        return output_files

    def download_model_and_data(self, model_filename):
        """
        Downloads the model file without loading it into memory.
        """
        self.logger.info(f"Downloading model {model_filename}...")

        model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)

        if model_path.lower().endswith(".yaml"):
            yaml_config_filename = model_path

        if yaml_config_filename is not None:
            model_data = self.load_model_data_from_yaml(yaml_config_filename)
        else:
            model_data = self.load_model_data_using_hash(model_path)

        model_data_dict_size = len(model_data)

        self.logger.info(f"Model downloaded, type: {model_type}, friendly name: {model_friendly_name}, model_path: {model_path}, model_data: {model_data_dict_size} items")

    def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
        """
        Returns a simplified, user-friendly list of models with their key metrics.
        Optionally filters and sorts the list based on the specified criteria.

        :param filter_sort_by: Criteria to filter and sort by. Can be "name", "filename", or any stem name.
        """
        model_files = self.list_supported_model_files()
        simplified_list = {}

        for model_type, models in model_files.items():
            for name, data in models.items():
                filename = data["filename"]
                scores = data.get("scores") or {}
                stems = data.get("stems") or []
                target_stem = data.get("target_stem")

                # Format stems with their SDR scores where available
                stems_with_scores = []
                stem_sdr_dict = {}

                # Process each stem from the model's stem list
                for stem in stems:
                    stem_scores = scores.get(stem, {})
                    # Add an asterisk if this is the target stem
                    stem_display = f"{stem}*" if stem == target_stem else stem

                    if isinstance(stem_scores, dict) and "SDR" in stem_scores:
                        sdr = round(stem_scores["SDR"], 1)
                        stems_with_scores.append(f"{stem_display} ({sdr})")
                        stem_sdr_dict[stem.lower()] = sdr
                    else:
                        # Include stem without an SDR score
                        stems_with_scores.append(stem_display)
                        stem_sdr_dict[stem.lower()] = None

                # If no stems are listed, mark the model as Unknown
                if not stems_with_scores:
                    stems_with_scores = ["Unknown"]
                    stem_sdr_dict["unknown"] = None

                simplified_list[filename] = {"Name": name, "Type": model_type, "Stems": stems_with_scores, "SDR": stem_sdr_dict}

        # Filter and sort the list if a filter_sort_by parameter is provided
        if filter_sort_by:
            if filter_sort_by == "name":
                return dict(sorted(simplified_list.items(), key=lambda x: x[1]["Name"]))
            elif filter_sort_by == "filename":
                return dict(sorted(simplified_list.items()))
            else:
                # Convert filter_sort_by to lowercase for case-insensitive comparison
                sort_by_lower = filter_sort_by.lower()
                # Filter out models that don't have the specified stem
                filtered_list = {k: v for k, v in simplified_list.items() if sort_by_lower in v["SDR"]}

                # Sort by SDR score if available, putting None values last
                def sort_key(item):
                    sdr = item[1]["SDR"][sort_by_lower]
                    return (0 if sdr is None else 1, sdr if sdr is not None else float("-inf"))

                return dict(sorted(filtered_list.items(), key=sort_key, reverse=True))

        return simplified_list
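As a quick illustration of the stem filtering above: passing a stem name ranks models by that stem's SDR, with unscored models last. A minimal sketch, assuming a default-constructed Separator as defined in this file:

# List models that can produce a "vocals" stem, best SDR first (sketch)
separator = Separator()
models = separator.get_simplified_model_list(filter_sort_by="vocals")
for filename, info in models.items():
    print(filename, info["Stems"], info["SDR"]["vocals"])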
audio_separator/separator/uvr_lib_v5/__init__.py
ADDED
File without changes
audio_separator/separator/uvr_lib_v5/demucs/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
audio_separator/separator/uvr_lib_v5/demucs/__main__.py
ADDED
@@ -0,0 +1,212 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import sys
import time
from dataclasses import dataclass, field
from fractions import Fraction

import torch as th
from torch import distributed, nn
from torch.nn.parallel.distributed import DistributedDataParallel

from .augment import FlipChannels, FlipSign, Remix, Shift
from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks
from .model import Demucs
from .parser import get_name, get_parser
from .raw import Rawset
from .tasnet import ConvTasNet
from .test import evaluate
from .train import train_model, validate_model
from .utils import human_seconds, load_model, save_model, sizeof_fmt


@dataclass
class SavedState:
    metrics: list = field(default_factory=list)
    last_state: dict = None
    best_state: dict = None
    optimizer: dict = None


def main():
    parser = get_parser()
    args = parser.parse_args()
    name = get_name(parser, args)
    print(f"Experiment {name}")

    if args.musdb is None and args.rank == 0:
        print("You must provide the path to the MusDB dataset with the --musdb flag. To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.", file=sys.stderr)
        sys.exit(1)

    eval_folder = args.evals / name
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.logs.mkdir(exist_ok=True)
    metrics_path = args.logs / f"{name}.json"
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.checkpoints.mkdir(exist_ok=True, parents=True)
    args.models.mkdir(exist_ok=True, parents=True)

    if args.device is None:
        device = "cpu"
        if th.cuda.is_available():
            device = "cuda"
    else:
        device = args.device

    th.manual_seed(args.seed)
    # Prevents too many threads from being started when running `museval`,
    # as it can be quite inefficient on NUMA architectures.
    os.environ["OMP_NUM_THREADS"] = "1"

    if args.world_size > 1:
        if device != "cuda" and args.rank == 0:
            print("Error: distributed training is only available with cuda device", file=sys.stderr)
            sys.exit(1)
        th.cuda.set_device(args.rank % th.cuda.device_count())
        distributed.init_process_group(backend="nccl", init_method="tcp://" + args.master, rank=args.rank, world_size=args.world_size)

    checkpoint = args.checkpoints / f"{name}.th"
    checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
    if args.restart and checkpoint.exists():
        checkpoint.unlink()

    if args.test:
        args.epochs = 1
        args.repeat = 0
        model = load_model(args.models / args.test)
    elif args.tasnet:
        model = ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X)
    else:
        model = Demucs(
            audio_channels=args.audio_channels,
            channels=args.channels,
            context=args.context,
            depth=args.depth,
            glu=args.glu,
            growth=args.growth,
            kernel_size=args.kernel_size,
            lstm_layers=args.lstm_layers,
            rescale=args.rescale,
            rewrite=args.rewrite,
            sources=4,
            stride=args.conv_stride,
            upsample=args.upsample,
            samplerate=args.samplerate,
        )
    model.to(device)
    if args.show:
        print(model)
        size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
        print(f"Model size {size}")
        return

    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    try:
        saved = th.load(checkpoint, map_location="cpu")
    except IOError:
        saved = SavedState()
    else:
        model.load_state_dict(saved.last_state)
        optimizer.load_state_dict(saved.optimizer)

    if args.save_model:
        if args.rank == 0:
            model.to("cpu")
            model.load_state_dict(saved.best_state)
            save_model(model, args.models / f"{name}.th")
        return

    if args.rank == 0:
        done = args.logs / f"{name}.done"
        if done.exists():
            done.unlink()

    if args.augment:
        augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride), Remix(group_size=args.remix_group_size)).to(device)
    else:
        augment = Shift(args.data_stride)

    if args.mse:
        criterion = nn.MSELoss()
    else:
        criterion = nn.L1Loss()

    # Setting number of samples so that all convolution windows are full.
    # Prevents hard to debug mistake with the prediction being shifted compared
    # to the input mixture.
    samples = model.valid_length(args.samples)
    print(f"Number of training samples adjusted to {samples}")

    if args.raw:
        train_set = Rawset(args.raw / "train", samples=samples + args.data_stride, channels=args.audio_channels, streams=[0, 1, 2, 3, 4], stride=args.data_stride)

        valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
    else:
        if not args.metadata.is_file() and args.rank == 0:
            build_musdb_metadata(args.metadata, args.musdb, args.workers)
        if args.world_size > 1:
            distributed.barrier()
        metadata = json.load(open(args.metadata))
        duration = Fraction(samples + args.data_stride, args.samplerate)
        stride = Fraction(args.data_stride, args.samplerate)
        train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"), metadata, duration=duration, stride=stride, samplerate=args.samplerate, channels=args.audio_channels)
        valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"), metadata, samplerate=args.samplerate, channels=args.audio_channels)

    best_loss = float("inf")
    for epoch, metrics in enumerate(saved.metrics):
        print(f"Epoch {epoch:03d}: train={metrics['train']:.8f} valid={metrics['valid']:.8f} best={metrics['best']:.4f} duration={human_seconds(metrics['duration'])}")
        best_loss = metrics["best"]

    if args.world_size > 1:
        dmodel = DistributedDataParallel(model, device_ids=[th.cuda.current_device()], output_device=th.cuda.current_device())
    else:
        dmodel = model

    for epoch in range(len(saved.metrics), args.epochs):
        begin = time.time()
        model.train()
        train_loss = train_model(
            epoch, train_set, dmodel, criterion, optimizer, augment, batch_size=args.batch_size, device=device, repeat=args.repeat, seed=args.seed, workers=args.workers, world_size=args.world_size
        )
        model.eval()
        valid_loss = validate_model(epoch, valid_set, model, criterion, device=device, rank=args.rank, split=args.split_valid, world_size=args.world_size)

        duration = time.time() - begin
        if valid_loss < best_loss:
            best_loss = valid_loss
            saved.best_state = {key: value.to("cpu").clone() for key, value in model.state_dict().items()}
        saved.metrics.append({"train": train_loss, "valid": valid_loss, "best": best_loss, "duration": duration})
        if args.rank == 0:
            json.dump(saved.metrics, open(metrics_path, "w"))

        saved.last_state = model.state_dict()
        saved.optimizer = optimizer.state_dict()
        if args.rank == 0 and not args.test:
            th.save(saved, checkpoint_tmp)
            checkpoint_tmp.rename(checkpoint)

        print(f"Epoch {epoch:03d}: train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} duration={human_seconds(duration)}")

    del dmodel
    model.load_state_dict(saved.best_state)
    if args.eval_cpu:
        device = "cpu"
        model.to(device)
    model.eval()
    evaluate(model, args.musdb, eval_folder, rank=args.rank, world_size=args.world_size, device=device, save=args.save, split=args.split_valid, shifts=args.shifts, workers=args.eval_workers)
    model.to("cpu")
    save_model(model, args.models / f"{name}.th")
    if args.rank == 0:
        print("done")
        done.write_text("done")


if __name__ == "__main__":
    main()
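For orientation: this is the original Demucs v1/v2 training entrypoint, vendored into this repo. Only the --musdb flag is referenced directly above (the rest of the flags come from the .parser module, which is not shown here), so a minimal launch would look roughly like the sketch below; the module path is an assumption based on this repo's layout:

# Sketch: train on a local MusDB copy (all other flags take parser defaults)
python -m audio_separator.separator.uvr_lib_v5.demucs --musdb /path/to/musdb18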
audio_separator/separator/uvr_lib_v5/demucs/apply.py
ADDED
@@ -0,0 +1,294 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Code to apply a model to a mix. It will handle chunking with overlaps and
interpolation between chunks, as well as the "shift trick".
"""
from concurrent.futures import ThreadPoolExecutor
import random
import typing as tp

import torch as th
from torch import nn
from torch.nn import functional as F
import tqdm

from .demucs import Demucs
from .hdemucs import HDemucs
from .utils import center_trim, DummyPoolExecutor

Model = tp.Union[Demucs, HDemucs]

progress_bar_num = 0


class BagOfModels(nn.Module):
    def __init__(self, models: tp.List[Model], weights: tp.Optional[tp.List[tp.List[float]]] = None, segment: tp.Optional[float] = None):
        """
        Represents a bag of models with specific weights.
        You should call `apply_model` rather than calling directly the forward here for
        optimal performance.

        Args:
            models (list[nn.Module]): list of Demucs/HDemucs models.
            weights (list[list[float]]): list of weights. If None, assumed to
                be all ones, otherwise it should be a list of N lists (N number of models),
                each containing S floats (S number of sources).
            segment (None or float): overrides the `segment` attribute of each model
                (this is performed inplace, be careful if you reuse the models passed).
        """

        super().__init__()
        assert len(models) > 0
        first = models[0]
        for other in models:
            assert other.sources == first.sources
            assert other.samplerate == first.samplerate
            assert other.audio_channels == first.audio_channels
            if segment is not None:
                other.segment = segment

        self.audio_channels = first.audio_channels
        self.samplerate = first.samplerate
        self.sources = first.sources
        self.models = nn.ModuleList(models)

        if weights is None:
            weights = [[1.0 for _ in first.sources] for _ in models]
        else:
            assert len(weights) == len(models)
            for weight in weights:
                assert len(weight) == len(first.sources)
        self.weights = weights

    def forward(self, x):
        raise NotImplementedError("Call `apply_model` on this.")


class TensorChunk:
    def __init__(self, tensor, offset=0, length=None):
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        if length is None:
            length = total_length - offset
        else:
            length = min(total_length - offset, length)

        if isinstance(tensor, TensorChunk):
            self.tensor = tensor.tensor
            self.offset = offset + tensor.offset
        else:
            self.tensor = tensor
            self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length):
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        start = self.offset - delta // 2
        end = start + target_length

        correct_start = max(0, start)
        correct_end = min(total_length, end)

        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
        assert out.shape[-1] == target_length
        return out


def tensor_chunk(tensor_or_chunk):
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    else:
        assert isinstance(tensor_or_chunk, th.Tensor)
        return TensorChunk(tensor_or_chunk)


def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1.0, static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
    """
    Apply model to a given mixture.

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` times and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down into 8 second extracts
            and predictions will be performed individually on each and concatenated.
            Useful for models with a large memory footprint like Tasnet.
        progress (bool): if True, show a progress bar (requires split=True)
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
    """

    global fut_length
    global bag_num
    global prog_bar

    if device is None:
        device = mix.device
    else:
        device = th.device(device)
    if pool is None:
        if num_workers > 0 and device.type == "cpu":
            pool = ThreadPoolExecutor(num_workers)
        else:
            pool = DummyPoolExecutor()

    kwargs = {
        "shifts": shifts,
        "split": split,
        "overlap": overlap,
        "transition_power": transition_power,
        "progress": progress,
        "device": device,
        "pool": pool,
        "set_progress_bar": set_progress_bar,
        "static_shifts": static_shifts,
    }

    if isinstance(model, BagOfModels):
        # Special treatment for a bag of models.
        # We explicitly apply `apply_model` multiple times so that the random shifts
        # are different for each model.

        estimates = 0
        totals = [0] * len(model.sources)
        bag_num = len(model.models)
        fut_length = 0
        prog_bar = 0
        current_model = 0  # (bag_num + 1)
        for sub_model, weight in zip(model.models, model.weights):
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)
            fut_length += fut_length
            current_model += 1
            out = apply_model(sub_model, mix, **kwargs)
            sub_model.to(original_model_device)
            for k, inst_weight in enumerate(weight):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight
            estimates += out
            del out

        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]
        return estimates

    model.to(device)
    model.eval()
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    batch, channels, length = mix.shape

    if shifts:
        kwargs["shifts"] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted, **kwargs)
            out += shifted_out[..., max_shift - offset :]
        out /= shifts
        return out
    elif split:
        kwargs["split"] = False
        out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = th.zeros(length, device=mix.device)
        segment = int(model.samplerate * model.segment)
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        scale = float(format(stride / model.samplerate, ".2f"))
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment // 2 + 1, device=device), th.arange(segment - segment // 2, 0, -1, device=device)])
        assert len(weight) == segment
        # If the overlap < 50%, this will translate to linear transition when
        # transition_power is 1.
        weight = (weight / weight.max()) ** transition_power
        futures = []
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            future = pool.submit(apply_model, model, chunk, **kwargs)
            futures.append((future, offset))
            offset += segment
        if progress:
            futures = tqdm.tqdm(futures)
        for future, offset in futures:
            if set_progress_bar:
                fut_length = len(futures) * bag_num * static_shifts
                prog_bar += 1
                set_progress_bar(0.1, (0.8 / fut_length * prog_bar))
            chunk_out = future.result()
            chunk_length = chunk_out.shape[-1]
            out[..., offset : offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[offset : offset + segment] += weight[:chunk_length].to(mix.device)
        assert sum_weight.min() > 0
        out /= sum_weight
        return out
    else:
        if hasattr(model, "valid_length"):
            valid_length = model.valid_length(length)
        else:
            valid_length = length
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length).to(device)
        with th.no_grad():
            out = model(padded_mix)
        return center_trim(out, length)


def demucs_segments(demucs_segment, demucs_model):

    if demucs_segment == "Default":
        segment = None
        if isinstance(demucs_model, BagOfModels):
            if segment is not None:
                for sub in demucs_model.models:
                    sub.segment = segment
        else:
            # single model: apply directly (the original referenced an undefined `sub` here)
            if segment is not None:
                demucs_model.segment = segment
    else:
        try:
            segment = int(demucs_segment)
            if isinstance(demucs_model, BagOfModels):
                if segment is not None:
                    for sub in demucs_model.models:
                        sub.segment = segment
            else:
                if segment is not None:
                    demucs_model.segment = segment
        except:
            segment = None
            if isinstance(demucs_model, BagOfModels):
                if segment is not None:
                    for sub in demucs_model.models:
                        sub.segment = segment
            else:
                if segment is not None:
                    demucs_model.segment = segment

    return demucs_model
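The chunked path in apply_model above blends overlapping windows with a triangular weight and then divides by the accumulated weight per sample. The self-contained sketch below reproduces that arithmetic on a toy signal to show the normalization reconstructs a constant input exactly (a minimal sketch for the transition_power=1 case; segment, stride, and length values are arbitrary):

import torch

segment, overlap, length = 16, 0.25, 64
stride = int((1 - overlap) * segment)

# Triangular weight, as built in apply_model above
weight = torch.cat([torch.arange(1, segment // 2 + 1), torch.arange(segment - segment // 2, 0, -1)]).float()
weight = weight / weight.max()

x = torch.ones(length)            # toy "separated" chunks: all ones
out = torch.zeros(length)
sum_weight = torch.zeros(length)
for offset in range(0, length, stride):
    chunk = x[offset : offset + segment]
    out[offset : offset + segment] += weight[: len(chunk)] * chunk
    sum_weight[offset : offset + segment] += weight[: len(chunk)]
out /= sum_weight                 # every sample was covered with positive weight
assert torch.allclose(out, x)     # overlap-add with normalization is exact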
audio_separator/separator/uvr_lib_v5/demucs/demucs.py
ADDED
@@ -0,0 +1,453 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import typing as tp

import julius
import torch
from torch import nn
from torch.nn import functional as F

from .states import capture_init
from .utils import center_trim, unfold


class BLSTM(nn.Module):
    """
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, input will be split into overlapping
    chunks and the LSTM applied separately on each chunk.
    """

    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        B, C, T = x.shape
        y = x
        framed = False
        if self.max_steps is not None and T > self.max_steps:
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        x = x.permute(2, 0, 1)

        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            out = out[..., :T]
            x = out
        if self.skip:
            x = x + y
        return x


def rescale_conv(conv, reference):
    """Rescale initial weight scale. It is unclear why it helps but it certainly does."""
    std = conv.weight.std().detach()
    scale = (std / reference) ** 0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
            rescale_conv(sub, reference)


class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonally the residual outputs, close to 0 initially, then learnt.
    """

    def __init__(self, channels: int, init: float = 0):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        return self.scale[:, None] * x


class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.
    """

    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4, norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention.
            init: initial scale for LayerNorm.
            norm: use GroupNorm.
            attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            gelu: Use GELU activation.
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
        """

        super().__init__()
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        dilate = depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if gelu:
            act = nn.GELU
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            dilation = 2**d if dilate else 1
            padding = dilation * (kernel // 2)
            mods = [
                nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
                norm_fn(hidden),
                act(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels),
                nn.GLU(1),
                LayerScale(channels, init),
            ]
            if attn:
                mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(x)
        return x


class LocalState(nn.Module):
    """Local state allows attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).

    Also a failed experiment with trying to provide some frequency based attention.
    """

    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None  # stupid type checker
            self.query_decay.bias.data[:] = -2
        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)

    def forward(self, x):
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= keys.shape[2] ** 0.5
        if self.nfreqs:
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs**0.5
            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
        if self.ndecay:
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        if self.nfreqs:
            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
            result = torch.cat([result, time_sig], 2)
        result = result.reshape(B, -1, T)
        return x + self.proj(result)


class Demucs(nn.Module):
    @capture_init
    def __init__(
        self,
        sources,
        # Channels
        audio_channels=2,
        channels=64,
        growth=2.0,
        # Main structure
        depth=6,
        rewrite=True,
        lstm_layers=0,
        # Convolutions
        kernel_size=8,
        stride=4,
        context=1,
        # Activations
        gelu=True,
        glu=True,
        # Normalization
        norm_starts=4,
        norm_groups=4,
        # DConv residual branch
        dconv_mode=1,
        dconv_depth=2,
        dconv_comp=4,
        dconv_attn=4,
        dconv_lstm=4,
        dconv_init=1e-4,
        # Pre/post processing
        normalize=True,
        resample=True,
        # Weight init
        rescale=0.1,
        # Metadata
        samplerate=44100,
        segment=4 * 10,
    ):
        """
        Args:
            sources (list[str]): list of source names
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            depth (int): number of layers in the encoder and in the decoder.
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            rewrite (bool): add 1x1 convolution to each layer.
            lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
                by default, as this is now replaced by the smaller and faster small LSTMs
                in the DConv branches.
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time steps.
            gelu: use GELU activation function.
            glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            normalize (bool): normalizes the input audio on the fly, and scales back
                the output by the same amount.
            resample (bool): upsample x2 the input and downsample /2 the output.
            rescale (int): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`.
            samplerate (int): stored as meta information for easing
                future evaluations of the model.
            segment (float): duration of the chunks of audio to ideally evaluate the model on.
                This is used by `demucs.apply.apply_model`.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.resample = resample
        self.channels = channels
        self.normalize = normalize
        self.samplerate = samplerate
        self.segment = segment
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.skip_scales = nn.ModuleList()

        if glu:
            activation = nn.GLU(dim=1)
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        if gelu:
            act2 = nn.GELU
        else:
            act2 = nn.ReLU

        in_channels = audio_channels
        padding = 0
        for index in range(depth):
            norm_fn = lambda d: nn.Identity()  # noqa
            if index >= norm_starts:
                norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa

            encode = []
            encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), norm_fn(channels), act2()]
            attn = index >= dconv_attn
            lstm = index >= dconv_lstm
            if dconv_mode & 1:
                encode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
            if rewrite:
                encode += [nn.Conv1d(channels, ch_scale * channels, 1), norm_fn(ch_scale * channels), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                out_channels = len(self.sources) * audio_channels
            if rewrite:
                decode += [nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), norm_fn(ch_scale * channels), activation]
            if dconv_mode & 2:
                decode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
            decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride, padding=padding)]
            if index > 0:
                decode += [norm_fn(out_channels), act2()]
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels
        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        no time steps are left over in a convolution, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        Note that inputs are automatically padded if necessary to ensure that the output
        has the same length as the input.
        """
        if self.resample:
            length *= 2

        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)

        for idx in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size

        if self.resample:
            length = math.ceil(length / 2)
        return int(length)

    def forward(self, mix):
        x = mix
        length = x.shape[-1]

        if self.normalize:
            mono = mix.mean(dim=1, keepdim=True)
            mean = mono.mean(dim=-1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            x = (x - mean) / (1e-5 + std)
        else:
            mean = 0
            std = 1

        delta = self.valid_length(length) - length
        x = F.pad(x, (delta // 2, delta - delta // 2))

        if self.resample:
            x = julius.resample_frac(x, 1, 2)

        saved = []
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)

        if self.lstm:
            x = self.lstm(x)

        for decode in self.decoder:
            skip = saved.pop(-1)
            skip = center_trim(skip, x)
            x = decode(x + skip)

        if self.resample:
            x = julius.resample_frac(x, 2, 1)
        x = x * std + mean
        x = center_trim(x, length)
        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
        return x

    def load_state_dict(self, state, strict=True):
        # fix a mismatch with previous generation Demucs models.
        for idx in range(self.depth):
            for a in ["encoder", "decoder"]:
                for b in ["bias", "weight"]:
                    new = f"{a}.{idx}.3.{b}"
                    old = f"{a}.{idx}.2.{b}"
                    if old in state and new not in state:
                        state[new] = state.pop(old)
        super().load_state_dict(state, strict=strict)
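Demucs.valid_length above walks the encoder length arithmetic down and the decoder arithmetic back up so that every convolution window is full. A standalone re-derivation of the same round trip, using the constructor defaults from this file (kernel_size=8, stride=4, depth=6, resample=True); a sketch for illustration only:

import math

def valid_length(length, depth=6, kernel_size=8, stride=4, resample=True):
    # Mirror of Demucs.valid_length: encoder shrinks, decoder expands.
    if resample:
        length *= 2
    for _ in range(depth):
        length = math.ceil((length - kernel_size) / stride) + 1
        length = max(1, length)
    for _ in range(depth):
        length = (length - 1) * stride + kernel_size
    if resample:
        length = math.ceil(length / 2)
    return int(length)

print(valid_length(44100))  # nearest padded length for 1 s of 44.1 kHz audio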
audio_separator/separator/uvr_lib_v5/demucs/filtering.py
ADDED
@@ -0,0 +1,451 @@
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader


def atan2(y, x):
    r"""Element-wise arctangent function of y/x.
    Returns a new tensor with signed angles in radians.
    It is an alternative implementation of torch.atan2

    Args:
        y (Tensor): First input tensor
        x (Tensor): Second input tensor [shape=y.shape]

    Returns:
        Tensor: [shape=y.shape].
    """
    pi = 2 * torch.asin(torch.tensor(1.0))
    # Note: this modifies `x` in place to avoid a division by zero when x == y == 0.
    x += ((x == 0) & (y == 0)) * 1.0
    out = torch.atan(y / x)
    out += ((y >= 0) & (x < 0)) * pi
    out -= ((y < 0) & (x < 0)) * pi
    out *= 1 - ((y > 0) & (x == 0)) * 1.0
    out += ((y > 0) & (x == 0)) * (pi / 2)
    out *= 1 - ((y < 0) & (x == 0)) * 1.0
    out += ((y < 0) & (x == 0)) * (-pi / 2)
    return out


# Define basic complex operations on torch.Tensor objects whose last dimension
# consists in the concatenation of the real and imaginary parts.


def _norm(x: torch.Tensor) -> torch.Tensor:
    r"""Computes the norm value of a torch Tensor, assuming that it
    comes as real and imaginary part in its last dimension.

    Args:
        x (Tensor): Input Tensor of shape [shape=(..., 2)]

    Returns:
        Tensor: shape as x excluding the last dimension.
    """
    return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2


def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise multiplication of two complex Tensors described
    through their real and imaginary parts.
    The result is added to the `out` tensor"""

    # check `out` and allocate it if needed
    target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
    if out is None or out.shape != target_shape:
        out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
    if out is a:
        real_a = a[..., 0]
        out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
        out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
    else:
        out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
        out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
    return out


def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise multiplication of two complex Tensors described
    through their real and imaginary parts;
    can work in place in case out is a only"""
    target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
    if out is None or out.shape != target_shape:
        out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
    if out is a:
        real_a = a[..., 0]
        out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
        out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
    else:
        out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
        out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
    return out

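# Sanity check (illustrative, not part of the library): the real/imaginary
# last-dimension layout used by these helpers agrees with torch's native
# complex arithmetic, e.g. for _mul:
#
#     a, b = torch.randn(4, 4, 2), torch.randn(4, 4, 2)
#     native = torch.view_as_real(torch.view_as_complex(a) * torch.view_as_complex(b))
#     assert torch.allclose(_mul(a, b), native, atol=1e-6)
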
85 |
+
def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
86 |
+
"""Element-wise multiplicative inverse of a Tensor with complex
|
87 |
+
entries described through their real and imaginary parts.
|
88 |
+
can work in place in case out is z"""
|
89 |
+
ez = _norm(z)
|
90 |
+
if out is None or out.shape != z.shape:
|
91 |
+
out = torch.zeros_like(z)
|
92 |
+
out[..., 0] = z[..., 0] / ez
|
93 |
+
out[..., 1] = -z[..., 1] / ez
|
94 |
+
return out
|
95 |
+
|
96 |
+
|
97 |
+
def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
98 |
+
"""Element-wise complex conjugate of a Tensor with complex entries
|
99 |
+
described through their real and imaginary parts.
|
100 |
+
can work in place in case out is z"""
|
101 |
+
if out is None or out.shape != z.shape:
|
102 |
+
out = torch.zeros_like(z)
|
103 |
+
out[..., 0] = z[..., 0]
|
104 |
+
out[..., 1] = -z[..., 1]
|
105 |
+
return out
|
106 |
+
|
107 |
+
|
108 |
+
def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
109 |
+
"""
|
110 |
+
Invert 1x1 or 2x2 matrices
|
111 |
+
|
112 |
+
Will generate errors if the matrices are singular: user must handle this
|
113 |
+
through his own regularization schemes.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
|
117 |
+
matrices to invert: must be square along dimensions -3 and -2
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
invM (Tensor): [shape=M.shape]
|
121 |
+
inverses of M
|
122 |
+
"""
|
123 |
+
nb_channels = M.shape[-2]
|
124 |
+
|
125 |
+
if out is None or out.shape != M.shape:
|
126 |
+
out = torch.empty_like(M)
|
127 |
+
|
128 |
+
if nb_channels == 1:
|
129 |
+
# scalar case
|
130 |
+
out = _inv(M, out)
|
131 |
+
elif nb_channels == 2:
|
132 |
+
# two channels case: analytical expression
|
133 |
+
|
134 |
+
# first compute the determinent
|
135 |
+
det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
|
136 |
+
det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
|
137 |
+
# invert it
|
138 |
+
invDet = _inv(det)
|
139 |
+
|
140 |
+
# then fill out the matrix with the inverse
|
141 |
+
out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
|
142 |
+
out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
|
143 |
+
out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
|
144 |
+
out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
|
145 |
+
else:
|
146 |
+
raise Exception("Only 2 channels are supported for the torch version.")
|
147 |
+
return out
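
A minimal sanity-check sketch (not part of the uploaded file): these real/imag-last-dim helpers should agree with PyTorch's native complex arithmetic. It assumes `_mul` and `_invert` above are in scope.

import torch

# element-wise complex product against torch.view_as_complex
a = torch.randn(4, 3, 2)  # last dim holds (real, imag)
b = torch.randn(4, 3, 2)
prod = _mul(a, b)
ref = torch.view_as_real(torch.view_as_complex(a) * torch.view_as_complex(b))
assert torch.allclose(prod, ref, atol=1e-6)

# 2x2 inversion: M @ inv(M) should be the identity for well-conditioned M
M = torch.randn(5, 2, 2, 2, dtype=torch.float64)
M[:, 0, 0, 0] += 4.0  # make diagonally dominant so inversion is stable
M[:, 1, 1, 0] += 4.0
invM = _invert(M)
prod = torch.view_as_complex(M) @ torch.view_as_complex(invM)
ident = torch.eye(2, dtype=prod.dtype)
assert torch.allclose(prod, ident.expand_as(prod), atol=1e-8)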


# Now define the signal-processing low-level functions used by the Separator


def expectation_maximization(y: torch.Tensor, x: torch.Tensor, iterations: int = 2, eps: float = 1e-10, batch_size: int = 200):
    r"""Expectation maximization algorithm, for refining source separation
    estimates.

    This algorithm improves source separation results by enforcing
    multichannel consistency for the estimates. This usually means
    a better perceptual quality in terms of spatial artifacts.

    The implementation follows the details presented in [1]_, taking
    inspiration from the original EM algorithm proposed in [2]_ and its
    weighted refinement proposed in [3]_, [4]_.
    It works by iteratively:

    * Re-estimating source parameters (power spectral densities and spatial
      covariance matrices) through :func:`get_local_gaussian_model`.

    * Separating the mixture again with the new parameters, by first computing
      the new modelled mixture covariance matrices with :func:`get_mix_model`,
      preparing the Wiener filters through :func:`wiener_gain` and applying them
      with :func:`apply_filter`.

    References
    ----------
    .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
        N. Takahashi and Y. Mitsufuji, "Improving music source separation based
        on deep neural networks through data augmentation and network
        blending." 2017 IEEE International Conference on Acoustics, Speech
        and Signal Processing (ICASSP). IEEE, 2017.

    .. [2] N.Q. Duong and E. Vincent and R. Gribonval. "Under-determined
        reverberant audio source separation using a full-rank spatial
        covariance model." IEEE Transactions on Audio, Speech, and Language
        Processing 18.7 (2010): 1830-1840.

    .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
        separation with deep neural networks." IEEE/ACM Transactions on Audio,
        Speech, and Language Processing 24.9 (2016): 1652-1664.

    .. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
        separation with deep neural networks." 2016 24th European Signal
        Processing Conference (EUSIPCO). IEEE, 2016.

    .. [5] A. Liutkus and R. Badeau and G. Richard. "Kernel additive models for
        source separation." IEEE Transactions on Signal Processing
        62.16 (2014): 4298-4310.

    Args:
        y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
            initial estimates for the sources
        x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
            complex STFT of the mixture signal
        iterations (int): [scalar]
            number of iterations for the EM algorithm.
        eps (float or None): [scalar]
            The epsilon value to use for regularization and filters.
        batch_size (int): [scalar]
            size of the batches of frames processed at once, to limit
            memory usage.

    Returns:
        y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
            estimated sources after iterations
        v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
            estimated power spectral densities
        R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)]
            estimated spatial covariance matrices

    Notes:
        * You need an initial estimate for the sources to apply this
          algorithm. This is precisely what the :func:`wiener` function does.
        * This algorithm *is not* an implementation of the "exact" EM
          proposed in [1]_. In particular, it does not compute the posterior
          covariance matrices the same (exact) way. Instead, it uses the
          simplified approximate scheme initially proposed in [5]_ and further
          refined in [3]_, [4]_, which boils down to just taking the empirical
          covariance of the recent source estimates, followed by a weighted
          average for the update of the spatial covariance matrix. It has been
          empirically demonstrated that this simplified algorithm is more
          robust for music separation.

    Warning:
        It is *very* important to make sure `x.dtype` is `torch.float64`
        if you want double precision, because this function will **not**
        do such a conversion for you from `torch.float32`, in case you want the
        smaller RAM usage on purpose.

        It is usually better in terms of quality to use double
        precision, by e.g. calling :func:`expectation_maximization`
        with ``x.to(torch.float64)``.
    """
    # dimensions
    (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
    nb_sources = y.shape[-1]

    regularization = torch.cat((torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None], torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device)), dim=2)
    regularization = torch.sqrt(torch.as_tensor(eps)) * (regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1)))

    # allocate the spatial covariance matrices
    R = [torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device) for j in range(nb_sources)]
    weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)

    v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
    for it in range(iterations):
        # constructing the mixture covariance matrix. Doing it with a loop
        # to avoid ever storing the whole 6D tensor in RAM

        # update the PSD as the average spectrogram over channels
        v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)

        # update spatial covariance matrices (weighted update)
        for j in range(nb_sources):
            R[j] = torch.tensor(0.0, device=x.device)
            weight = torch.tensor(eps, device=x.device)
            pos: int = 0
            batch_size = batch_size if batch_size else nb_frames
            while pos < nb_frames:
                t = torch.arange(pos, min(nb_frames, pos + batch_size))
                pos = int(t[-1]) + 1

                R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
                weight = weight + torch.sum(v[t, ..., j], dim=0)
            R[j] = R[j] / weight[..., None, None, None]
            weight = torch.zeros_like(weight)

        # cloning y if we track gradient, because we're going to update it
        if y.requires_grad:
            y = y.clone()

        pos = 0
        while pos < nb_frames:
            t = torch.arange(pos, min(nb_frames, pos + batch_size))
            pos = int(t[-1]) + 1

            y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)

            # compute mix covariance matrix
            Cxx = regularization
            for j in range(nb_sources):
                Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())

            # invert it
            inv_Cxx = _invert(Cxx)

            # separate the sources
            for j in range(nb_sources):

                # create a wiener gain for this source
                gain = torch.zeros_like(inv_Cxx)

                # computes multichannel Wiener gain as v_j R_j inv_Cxx
                indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels), torch.arange(nb_channels))
                for index in indices:
                    gain[:, :, index[0], index[1], :] = _mul_add(R[j][None, :, index[0], index[2], :].clone(), inv_Cxx[:, :, index[2], index[1], :], gain[:, :, index[0], index[1], :])
                gain = gain * v[t, ..., None, None, None, j]

                # apply it to the mixture
                for i in range(nb_channels):
                    y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])

    return y, v, R
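
An illustrative sketch of driving `expectation_maximization` on dummy data (the shapes follow the docstring; each iteration computes the multichannel Wiener gain `v_j R_j Cxx^{-1}` per the comment above). The sizes below are placeholders, not values from the diff.

import torch

nb_frames, nb_bins, nb_channels, nb_sources = 10, 513, 2, 4
x = torch.randn(nb_frames, nb_bins, nb_channels, 2, dtype=torch.float64)
# crude initial estimates: the mixture split evenly across sources
y = x[..., None].repeat(1, 1, 1, 1, nb_sources) / nb_sources

y_refined, v, R = expectation_maximization(y, x, iterations=2)
print(y_refined.shape)  # torch.Size([10, 513, 2, 2, 4])
print(v.shape)          # torch.Size([10, 513, 4])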


def wiener(targets_spectrograms: torch.Tensor, mix_stft: torch.Tensor, iterations: int = 1, softmask: bool = False, residual: bool = False, scale_factor: float = 10.0, eps: float = 1e-10):
    """Wiener-based separation for multichannel audio.

    The method uses the (possibly multichannel) spectrograms of the
    sources to separate the (complex) Short Term Fourier Transform of the
    mix. Separation is done in a sequential way by:

    * Getting an initial estimate. This can be done in two ways: either by
      directly using the spectrograms with the mixture phase, or
      by using a softmasking strategy. This initial phase is controlled
      by the `softmask` flag.

    * If required, adding an additional residual target as the mix minus
      all targets.

    * Refining these initial estimates through a call to
      :func:`expectation_maximization` if the number of iterations is nonzero.

    This implementation also allows specifying the epsilon value used for
    regularization. It is based on [1]_, [2]_, [3]_, [4]_.

    References
    ----------
    .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
        N. Takahashi and Y. Mitsufuji, "Improving music source separation based
        on deep neural networks through data augmentation and network
        blending." 2017 IEEE International Conference on Acoustics, Speech
        and Signal Processing (ICASSP). IEEE, 2017.

    .. [2] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
        separation with deep neural networks." IEEE/ACM Transactions on Audio,
        Speech, and Language Processing 24.9 (2016): 1652-1664.

    .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
        separation with deep neural networks." 2016 24th European Signal
        Processing Conference (EUSIPCO). IEEE, 2016.

    .. [4] A. Liutkus and R. Badeau and G. Richard. "Kernel additive models for
        source separation." IEEE Transactions on Signal Processing
        62.16 (2014): 4298-4310.

    Args:
        targets_spectrograms (Tensor): spectrograms of the sources
            [shape=(nb_frames, nb_bins, nb_channels, nb_sources)].
            This is a nonnegative tensor that is
            usually the output of the actual separation method of the user. The
            spectrograms may be mono, but they need to be 4-dimensional in all
            cases.
        mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)]
            STFT of the mixture signal.
        iterations (int): [scalar]
            number of iterations for the EM algorithm
        softmask (bool): Describes how the initial estimates are obtained.
            * if `False`, then the mixture phase will directly be used with the
              spectrogram as initial estimates.
            * if `True`, initial estimates are obtained by multiplying the
              complex mix element-wise with the ratio of each target spectrogram
              with the sum of them all. This strategy is better if the models are
              not really good, and worse otherwise.
        residual (bool): if `True`, an additional target is created, which is
            equal to the mixture minus the other targets, before application of
            expectation maximization
        scale_factor (float): [scalar]
            the estimates are scaled down by this factor before refinement,
            for numerical stability.
        eps (float): Epsilon value to use for computing the separations.
            This is used whenever division with a model energy is
            performed, i.e. when softmasking and when iterating the EM.
            It can be understood as the energy of the additional white noise
            that is taken out when separating.

    Returns:
        Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
            STFT of estimated sources

    Notes:
        * Be careful that you need *magnitude spectrogram estimates* for the
          case `softmask==False`.
        * `softmask=False` is recommended
        * The epsilon value will have a huge impact on performance. If it's
          large, only the parts of the signal with a significant energy will
          be kept in the sources. This epsilon then directly controls the
          energy of the reconstruction error.

    Warning:
        As in :func:`expectation_maximization`, we recommend converting the
        mixture `x` to double precision `torch.float64` *before* calling
        :func:`wiener`.
    """
    if softmask:
        # if we use softmask, we compute the ratio mask for all targets and
        # multiply by the mix stft
        y = mix_stft[..., None] * (targets_spectrograms / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype)))[..., None, :]
    else:
        # otherwise, we just multiply the targets spectrograms with mix phase
        # we tacitly assume that we have magnitude estimates.
        angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
        nb_sources = targets_spectrograms.shape[-1]
        y = torch.zeros(mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device)
        y[..., 0, :] = targets_spectrograms * torch.cos(angle)
        y[..., 1, :] = targets_spectrograms * torch.sin(angle)

    if residual:
        # if required, adding an additional target as the mix minus
        # available targets
        y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)

    if iterations == 0:
        return y

    # we need to refine the estimates. Scale down the estimates for
    # numerical stability
    max_abs = torch.max(torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device), torch.sqrt(_norm(mix_stft)).max() / scale_factor)

    mix_stft = mix_stft / max_abs
    y = y / max_abs

    # call expectation maximization
    y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]

    # scale estimates up again
    y = y * max_abs
    return y
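
A short, illustrative usage sketch for `wiener`, separating a 2-channel mixture STFT into 3 dummy targets. The magnitude estimates are random here; in practice they come from a spectrogram model.

import torch

nb_frames, nb_bins, nb_channels, nb_sources = 50, 257, 2, 3
mix_stft = torch.randn(nb_frames, nb_bins, nb_channels, 2, dtype=torch.float64)
mags = torch.rand(nb_frames, nb_bins, nb_channels, nb_sources, dtype=torch.float64)

y = wiener(mags, mix_stft, iterations=1, residual=False)
print(y.shape)  # torch.Size([50, 257, 2, 2, 3])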


def _covariance(y_j):
    """
    Compute the empirical covariance for a source.

    Args:
        y_j (Tensor): complex stft of the source.
            [shape=(nb_frames, nb_bins, nb_channels, 2)].

    Returns:
        Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
            just y_j * conj(y_j.T): empirical covariance for each TF bin.
    """
    (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
    Cj = torch.zeros((nb_frames, nb_bins, nb_channels, nb_channels, 2), dtype=y_j.dtype, device=y_j.device)
    indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))
    for index in indices:
        Cj[:, :, index[0], index[1], :] = _mul_add(y_j[:, :, index[0], :], _conj(y_j[:, :, index[1], :]), Cj[:, :, index[0], index[1], :])
    return Cj
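
A small illustrative check: `_covariance` should match an einsum over native complex tensors, `y_c[..., i] * conj(y_c[..., j])`. It assumes the helpers above are in scope.

import torch

y_j = torch.randn(6, 9, 2, 2)  # (frames, bins, channels, real/imag)
Cj = _covariance(y_j)
y_c = torch.view_as_complex(y_j)
ref = torch.einsum("fbi,fbj->fbij", y_c, torch.conj(y_c))
assert torch.allclose(torch.view_as_complex(Cj), ref, atol=1e-6)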
audio_separator/separator/uvr_lib_v5/demucs/hdemucs.py
ADDED
@@ -0,0 +1,783 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
This code contains the spectrogram and Hybrid version of Demucs.
"""
from copy import deepcopy
import math
import typing as tp
import torch
from torch import nn
from torch.nn import functional as F
from .filtering import wiener
from .demucs import DConv, rescale_module
from .states import capture_init
from .spec import spectro, ispectro


def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = "constant", value: float = 0.0):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small inputs.
    If this is the case, we insert extra 0 padding to the right before the reflection happens."""
    x0 = x
    length = x.shape[-1]
    padding_left, padding_right = paddings
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            extra_pad_right = min(padding_right, extra_pad)
            extra_pad_left = extra_pad - extra_pad_right
            paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
            x = F.pad(x, (extra_pad_left, extra_pad_right))
    out = F.pad(x, paddings, mode, value)
    assert out.shape[-1] == length + padding_left + padding_right
    assert (out[..., padding_left : padding_left + length] == x0).all()
    return out

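A quick illustration (not part of the uploaded file): `F.pad` with `mode="reflect"` fails when the padding is at least as large as the input, which is the case `pad1d` works around.

import torch

x = torch.arange(3.0).view(1, 1, 3)        # length-3 signal
padded = pad1d(x, (4, 4), mode="reflect")  # plain F.pad would raise here
print(padded.shape)  # torch.Size([1, 1, 11]) == 3 + 4 + 4
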
class ScaledEmbedding(nn.Module):
    """
    Boost learning rate for embeddings (with `scale`).
    Also, can make embeddings continuous with `smooth`.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if smooth:
            weight = torch.cumsum(self.embedding.weight.data, dim=0)
            # when summing gaussians, the overall scale rises as sqrt(n), so we normalize by that.
            weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
            self.embedding.weight.data[:] = weight
        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self):
        return self.embedding.weight * self.scale

    def forward(self, x):
        out = self.embedding(x) * self.scale
        return out

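An illustrative sketch of why this boosts the learning rate: the stored weights are divided by `scale` and the output is multiplied back, so the forward pass is unchanged at init, but gradients with respect to the stored parameter are `scale` times larger, acting like a per-module LR multiplier.

import torch

emb = ScaledEmbedding(num_embeddings=8, embedding_dim=4, scale=10.0)
idx = torch.tensor([0, 3, 7])
emb(idx).sum().backward()
# gradient on the stored (downscaled) parameter is amplified by `scale`
print(emb.embedding.weight.grad.abs().max())  # 10.0
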
class HEncLayer(nn.Module):
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, rewrite=True):
        """Encoder layer. This is used by both the time and the frequency branches.

        Args:
            chin: number of input channels.
            chout: number of output channels.
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. this is used
                before merging the time and freq. branches.
            freq: this is acting on frequencies.
            dconv: insert DConv residual branches.
            norm: use GroupNorm.
            context: context size for the 1x1 conv.
            dconv_kw: dict of kwargs for the DConv class.
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add 1x1 conv at the end of the layer.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the frequency branch,
        when both have the same stride.
        """
        if not self.freq and x.dim() == 4:
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv:
            if self.freq:
                B, C, Fr, T = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            if self.freq:
                y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        if self.rewrite:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else:
            z = y
        return z


class MultiWrap(nn.Module):
    """
    Takes one layer and replicates it N times. Each replica will act
    on a frequency band. All is done so that if the N replicas have the same weights,
    then this is exactly equivalent to applying the original module on all frequencies.

    This is a bit over-engineered to avoid edge artifacts when splitting
    the frequency bands, but it is possible the naive implementation would work as well...
    """

    def __init__(self, layer, split_ratios):
        """
        Args:
            layer: module to clone, must be either HEncLayer or HDecLayer.
            split_ratios: list of float indicating which ratio to keep for each band.
        """
        super().__init__()
        self.split_ratios = split_ratios
        self.layers = nn.ModuleList()
        self.conv = isinstance(layer, HEncLayer)
        assert not layer.norm
        assert layer.freq
        assert layer.pad
        if not self.conv:
            assert not layer.context_freq
        for k in range(len(split_ratios) + 1):
            lay = deepcopy(layer)
            if self.conv:
                lay.conv.padding = (0, 0)
            else:
                lay.pad = False
            for m in lay.modules():
                if hasattr(m, "reset_parameters"):
                    m.reset_parameters()
            self.layers.append(lay)

    def forward(self, x, skip=None, length=None):
        B, C, Fr, T = x.shape

        ratios = list(self.split_ratios) + [1]
        start = 0
        outs = []
        for ratio, layer in zip(ratios, self.layers):
            if self.conv:
                pad = layer.kernel_size // 4
                if ratio == 1:
                    limit = Fr
                    frames = -1
                else:
                    limit = int(round(Fr * ratio))
                    le = limit - start
                    if start == 0:
                        le += pad
                    frames = round((le - layer.kernel_size) / layer.stride + 1)
                    limit = start + (frames - 1) * layer.stride + layer.kernel_size
                    if start == 0:
                        limit -= pad
                assert limit - start > 0, (limit, start)
                assert limit <= Fr, (limit, Fr)
                y = x[:, :, start:limit, :]
                if start == 0:
                    y = F.pad(y, (0, 0, pad, 0))
                if ratio == 1:
                    y = F.pad(y, (0, 0, 0, pad))
                outs.append(layer(y))
                start = limit - layer.kernel_size + layer.stride
            else:
                if ratio == 1:
                    limit = Fr
                else:
                    limit = int(round(Fr * ratio))
                last = layer.last
                layer.last = True

                y = x[:, :, start:limit]
                s = skip[:, :, start:limit]
                out, _ = layer(y, s, None)
                if outs:
                    outs[-1][:, :, -layer.stride :] += out[:, :, : layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)
                    out = out[:, :, layer.stride :]
                if ratio == 1:
                    out = out[:, :, : -layer.stride // 2, :]
                if start == 0:
                    out = out[:, :, layer.stride // 2 :, :]
                outs.append(out)
                layer.last = last
                start = limit
        out = torch.cat(outs, dim=2)
        if not self.conv and not last:
            out = F.gelu(out)
        if self.conv:
            return out
        else:
            return out, None


class HDecLayer(nn.Module):
    def __init__(
        self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, context_freq=True, rewrite=True
    ):
        """
        Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, [0, context])
            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = x + skip

            if self.rewrite:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv:
                if self.freq:
                    B, C, Fr, T = y.shape
                    y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
                y = self.dconv(y)
                if self.freq:
                    y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = x
            assert skip is None
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                z = z[..., self.pad : -self.pad, :]
        else:
            z = z[..., self.pad : self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)
        return z, y


class HDemucs(nn.Module):
    """
    Spectrogram and hybrid Demucs model.
    The spectrogram model has the same structure as Demucs, except the first few layers are over the
    frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
    Frequency layers can still access information across time steps thanks to the DConv residual.

    Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
    as the frequency branch and then the two are combined. The opposite happens in the decoder.

    Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
    the Open Unmix implementation [Stoter et al. 2019].

    The loss is always in the temporal domain, by backpropagating through the above
    output methods and iSTFT. This makes it possible to define hybrid models nicely. However, this
    somewhat breaks Wiener filtering, as doing more iterations at test time will change the spectrogram
    contribution, without changing the one from the waveform, which will lead to worse performance.
    I tried using the residual option in the OpenUnmix Wiener implementation, but it didn't improve.
    CaC on the other hand provides similar performance for hybrid, and works naturally with
    hybrid models.

    This model also uses frequency embeddings to improve efficiency on convolutions
    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).

    Unlike classic Demucs, there is no resampling here, and normalization is always applied.
    """

    @capture_init
    def __init__(
        self,
        sources,
        # Channels
        audio_channels=2,
        channels=48,
        channels_time=None,
        growth=2,
        # STFT
        nfft=4096,
        wiener_iters=0,
        end_iters=0,
        wiener_residual=False,
        cac=True,
        # Main structure
        depth=6,
        rewrite=True,
        hybrid=True,
        hybrid_old=False,
        # Frequency branch
        multi_freqs=None,
        multi_freqs_depth=2,
        freq_emb=0.2,
        emb_scale=10,
        emb_smooth=True,
        # Convolutions
        kernel_size=8,
        time_stride=2,
        stride=4,
        context=1,
        context_enc=0,
        # Normalization
        norm_starts=4,
        norm_groups=4,
        # DConv residual branch
        dconv_mode=1,
        dconv_depth=2,
        dconv_comp=4,
        dconv_attn=4,
        dconv_lstm=4,
        dconv_init=1e-4,
        # Weight init
        rescale=0.1,
        # Metadata
        samplerate=44100,
        segment=4 * 10,
    ):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this requires careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            wiener_residual: add residual source before wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            hybrid (bool): make a hybrid time/frequency domain model, otherwise frequency only.
            hybrid_old: some models trained for MDX had a padding bug. This replicates
                this bug to avoid retraining them.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            rescale: weight rescaling trick

        """
        super().__init__()

        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment

        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        self.hybrid = hybrid
        self.hybrid_old = hybrid_old
        if hybrid_old:
            assert hybrid, "hybrid_old must come with hybrid=True"
        if hybrid:
            assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        if hybrid:
            self.tencoder = nn.ModuleList()
            self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # number of channels for the freq branch
        if self.cac:
            chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm": norm,
                "rewrite": rewrite,
                "norm_groups": norm_groups,
                "dconv_kw": {"lstm": lstm, "attn": attn, "depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
            }
            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec["context_freq"] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
            if hybrid and freq:
                tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if hybrid and freq:
                tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        if self.hybrid:
            # We re-pad the signal in order to keep the property
            # that the size of the output is exactly the size of the input
            # divided by the stride (here hop_length), when divisible.
            # This is achieved by padding by 1/4th of the kernel size (here nfft),
            # which is not supported by torch.stft.
            # Having all convolution operations follow this convention makes it easy
            # to align the time and frequency branches later on.
            assert hl == nfft // 4
            le = int(math.ceil(x.shape[-1] / hl))
            pad = hl // 2 * 3
            if not self.hybrid_old:
                x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
            else:
                x = pad1d(x, (pad, pad + le * hl - x.shape[-1]))

        z = spectro(x, nfft, hl)[..., :-1, :]
        if self.hybrid:
            assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
            z = z[..., 2 : 2 + le]
        return z

    def _ispec(self, z, length=None, scale=0):
        hl = self.hop_length // (4**scale)
        z = F.pad(z, (0, 0, 0, 1))
        if self.hybrid:
            z = F.pad(z, (2, 2))
            pad = hl // 2 * 3
            if not self.hybrid_old:
                le = hl * int(math.ceil(length / hl)) + 2 * pad
            else:
                le = hl * int(math.ceil(length / hl))
            x = ispectro(z, hl, length=le)
            if not self.hybrid_old:
                x = x[..., pad : pad + length]
            else:
                x = x[..., :length]
        else:
            x = ispectro(z, hl, length)
        return x

    def _magnitude(self, z):
        # return the magnitude of the spectrogram, except when cac is True,
        # in which case we just move the complex dimension to the channel one.
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else:
            m = z.abs()
        return m

    def _mask(self, z, m):
        # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
        # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training:
            niters = self.end_iters
        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else:
            return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        # apply wiener filtering from OpenUnmix.
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(B):
            pos = 0
            out = []
            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual:
            out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def forward(self, mix):
        x = mix
        length = x.shape[-1]

        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        if self.hybrid:
            # Prepare the time branch input.
            xt = mix
            meant = xt.mean(dim=(1, 2), keepdim=True)
            stdt = xt.std(dim=(1, 2), keepdim=True)
            xt = (xt - meant) / (1e-5 + stdt)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        lengths_t = []  # saved lengths for time branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if self.hybrid and idx < len(self.tencoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        if self.hybrid:
            xt = torch.zeros_like(x)
        # initialize everything to zero (signal will go through u-net skips).

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.

            if self.hybrid:
                offset = self.depth - len(self.tdecoder)
            if self.hybrid and idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # Let's make sure we used all stored skip connections.
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        # move to CPU as non-CUDA GPUs don't support complex numbers
        # demucs issues #435 and #432
        # NOTE: in this case z already is on cpu
        # TODO: remove this when mps supports complex numbers

        device_type = x.device.type
        device_load = f"{device_type}:{x.device.index}" if not device_type == "mps" else device_type
        x_is_other_gpu = not device_type in ["cuda", "cpu"]

        if x_is_other_gpu:
            x = x.cpu()

        zout = self._mask(z, x)
        x = self._ispec(zout, length)

        # back to other device
        if x_is_other_gpu:
            x = x.to(device_load)

        if self.hybrid:
            xt = xt.view(B, S, -1, length)
            xt = xt * stdt[:, None] + meant[:, None]
            x = xt + x
        return x
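
A hedged usage sketch for the class above: instantiating `HDemucs` with default settings and running a short stereo clip through it. The source names and 1-second length are placeholders; the actual checkpoints loaded by this repo define their own configuration.

import torch

model = HDemucs(sources=["drums", "bass", "other", "vocals"])
model.eval()
wav = torch.randn(1, 2, 44100)  # (batch, audio_channels, time)
with torch.no_grad():
    out = model(wav)
print(out.shape)  # torch.Size([1, 4, 2, 44100]): one stereo stem per source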
audio_separator/separator/uvr_lib_v5/demucs/htdemucs.py
ADDED
@@ -0,0 +1,620 @@
# Copyright (c) Meta, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# First author is Simon Rouard.
"""
This code contains the spectrogram and Hybrid version of Demucs.
"""
import math

from .filtering import wiener
import torch
from torch import nn
from torch.nn import functional as F
from fractions import Fraction
from einops import rearrange

from .transformer import CrossTransformerEncoder

from .demucs import rescale_module
from .states import capture_init
from .spec import spectro, ispectro
from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer


class HTDemucs(nn.Module):
    """
    Spectrogram and hybrid Demucs model.
    The spectrogram model has the same structure as Demucs, except the first few layers are over the
    frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
    Frequency layers can still access information across time steps thanks to the DConv residual.

    Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
    as the frequency branch and then the two are combined. The opposite happens in the decoder.

    Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
    the Open Unmix implementation [Stoter et al. 2019].

    The loss is always in the temporal domain, by backpropagating through the above
    output methods and iSTFT. This makes it possible to define hybrid models nicely. However, this
    somewhat breaks Wiener filtering, as doing more iterations at test time will change the spectrogram
    contribution, without changing the one from the waveform, which will lead to worse performance.
    I tried using the residual option in the OpenUnmix Wiener implementation, but it didn't improve.
    CaC on the other hand provides similar performance for hybrid, and works naturally with
    hybrid models.

    This model also uses frequency embeddings to improve efficiency on convolutions
    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).

    Unlike classic Demucs, there is no resampling here, and normalization is always applied.
    """

    @capture_init
    def __init__(
        self,
        sources,
        # Channels
        audio_channels=2,
        channels=48,
        channels_time=None,
        growth=2,
        # STFT
        nfft=4096,
        wiener_iters=0,
        end_iters=0,
        wiener_residual=False,
        cac=True,
        # Main structure
        depth=4,
        rewrite=True,
        # Frequency branch
        multi_freqs=None,
        multi_freqs_depth=3,
        freq_emb=0.2,
        emb_scale=10,
        emb_smooth=True,
        # Convolutions
        kernel_size=8,
        time_stride=2,
        stride=4,
        context=1,
        context_enc=0,
        # Normalization
        norm_starts=4,
        norm_groups=4,
        # DConv residual branch
        dconv_mode=1,
        dconv_depth=2,
        dconv_comp=8,
        dconv_init=1e-3,
        # Before the Transformer
        bottom_channels=0,
        # Transformer
        t_layers=5,
        t_emb="sin",
        t_hidden_scale=4.0,
        t_heads=8,
        t_dropout=0.0,
        t_max_positions=10000,
        t_norm_in=True,
        t_norm_in_group=False,
        t_group_norm=False,
        t_norm_first=True,
        t_norm_out=True,
        t_max_period=10000.0,
        t_weight_decay=0.0,
        t_lr=None,
        t_layer_scale=True,
        t_gelu=True,
        t_weight_pos_embed=1.0,
        t_sin_random_shift=0,
        t_cape_mean_normalize=True,
        t_cape_augment=True,
        t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
        t_sparse_self_attn=False,
        t_sparse_cross_attn=False,
        t_mask_type="diag",
        t_mask_random_seed=42,
        t_sparse_attn_window=500,
        t_global_window=100,
        t_sparsity=0.95,
        t_auto_sparsity=False,
        # ------ Particular parameters
        t_cross_first=False,
        # Weight init
        rescale=0.1,
        # Metadata
        samplerate=44100,
        segment=10,
        use_train_segment=True,
    ):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this requires careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            wiener_residual: add residual source before wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
                transformer in order to change the number of channels
            t_layers: number of layers in each branch (waveform and spec) of the transformer
            t_emb: "sin", "cape" or "scaled"
            t_hidden_scale: the hidden scale of the Feedforward parts of the transformer,
                for instance if C = 384 (the number of channels in the transformer) and
                t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
                384 * 4 = 1536
            t_heads: number of heads for the transformer
            t_dropout: dropout in the transformer
            t_max_positions: max_positions for the "scaled" positional embedding, only
                useful if t_emb="scaled"
            t_norm_in: (bool) norm before adding the positional embedding and getting into the
                transformer layers
            t_norm_in_group: (bool) if True while t_norm_in=True, the norm is over all the
                timesteps (GroupNorm with group=1)
            t_group_norm: (bool) if True, the norms of the Encoder Layers are over all the
                timesteps (GroupNorm with group=1)
            t_norm_first: (bool) if True the norm is before the attention and before the FFN
            t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
            t_max_period: (float) denominator in the sinusoidal embedding expression
            t_weight_decay: (float) weight decay for the transformer
            t_lr: (float) specific learning rate for the transformer
|
194 |
+
t_layer_scale: (bool) Layer Scale for the transformer
|
195 |
+
t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
|
196 |
+
t_weight_pos_embed: (float) weighting of the positional embedding
|
197 |
+
t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
|
198 |
+
see: https://arxiv.org/abs/2106.03143
|
199 |
+
t_cape_augment: (bool) if t_emb="cape", must be True during training and False
|
200 |
+
during the inference, see: https://arxiv.org/abs/2106.03143
|
201 |
+
t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
|
202 |
+
see: https://arxiv.org/abs/2106.03143
|
203 |
+
t_sparse_self_attn: (bool) if True, the self attentions are sparse
|
204 |
+
t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
|
205 |
+
unless you designed really specific masks)
|
206 |
+
t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
|
207 |
+
with '_' between: i.e. "diag_jmask_random" (note that this is permutation
|
208 |
+
invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
|
209 |
+
t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
|
210 |
+
that generated the random part of the mask
|
211 |
+
t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
|
212 |
+
a key (j), the mask is True id |i-j|<=t_sparse_attn_window
|
213 |
+
t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
|
214 |
+
and mask[:, :t_global_window] will be True
|
215 |
+
t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
|
216 |
+
level of the random part of the mask.
|
217 |
+
t_cross_first: (bool) if True cross attention is the first layer of the
|
218 |
+
transformer (False seems to be better)
|
219 |
+
rescale: weight rescaling trick
|
220 |
+
use_train_segment: (bool) if True, the actual size that is used during the
|
221 |
+
training is used during inference.
|
222 |
+
"""
|
223 |
+
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.bottom_channels = bottom_channels
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment
        self.use_train_segment = use_train_segment
        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.tencoder = nn.ModuleList()
        self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # number of channels for the freq branch
        if self.cac:
            chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm": norm,
                "rewrite": rewrite,
                "norm_groups": norm_groups,
                "dconv_kw": {"depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
            }
            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec["context_freq"] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
            if freq:
                tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if freq:
                tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

        transformer_channels = channels * growth ** (depth - 1)
        if bottom_channels:
            self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler = nn.Conv1d(bottom_channels, transformer_channels, 1)
            self.channel_upsampler_t = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler_t = nn.Conv1d(bottom_channels, transformer_channels, 1)

            transformer_channels = bottom_channels

        if t_layers > 0:
            self.crosstransformer = CrossTransformerEncoder(
                dim=transformer_channels,
                emb=t_emb,
                hidden_scale=t_hidden_scale,
                num_heads=t_heads,
                num_layers=t_layers,
                cross_first=t_cross_first,
                dropout=t_dropout,
                max_positions=t_max_positions,
                norm_in=t_norm_in,
                norm_in_group=t_norm_in_group,
                group_norm=t_group_norm,
                norm_first=t_norm_first,
                norm_out=t_norm_out,
                max_period=t_max_period,
                weight_decay=t_weight_decay,
                lr=t_lr,
                layer_scale=t_layer_scale,
                gelu=t_gelu,
                sin_random_shift=t_sin_random_shift,
                weight_pos_embed=t_weight_pos_embed,
                cape_mean_normalize=t_cape_mean_normalize,
                cape_augment=t_cape_augment,
                cape_glob_loc_scale=t_cape_glob_loc_scale,
                sparse_self_attn=t_sparse_self_attn,
                sparse_cross_attn=t_sparse_cross_attn,
                mask_type=t_mask_type,
                mask_random_seed=t_mask_random_seed,
                sparse_attn_window=t_sparse_attn_window,
                global_window=t_global_window,
                sparsity=t_sparsity,
                auto_sparsity=t_auto_sparsity,
            )
        else:
            self.crosstransformer = None

    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        # We re-pad the signal in order to keep the property
        # that the size of the output is exactly the size of the input
        # divided by the stride (here hop_length), when divisible.
        # This is achieved by padding by 1/4th of the kernel size (here nfft),
        # which is not supported natively by torch.stft.
        # Having all convolution operations follow this convention makes it
        # easy to align the time and frequency branches later on.
        assert hl == nfft // 4
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")

        z = spectro(x, nfft, hl)[..., :-1, :]
        assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
        z = z[..., 2 : 2 + le]
        return z

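    # Worked example for `_spec` (added comment, not in the original file):
    # with nfft=4096 and hop hl=1024, an input of 44100 samples gives
    # le = ceil(44100 / 1024) = 44 and pad = 3 * 1024 // 2 = 1536, so the padded
    # length is 44 * 1024 + 2 * 1536 = 48128 samples. A centered STFT then
    # yields 48128 // 1024 + 1 = 48 = le + 4 frames, and dropping the first and
    # last 2 frames leaves exactly le = ceil(length / hop) frames.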
    def _ispec(self, z, length=None, scale=0):
        hl = self.hop_length // (4**scale)
        z = F.pad(z, (0, 0, 0, 1))
        z = F.pad(z, (2, 2))
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = ispectro(z, hl, length=le)
        x = x[..., pad : pad + length]
        return x

    def _magnitude(self, z):
        # return the magnitude of the spectrogram, except when cac is True,
        # in which case we just move the complex dimension to the channel one.
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else:
            m = z.abs()
        return m

    def _mask(self, z, m):
        # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
        # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training:
            niters = self.end_iters
        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else:
            return self._wiener(m, z, niters)

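    # Note (added comment, not in the original file): `_wiener` below applies
    # the OpenUnmix expectation-maximization Wiener filter in chunks of
    # `wiener_win_len` = 300 STFT frames at a time, which bounds the memory
    # used by the EM iterations on long inputs; the chunks are concatenated
    # back along the time axis afterwards.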
    def _wiener(self, mag_out, mix_stft, niters):
        # apply wiener filtering from OpenUnmix.
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(B):
            pos = 0
            out = []
            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual:
            out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def valid_length(self, length: int):
        """
        Return a length that is appropriate for evaluation.
        In our case, always return the training length, unless
        it is smaller than the given length, in which case this
        raises an error.
        """
        if not self.use_train_segment:
            return length
        training_length = int(self.segment * self.samplerate)
        if training_length < length:
            raise ValueError(f"Given length {length} is longer than training length {training_length}")
        return training_length

    def forward(self, mix):
        length = mix.shape[-1]
        length_pre_pad = None
        if self.use_train_segment:
            if self.training:
                self.segment = Fraction(mix.shape[-1], self.samplerate)
            else:
                training_length = int(self.segment * self.samplerate)
                if mix.shape[-1] < training_length:
                    length_pre_pad = mix.shape[-1]
                    mix = F.pad(mix, (0, training_length - length_pre_pad))
        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        # Prepare the time branch input.
        xt = mix
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        lengths_t = []  # saved lengths for time branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.tencoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)
        if self.crosstransformer:
            if self.bottom_channels:
                b, c, f, t = x.shape
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_upsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_upsampler_t(xt)

            x, xt = self.crosstransformer(x, xt)

            if self.bottom_channels:
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_downsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_downsampler_t(xt)

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.

            offset = self.depth - len(self.tdecoder)
            if idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # Let's make sure we used all stored skip connections.
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        # Move to CPU, as non-CUDA GPUs don't support complex numbers yet
        # (demucs issues #435 and #432).
        # NOTE: in this case z already is on cpu.
        # TODO: remove this when mps supports complex numbers.

        device_type = x.device.type
        device_load = f"{device_type}:{x.device.index}" if device_type != "mps" else device_type
        x_is_other_gpu = device_type not in ["cuda", "cpu"]

        if x_is_other_gpu:
            x = x.cpu()

        zout = self._mask(z, x)
        if self.use_train_segment:
            if self.training:
                x = self._ispec(zout, length)
            else:
                x = self._ispec(zout, training_length)
        else:
            x = self._ispec(zout, length)

        # back to other device
        if x_is_other_gpu:
            x = x.to(device_load)

        if self.use_train_segment:
            if self.training:
                xt = xt.view(B, S, -1, length)
            else:
                xt = xt.view(B, S, -1, training_length)
        else:
            xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x
        if length_pre_pad:
            x = x[..., :length_pre_pad]
        return x
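

# Illustrative usage sketch (added comment, not part of the original file),
# assuming the `HTDemucs` class defined above and a randomly initialized
# model; real weights come from the pretrained/repo modules below. In eval
# mode, inputs shorter than the training segment are padded internally and
# the output is trimmed back to the input length:
#
#     model = HTDemucs(sources=["drums", "bass", "other", "vocals"])
#     model.eval()
#     wav = torch.randn(1, 2, 44100 * 5)   # 5 seconds of stereo audio
#     with torch.no_grad():
#         out = model(wav)                  # -> (1, 4, 2, 44100 * 5)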
audio_separator/separator/uvr_lib_v5/demucs/model.py
ADDED
@@ -0,0 +1,204 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch as th
from torch import nn

from .utils import capture_init, center_trim


class BLSTM(nn.Module):
    def __init__(self, dim, layers=1):
        super().__init__()
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        return x

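# Note on the rescaling trick below (added comment, not in the original file):
# dividing the weights by scale = sqrt(std / reference) moves their standard
# deviation to sqrt(std * reference), i.e. halfway (in log scale) between the
# current value and `reference`. For example std=0.4, reference=0.1 gives
# scale=2 and a new std of 0.2.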
def rescale_conv(conv, reference):
    std = conv.weight.std().detach()
    scale = (std / reference) ** 0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            rescale_conv(sub, reference)


def upsample(x, stride):
    """
    Linear upsampling, the output will be `stride` times longer.
    """
    batch, channels, time = x.size()
    weight = th.arange(stride, device=x.device, dtype=th.float) / stride
    x = x.view(batch, channels, time, 1)
    out = x[..., :-1, :] * (1 - weight) + x[..., 1:, :] * weight
    return out.reshape(batch, channels, -1)


def downsample(x, stride):
    """
    Downsample x by decimation.
    """
    return x[:, :, ::stride]


class Demucs(nn.Module):
    @capture_init
    def __init__(
        self, sources=4, audio_channels=2, channels=64, depth=6, rewrite=True, glu=True, upsample=False, rescale=0.1, kernel_size=8, stride=4, growth=2.0, lstm_layers=2, context=3, samplerate=44100
    ):
        """
        Args:
            sources (int): number of sources to separate
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            depth (int): number of encoder/decoder layers
            rewrite (bool): add 1x1 convolution to each encoder layer
                and a convolution to each decoder layer.
                For the decoder layer, `context` gives the kernel size.
            glu (bool): use glu instead of ReLU
            upsample (bool): use linear upsampling with convolutions
                Wave-U-Net style, instead of transposed convolutions
            rescale (int): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            lstm_layers (int): number of lstm layers, 0 = no lstm
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time
                steps.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.upsample = upsample
        self.channels = channels
        self.samplerate = samplerate

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.final = None
        if upsample:
            self.final = nn.Conv1d(channels + audio_channels, sources * audio_channels, 1)
            stride = 1

        if glu:
            activation = nn.GLU(dim=1)
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        in_channels = audio_channels
        for index in range(depth):
            encode = []
            encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
            if rewrite:
                encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                if upsample:
                    out_channels = channels
                else:
                    out_channels = sources * audio_channels
            if rewrite:
                decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
            if upsample:
                decode += [nn.Conv1d(channels, out_channels, kernel_size, stride=1)]
            else:
                decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
            if index > 0:
                decode.append(nn.ReLU())
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels

        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there are no time steps left over in the convolutions, i.e. for all
        layers, (size of the input - kernel_size) % stride == 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length when context = 1. If context > 1,
        the two signals can be center trimmed to match.

        For training, extracts should have a valid length. For evaluation
        on full tracks we recommend passing `pad = True` to :method:`forward`.
        """
        for _ in range(self.depth):
            if self.upsample:
                length = math.ceil(length / self.stride) + self.kernel_size - 1
            else:
                length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)
            length += self.context - 1
        for _ in range(self.depth):
            if self.upsample:
                length = length * self.stride + self.kernel_size - 1
            else:
                length = (length - 1) * self.stride + self.kernel_size

        return int(length)

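    # Worked example for `valid_length` (added comment, not in the original
    # file): with the defaults depth=6, kernel_size=8, stride=4, context=3 and
    # upsample=False, an input of 44100 samples shrinks to 13 frames through
    # the encoder loop (which also adds context - 1 = 2 per layer) and expands
    # back through the decoder pass, so valid_length(44100) == 58708.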
    def forward(self, mix):
        x = mix
        saved = [x]
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)
            if self.upsample:
                x = downsample(x, self.stride)
        if self.lstm:
            x = self.lstm(x)
        for decode in self.decoder:
            if self.upsample:
                x = upsample(x, stride=self.stride)
            skip = center_trim(saved.pop(-1), x)
            x = x + skip
            x = decode(x)
        if self.final:
            skip = center_trim(saved.pop(-1), x)
            x = th.cat([x, skip], dim=1)
            x = self.final(x)

        x = x.view(x.size(0), self.sources, self.audio_channels, x.size(-1))
        return x

audio_separator/separator/uvr_lib_v5/demucs/model_v2.py
ADDED
@@ -0,0 +1,222 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import julius
from torch import nn
from .tasnet_v2 import ConvTasNet

from .utils import capture_init, center_trim


class BLSTM(nn.Module):
    def __init__(self, dim, layers=1):
        super().__init__()
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        return x


def rescale_conv(conv, reference):
    std = conv.weight.std().detach()
    scale = (std / reference) ** 0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            rescale_conv(sub, reference)


def auto_load_demucs_model_v2(sources, demucs_model_name):

    if "48" in demucs_model_name:
        channels = 48
    elif "unittest" in demucs_model_name:
        channels = 4
    else:
        channels = 64

    if "tasnet" in demucs_model_name:
        init_demucs_model = ConvTasNet(sources, X=10)
    else:
        init_demucs_model = Demucs(sources, channels=channels)

    return init_demucs_model


class Demucs(nn.Module):
    @capture_init
    def __init__(
        self,
        sources,
        audio_channels=2,
        channels=64,
        depth=6,
        rewrite=True,
        glu=True,
        rescale=0.1,
        resample=True,
        kernel_size=8,
        stride=4,
        growth=2.0,
        lstm_layers=2,
        context=3,
        normalize=False,
        samplerate=44100,
        segment_length=4 * 10 * 44100,
    ):
        """
        Args:
            sources (list[str]): list of source names
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            depth (int): number of encoder/decoder layers
            rewrite (bool): add 1x1 convolution to each encoder layer
                and a convolution to each decoder layer.
                For the decoder layer, `context` gives the kernel size.
            glu (bool): use glu instead of ReLU
            resample (bool): upsample x2 the input and downsample /2 the output.
            rescale (int): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            lstm_layers (int): number of lstm layers, 0 = no lstm
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time
                steps.
            samplerate (int): stored as meta information for easing
                future evaluations of the model.
            segment_length (int): stored as meta information for easing
                future evaluations of the model. Length of the segments on which
                the model was trained.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.resample = resample
        self.channels = channels
        self.normalize = normalize
        self.samplerate = samplerate
        self.segment_length = segment_length

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        if glu:
            activation = nn.GLU(dim=1)
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        in_channels = audio_channels
        for index in range(depth):
            encode = []
            encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
            if rewrite:
                encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                out_channels = len(self.sources) * audio_channels
            if rewrite:
                decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
            decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
            if index > 0:
                decode.append(nn.ReLU())
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels

        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there are no time steps left over in the convolutions, i.e. for all
        layers, (size of the input - kernel_size) % stride == 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length when context = 1. If context > 1,
        the two signals can be center trimmed to match.

        For training, extracts should have a valid length. For evaluation
        on full tracks we recommend passing `pad = True` to :method:`forward`.
        """
        if self.resample:
            length *= 2
        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)
            length += self.context - 1
        for _ in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size

        if self.resample:
            length = math.ceil(length / 2)
        return int(length)

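    # Note on `forward` below (added comment, not in the original file): when
    # `normalize` is True, the input is standardized with mono statistics and
    # the estimates are de-normalized at the end; when `resample` is True, the
    # waveform is upsampled x2 with julius before the network and downsampled
    # back afterwards, mirroring the classic Demucs design (the hybrid
    # HTDemucs model above drops resampling entirely).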
    def forward(self, mix):
        x = mix

        if self.normalize:
            mono = mix.mean(dim=1, keepdim=True)
            mean = mono.mean(dim=-1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
        else:
            mean = 0
            std = 1

        x = (x - mean) / (1e-5 + std)

        if self.resample:
            x = julius.resample_frac(x, 1, 2)

        saved = []
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)
        if self.lstm:
            x = self.lstm(x)
        for decode in self.decoder:
            skip = center_trim(saved.pop(-1), x)
            x = x + skip
            x = decode(x)

        if self.resample:
            x = julius.resample_frac(x, 2, 1)
        x = x * std + mean
        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
        return x

audio_separator/separator/uvr_lib_v5/demucs/pretrained.py
ADDED
@@ -0,0 +1,181 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loading pretrained models.
"""

import logging
from pathlib import Path
import typing as tp

# from dora.log import fatal

from diffq import DiffQuantizer
import torch.hub

from .model import Demucs
from .tasnet_v2 import ConvTasNet
from .utils import set_state

from .hdemucs import HDemucs
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError  # noqa

logger = logging.getLogger(__name__)
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
REMOTE_ROOT = Path(__file__).parent / "remote"

SOURCES = ["drums", "bass", "other", "vocals"]


def fatal(msg):
    # Minimal stand-in for `dora.log.fatal` (import commented out above),
    # so that `get_model` below does not raise a NameError.
    raise SystemExit(msg)


def demucs_unittest():
    model = HDemucs(channels=4, sources=SOURCES)
    return model


def add_model_flags(parser):
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument("-s", "--sig", help="Locally trained XP signature.")
    group.add_argument("-n", "--name", default="mdx_extra_q", help="Pretrained model name or signature. Default is mdx_extra_q.")
    parser.add_argument("--repo", type=Path, help="Folder containing all pre-trained models for use with -n.")


def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
    root: str = ""
    models: tp.Dict[str, str] = {}
    for line in remote_file_list.read_text().split("\n"):
        line = line.strip()
        if line.startswith("#"):
            continue
        elif line.startswith("root:"):
            root = line.split(":", 1)[1].strip()
        else:
            sig = line.split("-", 1)[0]
            assert sig not in models
            models[sig] = ROOT_URL + root + line
    return models

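# Expected format of the remote file list parsed above (added comment, not in
# the original file; the names below are hypothetical):
#
#     # comments start with '#'
#     root: v4/
#     0d19c1c6-0f06f20e.th
#
# which would map signature "0d19c1c6" to ROOT_URL + "v4/0d19c1c6-0f06f20e.th".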
def get_model(name: str, repo: tp.Optional[Path] = None):
    """`name` must be a bag of models name or a pretrained signature
    from the remote AWS model repo, or from the specified local repo if `repo` is not None.
    """
    if name == "demucs_unittest":
        return demucs_unittest()
    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / "files.txt")
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    any_repo = AnyModelRepo(model_repo, bag_repo)
    model = any_repo.get_model(name)
    model.eval()
    return model


def get_model_from_args(args):
    """
    Load local model package or pre-trained model.
    """
    return get_model(name=args.name, repo=args.repo)


ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"

PRETRAINED_MODELS = {
    "demucs": "e07c671f",
    "demucs48_hq": "28a1282c",
    "demucs_extra": "3646af93",
    "demucs_quantized": "07afea75",
    "tasnet": "beb46fac",
    "tasnet_extra": "df3777b2",
    "demucs_unittest": "09ebc15f",
}


def get_url(name):
    sig = PRETRAINED_MODELS[name]
    return ROOT + name + "-" + sig[:8] + ".th"

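# Example (added comment, not in the original file):
# get_url("demucs_extra") returns
# "https://dl.fbaipublicfiles.com/demucs/v3.0/demucs_extra-3646af93.th",
# i.e. ROOT + name + "-" + first 8 chars of the signature + ".th".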
def is_pretrained(name):
    return name in PRETRAINED_MODELS


def load_pretrained(name):
    if name == "demucs":
        return demucs(pretrained=True)
    elif name == "demucs48_hq":
        return demucs(pretrained=True, hq=True, channels=48)
    elif name == "demucs_extra":
        return demucs(pretrained=True, extra=True)
    elif name == "demucs_quantized":
        return demucs(pretrained=True, quantized=True)
    elif name == "demucs_unittest":
        return demucs_unittest(pretrained=True)
    elif name == "tasnet":
        return tasnet(pretrained=True)
    elif name == "tasnet_extra":
        return tasnet(pretrained=True, extra=True)
    else:
        raise ValueError(f"Invalid pretrained name {name}")


def _load_state(name, model, quantizer=None):
    url = get_url(name)
    state = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=True)
    set_state(model, quantizer, state)
    if quantizer:
        quantizer.detach()


def demucs_unittest(pretrained=True):
    model = Demucs(channels=4, sources=SOURCES)
    if pretrained:
        _load_state("demucs_unittest", model)
    return model


def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
    if not pretrained and (extra or quantized or hq):
        raise ValueError("if extra or quantized is True, pretrained must be True.")
    model = Demucs(sources=SOURCES, channels=channels)
    if pretrained:
        name = "demucs"
        if channels != 64:
            name += str(channels)
        quantizer = None
        if sum([extra, quantized, hq]) > 1:
            raise ValueError("Only one of extra, quantized, hq, can be True.")
        if quantized:
            quantizer = DiffQuantizer(model, group_size=8, min_size=1)
            name += "_quantized"
        if extra:
            name += "_extra"
        if hq:
            name += "_hq"
        _load_state(name, model, quantizer)
    return model


def tasnet(pretrained=True, extra=False):
    if not pretrained and extra:
        raise ValueError("if extra is True, pretrained must be True.")
    model = ConvTasNet(X=10, sources=SOURCES)
    if pretrained:
        name = "tasnet"
        if extra:
            name = "tasnet_extra"
        _load_state(name, model)
    return model

audio_separator/separator/uvr_lib_v5/demucs/repo.py
ADDED
@@ -0,0 +1,146 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Represents a model repository, including pre-trained models and bags of models.
A repo can either be the main remote repository stored in AWS, or a local repository
with your own models.
"""

from hashlib import sha256
from pathlib import Path
import typing as tp

import torch
import yaml

from .apply import BagOfModels, Model
from .states import load_model


AnyModel = tp.Union[Model, BagOfModels]


class ModelLoadingError(RuntimeError):
    pass


def check_checksum(path: Path, checksum: str):
    sha = sha256()
    with open(path, "rb") as file:
        while True:
            buf = file.read(2**20)
            if not buf:
                break
            sha.update(buf)
    actual_checksum = sha.hexdigest()[: len(checksum)]
    if actual_checksum != checksum:
        raise ModelLoadingError(f"Invalid checksum for file {path}, expected {checksum} but got {actual_checksum}")

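# Note (added comment, not in the original file): `check_checksum` compares
# only a prefix of the sha256 digest, truncated to the length of the expected
# checksum. Local model files named "<signature>-<checksum>.th" (see
# LocalRepo.scan below) typically embed an 8-character prefix, which is
# enough to catch corrupted downloads.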
class ModelOnlyRepo:
    """Base class for all model only repos."""

    def has_model(self, sig: str) -> bool:
        raise NotImplementedError()

    def get_model(self, sig: str) -> Model:
        raise NotImplementedError()


class RemoteRepo(ModelOnlyRepo):
    def __init__(self, models: tp.Dict[str, str]):
        self._models = models

    def has_model(self, sig: str) -> bool:
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        try:
            url = self._models[sig]
        except KeyError:
            raise ModelLoadingError(f"Could not find a pre-trained model with signature {sig}.")
        pkg = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=True)
        return load_model(pkg)


class LocalRepo(ModelOnlyRepo):
    def __init__(self, root: Path):
        self.root = root
        self.scan()

    def scan(self):
        self._models = {}
        self._checksums = {}
        for file in self.root.iterdir():
            if file.suffix == ".th":
                if "-" in file.stem:
                    xp_sig, checksum = file.stem.split("-")
                    self._checksums[xp_sig] = checksum
                else:
                    xp_sig = file.stem
                if xp_sig in self._models:
                    raise ModelLoadingError(f"Duplicate pre-trained model exists for signature {xp_sig}. Please delete all but one.")
                self._models[xp_sig] = file

    def has_model(self, sig: str) -> bool:
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        try:
            file = self._models[sig]
        except KeyError:
            raise ModelLoadingError(f"Could not find pre-trained model with signature {sig}.")
        if sig in self._checksums:
            check_checksum(file, self._checksums[sig])
        return load_model(file)


class BagOnlyRepo:
    """Handles only YAML files containing bag of models, leaving the actual
    model loading to some Repo.
    """

    def __init__(self, root: Path, model_repo: ModelOnlyRepo):
        self.root = root
        self.model_repo = model_repo
        self.scan()

    def scan(self):
        self._bags = {}
        for file in self.root.iterdir():
            if file.suffix == ".yaml":
                self._bags[file.stem] = file

    def has_model(self, name: str) -> bool:
        return name in self._bags

    def get_model(self, name: str) -> BagOfModels:
        try:
            yaml_file = self._bags[name]
        except KeyError:
            raise ModelLoadingError(f"{name} is neither a single pre-trained model nor a bag of models.")
        bag = yaml.safe_load(open(yaml_file))
        signatures = bag["models"]
        models = [self.model_repo.get_model(sig) for sig in signatures]
        weights = bag.get("weights")
        segment = bag.get("segment")
        return BagOfModels(models, weights, segment)


class AnyModelRepo:
    def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
        self.model_repo = model_repo
        self.bag_repo = bag_repo

    def has_model(self, name_or_sig: str) -> bool:
        return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)

    def get_model(self, name_or_sig: str) -> AnyModel:
        if self.model_repo.has_model(name_or_sig):
            return self.model_repo.get_model(name_or_sig)
        else:
            return self.bag_repo.get_model(name_or_sig)

audio_separator/separator/uvr_lib_v5/demucs/spec.py
ADDED
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Convenience wrapper to perform STFT and iSTFT"""

import torch as th


def spectro(x, n_fft=512, hop_length=None, pad=0):
    *other, length = x.shape
    x = x.reshape(-1, length)

    device_type = x.device.type
    is_other_gpu = device_type not in ["cuda", "cpu"]

    if is_other_gpu:
        x = x.cpu()
    z = th.stft(x, n_fft * (1 + pad), hop_length or n_fft // 4, window=th.hann_window(n_fft).to(x), win_length=n_fft, normalized=True, center=True, return_complex=True, pad_mode="reflect")
    _, freqs, frame = z.shape
    return z.view(*other, freqs, frame)

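# Shape example for `spectro` (added comment, not in the original file):
# x of shape (2, 44100) with n_fft=512, hop_length=128 and center=True gives
# z of shape (2, 257, 345): 512 // 2 + 1 = 257 frequency bins and
# 44100 // 128 + 1 = 345 frames.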
def ispectro(z, hop_length=None, length=None, pad=0):
    *other, freqs, frames = z.shape
    n_fft = 2 * freqs - 2
    z = z.view(-1, freqs, frames)
    win_length = n_fft // (1 + pad)

    device_type = z.device.type
    is_other_gpu = device_type not in ["cuda", "cpu"]

    if is_other_gpu:
        z = z.cpu()
    x = th.istft(z, n_fft, hop_length, window=th.hann_window(win_length).to(z.real), win_length=win_length, normalized=True, length=length, center=True)
    _, length = x.shape
    return x.view(*other, length)

audio_separator/separator/uvr_lib_v5/demucs/states.py
ADDED
@@ -0,0 +1,131 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utilities to save and load models.
"""
from contextlib import contextmanager

import functools
import hashlib
import inspect
import io
from pathlib import Path
import warnings

from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state
import torch


def get_quantizer(model, args, optimizer=None):
    """Return the quantizer given the XP quantization args."""
    quantizer = None
    if args.diffq:
        quantizer = DiffQuantizer(model, min_size=args.min_size, group_size=args.group_size)
        if optimizer is not None:
            quantizer.setup_optimizer(optimizer)
    elif args.qat:
        quantizer = UniformQuantizer(model, bits=args.qat, min_size=args.min_size)
    return quantizer

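# Note (added comment, not in the original file): the serialized "package"
# loaded below is a dict with keys "klass" (the model class), "args" and
# "kwargs" (the constructor arguments, as captured by `capture_init`) and
# "state" (the weights, possibly quantized).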
def load_model(path_or_package, strict=False):
    """Load a model from the given serialized model, either given as a dict (already loaded)
    or a path to a file on disk."""
    if isinstance(path_or_package, dict):
        package = path_or_package
    elif isinstance(path_or_package, (str, Path)):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            path = path_or_package
            package = torch.load(path, "cpu", weights_only=False)
    else:
        raise ValueError(f"Invalid type for {path_or_package}.")

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if strict:
        model = klass(*args, **kwargs)
    else:
        sig = inspect.signature(klass)
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn("Dropping nonexistent parameter " + key)
                del kwargs[key]
        model = klass(*args, **kwargs)

    state = package["state"]

    set_state(model, state)
    return model


def get_state(model, quantizer, half=False):
    """Get the state from a model, potentially with quantization applied.
    If `half` is True, the model is stored in half precision, which shouldn't impact
    performance but halves the state size."""
    if quantizer is None:
        dtype = torch.half if half else None
        state = {k: p.data.to(device="cpu", dtype=dtype) for k, p in model.state_dict().items()}
    else:
        state = quantizer.get_quantized_state()
        state["__quantized"] = True
    return state


def set_state(model, state, quantizer=None):
    """Set the state on a given model."""
    if state.get("__quantized"):
        if quantizer is not None:
            quantizer.restore_quantized_state(model, state["quantized"])
        else:
            restore_quantized_state(model, state)
    else:
        model.load_state_dict(state)
    return state


def save_with_checksum(content, path):
    """Save the given value on disk, along with a sha256 hash.
    Should be used with the output of either `serialize_model` or `get_state`."""
    buf = io.BytesIO()
    torch.save(content, buf)
    sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]

    path = path.parent / (path.stem + "-" + sig + path.suffix)
    path.write_bytes(buf.getvalue())


def copy_state(state):
    return {k: v.cpu().clone() for k, v in state.items()}


@contextmanager
def swap_state(model, state):
    """
    Context manager that swaps the state of a model, e.g:

        # model is in old state
        with swap_state(model, new_state):
            # model in new state
        # model back to old state
    """
    old_state = copy_state(model.state_dict())
    model.load_state_dict(state, strict=False)
    try:
        yield
    finally:
        model.load_state_dict(old_state)

|
126 |
+
@functools.wraps(init)
|
127 |
+
def __init__(self, *args, **kwargs):
|
128 |
+
self._init_args_kwargs = (args, kwargs)
|
129 |
+
init(self, *args, **kwargs)
|
130 |
+
|
131 |
+
return __init__
|
audio_separator/separator/uvr_lib_v5/demucs/tasnet.py
ADDED
@@ -0,0 +1,401 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Created on 2018/12
# Author: Kaituo XU
# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
# Here is the original license:
# The MIT License (MIT)
#
# Copyright (c) 2018 Kaituo XU
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import capture_init

EPS = 1e-8


def overlap_and_add(signal, frame_step):
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    subframe_length = math.gcd(frame_length, frame_step)  # gcd = greatest common divisor
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)

    frame = torch.arange(0, output_subframes, device=signal.device).unfold(0, subframes_per_frame, subframe_step)
    frame = frame.long()  # the signal may live on GPU or CPU
    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    result = result.view(*outer_dimensions, -1)
    return result

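# An illustrative worked example (values chosen for this sketch, not from the
# original file): three ones-filled frames of length 4, hopped by frame_step=2,
# reconstruct output_size = 2 * (3 - 1) + 4 = 8 samples:
#     overlap_and_add(torch.ones(1, 3, 4), 2)
#     -> tensor([[1., 1., 2., 2., 2., 2., 1., 1.]])
# Each interior sample is covered by exactly two overlapping frames, so it sums to 2.
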
class ConvTasNet(nn.Module):
    @capture_init
    def __init__(self, N=256, L=20, B=256, H=512, P=3, X=8, R=4, C=4, audio_channels=1, samplerate=44100, norm_type="gLN", causal=False, mask_nonlinear="relu"):
        """
        Args:
            N: Number of filters in autoencoder
            L: Length of the filters (in samples)
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: which non-linearity to use to generate the mask
        """
        super(ConvTasNet, self).__init__()
        # Hyper-parameters
        self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C
        self.norm_type = norm_type
        self.causal = causal
        self.mask_nonlinear = mask_nonlinear
        self.audio_channels = audio_channels
        self.samplerate = samplerate
        # Components
        self.encoder = Encoder(L, N, audio_channels)
        self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear)
        self.decoder = Decoder(N, L, audio_channels)
        # init
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def valid_length(self, length):
        return length

    def forward(self, mixture):
        """
        Args:
            mixture: [M, audio_channels, T], M is batch size, T is #samples
        Returns:
            est_source: [M, C, audio_channels, T]
        """
        mixture_w = self.encoder(mixture)
        est_mask = self.separator(mixture_w)
        est_source = self.decoder(mixture_w, est_mask)

        # T changed after conv1d in the encoder, fix it here
        T_origin = mixture.size(-1)
        T_conv = est_source.size(-1)
        est_source = F.pad(est_source, (0, T_origin - T_conv))
        return est_source


class Encoder(nn.Module):
    """Estimation of the nonnegative mixture weight by a 1-D conv layer."""

    def __init__(self, L, N, audio_channels):
        super(Encoder, self).__init__()
        # Hyper-parameters
        self.L, self.N = L, N
        # Components
        # 50% overlap
        self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)

    def forward(self, mixture):
        """
        Args:
            mixture: [M, audio_channels, T], M is batch size, T is #samples
        Returns:
            mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
        """
        mixture_w = F.relu(self.conv1d_U(mixture))  # [M, N, K]
        return mixture_w


class Decoder(nn.Module):
    def __init__(self, N, L, audio_channels):
        super(Decoder, self).__init__()
        # Hyper-parameters
        self.N, self.L = N, L
        self.audio_channels = audio_channels
        # Components
        self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)

    def forward(self, mixture_w, est_mask):
        """
        Args:
            mixture_w: [M, N, K]
            est_mask: [M, C, N, K]
        Returns:
            est_source: [M, C, audio_channels, T]
        """
        # D = W * M
        source_w = torch.unsqueeze(mixture_w, 1) * est_mask  # [M, C, N, K]
        source_w = torch.transpose(source_w, 2, 3)  # [M, C, K, N]
        # S = DV
        est_source = self.basis_signals(source_w)  # [M, C, K, ac * L]
        m, c, k, _ = est_source.size()
        est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
        est_source = overlap_and_add(est_source, self.L // 2)  # M x C x ac x T
        return est_source


class TemporalConvNet(nn.Module):
    def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear="relu"):
        """
        Args:
            N: Number of filters in autoencoder
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: which non-linearity to use to generate the mask
        """
        super(TemporalConvNet, self).__init__()
        # Hyper-parameters
        self.C = C
        self.mask_nonlinear = mask_nonlinear
        # Components
        # [M, N, K] -> [M, N, K]
        layer_norm = ChannelwiseLayerNorm(N)
        # [M, N, K] -> [M, B, K]
        bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
        # [M, B, K] -> [M, B, K]
        repeats = []
        for r in range(R):
            blocks = []
            for x in range(X):
                dilation = 2**x
                padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
                blocks += [TemporalBlock(B, H, P, stride=1, padding=padding, dilation=dilation, norm_type=norm_type, causal=causal)]
            repeats += [nn.Sequential(*blocks)]
        temporal_conv_net = nn.Sequential(*repeats)
        # [M, B, K] -> [M, C*N, K]
        mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
        # Put together
        self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1)

    def forward(self, mixture_w):
        """
        Keep this API the same as TasNet
        Args:
            mixture_w: [M, N, K], M is batch size
        Returns:
            est_mask: [M, C, N, K]
        """
        M, N, K = mixture_w.size()
        score = self.network(mixture_w)  # [M, N, K] -> [M, C*N, K]
        score = score.view(M, self.C, N, K)  # [M, C*N, K] -> [M, C, N, K]
        if self.mask_nonlinear == "softmax":
            est_mask = F.softmax(score, dim=1)
        elif self.mask_nonlinear == "relu":
            est_mask = F.relu(score)
        else:
            raise ValueError("Unsupported mask non-linear function")
        return est_mask


class TemporalBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type="gLN", causal=False):
        super(TemporalBlock, self).__init__()
        # [M, B, K] -> [M, H, K]
        conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, out_channels)
        # [M, H, K] -> [M, B, K]
        dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, dilation, norm_type, causal)
        # Put together
        self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)

    def forward(self, x):
        """
        Args:
            x: [M, B, K]
        Returns:
            [M, B, K]
        """
        residual = x
        out = self.net(x)
        # TODO: this works fine when P = 3, but when P = 2 we may need to pad?
        return out + residual  # omitting F.relu here seems to work better than keeping it
        # return F.relu(out + residual)


class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type="gLN", causal=False):
        super(DepthwiseSeparableConv, self).__init__()
        # Use the `groups` option to implement a depthwise convolution
        # [M, H, K] -> [M, H, K]
        depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=in_channels, bias=False)
        if causal:
            chomp = Chomp1d(padding)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, in_channels)
        # [M, H, K] -> [M, B, K]
        pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        # Put together
        if causal:
            self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
        else:
            self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)

    def forward(self, x):
        """
        Args:
            x: [M, H, K]
        Returns:
            result: [M, B, K]
        """
        return self.net(x)


class Chomp1d(nn.Module):
    """Ensures the output length is the same as the input's."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        Args:
            x: [M, H, Kpad]
        Returns:
            [M, H, K]
        """
        return x[:, :, : -self.chomp_size].contiguous()


def chose_norm(norm_type, channel_size):
    """The input of normalization will be (M, C, K), where M is batch size,
    C is channel size and K is sequence length.
    """
    if norm_type == "gLN":
        return GlobalLayerNorm(channel_size)
    elif norm_type == "cLN":
        return ChannelwiseLayerNorm(channel_size)
    elif norm_type == "id":
        return nn.Identity()
    else:  # norm_type == "BN"
        # Given input (M, C, K), nn.BatchNorm1d(C) accumulates statistics
        # along M and K, so this BN usage is correct.
        return nn.BatchNorm1d(channel_size)


# TODO: use nn.LayerNorm to implement cLN and speed it up
class ChannelwiseLayerNorm(nn.Module):
    """Channel-wise Layer Normalization (cLN)"""

    def __init__(self, channel_size):
        super(ChannelwiseLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            cLN_y: [M, N, K]
        """
        mean = torch.mean(y, dim=1, keepdim=True)  # [M, 1, K]
        var = torch.var(y, dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
        cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
        return cLN_y


class GlobalLayerNorm(nn.Module):
    """Global Layer Normalization (gLN)"""

    def __init__(self, channel_size):
        super(GlobalLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            gLN_y: [M, N, K]
        """
        # TODO: since torch 1.0, torch.mean() supports a list of dims
        mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # [M, 1, 1]
        var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
        return gLN_y


if __name__ == "__main__":
    torch.manual_seed(123)
    M, N, L, T = 2, 3, 4, 12
    K = 2 * T // L - 1
    B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
    # conv1d expects a float input of shape [M, audio_channels, T]
    mixture = torch.randint(3, (M, 1, T)).float()
    # test Encoder
    encoder = Encoder(L, N, audio_channels=1)
    encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size()).float()
    mixture_w = encoder(mixture)
    print("mixture", mixture)
    print("U", encoder.conv1d_U.weight)
    print("mixture_w", mixture_w)
    print("mixture_w size", mixture_w.size())

    # test TemporalConvNet
    separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
    est_mask = separator(mixture_w)
    print("est_mask", est_mask)

    # test Decoder
    decoder = Decoder(N, L, audio_channels=1)
    est_mask = torch.randint(2, (M, C, N, K)).float()  # [M, C, N, K]
    est_source = decoder(mixture_w, est_mask)
    print("est_source", est_source)

    # test Conv-TasNet
    conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, audio_channels=1, norm_type=norm_type)
    est_source = conv_tasnet(mixture)
    print("est_source", est_source)
    print("est_source size", est_source.size())
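
Assuming the module above is importable, a forward pass wires the three stages together. The hyper-parameters below are toy values chosen for this sketch, not the library defaults:

    import torch

    model = ConvTasNet(N=64, L=20, B=64, H=128, P=3, X=4, R=2, C=2, audio_channels=1)
    mix = torch.randn(1, 1, 44100)  # [M, audio_channels, T]: one second at 44.1 kHz
    out = model(mix)
    print(out.shape)  # torch.Size([1, 2, 1, 44100]): C=2 estimated sources, padded back to T

With L=20 the encoder produces K = (44100 - 20) / 10 + 1 = 4409 frames, and the decoder's overlap-and-add maps them back to exactly 44100 samples, so the final `F.pad` is a no-op here.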
audio_separator/separator/uvr_lib_v5/demucs/tasnet_v2.py
ADDED
@@ -0,0 +1,404 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Created on 2018/12
# Author: Kaituo XU
# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
# Here is the original license:
# The MIT License (MIT)
#
# Copyright (c) 2018 Kaituo XU
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import capture_init

EPS = 1e-8


def overlap_and_add(signal, frame_step):
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    subframe_length = math.gcd(frame_length, frame_step)  # gcd = greatest common divisor
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)

    frame = torch.arange(0, output_subframes, device=signal.device).unfold(0, subframes_per_frame, subframe_step)
    frame = frame.long()  # the signal may live on GPU or CPU
    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    result = result.view(*outer_dimensions, -1)
    return result


class ConvTasNet(nn.Module):
    @capture_init
    def __init__(self, sources, N=256, L=20, B=256, H=512, P=3, X=8, R=4, audio_channels=2, norm_type="gLN", causal=False, mask_nonlinear="relu", samplerate=44100, segment_length=44100 * 2 * 4):
        """
        Args:
            sources: list of sources
            N: Number of filters in autoencoder
            L: Length of the filters (in samples)
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: which non-linearity to use to generate the mask
        """
        super(ConvTasNet, self).__init__()
        # Hyper-parameters
        self.sources = sources
        self.C = len(sources)
        self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R
        self.norm_type = norm_type
        self.causal = causal
        self.mask_nonlinear = mask_nonlinear
        self.audio_channels = audio_channels
        self.samplerate = samplerate
        self.segment_length = segment_length
        # Components
        self.encoder = Encoder(L, N, audio_channels)
        self.separator = TemporalConvNet(N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear)
        self.decoder = Decoder(N, L, audio_channels)
        # init
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def valid_length(self, length):
        return length

    def forward(self, mixture):
        """
        Args:
            mixture: [M, audio_channels, T], M is batch size, T is #samples
        Returns:
            est_source: [M, C, audio_channels, T]
        """
        mixture_w = self.encoder(mixture)
        est_mask = self.separator(mixture_w)
        est_source = self.decoder(mixture_w, est_mask)

        # T changed after conv1d in the encoder, fix it here
        T_origin = mixture.size(-1)
        T_conv = est_source.size(-1)
        est_source = F.pad(est_source, (0, T_origin - T_conv))
        return est_source


class Encoder(nn.Module):
    """Estimation of the nonnegative mixture weight by a 1-D conv layer."""

    def __init__(self, L, N, audio_channels):
        super(Encoder, self).__init__()
        # Hyper-parameters
        self.L, self.N = L, N
        # Components
        # 50% overlap
        self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)

    def forward(self, mixture):
        """
        Args:
            mixture: [M, audio_channels, T], M is batch size, T is #samples
        Returns:
            mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
        """
        mixture_w = F.relu(self.conv1d_U(mixture))  # [M, N, K]
        return mixture_w


class Decoder(nn.Module):
    def __init__(self, N, L, audio_channels):
        super(Decoder, self).__init__()
        # Hyper-parameters
        self.N, self.L = N, L
        self.audio_channels = audio_channels
        # Components
        self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)

    def forward(self, mixture_w, est_mask):
        """
        Args:
            mixture_w: [M, N, K]
            est_mask: [M, C, N, K]
        Returns:
            est_source: [M, C, audio_channels, T]
        """
        # D = W * M
        source_w = torch.unsqueeze(mixture_w, 1) * est_mask  # [M, C, N, K]
        source_w = torch.transpose(source_w, 2, 3)  # [M, C, K, N]
        # S = DV
        est_source = self.basis_signals(source_w)  # [M, C, K, ac * L]
        m, c, k, _ = est_source.size()
        est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
        est_source = overlap_and_add(est_source, self.L // 2)  # M x C x ac x T
        return est_source


class TemporalConvNet(nn.Module):
    def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear="relu"):
        """
        Args:
            N: Number of filters in autoencoder
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: which non-linearity to use to generate the mask
        """
        super(TemporalConvNet, self).__init__()
        # Hyper-parameters
        self.C = C
        self.mask_nonlinear = mask_nonlinear
        # Components
        # [M, N, K] -> [M, N, K]
        layer_norm = ChannelwiseLayerNorm(N)
        # [M, N, K] -> [M, B, K]
        bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
        # [M, B, K] -> [M, B, K]
        repeats = []
        for r in range(R):
            blocks = []
            for x in range(X):
                dilation = 2**x
                padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
                blocks += [TemporalBlock(B, H, P, stride=1, padding=padding, dilation=dilation, norm_type=norm_type, causal=causal)]
            repeats += [nn.Sequential(*blocks)]
        temporal_conv_net = nn.Sequential(*repeats)
        # [M, B, K] -> [M, C*N, K]
        mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
        # Put together
        self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1)

    def forward(self, mixture_w):
        """
        Keep this API the same as TasNet
        Args:
            mixture_w: [M, N, K], M is batch size
        Returns:
            est_mask: [M, C, N, K]
        """
        M, N, K = mixture_w.size()
        score = self.network(mixture_w)  # [M, N, K] -> [M, C*N, K]
        score = score.view(M, self.C, N, K)  # [M, C*N, K] -> [M, C, N, K]
        if self.mask_nonlinear == "softmax":
            est_mask = F.softmax(score, dim=1)
        elif self.mask_nonlinear == "relu":
            est_mask = F.relu(score)
        else:
            raise ValueError("Unsupported mask non-linear function")
        return est_mask


class TemporalBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type="gLN", causal=False):
        super(TemporalBlock, self).__init__()
        # [M, B, K] -> [M, H, K]
        conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, out_channels)
        # [M, H, K] -> [M, B, K]
        dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, dilation, norm_type, causal)
        # Put together
        self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)

    def forward(self, x):
        """
        Args:
            x: [M, B, K]
        Returns:
            [M, B, K]
        """
        residual = x
        out = self.net(x)
        # TODO: this works fine when P = 3, but when P = 2 we may need to pad?
        return out + residual  # omitting F.relu here seems to work better than keeping it
        # return F.relu(out + residual)


class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type="gLN", causal=False):
        super(DepthwiseSeparableConv, self).__init__()
        # Use the `groups` option to implement a depthwise convolution
        # [M, H, K] -> [M, H, K]
        depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=in_channels, bias=False)
        if causal:
            chomp = Chomp1d(padding)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, in_channels)
        # [M, H, K] -> [M, B, K]
        pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        # Put together
        if causal:
            self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
        else:
            self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)

    def forward(self, x):
        """
        Args:
            x: [M, H, K]
        Returns:
            result: [M, B, K]
        """
        return self.net(x)


class Chomp1d(nn.Module):
    """Ensures the output length is the same as the input's."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        Args:
            x: [M, H, Kpad]
        Returns:
            [M, H, K]
        """
        return x[:, :, : -self.chomp_size].contiguous()


def chose_norm(norm_type, channel_size):
    """The input of normalization will be (M, C, K), where M is batch size,
    C is channel size and K is sequence length.
    """
    if norm_type == "gLN":
        return GlobalLayerNorm(channel_size)
    elif norm_type == "cLN":
        return ChannelwiseLayerNorm(channel_size)
    elif norm_type == "id":
        return nn.Identity()
    else:  # norm_type == "BN"
        # Given input (M, C, K), nn.BatchNorm1d(C) accumulates statistics
        # along M and K, so this BN usage is correct.
        return nn.BatchNorm1d(channel_size)


# TODO: use nn.LayerNorm to implement cLN and speed it up
class ChannelwiseLayerNorm(nn.Module):
    """Channel-wise Layer Normalization (cLN)"""

    def __init__(self, channel_size):
        super(ChannelwiseLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            cLN_y: [M, N, K]
        """
        mean = torch.mean(y, dim=1, keepdim=True)  # [M, 1, K]
        var = torch.var(y, dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
        cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
        return cLN_y


class GlobalLayerNorm(nn.Module):
    """Global Layer Normalization (gLN)"""

    def __init__(self, channel_size):
        super(GlobalLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            gLN_y: [M, N, K]
        """
        # TODO: since torch 1.0, torch.mean() supports a list of dims
        mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # [M, 1, 1]
        var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
        return gLN_y


if __name__ == "__main__":
    torch.manual_seed(123)
    M, N, L, T = 2, 3, 4, 12
    K = 2 * T // L - 1
    B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
    # conv1d expects a float input of shape [M, audio_channels, T]
    mixture = torch.randint(3, (M, 1, T)).float()
    # test Encoder
    encoder = Encoder(L, N, audio_channels=1)
    encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size()).float()
    mixture_w = encoder(mixture)
    print("mixture", mixture)
    print("U", encoder.conv1d_U.weight)
    print("mixture_w", mixture_w)
    print("mixture_w size", mixture_w.size())

    # test TemporalConvNet
    separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
    est_mask = separator(mixture_w)
    print("est_mask", est_mask)

    # test Decoder
    decoder = Decoder(N, L, audio_channels=1)
    est_mask = torch.randint(2, (M, C, N, K)).float()  # [M, C, N, K]
    est_source = decoder(mixture_w, est_mask)
    print("est_source", est_source)

    # test Conv-TasNet (this version derives C from the list of sources)
    conv_tasnet = ConvTasNet(["source%d" % i for i in range(C)], N=N, L=L, B=B, H=H, P=P, X=X, R=R, audio_channels=1, norm_type=norm_type)
    est_source = conv_tasnet(mixture)
    print("est_source", est_source)
    print("est_source size", est_source.size())
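
Relative to `tasnet.py` above, this constructor takes an explicit list of `sources` and derives `C = len(sources)`, and it records `samplerate` and `segment_length` metadata on the module. A minimal sketch (the stem names are illustrative):

    model = ConvTasNet(sources=["drums", "bass", "other", "vocals"], audio_channels=2)
    assert model.C == 4  # one estimated stem per named source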
audio_separator/separator/uvr_lib_v5/demucs/transformer.py
ADDED
@@ -0,0 +1,675 @@
# Copyright (c) 2019-present, Meta, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# First author is Simon Rouard.

import random
import typing as tp

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from einops import rearrange


def create_sin_embedding(length: int, dim: int, shift: int = 0, device="cpu", max_period=10000):
    # We aim for TBC format
    assert dim % 2 == 0
    pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
    half_dim = dim // 2
    adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
    phase = pos / (max_period ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

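# An illustrative sketch (values chosen here): create_sin_embedding(length=100, dim=512)
# returns a tensor of shape (100, 1, 512) in TBC layout; the singleton batch axis
# broadcasts over B when the table is added to (T, B, C) activations.
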
def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
    """
    :param d_model: dimension of the model
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model*height*width position matrix
    """
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with odd dimension (got dim={:d})".format(d_model))
    pe = torch.zeros(d_model, height, width)
    # Each dimension uses half of d_model
    d_model = int(d_model / 2)
    div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model))
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)
    pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[d_model + 1 :: 2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

    return pe[None, :].to(device)


def create_sin_embedding_cape(
    length: int,
    dim: int,
    batch_size: int,
    mean_normalize: bool,
    augment: bool,  # True during training
    max_global_shift: float = 0.0,  # delta max
    max_local_shift: float = 0.0,  # epsilon max
    max_scale: float = 1.0,
    device: str = "cpu",
    max_period: float = 10000.0,
):
    # We aim for TBC format
    assert dim % 2 == 0
    pos = 1.0 * torch.arange(length).view(-1, 1, 1)  # (length, 1, 1)
    pos = pos.repeat(1, batch_size, 1)  # (length, batch_size, 1)
    if mean_normalize:
        pos -= torch.nanmean(pos, dim=0, keepdim=True)

    if augment:
        delta = np.random.uniform(-max_global_shift, +max_global_shift, size=[1, batch_size, 1])
        delta_local = np.random.uniform(-max_local_shift, +max_local_shift, size=[length, batch_size, 1])
        log_lambdas = np.random.uniform(-np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1])
        pos = (pos + delta + delta_local) * np.exp(log_lambdas)

    pos = pos.to(device)

    half_dim = dim // 2
    adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
    phase = pos / (max_period ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1).float()


def get_causal_mask(length):
    pos = torch.arange(length)
    return pos > pos[:, None]

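# An illustrative sketch: get_causal_mask(4) marks future positions with True
# (i.e. masked), giving
#     tensor([[False,  True,  True,  True],
#             [False, False,  True,  True],
#             [False, False, False,  True],
#             [False, False, False, False]])
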
def get_elementary_mask(T1, T2, mask_type, sparse_attn_window, global_window, mask_random_seed, sparsity, device):
    """
    When the input of the decoder has length T1 and the output length T2,
    the mask matrix has shape (T2, T1).
    """
    assert mask_type in ["diag", "jmask", "random", "global"]

    if mask_type == "global":
        mask = torch.zeros(T2, T1, dtype=torch.bool)
        mask[:, :global_window] = True
        line_window = int(global_window * T2 / T1)
        mask[:line_window, :] = True

    if mask_type == "diag":
        mask = torch.zeros(T2, T1, dtype=torch.bool)
        rows = torch.arange(T2)[:, None]
        cols = (T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1)).long().clamp(0, T1 - 1)
        mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))

    elif mask_type == "jmask":
        mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
        rows = torch.arange(T2 + 2)[:, None]
        t = torch.arange(0, int((2 * T1) ** 0.5 + 1))
        t = (t * (t + 1) / 2).int()
        t = torch.cat([-t.flip(0)[:-1], t])
        cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1)
        mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
        mask = mask[1:-1, 1:-1]

    elif mask_type == "random":
        gene = torch.Generator(device=device)
        gene.manual_seed(mask_random_seed)
        mask = torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1) > sparsity

    mask = mask.to(device)
    return mask


def get_mask(T1, T2, mask_type, sparse_attn_window, global_window, mask_random_seed, sparsity, device):
    """
    Return a SparseCSRTensor mask that is a combination of elementary masks.
    mask_type can be a combination of multiple masks: for instance "diag_jmask_random"
    """
    from xformers.sparse import SparseCSRTensor

    # create a list
    mask_types = mask_type.split("_")

    all_masks = [get_elementary_mask(T1, T2, mask, sparse_attn_window, global_window, mask_random_seed, sparsity, device) for mask in mask_types]

    final_mask = torch.stack(all_masks).sum(axis=0) > 0

    return SparseCSRTensor.from_dense(final_mask[None])


class ScaledEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 1.0, boost: float = 3.0):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data *= scale / boost
        self.boost = boost

    @property
    def weight(self):
        return self.embedding.weight * self.boost

    def forward(self, x):
        return self.embedding(x) * self.boost


class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales the residual outputs diagonally, close to 0 initially, then learnt.
    """

    def __init__(self, channels: int, init: float = 0, channel_last=False):
        """
        channel_last = False corresponds to (B, C, T) tensors
        channel_last = True corresponds to (T, B, C) tensors
        """
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        if self.channel_last:
            return self.scale * x
        else:
            return self.scale[:, None] * x


class MyGroupNorm(nn.GroupNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """
        x: (B, T, C)
        if num_groups=1: normalisation over all T and C together for each B
        """
        x = x.transpose(1, 2)
        return super().forward(x).transpose(1, 2)


class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        group_norm=0,
        norm_first=False,
        norm_out=False,
        layer_norm_eps=1e-5,
        layer_scale=False,
        init_values=1e-4,
        device=None,
        dtype=None,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        auto_sparsity=False,
        sparsity=0.95,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            norm_first=norm_first,
            device=device,
            dtype=dtype,
        )
        self.sparse = sparse
        self.auto_sparsity = auto_sparsity
        if sparse:
            if not auto_sparsity:
                self.mask_type = mask_type
                self.sparse_attn_window = sparse_attn_window
                self.global_window = global_window
            self.sparsity = sparsity
        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        if self.norm_first & norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
        self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()

        if sparse:
            self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, auto_sparsity=sparsity if auto_sparsity else 0)
            self.__setattr__("src_mask", torch.zeros(1, 1))
            self.mask_random_seed = mask_random_seed

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        if batch_first = False, src shape is (T, B, C);
        the case where batch_first = True is not covered
        """
        device = src.device
        x = src
        T, B, C = x.shape
        if self.sparse and not self.auto_sparsity:
            assert src_mask is None
            src_mask = self.src_mask
            if src_mask.shape[-1] != T:
                src_mask = get_mask(T, T, self.mask_type, self.sparse_attn_window, self.global_window, self.mask_random_seed, self.sparsity, device)
                self.__setattr__("src_mask", src_mask)

        if self.norm_first:
            x = x + self.gamma_1(self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
            x = x + self.gamma_2(self._ff_block(self.norm2(x)))

            if self.norm_out:
                x = self.norm_out(x)
        else:
            x = self.norm1(x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x


class CrossTransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation=F.relu,
        layer_norm_eps: float = 1e-5,
        layer_scale: bool = False,
        init_values: float = 1e-4,
        norm_first: bool = False,
        group_norm: bool = False,
        norm_out: bool = False,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        sparsity=0.95,
        auto_sparsity=None,
        device=None,
        dtype=None,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.sparse = sparse
        self.auto_sparsity = auto_sparsity
        if sparse:
            if not auto_sparsity:
                self.mask_type = mask_type
                self.sparse_attn_window = sparse_attn_window
                self.global_window = global_window
            self.sparsity = sparsity

        self.cross_attn: nn.Module
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        # Implementation of the feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1: nn.Module
        self.norm2: nn.Module
        self.norm3: nn.Module
        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
        else:
            self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        if self.norm_first & norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)

        self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for the activation function.
        if isinstance(activation, str):
            self.activation = self._get_activation_fn(activation)
        else:
            self.activation = activation

        if sparse:
            self.cross_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, auto_sparsity=sparsity if auto_sparsity else 0)
            if not auto_sparsity:
                self.__setattr__("mask", torch.zeros(1, 1))
                self.mask_random_seed = mask_random_seed

    def forward(self, q, k, mask=None):
        """
        Args:
            q: tensor of shape (T, B, C)
            k: tensor of shape (S, B, C)
            mask: tensor of shape (T, S)
        """
        device = q.device
        T, B, C = q.shape
        S, B, C = k.shape
        if self.sparse and not self.auto_sparsity:
            assert mask is None
            mask = self.mask
            if mask.shape[-1] != S or mask.shape[-2] != T:
                mask = get_mask(S, T, self.mask_type, self.sparse_attn_window, self.global_window, self.mask_random_seed, self.sparsity, device)
                self.__setattr__("mask", mask)

        if self.norm_first:
            x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
            x = x + self.gamma_2(self._ff_block(self.norm3(x)))
            if self.norm_out:
                x = self.norm_out(x)
        else:
            x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x

    # cross-attention block
    def _ca_block(self, q, k, attn_mask=None):
        x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
        return self.dropout1(x)

    # feed-forward block
    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def _get_activation_fn(self, activation):
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu

        raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


# ----------------- MULTI-BLOCKS MODELS: -----------------------


class CrossTransformerEncoder(nn.Module):
    def __init__(
        self,
        dim: int,
        emb: str = "sin",
        hidden_scale: float = 4.0,
        num_heads: int = 8,
        num_layers: int = 6,
        cross_first: bool = False,
        dropout: float = 0.0,
        max_positions: int = 1000,
        norm_in: bool = True,
        norm_in_group: bool = False,
        group_norm: int = False,
        norm_first: bool = False,
        norm_out: bool = False,
        max_period: float = 10000.0,
        weight_decay: float = 0.0,
        lr: tp.Optional[float] = None,
        layer_scale: bool = False,
        gelu: bool = True,
        sin_random_shift: int = 0,
        weight_pos_embed: float = 1.0,
        cape_mean_normalize: bool = True,
        cape_augment: bool = True,
        cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
        sparse_self_attn: bool = False,
        sparse_cross_attn: bool = False,
        mask_type: str = "diag",
        mask_random_seed: int = 42,
        sparse_attn_window: int = 500,
        global_window: int = 50,
        auto_sparsity: bool = False,
        sparsity: float = 0.95,
    ):
        super().__init__()
        assert dim % num_heads == 0

        hidden_dim = int(dim * hidden_scale)

        self.num_layers = num_layers
        # classic_parity = 1 means that if idx % 2 == 1 there is a
        # classical encoder, else there is a cross encoder
        self.classic_parity = 1 if cross_first else 0
        self.emb = emb
        self.max_period = max_period
        self.weight_decay = weight_decay
        self.weight_pos_embed = weight_pos_embed
        self.sin_random_shift = sin_random_shift
        if emb == "cape":
            self.cape_mean_normalize = cape_mean_normalize
            self.cape_augment = cape_augment
            self.cape_glob_loc_scale = cape_glob_loc_scale
        if emb == "scaled":
            self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)

        self.lr = lr

        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        self.norm_in_t: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
            self.norm_in_t = nn.LayerNorm(dim)
        elif norm_in_group:
            self.norm_in = MyGroupNorm(int(norm_in_group), dim)
            self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
        else:
            self.norm_in = nn.Identity()
            self.norm_in_t = nn.Identity()

        # spectrogram layers
        self.layers = nn.ModuleList()
        # temporal layers
        self.layers_t = nn.ModuleList()

        kwargs_common = {
            "d_model": dim,
            "nhead": num_heads,
            "dim_feedforward": hidden_dim,
            "dropout": dropout,
            "activation": activation,
            "group_norm": group_norm,
            "norm_first": norm_first,
            "norm_out": norm_out,
            "layer_scale": layer_scale,
            "mask_type": mask_type,
            "mask_random_seed": mask_random_seed,
            "sparse_attn_window": sparse_attn_window,
            "global_window": global_window,
            "sparsity": sparsity,
            "auto_sparsity": auto_sparsity,
            "batch_first": True,
        }

        kwargs_classic_encoder = dict(kwargs_common)
        kwargs_classic_encoder.update({"sparse": sparse_self_attn})
        kwargs_cross_encoder = dict(kwargs_common)
        kwargs_cross_encoder.update({"sparse": sparse_cross_attn})
|
517 |
+
|
518 |
+
for idx in range(num_layers):
|
519 |
+
if idx % 2 == self.classic_parity:
|
520 |
+
|
521 |
+
self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
|
522 |
+
self.layers_t.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
|
523 |
+
|
524 |
+
else:
|
525 |
+
self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
|
526 |
+
|
527 |
+
self.layers_t.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
|
528 |
+
|
529 |
+
def forward(self, x, xt):
|
530 |
+
B, C, Fr, T1 = x.shape
|
531 |
+
pos_emb_2d = create_2d_sin_embedding(C, Fr, T1, x.device, self.max_period) # (1, C, Fr, T1)
|
532 |
+
pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
|
533 |
+
x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
|
534 |
+
x = self.norm_in(x)
|
535 |
+
x = x + self.weight_pos_embed * pos_emb_2d
|
536 |
+
|
537 |
+
B, C, T2 = xt.shape
|
538 |
+
xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C
|
539 |
+
pos_emb = self._get_pos_embedding(T2, B, C, x.device)
|
540 |
+
pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
|
541 |
+
xt = self.norm_in_t(xt)
|
542 |
+
xt = xt + self.weight_pos_embed * pos_emb
|
543 |
+
|
544 |
+
for idx in range(self.num_layers):
|
545 |
+
if idx % 2 == self.classic_parity:
|
546 |
+
x = self.layers[idx](x)
|
547 |
+
xt = self.layers_t[idx](xt)
|
548 |
+
else:
|
549 |
+
old_x = x
|
550 |
+
x = self.layers[idx](x, xt)
|
551 |
+
xt = self.layers_t[idx](xt, old_x)
|
552 |
+
|
553 |
+
x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
|
554 |
+
xt = rearrange(xt, "b t2 c -> b c t2")
|
555 |
+
return x, xt
|
556 |
+
|
557 |
+
def _get_pos_embedding(self, T, B, C, device):
|
558 |
+
if self.emb == "sin":
|
559 |
+
shift = random.randrange(self.sin_random_shift + 1)
|
560 |
+
pos_emb = create_sin_embedding(T, C, shift=shift, device=device, max_period=self.max_period)
|
561 |
+
elif self.emb == "cape":
|
562 |
+
if self.training:
|
563 |
+
pos_emb = create_sin_embedding_cape(
|
564 |
+
T,
|
565 |
+
C,
|
566 |
+
B,
|
567 |
+
device=device,
|
568 |
+
max_period=self.max_period,
|
569 |
+
mean_normalize=self.cape_mean_normalize,
|
570 |
+
augment=self.cape_augment,
|
571 |
+
max_global_shift=self.cape_glob_loc_scale[0],
|
572 |
+
max_local_shift=self.cape_glob_loc_scale[1],
|
573 |
+
max_scale=self.cape_glob_loc_scale[2],
|
574 |
+
)
|
575 |
+
else:
|
576 |
+
pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=False)
|
577 |
+
|
578 |
+
elif self.emb == "scaled":
|
579 |
+
pos = torch.arange(T, device=device)
|
580 |
+
pos_emb = self.position_embeddings(pos)[:, None]
|
581 |
+
|
582 |
+
return pos_emb
|
583 |
+
|
584 |
+
def make_optim_group(self):
|
585 |
+
group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
|
586 |
+
if self.lr is not None:
|
587 |
+
group["lr"] = self.lr
|
588 |
+
return group
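

def _cross_transformer_shape_demo():
    # Minimal usage sketch, not part of the upstream module and never called
    # here: the hyperparameters are arbitrary, chosen only to exercise the
    # shape contract of `CrossTransformerEncoder.forward`. The spectral
    # branch takes (B, C, Fr, T1), the temporal branch (B, C, T2), and both
    # come back with their input shapes preserved.
    enc = CrossTransformerEncoder(dim=64, num_heads=4, num_layers=2)
    z = torch.randn(1, 64, 8, 10)   # spectrogram features (B, C, Fr, T1)
    zt = torch.randn(1, 64, 80)     # waveform features (B, C, T2)
    z_out, zt_out = enc(z, zt)
    assert z_out.shape == z.shape and zt_out.shape == zt.shape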


# Attention Modules


class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, auto_sparsity=None):
        super().__init__()
        assert auto_sparsity is not None, "sanity check"
        self.num_heads = num_heads
        self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = torch.nn.Dropout(dropout)
        self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
        self.proj_drop = torch.nn.Dropout(dropout)
        self.batch_first = batch_first
        self.auto_sparsity = auto_sparsity

    def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True):

        if not self.batch_first:  # N, B, C
            query = query.permute(1, 0, 2)  # B, N_q, C
            key = key.permute(1, 0, 2)  # B, N_k, C
            value = value.permute(1, 0, 2)  # B, N_k, C
        B, N_q, C = query.shape
        B, N_k, C = key.shape

        q = self.q(query).reshape(B, N_q, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = q.flatten(0, 1)
        k = self.k(key).reshape(B, N_k, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        k = k.flatten(0, 1)
        v = self.v(value).reshape(B, N_k, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = v.flatten(0, 1)

        if self.auto_sparsity:
            assert attn_mask is None
            x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
        else:
            x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)
        x = x.reshape(B, self.num_heads, N_q, C // self.num_heads)

        x = x.transpose(1, 2).reshape(B, N_q, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        if not self.batch_first:
            x = x.permute(1, 0, 2)
        return x, None


def scaled_query_key_softmax(q, k, att_mask):
    from xformers.ops import masked_matmul

    q = q / (k.size(-1)) ** 0.5
    att = masked_matmul(q, k.transpose(-2, -1), att_mask)
    att = torch.nn.functional.softmax(att, -1)
    return att


def scaled_dot_product_attention(q, k, v, att_mask, dropout):
    att = scaled_query_key_softmax(q, k, att_mask=att_mask)
    att = dropout(att)
    y = att @ v
    return y
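

def _dense_attention_reference(q, k, v):
    # Torch-only reference sketch, not part of the upstream module and never
    # called here: what `scaled_dot_product_attention` above computes when
    # `att_mask` is None and dropout is the identity, without the xformers
    # `masked_matmul` dependency.
    att = torch.softmax((q / k.size(-1) ** 0.5) @ k.transpose(-2, -1), dim=-1)
    return att @ v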


def _compute_buckets(x, R):
    qq = torch.einsum("btf,bfhi->bhti", x, R)
    qq = torch.cat([qq, -qq], dim=-1)
    buckets = qq.argmax(dim=-1)

    return buckets.permute(0, 2, 1).byte().contiguous()


def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
    # assert False, "The code for the custom sparse kernel is not ready for release yet."
    from xformers.ops import find_locations, sparse_memory_efficient_attention

    n_hashes = 32
    proj_size = 4
    query, key, value = [x.contiguous() for x in [query, key, value]]
    with torch.no_grad():
        R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
        bucket_query = _compute_buckets(query, R)
        bucket_key = _compute_buckets(key, R)
    row_offsets, column_indices = find_locations(bucket_query, bucket_key, sparsity, infer_sparsity)
    return sparse_memory_efficient_attention(query, key, value, row_offsets, column_indices, attn_bias)
audio_separator/separator/uvr_lib_v5/demucs/utils.py
ADDED
@@ -0,0 +1,496 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from contextlib import contextmanager
import math
import os
import tempfile
import typing as tp

import errno
import functools
import hashlib
import inspect
import io
import os
import random
import socket
import tempfile
import warnings
import zlib

from diffq import UniformQuantizer, DiffQuantizer
import torch as th
import tqdm
from torch import distributed
from torch.nn import functional as F

import torch


def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    This will pad the input so that `F = ceil(T / stride)`.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, "data should be contiguous"
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)
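

def _unfold_demo():
    # Quick sanity sketch, not part of the upstream module and never called
    # here: framing a length-10 signal with kernel_size=4 and stride=2 yields
    # ceil(10 / 2) = 5 frames of 4 samples, zero-padded on the right.
    a = torch.arange(10.0)
    frames = unfold(a, kernel_size=4, stride=2)
    assert list(frames.shape) == [5, 4]
    assert frames[0].tolist() == [0.0, 1.0, 2.0, 3.0]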


def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    """
    Center trim `tensor` with respect to `reference`, along the last dimension.
    `reference` can also be a number, representing the length to trim to.
    If the size difference != 0 mod 2, the extra sample is removed on the right side.
    """
    ref_size: int
    if isinstance(reference, torch.Tensor):
        ref_size = reference.size(-1)
    else:
        ref_size = reference
    delta = tensor.size(-1) - ref_size
    if delta < 0:
        raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
    if delta:
        tensor = tensor[..., delta // 2 : -(delta - delta // 2)]
    return tensor


def pull_metric(history: tp.List[dict], name: str):
    out = []
    for metrics in history:
        metric = metrics
        for part in name.split("."):
            metric = metric[part]
        out.append(metric)
    return out


def EMA(beta: float = 1):
    """
    Exponential Moving Average callback.
    Returns a single function that can be called to repeatedly update the EMA
    with a dict of metrics. The callback will return
    the new averaged dict of metrics.

    Note that for `beta=1`, this is just plain averaging.
    """
    fix: tp.Dict[str, float] = defaultdict(float)
    total: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        nonlocal total, fix
        for key, value in metrics.items():
            total[key] = total[key] * beta + weight * float(value)
            fix[key] = fix[key] * beta + weight
        return {key: tot / fix[key] for key, tot in total.items()}

    return _update
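

def _ema_demo():
    # Usage sketch, not part of the upstream module and never called here:
    # with the default beta=1 the callback is a plain running average of
    # every metrics dict it is fed.
    update = EMA()
    update({"loss": 2.0})
    averaged = update({"loss": 4.0})
    assert averaged["loss"] == 3.0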


def sizeof_fmt(num: float, suffix: str = "B"):
    """
    Given `num` bytes, return human readable size.
    Taken from https://stackoverflow.com/a/1094933
    """
    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
        if abs(num) < 1024.0:
            return "%3.1f%s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f%s%s" % (num, "Yi", suffix)


@contextmanager
def temp_filenames(count: int, delete=True):
    names = []
    try:
        for _ in range(count):
            names.append(tempfile.NamedTemporaryFile(delete=False).name)
        yield names
    finally:
        if delete:
            for name in names:
                os.unlink(name)


def average_metric(metric, count=1.0):
    """
    Average `metric` which should be a float across all hosts. `count` should be
    the weight for this particular host (i.e. number of examples).
    """
    metric = th.tensor([count, count * metric], dtype=th.float32, device="cuda")
    distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
    return metric[1].item() / metric[0].item()


def free_port(host="", low=20000, high=40000):
    """
    Return a port number that is most likely free.
    This could suffer from a race condition although
    it should be quite rare.
    """
    sock = socket.socket()
    while True:
        port = random.randint(low, high)
        try:
            sock.bind((host, port))
        except OSError as error:
            if error.errno == errno.EADDRINUSE:
                continue
            raise
        return port


def sizeof_fmt(num, suffix="B"):
    """
    Given `num` bytes, return human readable size.
    Taken from https://stackoverflow.com/a/1094933
    """
    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
        if abs(num) < 1024.0:
            return "%3.1f%s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f%s%s" % (num, "Yi", suffix)


def human_seconds(seconds, display=".2f"):
    """
    Given `seconds` seconds, return human readable duration.
    """
    value = seconds * 1e6
    ratios = [1e3, 1e3, 60, 60, 24]
    names = ["us", "ms", "s", "min", "hrs", "days"]
    last = names.pop(0)
    for name, ratio in zip(names, ratios):
        if value / ratio < 0.3:
            break
        value /= ratio
        last = name
    return f"{format(value, display)} {last}"


class TensorChunk:
    def __init__(self, tensor, offset=0, length=None):
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        if length is None:
            length = total_length - offset
        else:
            length = min(total_length - offset, length)

        self.tensor = tensor
        self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length):
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        start = self.offset - delta // 2
        end = start + target_length

        correct_start = max(0, start)
        correct_end = min(total_length, end)

        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
        assert out.shape[-1] == target_length
        return out


def tensor_chunk(tensor_or_chunk):
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    else:
        assert isinstance(tensor_or_chunk, th.Tensor)
        return TensorChunk(tensor_or_chunk)
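

def _tensor_chunk_demo():
    # Shape sketch, not part of the upstream module and never called here:
    # `padded` centers the chunk inside the requested length and zero-pads
    # whenever the underlying tensor runs out of samples on either side.
    chunk = TensorChunk(th.ones(2, 8), offset=0, length=4)
    padded = chunk.padded(6)        # needs one sample left of offset 0
    assert padded.shape[-1] == 6
    assert padded[:, 0].sum() == 0  # left edge was zero-padded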


def apply_model_v1(model, mix, shifts=None, split=False, progress=False, set_progress_bar=None):
    """
    Apply model to a given mixture.

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` times and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down into 8 second extracts
            and predictions will be performed individually on each and concatenated.
            Useful for models with a large memory footprint like Tasnet.
        progress (bool): if True, show a progress bar (requires split=True)
    """

    channels, length = mix.size()
    device = mix.device
    progress_value = 0

    if split:
        out = th.zeros(4, channels, length, device=device)
        shift = model.samplerate * 10
        offsets = range(0, length, shift)
        scale = 10
        if progress:
            offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit="seconds")
        for offset in offsets:
            chunk = mix[..., offset : offset + shift]
            if set_progress_bar:
                progress_value += 1
                set_progress_bar(0.1, (0.8 / len(offsets) * progress_value))
                chunk_out = apply_model_v1(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
            else:
                chunk_out = apply_model_v1(model, chunk, shifts=shifts)
            out[..., offset : offset + shift] = chunk_out
            offset += shift
        return out
    elif shifts:
        max_shift = int(model.samplerate / 2)
        mix = F.pad(mix, (max_shift, max_shift))
        offsets = list(range(max_shift))
        random.shuffle(offsets)
        out = 0
        for offset in offsets[:shifts]:
            shifted = mix[..., offset : offset + length + max_shift]
            if set_progress_bar:
                shifted_out = apply_model_v1(model, shifted, set_progress_bar=set_progress_bar)
            else:
                shifted_out = apply_model_v1(model, shifted)
            out += shifted_out[..., max_shift - offset : max_shift - offset + length]
        out /= shifts
        return out
    else:
        valid_length = model.valid_length(length)
        delta = valid_length - length
        padded = F.pad(mix, (delta // 2, delta - delta // 2))
        with th.no_grad():
            out = model(padded.unsqueeze(0))[0]
        return center_trim(out, mix)


def apply_model_v2(model, mix, shifts=None, split=False, overlap=0.25, transition_power=1.0, progress=False, set_progress_bar=None):
    """
    Apply model to a given mixture.

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` times and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down into 8 second extracts
            and predictions will be performed individually on each and concatenated.
            Useful for models with a large memory footprint like Tasnet.
        progress (bool): if True, show a progress bar (requires split=True)
    """

    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    device = mix.device
    channels, length = mix.shape
    progress_value = 0

    if split:
        out = th.zeros(len(model.sources), channels, length, device=device)
        sum_weight = th.zeros(length, device=device)
        segment = model.segment_length
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        scale = stride / model.samplerate
        if progress:
            offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit="seconds")
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment // 2 + 1), th.arange(segment - segment // 2, 0, -1)]).to(device)
        assert len(weight) == segment
        # If the overlap < 50%, this will translate to linear transition when
        # transition_power is 1.
        weight = (weight / weight.max()) ** transition_power
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            if set_progress_bar:
                progress_value += 1
                set_progress_bar(0.1, (0.8 / len(offsets) * progress_value))
                chunk_out = apply_model_v2(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
            else:
                chunk_out = apply_model_v2(model, chunk, shifts=shifts)
            chunk_length = chunk_out.shape[-1]
            out[..., offset : offset + segment] += weight[:chunk_length] * chunk_out
            sum_weight[offset : offset + segment] += weight[:chunk_length]
            offset += segment
        assert sum_weight.min() > 0
        out /= sum_weight
        return out
    elif shifts:
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)

            if set_progress_bar:
                progress_value += 1
                shifted_out = apply_model_v2(model, shifted, set_progress_bar=set_progress_bar)
            else:
                shifted_out = apply_model_v2(model, shifted)
            out += shifted_out[..., max_shift - offset :]
        out /= shifts
        return out
    else:
        valid_length = model.valid_length(length)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length)
        with th.no_grad():
            out = model(padded_mix.unsqueeze(0))[0]
        return center_trim(out, length)
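

def _transition_weight(segment, transition_power=1.0):
    # Sketch of the cross-fade weight used in the `split` branch above, not
    # part of the upstream module and never called here: a triangle over the
    # segment, normalized, then raised to `transition_power`; power 1 keeps
    # the linear cross-fade.
    weight = th.cat([th.arange(1, segment // 2 + 1), th.arange(segment - segment // 2, 0, -1)])
    return (weight / weight.max()) ** transition_power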


@contextmanager
def temp_filenames(count, delete=True):
    names = []
    try:
        for _ in range(count):
            names.append(tempfile.NamedTemporaryFile(delete=False).name)
        yield names
    finally:
        if delete:
            for name in names:
                os.unlink(name)


def get_quantizer(model, args, optimizer=None):
    quantizer = None
    if args.diffq:
        quantizer = DiffQuantizer(model, min_size=args.q_min_size, group_size=8)
        if optimizer is not None:
            quantizer.setup_optimizer(optimizer)
    elif args.qat:
        quantizer = UniformQuantizer(model, bits=args.qat, min_size=args.q_min_size)
    return quantizer


def load_model(path, strict=False):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        load_from = path
        package = th.load(load_from, "cpu")

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if strict:
        model = klass(*args, **kwargs)
    else:
        sig = inspect.signature(klass)
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn("Dropping nonexistent parameter " + key)
                del kwargs[key]
        model = klass(*args, **kwargs)

    state = package["state"]
    training_args = package["training_args"]
    quantizer = get_quantizer(model, training_args)

    set_state(model, quantizer, state)
    return model


def get_state(model, quantizer):
    if quantizer is None:
        state = {k: p.data.to("cpu") for k, p in model.state_dict().items()}
    else:
        state = quantizer.get_quantized_state()
        buf = io.BytesIO()
        th.save(state, buf)
        state = {"compressed": zlib.compress(buf.getvalue())}
    return state


def set_state(model, quantizer, state):
    if quantizer is None:
        model.load_state_dict(state)
    else:
        buf = io.BytesIO(zlib.decompress(state["compressed"]))
        state = th.load(buf, "cpu")
        quantizer.restore_quantized_state(state)

    return state


def save_state(state, path):
    buf = io.BytesIO()
    th.save(state, buf)
    sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]

    path = path.parent / (path.stem + "-" + sig + path.suffix)
    path.write_bytes(buf.getvalue())


def save_model(model, quantizer, training_args, path):
    args, kwargs = model._init_args_kwargs
    klass = model.__class__

    state = get_state(model, quantizer)

    save_to = path
    package = {"klass": klass, "args": args, "kwargs": kwargs, "state": state, "training_args": training_args}
    th.save(package, save_to)


def capture_init(init):
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__


class DummyPoolExecutor:
    class DummyResult:
        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        pass

    def submit(self, func, *args, **kwargs):
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return
audio_separator/separator/uvr_lib_v5/mdxnet.py
ADDED
@@ -0,0 +1,136 @@
import torch
import torch.nn as nn
from .modules import TFC_TDF
from pytorch_lightning import LightningModule

dim_s = 4


class AbstractMDXNet(LightningModule):
    def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap):
        super().__init__()
        self.target_name = target_name
        self.lr = lr
        self.optimizer = optimizer
        self.dim_c = dim_c
        self.dim_f = dim_f
        self.dim_t = dim_t
        self.n_fft = n_fft
        self.n_bins = n_fft // 2 + 1
        self.hop_length = hop_length
        self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False)
        self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False)

    def get_optimizer(self):
        if self.optimizer == 'rmsprop':
            return torch.optim.RMSprop(self.parameters(), self.lr)

        if self.optimizer == 'adamw':
            return torch.optim.AdamW(self.parameters(), self.lr)


class ConvTDFNet(AbstractMDXNet):
    def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length,
                 num_blocks, l, g, k, bn, bias, overlap):

        super(ConvTDFNet, self).__init__(
            target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap)
        # self.save_hyperparameters()

        self.num_blocks = num_blocks
        self.l = l
        self.g = g
        self.k = k
        self.bn = bn
        self.bias = bias

        if optimizer == 'rmsprop':
            norm = nn.BatchNorm2d

        if optimizer == 'adamw':
            norm = lambda input: nn.GroupNorm(2, input)

        self.n = num_blocks // 2
        scale = (2, 2)

        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)),
            norm(g),
            nn.ReLU(),
        )

        f = self.dim_f
        c = g
        self.encoding_blocks = nn.ModuleList()
        self.ds = nn.ModuleList()
        for i in range(self.n):
            self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
            self.ds.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale),
                    norm(c + g),
                    nn.ReLU()
                )
            )
            f = f // 2
            c += g

        self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)

        self.decoding_blocks = nn.ModuleList()
        self.us = nn.ModuleList()
        for i in range(self.n):
            self.us.append(
                nn.Sequential(
                    nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale),
                    norm(c - g),
                    nn.ReLU()
                )
            )
            f = f * 2
            c -= g

            self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)),
        )

    def forward(self, x):

        x = self.first_conv(x)

        x = x.transpose(-1, -2)

        ds_outputs = []
        for i in range(self.n):
            x = self.encoding_blocks[i](x)
            ds_outputs.append(x)
            x = self.ds[i](x)

        x = self.bottleneck_block(x)

        for i in range(self.n):
            x = self.us[i](x)
            x *= ds_outputs[-i - 1]
            x = self.decoding_blocks[i](x)

        x = x.transpose(-1, -2)

        x = self.final_conv(x)

        return x
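

def _conv_tdf_shape_demo():
    # Shape sketch, not part of the upstream module and never called here.
    # Every hyperparameter below is arbitrary, chosen so the two (2, 2)-strided
    # encoder levels divide evenly: the U-Net keeps the (batch, dim_c, dim_f,
    # dim_t) layout of its input.
    net = ConvTDFNet(target_name='demo', lr=1e-3, optimizer='adamw',
                     dim_c=4, dim_f=16, dim_t=8, n_fft=30, hop_length=256,
                     num_blocks=4, l=2, g=8, k=3, bn=4, bias=False, overlap=0)
    x = torch.randn(1, 4, 16, 8)
    assert net(x).shape == x.shape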


class Mixer(nn.Module):
    def __init__(self, device, mixer_path):

        super(Mixer, self).__init__()

        self.linear = nn.Linear((dim_s + 1) * 2, dim_s * 2, bias=False)

        self.load_state_dict(
            torch.load(mixer_path, map_location=device)
        )

    def forward(self, x):
        x = x.reshape(1, (dim_s + 1) * 2, -1).transpose(-1, -2)
        x = self.linear(x)
        return x.transpose(-1, -2).reshape(dim_s, 2, -1)
audio_separator/separator/uvr_lib_v5/mixer.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ea781bd52c6a523b825fa6cdbb6189f52e318edd8b17e6fe404f76f7af8caa9c
size 1208
audio_separator/separator/uvr_lib_v5/modules.py
ADDED
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn


class TFC(nn.Module):
    def __init__(self, c, l, k, norm):
        super(TFC, self).__init__()

        self.H = nn.ModuleList()
        for i in range(l):
            self.H.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
                    norm(c),
                    nn.ReLU(),
                )
            )

    def forward(self, x):
        for h in self.H:
            x = h(x)
        return x


class DenseTFC(nn.Module):
    def __init__(self, c, l, k, norm):
        super(DenseTFC, self).__init__()

        self.conv = nn.ModuleList()
        for i in range(l):
            self.conv.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
                    norm(c),
                    nn.ReLU(),
                )
            )

    def forward(self, x):
        for layer in self.conv[:-1]:
            x = torch.cat([layer(x), x], 1)
        return self.conv[-1](x)


class TFC_TDF(nn.Module):
    def __init__(self, c, l, f, k, bn, dense=False, bias=True, norm=nn.BatchNorm2d):

        super(TFC_TDF, self).__init__()

        self.use_tdf = bn is not None

        self.tfc = DenseTFC(c, l, k, norm) if dense else TFC(c, l, k, norm)

        if self.use_tdf:
            if bn == 0:
                self.tdf = nn.Sequential(
                    nn.Linear(f, f, bias=bias),
                    norm(c),
                    nn.ReLU()
                )
            else:
                self.tdf = nn.Sequential(
                    nn.Linear(f, f // bn, bias=bias),
                    norm(c),
                    nn.ReLU(),
                    nn.Linear(f // bn, f, bias=bias),
                    norm(c),
                    nn.ReLU()
                )

    def forward(self, x):
        x = self.tfc(x)
        return x + self.tdf(x) if self.use_tdf else x
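

def _tfc_tdf_demo():
    # Shape sketch, not part of the upstream module and never called here:
    # TFC_TDF preserves (batch, c, t, f) shapes, and with bn=4 the TDF branch
    # squeezes the f=16 frequency axis through a 16 -> 4 -> 16 bottleneck
    # before the residual sum.
    block = TFC_TDF(c=8, l=2, f=16, k=3, bn=4, norm=lambda c: nn.GroupNorm(2, c))
    x = torch.randn(1, 8, 10, 16)
    assert block(x).shape == x.shape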
audio_separator/separator/uvr_lib_v5/playsound.py
ADDED
@@ -0,0 +1,241 @@
import logging
logger = logging.getLogger(__name__)


class PlaysoundException(Exception):
    pass


def _canonicalizePath(path):
    """
    Support passing in a pathlib.Path-like object by converting to str.
    """
    import sys
    if sys.version_info[0] >= 3:
        return str(path)
    else:
        # On earlier Python versions, str is a byte string, so attempting to
        # convert a unicode string to str will fail. Leave it alone in this case.
        return path


def _playsoundWin(sound, block=True):
    '''
    Utilizes windll.winmm. Tested and known to work with MP3 and WAVE on
    Windows 7 with Python 2.7. Probably works with more file formats.
    Probably works on Windows XP thru Windows 10. Probably works with all
    versions of Python.

    Inspired by (but not copied from) Michael Gundlach <[email protected]>'s mp3play:
    https://github.com/michaelgundlach/mp3play

    I never would have tried using windll.winmm without seeing his code.
    '''
    sound = '"' + _canonicalizePath(sound) + '"'

    from ctypes import create_unicode_buffer, windll, wintypes
    windll.winmm.mciSendStringW.argtypes = [wintypes.LPCWSTR, wintypes.LPWSTR, wintypes.UINT, wintypes.HANDLE]
    windll.winmm.mciGetErrorStringW.argtypes = [wintypes.DWORD, wintypes.LPWSTR, wintypes.UINT]

    def winCommand(*command):
        bufLen = 600
        buf = create_unicode_buffer(bufLen)
        command = ' '.join(command)
        errorCode = int(windll.winmm.mciSendStringW(command, buf, bufLen - 1, 0))  # use widestring version of the function
        if errorCode:
            errorBuffer = create_unicode_buffer(bufLen)
            windll.winmm.mciGetErrorStringW(errorCode, errorBuffer, bufLen - 1)  # use widestring version of the function
            exceptionMessage = ('\n    Error ' + str(errorCode) + ' for command:'
                                '\n        ' + command +
                                '\n    ' + errorBuffer.value)
            logger.error(exceptionMessage)
            raise PlaysoundException(exceptionMessage)
        return buf.value

    try:
        logger.debug('Starting')
        winCommand(u'open {}'.format(sound))
        winCommand(u'play {}{}'.format(sound, ' wait' if block else ''))
        logger.debug('Returning')
    finally:
        try:
            winCommand(u'close {}'.format(sound))
        except PlaysoundException:
            logger.warning(u'Failed to close the file: {}'.format(sound))
            # If it fails, there's nothing more that can be done...
            pass


def _handlePathOSX(sound):
    sound = _canonicalizePath(sound)

    if '://' not in sound:
        if not sound.startswith('/'):
            from os import getcwd
            sound = getcwd() + '/' + sound
        sound = 'file://' + sound

    try:
        # Don't double-encode it.
        sound.encode('ascii')
        return sound.replace(' ', '%20')
    except UnicodeEncodeError:
        try:
            from urllib.parse import quote  # Try the Python 3 import first...
        except ImportError:
            from urllib import quote  # Try using the Python 2 import before giving up entirely...

        parts = sound.split('://', 1)
        return parts[0] + '://' + quote(parts[1].encode('utf-8')).replace(' ', '%20')


def _playsoundOSX(sound, block=True):
    '''
    Utilizes AppKit.NSSound. Tested and known to work with MP3 and WAVE on
    OS X 10.11 with Python 2.7. Probably works with anything QuickTime supports.
    Probably works on OS X 10.5 and newer. Probably works with all versions of
    Python.

    Inspired by (but not copied from) Aaron's Stack Overflow answer here:
    http://stackoverflow.com/a/34568298/901641

    I never would have tried using AppKit.NSSound without seeing his code.
    '''
    try:
        from AppKit import NSSound
    except ImportError:
        logger.warning("playsound could not find a copy of AppKit - falling back to using macOS's system copy.")
        sys.path.append('/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/PyObjC')
        from AppKit import NSSound

    from Foundation import NSURL
    from time import sleep

    sound = _handlePathOSX(sound)
    url = NSURL.URLWithString_(sound)
    if not url:
        raise PlaysoundException('Cannot find a sound with filename: ' + sound)

    for i in range(5):
        nssound = NSSound.alloc().initWithContentsOfURL_byReference_(url, True)
        if nssound:
            break
        else:
            logger.debug('Failed to load sound, although url was good... ' + sound)
    else:
        raise PlaysoundException('Could not load sound with filename, although URL was good... ' + sound)
    nssound.play()

    if block:
        sleep(nssound.duration())


def _playsoundNix(sound, block=True):
    """Play a sound using GStreamer.

    Inspired by this:
    https://gstreamer.freedesktop.org/documentation/tutorials/playback/playbin-usage.html
    """
    sound = _canonicalizePath(sound)

    # pathname2url escapes non-URL-safe characters
    from os.path import abspath, exists
    try:
        from urllib.request import pathname2url
    except ImportError:
        # python 2
        from urllib import pathname2url

    import gi
    gi.require_version('Gst', '1.0')
    from gi.repository import Gst

    Gst.init(None)

    playbin = Gst.ElementFactory.make('playbin', 'playbin')
    if sound.startswith(('http://', 'https://')):
        playbin.props.uri = sound
    else:
        path = abspath(sound)
        if not exists(path):
            raise PlaysoundException(u'File not found: {}'.format(path))
        playbin.props.uri = 'file://' + pathname2url(path)

    set_result = playbin.set_state(Gst.State.PLAYING)
    if set_result != Gst.StateChangeReturn.ASYNC:
        raise PlaysoundException(
            "playbin.set_state returned " + repr(set_result))

    # FIXME: use some other bus method than poll() with block=False
    # https://lazka.github.io/pgi-docs/#Gst-1.0/classes/Bus.html
    logger.debug('Starting play')
    if block:
        bus = playbin.get_bus()
        try:
            bus.poll(Gst.MessageType.EOS, Gst.CLOCK_TIME_NONE)
        finally:
            playbin.set_state(Gst.State.NULL)

    logger.debug('Finishing play')


def _playsoundAnotherPython(otherPython, sound, block=True, macOS=False):
    '''
    Mostly written so that when this is run on python3 on macOS, it can invoke
    python2 on macOS... but maybe this idea could be useful on linux, too.
    '''
    from inspect import getsourcefile
    from os.path import abspath, exists
    from subprocess import check_call
    from threading import Thread

    sound = _canonicalizePath(sound)

    class PropogatingThread(Thread):
        def run(self):
            self.exc = None
            try:
                self.ret = self._target(*self._args, **self._kwargs)
            except BaseException as e:
                self.exc = e

        def join(self, timeout=None):
            super().join(timeout)
            if self.exc:
                raise self.exc
            return self.ret

    # Check if the file exists...
    if not exists(abspath(sound)):
        raise PlaysoundException('Cannot find a sound with filename: ' + sound)

    playsoundPath = abspath(getsourcefile(lambda: 0))
    t = PropogatingThread(target=lambda: check_call([otherPython, playsoundPath, _handlePathOSX(sound) if macOS else sound]))
    t.start()
    if block:
        t.join()


from platform import system
system = system()

if system == 'Windows':
    playsound_func = _playsoundWin
elif system == 'Darwin':
    playsound_func = _playsoundOSX
    import sys
    if sys.version_info[0] > 2:
        try:
            from AppKit import NSSound
        except ImportError:
            logger.warning("playsound is relying on a python 2 subprocess. Please use `pip3 install PyObjC` if you want playsound to run more efficiently.")
            playsound_func = lambda sound, block=True: _playsoundAnotherPython('/System/Library/Frameworks/Python.framework/Versions/2.7/bin/python', sound, block, macOS=True)
else:
    playsound_func = _playsoundNix
    if __name__ != '__main__':  # Ensure we don't infinitely recurse trying to get another python instance.
        try:
            import gi
            gi.require_version('Gst', '1.0')
            from gi.repository import Gst
        except:
            logger.warning("playsound is relying on another python subprocess. Please use `pip install pygobject` if you want playsound to run more efficiently.")
            playsound_func = lambda sound, block=True: _playsoundAnotherPython('/usr/bin/python3', sound, block, macOS=False)

del system


def play(audio_filepath):
    playsound_func(audio_filepath)
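

def _play_demo(path="/path/to/notification.wav"):
    # Usage sketch, not part of the upstream module and never called here;
    # the path is a hypothetical placeholder. `play` blocks until the
    # platform-specific backend selected above finishes playback.
    play(path)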
audio_separator/separator/uvr_lib_v5/pyrb.py
ADDED
@@ -0,0 +1,92 @@
import os
import subprocess
import tempfile
import six
import numpy as np
import soundfile as sf
import sys

if getattr(sys, 'frozen', False):
    BASE_PATH_RUB = sys._MEIPASS
else:
    BASE_PATH_RUB = os.path.dirname(os.path.abspath(__file__))

__all__ = ['time_stretch', 'pitch_shift']

__RUBBERBAND_UTIL = os.path.join(BASE_PATH_RUB, 'rubberband')

if six.PY2:
    DEVNULL = open(os.devnull, 'w')
else:
    DEVNULL = subprocess.DEVNULL


def __rubberband(y, sr, **kwargs):

    assert sr > 0

    # Get the input and output tempfile
    fd, infile = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    fd, outfile = tempfile.mkstemp(suffix='.wav')
    os.close(fd)

    # dump the audio
    sf.write(infile, y, sr)

    try:
        # Execute rubberband
        arguments = [__RUBBERBAND_UTIL, '-q']

        for key, value in six.iteritems(kwargs):
            arguments.append(str(key))
            arguments.append(str(value))

        arguments.extend([infile, outfile])

        subprocess.check_call(arguments, stdout=DEVNULL, stderr=DEVNULL)

        # Load the processed audio.
        y_out, _ = sf.read(outfile, always_2d=True)

        # make sure that output dimensions matches input
        if y.ndim == 1:
            y_out = np.squeeze(y_out)

    except OSError as exc:
        six.raise_from(RuntimeError('Failed to execute rubberband. '
                                    'Please verify that rubberband-cli '
                                    'is installed.'),
                       exc)

    finally:
        # Remove temp files
        os.unlink(infile)
        os.unlink(outfile)

    return y_out


def time_stretch(y, sr, rate, rbargs=None):
    if rate <= 0:
        raise ValueError('rate must be strictly positive')

    if rate == 1.0:
        return y

    if rbargs is None:
        rbargs = dict()

    rbargs.setdefault('--tempo', rate)

    return __rubberband(y, sr, **rbargs)


def pitch_shift(y, sr, n_steps, rbargs=None):

    if n_steps == 0:
        return y

    if rbargs is None:
        rbargs = dict()

    rbargs.setdefault('--pitch', n_steps)

    return __rubberband(y, sr, **rbargs)
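

def _pyrb_demo(y, sr):
    # Usage sketch, not part of the upstream module and never called here;
    # running it requires the bundled `rubberband` binary to be present next
    # to this module. `rate` multiplies the tempo and `n_steps` is in
    # semitones.
    slower = time_stretch(y, sr, rate=0.5)   # half tempo: double duration
    higher = pitch_shift(y, sr, n_steps=2)   # two semitones up
    return slower, higher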
audio_separator/separator/uvr_lib_v5/results.py
ADDED
@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-

"""
Matchering - Audio Matching and Mastering Python Library
Copyright (C) 2016-2022 Sergree

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

import os
import soundfile as sf


class Result:
    def __init__(
        self, file: str, subtype: str, use_limiter: bool = True, normalize: bool = True
    ):
        _, file_ext = os.path.splitext(file)
        file_ext = file_ext[1:].upper()
        if not sf.check_format(file_ext):
            raise TypeError(f"{file_ext} format is not supported")
        if not sf.check_format(file_ext, subtype):
            raise TypeError(f"{file_ext} format does not have {subtype} subtype")
        self.file = file
        self.subtype = subtype
        self.use_limiter = use_limiter
        self.normalize = normalize


def pcm16(file: str) -> Result:
    return Result(file, "PCM_16")


def pcm24(file: str) -> Result:
    return Result(file, "FLOAT")


def save_audiofile(file: str, wav_set="PCM_16") -> Result:
    return Result(file, wav_set)
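

def _result_demo():
    # Usage sketch, not part of the upstream module and never called here;
    # the filename is a placeholder. The factory helpers above just
    # pre-select a libsndfile subtype for the given output file.
    return save_audiofile("output.wav", wav_set="PCM_16")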
audio_separator/separator/uvr_lib_v5/roformer/attend.py
ADDED
@@ -0,0 +1,112 @@
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# constants

FlashAttentionConfig = namedtuple("FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])

# helpers


def exists(val):
    return val is not None


def once(fn):
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)

    return inner


print_once = once(print)

# main class


class Attend(nn.Module):
    def __init__(self, dropout=0.0, flash=False, scale=None):
        super().__init__()
        self.dropout = dropout
        # scale: optional fixed softmax scale; accepted here because LinearAttention
        # in bs_roformer.py constructs Attend(scale=...). When None, the usual
        # dim_head ** -0.5 scaling applies.
        self.scale = scale
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse("2.0.0")), "in order to use flash attention, you must be using pytorch 2.0 or above"

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
            self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            self.cuda_config = FlashAttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        if exists(self.scale):
            # rescale queries so scaled_dot_product_attention's default scaling realizes self.scale
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        # check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # sdpa_flash kernel only supports float16 on sm80+ architecture gpu
        if is_cuda and q.dtype != torch.float16:
            config = FlashAttentionConfig(False, True, True)

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0)

        return out

    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = self.scale if exists(self.scale) else q.shape[-1] ** -0.5

        if self.flash:
            return self.flash_attn(q, k, v)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # attention

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
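
A minimal smoke test for the Attend module above (illustrative shapes; flash disabled so the plain einsum path runs on any device):

# sketch: non-flash attention on random tensors
import torch
from audio_separator.separator.uvr_lib_v5.roformer.attend import Attend

attend = Attend(dropout=0.0, flash=False)
q = torch.randn(1, 8, 64, 32)  # (batch, heads, seq_len, dim_head)
k = torch.randn(1, 8, 64, 32)
v = torch.randn(1, 8, 64, 32)
out = attend(q, k, v)          # softmax(q k^T / sqrt(d)) v
print(out.shape)               # -> torch.Size([1, 8, 64, 32])
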
audio_separator/separator/uvr_lib_v5/roformer/bs_roformer.py
ADDED
@@ -0,0 +1,535 @@
from functools import partial

import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from .attend import Attend

from beartype.typing import Tuple, Optional, List, Callable
from beartype import beartype

from rotary_embedding_torch import RotaryEmbedding

from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange

# helper functions


def exists(val):
    return val is not None


def default(v, d):
    return v if exists(v) else d


def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


# norm


def l2norm(t):
    return F.normalize(t, dim=-1, p=2)


class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = x.to(self.gamma.device)
        return F.normalize(x, dim=-1) * self.scale * self.gamma


# attention


class FeedForward(Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(RMSNorm(dim), nn.Linear(dim, dim_inner), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim_inner, dim), nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)


class Attention(Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed

        self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)

        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout))

    def forward(self, x):
        x = self.norm(x)

        q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

        if exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        out = self.attend(q, k, v)

        gates = self.to_gates(x)
        out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


class LinearAttention(Module):
    """
    this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
    """

    @beartype
    def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
        super().__init__()
        dim_inner = dim_head * heads
        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads))

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = Attend(scale=scale, dropout=dropout, flash=flash)

        self.to_out = nn.Sequential(Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))

    def forward(self, x):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        out = self.attend(q, k, v)

        return self.to_out(out)


class Transformer(Module):
    def __init__(self, *, dim, depth, dim_head=64, heads=8, attn_dropout=0.0, ff_dropout=0.0, ff_mult=4, norm_output=True, rotary_embed=None, flash_attn=True, linear_attn=False):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            if linear_attn:
                attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
            else:
                attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, rotary_embed=rotary_embed, flash=flash_attn)

            self.layers.append(ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)


# bandsplit module


class BandSplit(Module):
    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for dim_in in dim_inputs:
            net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))

            self.to_features.append(net)

    def forward(self, x):
        x = x.split(self.dim_inputs, dim=-1)

        outs = []
        for split_input, to_feature in zip(x, self.to_features):
            split_output = to_feature(split_input)
            outs.append(split_output)

        return torch.stack(outs, dim=-2)


def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    dim_hidden = default(dim_hidden, dim_in)

    net = []
    dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)

    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        is_last = ind == (len(dims) - 2)

        net.append(nn.Linear(layer_dim_in, layer_dim_out))

        if is_last:
            continue

        net.append(activation())

    return nn.Sequential(*net)


class MaskEstimator(Module):
    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))

            self.to_freqs.append(mlp)

    def forward(self, x):
        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            freq_out = mlp(band_features)
            outs.append(freq_out)

        return torch.cat(outs, dim=-1)


# main class

# 62 bands: 24 x 2, 12 x 4, 8 x 12, 8 x 24, 8 x 48, then 128 and 129
# (sums to 1025, the number of stft bins for n_fft = 2048)
DEFAULT_FREQS_PER_BANDS = (
    *((2,) * 24),
    *((4,) * 12),
    *((12,) * 8),
    *((24,) * 8),
    *((48,) * 8),
    128,
    129,
)


class BSRoformer(Module):

    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        linear_transformer_depth=0,
        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
        # in the paper, they divide into ~60 bands, test with 1 for starters
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        flash_attn=True,
        dim_freqs_in=1025,
        stft_n_fft=2048,
        stft_hop_length=512,
        # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=2,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems

        self.layers = ModuleList([])

        transformer_kwargs = dict(dim=dim, heads=heads, dim_head=dim_head, attn_dropout=attn_dropout, ff_dropout=ff_dropout, flash_attn=flash_attn, norm_output=False)

        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            tran_modules = []
            if linear_transformer_depth > 0:
                tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
            tran_modules.append(Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs))
            tran_modules.append(Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs))
            self.layers.append(nn.ModuleList(tran_modules))

        self.final_norm = RMSNorm(dim)

        self.stft_kwargs = dict(n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized)

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]

        assert len(freqs_per_bands) > 1
        assert sum(freqs_per_bands) == freqs, f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"

        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)

        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(dim=dim, dim_inputs=freqs_per_bands_with_complex, depth=mask_estimator_depth)

            self.mask_estimators.append(mask_estimator)

        # for the multi-resolution stft loss

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)

    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
        """
        einops

        b - batch
        f - freq
        t - time
        s - audio channel (1 for mono, 2 for stereo)
        n - number of 'stems'
        c - complex (2)
        d - feature dimension
        """

        original_device = raw_audio.device
        x_is_mps = original_device.type == "mps"

        # if x_is_mps:
        #     raw_audio = raw_audio.cpu()

        device = raw_audio.device

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, "b t -> b 1 t")

        channels = raw_audio.shape[1]
        assert (not self.stereo and channels == 1) or (
            self.stereo and channels == 2
        ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"

        # to stft

        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

        stft_window = self.stft_window_fn().to(device)

        stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")  # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting

        x = rearrange(stft_repr, "b f t c -> b t (f c)")

        x = self.band_split(x)

        # axial / hierarchical attention

        for transformer_block in self.layers:

            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = transformer_block

                x, ft_ps = pack([x], "b * d")
                x = linear_transformer(x)
                (x,) = unpack(x, ft_ps, "b * d")
            else:
                time_transformer, freq_transformer = transformer_block

            x = rearrange(x, "b t f d -> b f t d")
            x, ps = pack([x], "* t d")

            x = time_transformer(x)

            (x,) = unpack(x, ps, "* t d")
            x = rearrange(x, "b f t d -> b t f d")
            x, ps = pack([x], "* f d")

            x = freq_transformer(x)

            (x,) = unpack(x, ps, "* f d")

        x = self.final_norm(x)

        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

        # if x_is_mps:
        #     mask = mask.to('cpu')

        # modulate frequency representation

        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        # complex number multiplication

        stft_repr = torch.view_as_complex(stft_repr)
        mask = torch.view_as_complex(mask)

        stft_repr = stft_repr * mask

        # istft

        stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)

        recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False).to(device)

        recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=self.num_stems)

        if self.num_stems == 1:
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        # if a target is passed in, calculate loss for learning

        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, "... t -> ... 1 t")

        target = target[..., : recon_audio.shape[-1]]

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.0

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft), win_length=window_size, return_complex=True, window=self.multi_stft_window_fn(window_size, device=device), **self.multi_stft_kwargs
            )

            recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
            target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            # if x_is_mps:
            #     total_loss = total_loss.to(original_device)
            return total_loss

        # if x_is_mps:
        #     loss = loss.to(original_device)
        #     multi_stft_resolution_loss = multi_stft_resolution_loss.to(original_device)

        return total_loss, (loss, multi_stft_resolution_loss)
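
A minimal inference sketch for BSRoformer above (toy hyperparameters chosen for speed, not quality; assumes rotary-embedding-torch, beartype and einops are installed; flash_attn disabled for portability):

# sketch: run a tiny BSRoformer on one second of random mono audio
import torch
from audio_separator.separator.uvr_lib_v5.roformer.bs_roformer import BSRoformer

model = BSRoformer(dim=32, depth=1, flash_attn=False)
audio = torch.randn(1, 44100)   # (batch, samples); mono because stereo=False
with torch.no_grad():
    stem = model(audio)         # no target passed -> returns the reconstructed audio
print(stem.shape)               # -> (1, 1, ~44100); the istft returns (frames - 1) * hop samples
# passing a target of the same shape instead returns the training loss:
# loss = model(audio, target=torch.randn_like(stem))
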
audio_separator/separator/uvr_lib_v5/roformer/mel_band_roformer.py
ADDED
@@ -0,0 +1,445 @@
from functools import partial

import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from .attend import Attend

from beartype.typing import Tuple, Optional, List, Callable
from beartype import beartype

from rotary_embedding_torch import RotaryEmbedding

from einops import rearrange, pack, unpack, reduce, repeat

from librosa import filters


def exists(val):
    return val is not None


def default(v, d):
    return v if exists(v) else d


def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


def pad_at_dim(t, pad, dim=-1, value=0.0):
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)


class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = x.to(self.gamma.device)
        return F.normalize(x, dim=-1) * self.scale * self.gamma


class FeedForward(Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(RMSNorm(dim), nn.Linear(dim, dim_inner), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim_inner, dim), nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)


class Attention(Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed

        self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)

        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout))

    def forward(self, x):
        x = self.norm(x)

        q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

        if exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        out = self.attend(q, k, v)

        gates = self.to_gates(x)
        out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


class Transformer(Module):
    def __init__(self, *, dim, depth, dim_head=64, heads=8, attn_dropout=0.0, ff_dropout=0.0, ff_mult=4, norm_output=True, rotary_embed=None, flash_attn=True):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.append(
                ModuleList(
                    [Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, rotary_embed=rotary_embed, flash=flash_attn), FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
                )
            )

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)


class BandSplit(Module):
    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for dim_in in dim_inputs:
            net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))

            self.to_features.append(net)

    def forward(self, x):
        x = x.split(self.dim_inputs, dim=-1)

        outs = []
        for split_input, to_feature in zip(x, self.to_features):
            split_output = to_feature(split_input)
            outs.append(split_output)

        return torch.stack(outs, dim=-2)


def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    dim_hidden = default(dim_hidden, dim_in)

    net = []
    # note: uses depth hidden layers here, where bs_roformer.py's MLP uses depth - 1
    dims = (dim_in, *((dim_hidden,) * depth), dim_out)

    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        is_last = ind == (len(dims) - 2)

        net.append(nn.Linear(layer_dim_in, layer_dim_out))

        if is_last:
            continue

        net.append(activation())

    return nn.Sequential(*net)


class MaskEstimator(Module):
    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))

            self.to_freqs.append(mlp)

    def forward(self, x):
        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            freq_out = mlp(band_features)
            outs.append(freq_out)

        return torch.cat(outs, dim=-1)


class MelBandRoformer(Module):

    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        num_bands=60,
        dim_head=64,
        heads=8,
        attn_dropout=0.1,
        ff_dropout=0.1,
        flash_attn=True,
        dim_freqs_in=1025,
        sample_rate=44100,
        stft_n_fft=2048,
        stft_hop_length=512,
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=1,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        match_input_audio_length=False,
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems

        self.layers = ModuleList([])

        transformer_kwargs = dict(dim=dim, heads=heads, dim_head=dim_head, attn_dropout=attn_dropout, ff_dropout=ff_dropout, flash_attn=flash_attn)

        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs),
                        Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs),
                    ]
                )
            )

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        self.stft_kwargs = dict(n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized)

        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]

        # build a mel filter bank and derive which stft bins belong to each mel band

        mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)

        mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)

        # the lowest and highest filters are zero at the spectrum edges; force them on
        # so that every stft bin is covered by at least one band
        mel_filter_bank[0][0] = 1.0

        mel_filter_bank[-1, -1] = 1.0

        freqs_per_band = mel_filter_bank > 0
        assert freqs_per_band.any(dim=0).all(), "all frequencies need to be covered by all bands for now"

        repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
        freq_indices = repeated_freq_indices[freqs_per_band]

        if stereo:
            freq_indices = repeat(freq_indices, "f -> f s", s=2)
            freq_indices = freq_indices * 2 + torch.arange(2)
            freq_indices = rearrange(freq_indices, "f s -> (f s)")

        self.register_buffer("freq_indices", freq_indices, persistent=False)
        self.register_buffer("freqs_per_band", freqs_per_band, persistent=False)

        num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
        num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")

        self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False)
        self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False)

        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())

        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(dim=dim, dim_inputs=freqs_per_bands_with_complex, depth=mask_estimator_depth)

            self.mask_estimators.append(mask_estimator)

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)

        self.match_input_audio_length = match_input_audio_length

    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
        """
        einops

        b - batch
        f - freq
        t - time
        s - audio channel (1 for mono, 2 for stereo)
        n - number of 'stems'
        c - complex (2)
        d - feature dimension
        """

        original_device = raw_audio.device
        x_is_mps = original_device.type == "mps"

        if x_is_mps:
            raw_audio = raw_audio.cpu()

        device = raw_audio.device

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, "b t -> b 1 t")

        batch, channels, raw_audio_length = raw_audio.shape

        istft_length = raw_audio_length if self.match_input_audio_length else None

        assert (not self.stereo and channels == 1) or (
            self.stereo and channels == 2
        ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"

        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

        stft_window = self.stft_window_fn().to(device)

        stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")  # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting

        batch_arange = torch.arange(batch, device=device)[..., None]

        x = stft_repr[batch_arange, self.freq_indices.cpu()] if x_is_mps else stft_repr[batch_arange, self.freq_indices]

        x = rearrange(x, "b f t c -> b t (f c)")

        x = self.band_split(x)

        for time_transformer, freq_transformer in self.layers:
            x = rearrange(x, "b t f d -> b f t d")
            x, ps = pack([x], "* t d")

            x = time_transformer(x)

            (x,) = unpack(x, ps, "* t d")
            x = rearrange(x, "b f t d -> b t f d")
            x, ps = pack([x], "* f d")

            x = freq_transformer(x)

            (x,) = unpack(x, ps, "* f d")

        masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2)

        if x_is_mps:
            masks = masks.cpu()

        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        stft_repr = torch.view_as_complex(stft_repr)
        masks = torch.view_as_complex(masks)

        masks = masks.type(stft_repr.dtype)

        # scatter the per-band masks back onto the stft bins; bins shared by several
        # mel bands are summed here and averaged below
        if x_is_mps:
            scatter_indices = repeat(self.freq_indices.cpu(), "f -> b n f t", b=batch, n=self.num_stems, t=stft_repr.shape[-1])
        else:
            scatter_indices = repeat(self.freq_indices, "f -> b n f t", b=batch, n=self.num_stems, t=stft_repr.shape[-1])

        stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... -> b n ...", n=self.num_stems)
        masks_summed = (
            torch.zeros_like(stft_repr_expanded_stems.cpu() if x_is_mps else stft_repr_expanded_stems)
            .scatter_add_(2, scatter_indices.cpu() if x_is_mps else scatter_indices, masks.cpu() if x_is_mps else masks)
            .to(device)
        )

        denom = repeat(self.num_bands_per_freq, "f -> (f r) 1", r=channels)

        if x_is_mps:
            denom = denom.cpu()

        masks_averaged = masks_summed / denom.clamp(min=1e-8)

        stft_repr = stft_repr * masks_averaged

        stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)

        recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=istft_length)

        recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", b=batch, s=self.audio_channels, n=self.num_stems)

        if self.num_stems == 1:
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, "... t -> ... 1 t")

        target = target[..., : recon_audio.shape[-1]]

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.0

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft), win_length=window_size, return_complex=True, window=self.multi_stft_window_fn(window_size, device=device), **self.multi_stft_kwargs
            )

            recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
            target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss = loss + weighted_multi_resolution_loss

        # Move the total loss back to the original device if necessary
        if x_is_mps:
            total_loss = total_loss.to(original_device)

        if not return_loss_breakdown:
            return total_loss

        # If a detailed loss breakdown is requested, ensure all components are on the original device
        return total_loss, (loss.to(original_device) if x_is_mps else loss, multi_stft_resolution_loss.to(original_device) if x_is_mps else multi_stft_resolution_loss)
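
The Mel-band variant can be exercised the same way (toy sizes; same dependency assumptions as the BSRoformer sketch; match_input_audio_length makes the istft return exactly the input length):

# sketch: run a tiny MelBandRoformer on random mono audio
import torch
from audio_separator.separator.uvr_lib_v5.roformer.mel_band_roformer import MelBandRoformer

model = MelBandRoformer(dim=32, depth=1, flash_attn=False, match_input_audio_length=True)
audio = torch.randn(1, 44100)
with torch.no_grad():
    stem = model(audio)
print(stem.shape)  # -> torch.Size([1, 1, 44100])
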
audio_separator/separator/uvr_lib_v5/spec_utils.py
ADDED
@@ -0,0 +1,1327 @@
import audioread
import librosa
import numpy as np
import soundfile as sf
import math
import platform
import traceback
from audio_separator.separator.uvr_lib_v5 import pyrb
from scipy.signal import correlate, hilbert
import io

OPERATING_SYSTEM = platform.system()
SYSTEM_ARCH = platform.platform()
SYSTEM_PROC = platform.processor()
ARM = "arm"

AUTO_PHASE = "Automatic"
POSITIVE_PHASE = "Positive Phase"
NEGATIVE_PHASE = "Negative Phase"
NONE_P = ("None",)
LOW_P = ("Shifts: Low",)
MED_P = ("Shifts: Medium",)
HIGH_P = ("Shifts: High",)
VHIGH_P = "Shifts: Very High"
MAXIMUM_P = "Shifts: Maximum"

progress_value = 0
last_update_time = 0
is_macos = False


if OPERATING_SYSTEM == "Darwin":
    wav_resolution = "polyphase" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else "sinc_fastest"
    wav_resolution_float_resampling = "kaiser_best" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else wav_resolution
    is_macos = True
else:
    wav_resolution = "sinc_fastest"
    wav_resolution_float_resampling = wav_resolution

MAX_SPEC = "Max Spec"
MIN_SPEC = "Min Spec"
LIN_ENSE = "Linear Ensemble"

MAX_WAV = MAX_SPEC
MIN_WAV = MIN_SPEC

AVERAGE = "Average"


def crop_center(h1, h2):
    """
    This function crops the center of the first input tensor to match the size of the second input tensor.
    It is used to ensure that the two tensors have the same size in the time dimension.
    """
    h1_shape = h1.size()
    h2_shape = h2.size()

    # If the time dimensions are already equal, return the first tensor as is
    if h1_shape[3] == h2_shape[3]:
        return h1
    # If the time dimension of the first tensor is smaller, raise an error
    elif h1_shape[3] < h2_shape[3]:
        raise ValueError("h1_shape[3] must be greater than h2_shape[3]")

    # Calculate the start and end indices for cropping
    s_time = (h1_shape[3] - h2_shape[3]) // 2
    e_time = s_time + h2_shape[3]
    # Crop the first tensor
    h1 = h1[:, :, :, s_time:e_time]

    return h1


def preprocess(X_spec):
    """
    This function preprocesses a spectrogram by separating it into magnitude and phase components.
    This is a common preprocessing step in audio processing tasks.
    """
    X_mag = np.abs(X_spec)
    X_phase = np.angle(X_spec)

    return X_mag, X_phase


def make_padding(width, cropsize, offset):
    """
    This function calculates the padding needed to make the width of an image divisible by the crop size.
    It is used in the process of splitting an image into smaller patches.
    """
    left = offset
    roi_size = cropsize - offset * 2
    if roi_size == 0:
        roi_size = cropsize
    right = roi_size - (width % roi_size) + left

    return left, right, roi_size


def normalize(wave, max_peak=1.0, min_peak=None):
    """Normalize (or amplify) an audio waveform to a specified peak value.

    Args:
        wave (array-like): Audio waveform.
        max_peak (float): Maximum peak value for normalization.
        min_peak (float, optional): If set, waveforms quieter than this peak are amplified up to it.

    Returns:
        array-like: Normalized or original waveform.
    """
    maxv = np.abs(wave).max()
    if maxv > max_peak:
        wave *= max_peak / maxv
    elif min_peak is not None and maxv < min_peak:
        wave *= min_peak / maxv

    return wave
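
A quick numeric check of normalize above (illustrative arrays; note the function scales the array in place and also returns it):

# sketch: peaks above max_peak are scaled down; quiet signals can be pulled up via min_peak
import numpy as np
from audio_separator.separator.uvr_lib_v5.spec_utils import normalize

wave = np.array([0.0, 2.0, -1.0])
print(normalize(wave.copy(), max_peak=1.0))                 # peak 2.0 > 1.0 -> [0.0, 1.0, -0.5]
quiet = np.array([0.0, 0.1, -0.05])
print(normalize(quiet.copy(), max_peak=1.0, min_peak=0.5))  # peak 0.1 < 0.5 -> [0.0, 0.5, -0.25]
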

def auto_transpose(audio_array: np.ndarray):
    """
    Ensure that the audio array is in the (channels, samples) format.

    Parameters:
        audio_array (ndarray): Input audio array.

    Returns:
        ndarray: Transposed audio array if necessary.
    """

    # If the second dimension is 2 (indicating stereo channels), transpose the array
    if audio_array.shape[1] == 2:
        return audio_array.T
    return audio_array


def write_array_to_mem(audio_data, subtype):
    if isinstance(audio_data, np.ndarray):
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_data, 44100, subtype=subtype, format="WAV")
        audio_buffer.seek(0)
        return audio_buffer
    else:
        return audio_data


def spectrogram_to_image(spec, mode="magnitude"):
    if mode == "magnitude":
        if np.iscomplexobj(spec):
            y = np.abs(spec)
        else:
            y = spec
        y = np.log10(y**2 + 1e-8)
    elif mode == "phase":
        if np.iscomplexobj(spec):
            y = np.angle(spec)
        else:
            y = spec

    y -= y.min()
    y *= 255 / y.max()
    img = np.uint8(y)

    if y.ndim == 3:
        img = img.transpose(1, 2, 0)
        img = np.concatenate([np.max(img, axis=2, keepdims=True), img], axis=2)

    return img


def reduce_vocal_aggressively(X, y, softmask):
    v = X - y
    y_mag_tmp = np.abs(y)
    v_mag_tmp = np.abs(v)

    v_mask = v_mag_tmp > y_mag_tmp
    y_mag = np.clip(y_mag_tmp - v_mag_tmp * v_mask * softmask, 0, np.inf)

    return y_mag * np.exp(1.0j * np.angle(y))


def merge_artifacts(y_mask, thres=0.01, min_range=64, fade_size=32):
    mask = y_mask

    try:
        if min_range < fade_size * 2:
            raise ValueError("min_range must be >= fade_size * 2")

        idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
        start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
        end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
        artifact_idx = np.where(end_idx - start_idx > min_range)[0]
        weight = np.zeros_like(y_mask)
        if len(artifact_idx) > 0:
            start_idx = start_idx[artifact_idx]
            end_idx = end_idx[artifact_idx]
            old_e = None
            for s, e in zip(start_idx, end_idx):
                if old_e is not None and s - old_e < fade_size:
                    s = old_e - fade_size * 2

                if s != 0:
                    weight[:, :, s : s + fade_size] = np.linspace(0, 1, fade_size)
                else:
                    s -= fade_size

                if e != y_mask.shape[2]:
                    weight[:, :, e - fade_size : e] = np.linspace(1, 0, fade_size)
                else:
                    e += fade_size

                weight[:, :, s + fade_size : e - fade_size] = 1
                old_e = e

        v_mask = 1 - y_mask
        y_mask += weight * v_mask

        mask = y_mask
    except Exception as e:
        error_name = f"{type(e).__name__}"
        traceback_text = "".join(traceback.format_tb(e.__traceback__))
        message = f'{error_name}: "{e}"\n{traceback_text}"'
        print("Post Process Failed: ", message)

    return mask


def align_wave_head_and_tail(a, b):
    l = min([a[0].size, b[0].size])

    return a[:l, :l], b[:l, :l]


def convert_channels(spec, mp, band):
    cc = mp.param["band"][band].get("convert_channels")

    if "mid_side_c" == cc:
        spec_left = np.add(spec[0], spec[1] * 0.25)
        spec_right = np.subtract(spec[1], spec[0] * 0.25)
    elif "mid_side" == cc:
        spec_left = np.add(spec[0], spec[1]) / 2
        spec_right = np.subtract(spec[0], spec[1])
    elif "stereo_n" == cc:
        spec_left = np.add(spec[0], spec[1] * 0.25) / 0.9375
        spec_right = np.add(spec[1], spec[0] * 0.25) / 0.9375
    else:
        return spec

    return np.asfortranarray([spec_left, spec_right])
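
For intuition, the "mid_side" branch above encodes stereo (L, R) as mid M = (L + R) / 2 and side S = L - R; the matching branch in spectrogram_to_wave further down decodes with (M + S / 2, M - S / 2). A quick check with plain numbers:

# sketch: verify the mid_side encode/decode pair used in this module
import math

L, R = 0.8, 0.2
M, S = (L + R) / 2, L - R          # encode ("mid_side")
assert math.isclose(M + S / 2, L)  # decode left
assert math.isclose(M - S / 2, R)  # decode right
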
248 |
+
|
249 |
+
|
250 |
+
def combine_spectrograms(specs, mp, is_v51_model=False):
|
251 |
+
l = min([specs[i].shape[2] for i in specs])
|
252 |
+
spec_c = np.zeros(shape=(2, mp.param["bins"] + 1, l), dtype=np.complex64)
|
253 |
+
offset = 0
|
254 |
+
bands_n = len(mp.param["band"])
|
255 |
+
|
256 |
+
for d in range(1, bands_n + 1):
|
257 |
+
h = mp.param["band"][d]["crop_stop"] - mp.param["band"][d]["crop_start"]
|
258 |
+
spec_c[:, offset : offset + h, :l] = specs[d][:, mp.param["band"][d]["crop_start"] : mp.param["band"][d]["crop_stop"], :l]
|
259 |
+
offset += h
|
260 |
+
|
261 |
+
if offset > mp.param["bins"]:
|
262 |
+
raise ValueError("Too much bins")
|
263 |
+
|
264 |
+
# lowpass fiter
|
265 |
+
|
266 |
+
if mp.param["pre_filter_start"] > 0:
|
267 |
+
if is_v51_model:
|
268 |
+
spec_c *= get_lp_filter_mask(spec_c.shape[1], mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
|
269 |
+
else:
|
270 |
+
if bands_n == 1:
|
271 |
+
spec_c = fft_lp_filter(spec_c, mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
|
272 |
+
else:
|
273 |
+
gp = 1
|
274 |
+
for b in range(mp.param["pre_filter_start"] + 1, mp.param["pre_filter_stop"]):
|
275 |
+
g = math.pow(10, -(b - mp.param["pre_filter_start"]) * (3.5 - gp) / 20.0)
|
276 |
+
gp = g
|
277 |
+
spec_c[:, b, :] *= g
|
278 |
+
|
279 |
+
return np.asfortranarray(spec_c)
|
280 |
+
|
281 |
+
|
282 |
+
def wave_to_spectrogram(wave, hop_length, n_fft, mp, band, is_v51_model=False):
    if wave.ndim == 1:
        wave = np.asfortranarray([wave, wave])

    if not is_v51_model:
        if mp.param["reverse"]:
            wave_left = np.flip(np.asfortranarray(wave[0]))
            wave_right = np.flip(np.asfortranarray(wave[1]))
        elif mp.param["mid_side"]:
            wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2)
            wave_right = np.asfortranarray(np.subtract(wave[0], wave[1]))
        elif mp.param["mid_side_b2"]:
            wave_left = np.asfortranarray(np.add(wave[1], wave[0] * 0.5))
            wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * 0.5))
        else:
            wave_left = np.asfortranarray(wave[0])
            wave_right = np.asfortranarray(wave[1])
    else:
        wave_left = np.asfortranarray(wave[0])
        wave_right = np.asfortranarray(wave[1])

    spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
    spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)

    spec = np.asfortranarray([spec_left, spec_right])

    if is_v51_model:
        spec = convert_channels(spec, mp, band)

    return spec


def spectrogram_to_wave(spec, hop_length=1024, mp={}, band=0, is_v51_model=True):
    spec_left = np.asfortranarray(spec[0])
    spec_right = np.asfortranarray(spec[1])

    wave_left = librosa.istft(spec_left, hop_length=hop_length)
    wave_right = librosa.istft(spec_right, hop_length=hop_length)

    if is_v51_model:
        cc = mp.param["band"][band].get("convert_channels")
        if "mid_side_c" == cc:
            return np.asfortranarray([np.subtract(wave_left / 1.0625, wave_right / 4.25), np.add(wave_right / 1.0625, wave_left / 4.25)])
        elif "mid_side" == cc:
            return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
        elif "stereo_n" == cc:
            return np.asfortranarray([np.subtract(wave_left, wave_right * 0.25), np.subtract(wave_right, wave_left * 0.25)])
    else:
        if mp.param["reverse"]:
            return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
        elif mp.param["mid_side"]:
            return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
        elif mp.param["mid_side_b2"]:
            return np.asfortranarray([np.add(wave_right / 1.25, 0.4 * wave_left), np.subtract(wave_left / 1.25, 0.4 * wave_right)])

    return np.asfortranarray([wave_left, wave_right])

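# A quick numeric check (a sketch, not part of spec_utils.py) of the mid/side
# round trip used above: wave_to_spectrogram encodes mid = (L + R) / 2 and
# side = L - R, and spectrogram_to_wave decodes L = mid + side / 2 and
# R = mid - side / 2.
def _mid_side_round_trip_sketch():
    import numpy as np

    left, right = np.array([1.0, 0.5]), np.array([0.2, -0.4])
    mid, side = (left + right) / 2, left - right  # encode
    l2, r2 = mid + side / 2, mid - side / 2  # decode
    assert np.allclose(left, l2) and np.allclose(right, r2)
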
def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None, is_v51_model=False):
    bands_n = len(mp.param["band"])
    offset = 0

    for d in range(1, bands_n + 1):
        bp = mp.param["band"][d]
        spec_s = np.zeros(shape=(2, bp["n_fft"] // 2 + 1, spec_m.shape[2]), dtype=complex)
        h = bp["crop_stop"] - bp["crop_start"]
        spec_s[:, bp["crop_start"] : bp["crop_stop"], :] = spec_m[:, offset : offset + h, :]

        offset += h
        if d == bands_n:  # higher
            if extra_bins_h:  # if --high_end_process bypass
                max_bin = bp["n_fft"] // 2
                spec_s[:, max_bin - extra_bins_h : max_bin, :] = extra_bins[:, :extra_bins_h, :]
            if bp["hpf_start"] > 0:
                if is_v51_model:
                    spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
                else:
                    spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
            if bands_n == 1:
                wave = spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model)
            else:
                wave = np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model))
        else:
            sr = mp.param["band"][d + 1]["sr"]
            if d == 1:  # lower
                if is_v51_model:
                    spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
                else:
                    spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])

                try:
                    wave = librosa.resample(spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model), orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
                except ValueError as e:
                    print(f"Error during resampling: {e}")
                    print(f"Spec_s shape: {spec_s.shape}, SR: {sr}, Res type: {wav_resolution}")

            else:  # mid
                if is_v51_model:
                    spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
                    spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
                else:
                    spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
                    spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])

                wave2 = np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model))

                try:
                    wave = librosa.resample(wave2, orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
                except ValueError as e:
                    print(f"Error during resampling: {e}")
                    print(f"Spec_s shape: {spec_s.shape}, SR: {sr}, Res type: {wav_resolution}")

    return wave

def get_lp_filter_mask(n_bins, bin_start, bin_stop):
    mask = np.concatenate([np.ones((bin_start - 1, 1)), np.linspace(1, 0, bin_stop - bin_start + 1)[:, None], np.zeros((n_bins - bin_stop, 1))], axis=0)

    return mask


def get_hp_filter_mask(n_bins, bin_start, bin_stop):
    mask = np.concatenate([np.zeros((bin_stop + 1, 1)), np.linspace(0, 1, 1 + bin_start - bin_stop)[:, None], np.ones((n_bins - bin_start - 2, 1))], axis=0)

    return mask


def fft_lp_filter(spec, bin_start, bin_stop):
    g = 1.0
    for b in range(bin_start, bin_stop):
        g -= 1 / (bin_stop - bin_start)
        spec[:, b, :] = g * spec[:, b, :]

    spec[:, bin_stop:, :] *= 0

    return spec


def fft_hp_filter(spec, bin_start, bin_stop):
    g = 1.0
    for b in range(bin_start, bin_stop, -1):
        g -= 1 / (bin_start - bin_stop)
        spec[:, b, :] = g * spec[:, b, :]

    spec[:, 0 : bin_stop + 1, :] *= 0

    return spec

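# Sketch (hypothetical sizes, not part of spec_utils.py): the lowpass mask is
# an (n_bins, 1) column - passband ones, a linear roll-off, stopband zeros -
# so multiplying it against a (2, n_bins, n_frames) spectrogram broadcasts
# across both channels and all time frames.
def _lp_mask_sketch():
    import numpy as np

    n_bins, bin_start, bin_stop = 16, 6, 10
    mask = np.concatenate(
        [
            np.ones((bin_start - 1, 1)),  # passband
            np.linspace(1, 0, bin_stop - bin_start + 1)[:, None],  # roll-off
            np.zeros((n_bins - bin_stop, 1)),  # stopband
        ],
        axis=0,
    )
    assert mask.shape == (n_bins, 1)
    return mask
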
def spectrogram_to_wave_old(spec, hop_length=1024):
    if spec.ndim == 2:
        wave = librosa.istft(spec, hop_length=hop_length)
    elif spec.ndim == 3:
        spec_left = np.asfortranarray(spec[0])
        spec_right = np.asfortranarray(spec[1])

        wave_left = librosa.istft(spec_left, hop_length=hop_length)
        wave_right = librosa.istft(spec_right, hop_length=hop_length)
        wave = np.asfortranarray([wave_left, wave_right])

    return wave


def wave_to_spectrogram_old(wave, hop_length, n_fft):
    wave_left = np.asfortranarray(wave[0])
    wave_right = np.asfortranarray(wave[1])

    spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
    spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)

    spec = np.asfortranarray([spec_left, spec_right])

    return spec

def mirroring(a, spec_m, input_high_end, mp):
    if "mirroring" == a:
        mirror = np.flip(np.abs(spec_m[:, mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10, :]), 1)
        mirror = mirror * np.exp(1.0j * np.angle(input_high_end))

        return np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror)

    if "mirroring2" == a:
        mirror = np.flip(np.abs(spec_m[:, mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10, :]), 1)
        mi = np.multiply(mirror, input_high_end * 1.7)

        return np.where(np.abs(input_high_end) <= np.abs(mi), input_high_end, mi)

def adjust_aggr(mask, is_non_accom_stem, aggressiveness):
    aggr = aggressiveness["value"] * 2

    if aggr != 0:
        if is_non_accom_stem:
            aggr = 1 - aggr

        if np.any(aggr > 10) or np.any(aggr < -10):
            print(f"Warning: Extreme aggressiveness values detected: {aggr}")

        aggr = [aggr, aggr]

        if aggressiveness["aggr_correction"] is not None:
            aggr[0] += aggressiveness["aggr_correction"]["left"]
            aggr[1] += aggressiveness["aggr_correction"]["right"]

        for ch in range(2):
            mask[ch, : aggressiveness["split_bin"]] = np.power(mask[ch, : aggressiveness["split_bin"]], 1 + aggr[ch] / 3)
            mask[ch, aggressiveness["split_bin"] :] = np.power(mask[ch, aggressiveness["split_bin"] :], 1 + aggr[ch])

    return mask

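# Sketch (not part of spec_utils.py): adjust_aggr sharpens a soft mask by
# raising it to a power. Values in (0, 1) shrink as the exponent grows, so
# low-confidence bins are suppressed hardest while bins near 1 barely move.
def _aggr_power_sketch():
    import numpy as np

    mask = np.array([0.1, 0.5, 0.9])
    aggr = 2  # hypothetical aggressiveness
    sharpened = np.power(mask, 1 + aggr)
    # -> [0.001, 0.125, 0.729]
    return sharpened
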
def stft(wave, nfft, hl):
    wave_left = np.asfortranarray(wave[0])
    wave_right = np.asfortranarray(wave[1])
    spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
    spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
    spec = np.asfortranarray([spec_left, spec_right])

    return spec


def istft(spec, hl):
    spec_left = np.asfortranarray(spec[0])
    spec_right = np.asfortranarray(spec[1])
    wave_left = librosa.istft(spec_left, hop_length=hl)
    wave_right = librosa.istft(spec_right, hop_length=hl)
    wave = np.asfortranarray([wave_left, wave_right])

    return wave

def spec_effects(wave, algorithm="Default", value=None):
    if np.isnan(wave).any() or np.isinf(wave).any():
        print(f"Warning: Detected NaN or infinite values in wave input. Shape: {wave.shape}")

    spec = [stft(wave[0], 2048, 1024), stft(wave[1], 2048, 1024)]
    if algorithm == "Min_Mag":
        v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0])
        wave = istft(v_spec_m, 1024)
    elif algorithm == "Max_Mag":
        v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0])
        wave = istft(v_spec_m, 1024)
    elif algorithm == "Default":
        wave = (wave[1] * value) + (wave[0] * (1 - value))
    elif algorithm == "Invert_p":
        X_mag = np.abs(spec[0])
        y_mag = np.abs(spec[1])
        max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
        v_spec = spec[1] - max_mag * np.exp(1.0j * np.angle(spec[0]))
        wave = istft(v_spec, 1024)

    return wave

def spectrogram_to_wave_no_mp(spec, n_fft=2048, hop_length=1024):
    wave = librosa.istft(spec, n_fft=n_fft, hop_length=hop_length)

    if wave.ndim == 1:
        wave = np.asfortranarray([wave, wave])

    return wave


def wave_to_spectrogram_no_mp(wave):
    spec = librosa.stft(wave, n_fft=2048, hop_length=1024)

    if spec.ndim == 1:
        spec = np.asfortranarray([spec, spec])

    return spec

def invert_audio(specs, invert_p=True):
    ln = min([specs[0].shape[2], specs[1].shape[2]])
    specs[0] = specs[0][:, :, :ln]
    specs[1] = specs[1][:, :, :ln]

    if invert_p:
        X_mag = np.abs(specs[0])
        y_mag = np.abs(specs[1])
        max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
        v_spec = specs[1] - max_mag * np.exp(1.0j * np.angle(specs[0]))
    else:
        specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2)
        v_spec = specs[0] - specs[1]

    return v_spec


def invert_stem(mixture, stem):
    mixture = wave_to_spectrogram_no_mp(mixture)
    stem = wave_to_spectrogram_no_mp(stem)
    output = spectrogram_to_wave_no_mp(invert_audio([mixture, stem]))

    return -output.T

def ensembling(a, inputs, is_wavs=False):
    for i in range(1, len(inputs)):
        if i == 1:
            input = inputs[0]

        if is_wavs:
            ln = min([input.shape[1], inputs[i].shape[1]])
            input = input[:, :ln]
            inputs[i] = inputs[i][:, :ln]
        else:
            ln = min([input.shape[2], inputs[i].shape[2]])
            input = input[:, :, :ln]
            inputs[i] = inputs[i][:, :, :ln]

        if MIN_SPEC == a:
            input = np.where(np.abs(inputs[i]) <= np.abs(input), inputs[i], input)
        if MAX_SPEC == a:
            input = np.where(np.abs(inputs[i]) >= np.abs(input), inputs[i], input)

    return input

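# Sketch (not part of spec_utils.py) of the MIN_SPEC rule above: per bin,
# keep the complex value with the smaller magnitude, which suppresses
# residues that only one of the inputs contains.
def _min_mag_sketch():
    import numpy as np

    a = np.array([1.0 + 1.0j, 0.1 + 0.0j])
    b = np.array([0.5 + 0.0j, 2.0 + 2.0j])
    min_spec = np.where(np.abs(b) <= np.abs(a), b, a)
    # -> [0.5+0j, 0.1+0j]
    return min_spec
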
def ensemble_for_align(waves):
    specs = []

    for wav in waves:
        spec = wave_to_spectrogram_no_mp(wav.T)
        specs.append(spec)

    wav_aligned = spectrogram_to_wave_no_mp(ensembling(MIN_SPEC, specs)).T
    wav_aligned = match_array_shapes(wav_aligned, waves[1], is_swap=True)

    return wav_aligned

def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save_path, is_wave=False, is_array=False):
    wavs_ = []

    if algorithm == AVERAGE:
        output = average_audio(audio_input)
        samplerate = 44100
    else:
        specs = []

        for i in range(len(audio_input)):
            wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100)
            wavs_.append(wave)
            spec = wave if is_wave else wave_to_spectrogram_no_mp(wave)
            specs.append(spec)

        wave_shapes = [w.shape[1] for w in wavs_]
        target_shape = wavs_[wave_shapes.index(max(wave_shapes))]

        if is_wave:
            output = ensembling(algorithm, specs, is_wavs=True)
        else:
            output = spectrogram_to_wave_no_mp(ensembling(algorithm, specs))

        output = to_shape(output, target_shape.shape)

    sf.write(save_path, normalize(output.T, is_normalization), samplerate, subtype=wav_type_set)

def to_shape(x, target_shape):
    padding_list = []
    for x_dim, target_dim in zip(x.shape, target_shape):
        pad_value = target_dim - x_dim
        pad_tuple = (0, pad_value)
        padding_list.append(pad_tuple)

    return np.pad(x, tuple(padding_list), mode="constant")


def to_shape_minimize(x: np.ndarray, target_shape):
    # Currently identical to to_shape; kept as a separate name for its callers.
    padding_list = []
    for x_dim, target_dim in zip(x.shape, target_shape):
        pad_value = target_dim - x_dim
        pad_tuple = (0, pad_value)
        padding_list.append(pad_tuple)

    return np.pad(x, tuple(padding_list), mode="constant")

def detect_leading_silence(audio, sr, silence_threshold=0.007, frame_length=1024):
    """
    Detect silence at the beginning of an audio signal.

    :param audio: np.array, audio signal
    :param sr: int, sample rate
    :param silence_threshold: float, magnitude threshold below which is considered silence
    :param frame_length: int, the number of samples to consider for each check

    :return: float, duration of the leading silence in milliseconds
    """

    if len(audio.shape) == 2:
        # If stereo, pick the channel with more energy to determine the silence
        channel = np.argmax(np.sum(np.abs(audio), axis=1))
        audio = audio[channel]

    for i in range(0, len(audio), frame_length):
        if np.max(np.abs(audio[i : i + frame_length])) > silence_threshold:
            return (i / sr) * 1000

    return (len(audio) / sr) * 1000

def adjust_leading_silence(target_audio, reference_audio, silence_threshold=0.01, frame_length=1024):
    """
    Adjust the leading silence of the target_audio to match the leading silence of the reference_audio.

    :param target_audio: np.array, audio signal that will have its silence adjusted
    :param reference_audio: np.array, audio signal used as a reference
    :param silence_threshold: float, magnitude threshold below which is considered silence
    :param frame_length: int, the number of samples to consider for each check

    :return: np.array, target_audio adjusted to have the same leading silence as reference_audio
    """

    def find_silence_end(audio):
        if len(audio.shape) == 2:
            # If stereo, pick the channel with more energy to determine the silence
            channel = np.argmax(np.sum(np.abs(audio), axis=1))
            audio_mono = audio[channel]
        else:
            audio_mono = audio

        for i in range(0, len(audio_mono), frame_length):
            if np.max(np.abs(audio_mono[i : i + frame_length])) > silence_threshold:
                return i
        return len(audio_mono)

    ref_silence_end = find_silence_end(reference_audio)
    target_silence_end = find_silence_end(target_audio)
    silence_difference = ref_silence_end - target_silence_end

    try:
        ref_silence_end_p = (ref_silence_end / 44100) * 1000
        target_silence_end_p = (target_silence_end / 44100) * 1000
        silence_difference_p = ref_silence_end_p - target_silence_end_p
        print("silence_difference: ", silence_difference_p)
    except Exception:
        pass

    if silence_difference > 0:  # Add silence to target_audio
        if len(target_audio.shape) == 2:  # stereo
            silence_to_add = np.zeros((target_audio.shape[0], silence_difference))
        else:  # mono
            silence_to_add = np.zeros(silence_difference)
        return np.hstack((silence_to_add, target_audio))
    elif silence_difference < 0:  # Remove silence from target_audio
        if len(target_audio.shape) == 2:  # stereo
            return target_audio[:, -silence_difference:]
        else:  # mono
            return target_audio[-silence_difference:]
    else:  # No adjustment needed
        return target_audio

def match_array_shapes(array_1: np.ndarray, array_2: np.ndarray, is_swap=False):
    if is_swap:
        array_1, array_2 = array_1.T, array_2.T

    if array_1.shape[1] > array_2.shape[1]:
        array_1 = array_1[:, : array_2.shape[1]]
    elif array_1.shape[1] < array_2.shape[1]:
        padding = array_2.shape[1] - array_1.shape[1]
        array_1 = np.pad(array_1, ((0, 0), (0, padding)), "constant", constant_values=0)

    if is_swap:
        array_1, array_2 = array_1.T, array_2.T

    return array_1


def match_mono_array_shapes(array_1: np.ndarray, array_2: np.ndarray):
    if len(array_1) > len(array_2):
        array_1 = array_1[: len(array_2)]
    elif len(array_1) < len(array_2):
        padding = len(array_2) - len(array_1)
        array_1 = np.pad(array_1, (0, padding), "constant", constant_values=0)

    return array_1

def change_pitch_semitones(y, sr, semitone_shift):
    factor = 2 ** (semitone_shift / 12)  # convert semitone shift to a resampling factor
    y_pitch_tuned = []
    for y_channel in y:
        y_pitch_tuned.append(librosa.resample(y_channel, orig_sr=sr, target_sr=sr * factor, res_type=wav_resolution_float_resampling))
    y_pitch_tuned = np.array(y_pitch_tuned)
    new_sr = sr * factor
    return y_pitch_tuned, new_sr

def augment_audio(export_path, audio_file, rate, is_normalization, wav_type_set, save_format=None, is_pitch=False, is_time_correction=True):
    wav, sr = librosa.load(audio_file, sr=44100, mono=False)

    if wav.ndim == 1:
        wav = np.asfortranarray([wav, wav])

    if not is_time_correction:
        wav_mix = change_pitch_semitones(wav, 44100, semitone_shift=-rate)[0]
    else:
        if is_pitch:
            wav_1 = pyrb.pitch_shift(wav[0], sr, rate, rbargs=None)
            wav_2 = pyrb.pitch_shift(wav[1], sr, rate, rbargs=None)
        else:
            wav_1 = pyrb.time_stretch(wav[0], sr, rate, rbargs=None)
            wav_2 = pyrb.time_stretch(wav[1], sr, rate, rbargs=None)

        if wav_1.shape > wav_2.shape:
            wav_2 = to_shape(wav_2, wav_1.shape)
        if wav_1.shape < wav_2.shape:
            wav_1 = to_shape(wav_1, wav_2.shape)

        wav_mix = np.asfortranarray([wav_1, wav_2])

    sf.write(export_path, normalize(wav_mix.T, is_normalization), sr, subtype=wav_type_set)
    save_format(export_path)

def average_audio(audio):
    waves = []
    wave_shapes = []
    final_waves = []

    for i in range(len(audio)):
        wave = librosa.load(audio[i], sr=44100, mono=False)
        waves.append(wave[0])
        wave_shapes.append(wave[0].shape[1])

    wave_shapes_index = wave_shapes.index(max(wave_shapes))
    target_shape = waves[wave_shapes_index]
    waves.pop(wave_shapes_index)
    final_waves.append(target_shape)

    for n_array in waves:
        wav_target = to_shape(n_array, target_shape.shape)
        final_waves.append(wav_target)

    waves = sum(final_waves)
    waves = waves / len(audio)

    return waves

def average_dual_sources(wav_1, wav_2, value):
    if wav_1.shape > wav_2.shape:
        wav_2 = to_shape(wav_2, wav_1.shape)
    if wav_1.shape < wav_2.shape:
        wav_1 = to_shape(wav_1, wav_2.shape)

    wave = (wav_1 * value) + (wav_2 * (1 - value))

    return wave


def reshape_sources(wav_1: np.ndarray, wav_2: np.ndarray):
    if wav_1.shape > wav_2.shape:
        wav_2 = to_shape(wav_2, wav_1.shape)
    if wav_1.shape < wav_2.shape:
        ln = min([wav_1.shape[1], wav_2.shape[1]])
        wav_2 = wav_2[:, :ln]

    ln = min([wav_1.shape[1], wav_2.shape[1]])
    wav_1 = wav_1[:, :ln]
    wav_2 = wav_2[:, :ln]

    return wav_2


def reshape_sources_ref(wav_1_shape, wav_2: np.ndarray):
    if wav_1_shape > wav_2.shape:
        wav_2 = to_shape(wav_2, wav_1_shape)

    return wav_2

def combine_arrarys(audio_sources, is_swap=False):
    source = np.zeros_like(max(audio_sources, key=np.size))

    for v in audio_sources:
        v = match_array_shapes(v, source, is_swap=is_swap)
        source += v

    return source


def combine_audio(paths: list, audio_file_base=None, wav_type_set="FLOAT", save_format=None):
    source = combine_arrarys([load_audio(i) for i in paths])
    save_path = f"{audio_file_base}_combined.wav"
    sf.write(save_path, source.T, 44100, subtype=wav_type_set)
    save_format(save_path)


def reduce_mix_bv(inst_source, voc_source, reduction_rate=0.9):
    # Reduce the instrumental volume before recombining with the vocals
    inst_source = inst_source * (1 - reduction_rate)

    mix_reduced = combine_arrarys([inst_source, voc_source], is_swap=True)

    return mix_reduced

def organize_inputs(inputs):
    input_list = {"target": None, "reference": None, "reverb": None, "inst": None}

    for i in inputs:
        if i.endswith("_(Vocals).wav"):
            input_list["reference"] = i
        elif "_RVC_" in i:
            input_list["target"] = i
        elif i.endswith("reverbed_stem.wav"):
            input_list["reverb"] = i
        elif i.endswith("_(Instrumental).wav"):
            input_list["inst"] = i

    return input_list

def check_if_phase_inverted(wav1, wav2, is_mono=False):
    # Collapse to mono so a single correlation covers both channels
    if not is_mono:
        wav1 = np.mean(wav1, axis=0)
        wav2 = np.mean(wav2, axis=0)

    # Compute the correlation over the first 1000 samples
    correlation = np.corrcoef(wav1[:1000], wav2[:1000])

    return correlation[0, 1] < 0

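# Sketch (not part of spec_utils.py): a polarity-flipped copy correlates
# negatively with the original, which is all the check above looks for in
# the first 1000 samples.
def _phase_inversion_sketch():
    import numpy as np

    t = np.linspace(0, 1, 1000)
    wav = np.sin(2 * np.pi * 5 * t)
    assert np.corrcoef(wav[:1000], (-wav)[:1000])[0, 1] < 0
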
def align_audio(
    file1,
    file2,
    file2_aligned,
    file_subtracted,
    wav_type_set,
    is_save_aligned,
    command_Text,
    save_format,
    align_window: list,
    align_intro_val: list,
    db_analysis: tuple,
    set_progress_bar,
    phase_option,
    phase_shifts,
    is_match_silence,
    is_spec_match,
):
    global progress_value
    progress_value = 0
    is_mono = False

    def get_diff(a, b):
        corr = np.correlate(a, b, "full")
        diff = corr.argmax() - (b.shape[0] - 1)

        return diff

    def progress_bar(length):
        global progress_value
        progress_value += 1

        if (0.90 / length * progress_value) >= 0.9:
            length = progress_value + 1

        set_progress_bar(0.1, (0.9 / length * progress_value))

    # read tracks
    if file1.endswith(".mp3") and is_macos:
        length1 = rerun_mp3(file1)
        wav1, sr1 = librosa.load(file1, duration=length1, sr=44100, mono=False)
    else:
        wav1, sr1 = librosa.load(file1, sr=44100, mono=False)

    if file2.endswith(".mp3") and is_macos:
        length2 = rerun_mp3(file2)
        wav2, sr2 = librosa.load(file2, duration=length2, sr=44100, mono=False)
    else:
        wav2, sr2 = librosa.load(file2, sr=44100, mono=False)

    if wav1.ndim == 1 and wav2.ndim == 1:
        is_mono = True
    elif wav1.ndim == 1:
        wav1 = np.asfortranarray([wav1, wav1])
    elif wav2.ndim == 1:
        wav2 = np.asfortranarray([wav2, wav2])

    # Check if phase is inverted
    if phase_option == AUTO_PHASE:
        if check_if_phase_inverted(wav1, wav2, is_mono=is_mono):
            wav2 = -wav2
    elif phase_option == POSITIVE_PHASE:
        wav2 = +wav2
    elif phase_option == NEGATIVE_PHASE:
        wav2 = -wav2

    if is_match_silence:
        wav2 = adjust_leading_silence(wav2, wav1)

    wav1_length = int(librosa.get_duration(y=wav1, sr=44100))
    wav2_length = int(librosa.get_duration(y=wav2, sr=44100))

    if not is_mono:
        wav1 = wav1.transpose()
        wav2 = wav2.transpose()

    wav2_org = wav2.copy()

    command_Text("Processing files... \n")
    seconds_length = min(wav1_length, wav2_length)

    wav2_aligned_sources = []

    for sec_len in align_intro_val:
        # pick a probe position and estimate the offset there
        sec_seg = 1 if sec_len == 1 else int(seconds_length // sec_len)
        index = sr1 * sec_seg  # assuming sr1 = sr2 = 44100

        if is_mono:
            samp1, samp2 = wav1[index : index + sr1], wav2[index : index + sr1]
            diff = get_diff(samp1, samp2)
        else:
            samp1, samp2 = wav1[index : index + sr1, 0], wav2[index : index + sr1, 0]
            samp1_r, samp2_r = wav1[index : index + sr1, 1], wav2[index : index + sr1, 1]
            diff, diff_r = get_diff(samp1, samp2), get_diff(samp1_r, samp2_r)

        # make aligned track 2
        if diff > 0:
            zeros_to_append = np.zeros(diff) if is_mono else np.zeros((diff, 2))
            wav2_aligned = np.append(zeros_to_append, wav2_org, axis=0)
        elif diff < 0:
            wav2_aligned = wav2_org[-diff:]
        else:
            wav2_aligned = wav2_org

        if not any(np.array_equal(wav2_aligned, source) for source in wav2_aligned_sources):
            wav2_aligned_sources.append(wav2_aligned)

    unique_sources = len(wav2_aligned_sources)

    sub_mapper_big_mapper = {}

    for s in wav2_aligned_sources:
        wav2_aligned = match_mono_array_shapes(s, wav1) if is_mono else match_array_shapes(s, wav1, is_swap=True)

        if align_window:
            wav_sub = time_correction(
                wav1, wav2_aligned, seconds_length, align_window=align_window, db_analysis=db_analysis, progress_bar=progress_bar, unique_sources=unique_sources, phase_shifts=phase_shifts
            )
            wav_sub_size = np.abs(wav_sub).mean()
            sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size: wav_sub}}
        else:
            wav2_aligned = wav2_aligned * np.power(10, db_analysis[0] / 20)
            db_range = db_analysis[1]

            for db_adjustment in db_range:
                # Adjust the dB of track 2
                s_adjusted = wav2_aligned * (10 ** (db_adjustment / 20))
                wav_sub = wav1 - s_adjusted
                wav_sub_size = np.abs(wav_sub).mean()
                sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size: wav_sub}}

    sub_mapper_value_list = list(sub_mapper_big_mapper.values())

    if is_spec_match and len(sub_mapper_value_list) >= 2:
        wav_sub = ensemble_for_align(list(sub_mapper_big_mapper.values()))
    else:
        wav_sub = ensemble_wav(list(sub_mapper_big_mapper.values()))

    wav_sub = np.clip(wav_sub, -1, +1)

    command_Text("Saving inverted track... ")

    if is_save_aligned or is_spec_match:
        wav1 = match_mono_array_shapes(wav1, wav_sub) if is_mono else match_array_shapes(wav1, wav_sub, is_swap=True)
        wav2_aligned = wav1 - wav_sub

        if is_spec_match:
            if wav1.ndim == 1 and wav2.ndim == 1:
                wav2_aligned = np.asfortranarray([wav2_aligned, wav2_aligned]).T
                wav1 = np.asfortranarray([wav1, wav1]).T

            wav2_aligned = ensemble_for_align([wav2_aligned, wav1])
            wav_sub = wav1 - wav2_aligned

        if is_save_aligned:
            sf.write(file2_aligned, wav2_aligned, sr1, subtype=wav_type_set)
            save_format(file2_aligned)

    sf.write(file_subtracted, wav_sub, sr1, subtype=wav_type_set)
    save_format(file_subtracted)

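# Sketch (not part of spec_utils.py) of get_diff above: the full
# cross-correlation peaks at the lag between two copies of a signal,
# here a known shift of 3 samples.
def _cross_correlation_delay_sketch():
    import numpy as np

    a = np.array([0.0, 0.0, 0.0, 1.0, 2.0, 3.0])
    b = np.array([1.0, 2.0, 3.0, 0.0, 0.0, 0.0])
    corr = np.correlate(a, b, "full")
    diff = corr.argmax() - (b.shape[0] - 1)
    assert diff == 3  # b must be delayed by 3 samples to line up with a
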
def phase_shift_hilbert(signal, degree):
    analytic_signal = hilbert(signal)
    return np.cos(np.radians(degree)) * analytic_signal.real - np.sin(np.radians(degree)) * analytic_signal.imag


def get_phase_shifted_tracks(track, phase_shift):
    if phase_shift == 180:
        return [track, -track]

    step = phase_shift
    end = 180 - (180 % step) if 180 % step == 0 else 181
    phase_range = range(step, end, step)

    flipped_list = [track, -track]
    for i in phase_range:
        flipped_list.extend([phase_shift_hilbert(track, i), phase_shift_hilbert(track, -i)])

    return flipped_list

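# Sketch (not part of spec_utils.py): phase_shift_hilbert at 90 degrees turns
# a sine into a cosine, since the analytic signal of sin(wt) is
# sin(wt) - 1j * cos(wt). Using whole periods keeps the FFT-based Hilbert
# transform essentially exact.
def _hilbert_shift_sketch():
    import numpy as np
    from scipy.signal import hilbert

    t = np.linspace(0, 1, 4096, endpoint=False)
    sig = np.sin(2 * np.pi * 8 * t)
    analytic = hilbert(sig)
    shifted = np.cos(np.radians(90)) * analytic.real - np.sin(np.radians(90)) * analytic.imag
    assert np.allclose(shifted, np.cos(2 * np.pi * 8 * t), atol=1e-6)
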
def time_correction(mix: np.ndarray, instrumental: np.ndarray, seconds_length, align_window, db_analysis, sr=44100, progress_bar=None, unique_sources=None, phase_shifts=NONE_P):
    # Align two tracks using windowed cross-correlation

    def align_tracks(track1, track2):
        # A dictionary to store each version of track2_shifted and its mean absolute value
        shifted_tracks = {}

        # Apply the base dB adjustment to track2
        track2 = track2 * np.power(10, db_analysis[0] / 20)
        db_range = db_analysis[1]

        if phase_shifts == 190:  # 190 is the "no phase shifts" sentinel (NONE_P)
            track2_flipped = [track2]
        else:
            track2_flipped = get_phase_shifted_tracks(track2, phase_shifts)

        for db_adjustment in db_range:
            for t in track2_flipped:
                # Adjust the dB of track2
                track2_adjusted = t * (10 ** (db_adjustment / 20))
                corr = correlate(track1, track2_adjusted)
                delay = np.argmax(np.abs(corr)) - (len(track1) - 1)
                track2_shifted = np.roll(track2_adjusted, shift=delay)

                # Compute the mean absolute value of the residual
                track2_shifted_sub = track1 - track2_shifted
                mean_abs_value = np.abs(track2_shifted_sub).mean()

                # Store track2_shifted keyed by its residual size
                shifted_tracks[mean_abs_value] = track2_shifted

        # Return the version of track2_shifted with the smallest residual
        return shifted_tracks[min(shifted_tracks.keys())]

    # Make sure the audio files have the same shape
    assert mix.shape == instrumental.shape, f"Audio files must have the same shape - Mix: {mix.shape}, Inst: {instrumental.shape}"

    seconds_length = seconds_length // 2

    sub_mapper = {}

    progress_update_interval = 120
    total_iterations = 0

    if len(align_window) > 2:
        progress_update_interval = 320

    for secs in align_window:
        step = secs / 2
        window_size = int(sr * secs)
        step_size = int(sr * step)

        if len(mix.shape) == 1:
            total_mono = (len(range(0, len(mix) - window_size, step_size)) // progress_update_interval) * unique_sources
            total_iterations += total_mono
        else:
            total_stereo_ = len(range(0, len(mix[:, 0]) - window_size, step_size)) * 2
            total_stereo = (total_stereo_ // progress_update_interval) * unique_sources
            total_iterations += total_stereo

    for secs in align_window:
        sub = np.zeros_like(mix)
        divider = np.zeros_like(mix)
        step = secs / 2
        window_size = int(sr * secs)
        step_size = int(sr * step)
        window = np.hanning(window_size)

        if len(mix.shape) == 1:
            # The files are mono
            counter = 0
            for i in range(0, len(mix) - window_size, step_size):
                counter += 1
                if counter % progress_update_interval == 0:
                    progress_bar(total_iterations)
                window_mix = mix[i : i + window_size] * window
                window_instrumental = instrumental[i : i + window_size] * window
                window_instrumental_aligned = align_tracks(window_mix, window_instrumental)
                sub[i : i + window_size] += window_mix - window_instrumental_aligned
                divider[i : i + window_size] += window
        else:
            # The files are stereo
            counter = 0
            for ch in range(mix.shape[1]):
                for i in range(0, len(mix[:, ch]) - window_size, step_size):
                    counter += 1
                    if counter % progress_update_interval == 0:
                        progress_bar(total_iterations)
                    window_mix = mix[i : i + window_size, ch] * window
                    window_instrumental = instrumental[i : i + window_size, ch] * window
                    window_instrumental_aligned = align_tracks(window_mix, window_instrumental)
                    sub[i : i + window_size, ch] += window_mix - window_instrumental_aligned
                    divider[i : i + window_size, ch] += window

        # Normalize the result by the overlap count
        sub = np.where(divider > 1e-6, sub / divider, sub)
        sub_size = np.abs(sub).mean()
        sub_mapper = {**sub_mapper, **{sub_size: sub}}

    sub = ensemble_wav(list(sub_mapper.values()), split_size=12)

    return sub

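# Sketch (not part of spec_utils.py) of the Hann overlap-add bookkeeping in
# time_correction: window weights are accumulated into `divider`, so dividing
# at the end restores unity gain wherever at least one window contributed.
def _overlap_add_sketch():
    import numpy as np

    n, window = 1024, np.hanning(256)
    sig = np.random.randn(n)
    out, divider = np.zeros(n), np.zeros(n)
    for i in range(0, n - 256, 128):  # 50% hop, as above
        out[i : i + 256] += sig[i : i + 256] * window
        divider[i : i + 256] += window
    mid = slice(256, 768)  # ignore the partially covered edges
    assert np.allclose(out[mid] / divider[mid], sig[mid])
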
def ensemble_wav(waveforms, split_size=240):
    # Split each waveform into `split_size` chunks, keyed by waveform index
    waveform_thirds = {i: np.array_split(waveform, split_size) for i, waveform in enumerate(waveforms)}

    # Initialize the final waveform
    final_waveform = []

    # For each chunk position
    for third_idx in range(split_size):
        # Compute the mean absolute value of this chunk from each waveform
        means = [np.abs(waveform_thirds[i][third_idx]).mean() for i in range(len(waveforms))]

        # Find the waveform with the lowest mean absolute value for this chunk
        min_index = np.argmin(means)

        # Add the quietest (least noisy) chunk to the final waveform
        final_waveform.append(waveform_thirds[min_index][third_idx])

    # Concatenate all chunks to create the final waveform
    final_waveform = np.concatenate(final_waveform)

    return final_waveform


def ensemble_wav_min(waveforms):
    for i in range(1, len(waveforms)):
        if i == 1:
            wave = waveforms[0]

        ln = min(len(wave), len(waveforms[i]))
        wave = wave[:ln]
        waveforms[i] = waveforms[i][:ln]

        wave = np.where(np.abs(waveforms[i]) <= np.abs(wave), waveforms[i], wave)

    return wave

def align_audio_test(wav1, wav2, sr1=44100):
    def get_diff(a, b):
        corr = np.correlate(a, b, "full")
        diff = corr.argmax() - (b.shape[0] - 1)
        return diff

    # read tracks
    wav1 = wav1.transpose()
    wav2 = wav2.transpose()

    wav2_org = wav2.copy()

    # pick a position at 1 second in and get diff
    index = sr1  # 1 second in, assuming sr1 = sr2 = 44100
    samp1 = wav1[index : index + sr1, 0]  # currently use left channel
    samp2 = wav2[index : index + sr1, 0]
    diff = get_diff(samp1, samp2)

    # make aligned track 2
    if diff > 0:
        wav2_aligned = np.append(np.zeros((diff, 1)), wav2_org, axis=0)
    elif diff < 0:
        wav2_aligned = wav2_org[-diff:]
    else:
        wav2_aligned = wav2_org

    return wav2_aligned


def load_audio(audio_file):
    wav, sr = librosa.load(audio_file, sr=44100, mono=False)

    if wav.ndim == 1:
        wav = np.asfortranarray([wav, wav])

    return wav


def rerun_mp3(audio_file):
    with audioread.audio_open(audio_file) as f:
        track_length = int(f.duration)

    return track_length

audio_separator/separator/uvr_lib_v5/stft.py
ADDED
@@ -0,0 +1,126 @@
import torch


class STFT:
    """
    This class performs the Short-Time Fourier Transform (STFT) and its inverse (ISTFT).
    These functions are essential for converting the audio between the time domain and the frequency domain,
    which is a crucial aspect of audio processing in neural networks.
    """

    def __init__(self, logger, n_fft, hop_length, dim_f, device):
        self.logger = logger
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.dim_f = dim_f
        self.device = device
        # Create a Hann window tensor for use in the STFT.
        self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)

    def __call__(self, input_tensor):
        # Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA).
        is_non_standard_device = input_tensor.device.type not in ("cuda", "cpu")

        # If on a non-standard device, temporarily move the tensor to CPU for processing.
        if is_non_standard_device:
            input_tensor = input_tensor.cpu()

        # Transfer the pre-defined window tensor to the same device as the input tensor.
        stft_window = self.hann_window.to(input_tensor.device)

        # Extract batch dimensions (all dimensions except the last two, which are channel and time).
        batch_dimensions = input_tensor.shape[:-2]

        # Extract channel and time dimensions (last two dimensions of the tensor).
        channel_dim, time_dim = input_tensor.shape[-2:]

        # Reshape the tensor to merge batch and channel dimensions for STFT processing.
        reshaped_tensor = input_tensor.reshape([-1, time_dim])

        # Perform the Short-Time Fourier Transform (STFT) on the reshaped tensor.
        stft_output = torch.stft(reshaped_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True, return_complex=False)

        # Rearrange the dimensions of the STFT output to bring the real/imaginary dimension forward.
        permuted_stft_output = stft_output.permute([0, 3, 1, 2])

        # Reshape the output to restore the original batch and channel dimensions, folding real/imaginary parts into channels.
        final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]).reshape(
            [*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]]
        )

        # If the original tensor was on a non-standard device, move the processed tensor back to that device.
        if is_non_standard_device:
            final_output = final_output.to(self.device)

        # Return the transformed tensor, sliced to retain only the required frequency dimension (`dim_f`).
        return final_output[..., : self.dim_f, :]

    def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
        """
        Adds zero padding to the frequency dimension of the input tensor.
        """
        # Create a padding tensor for the frequency dimension.
        freq_padding = torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]).to(input_tensor.device)

        # Concatenate the padding to the input tensor along the frequency dimension.
        padded_tensor = torch.cat([input_tensor, freq_padding], -2)

        return padded_tensor

    def calculate_inverse_dimensions(self, input_tensor):
        # Extract batch dimensions and channel-frequency-time dimensions.
        batch_dimensions = input_tensor.shape[:-3]
        channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]

        # Calculate the number of frequency bins for the inverse STFT.
        num_freq_bins = self.n_fft // 2 + 1

        return batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins

    def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
        """
        Prepares the tensor for Inverse Short-Time Fourier Transform (ISTFT) by reshaping
        and creating a complex tensor from the real and imaginary parts.
        """
        # Reshape the tensor to separate real and imaginary parts and prepare for ISTFT.
        reshaped_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim])

        # Flatten batch dimensions and rearrange for ISTFT.
        flattened_tensor = reshaped_tensor.reshape([-1, 2, num_freq_bins, time_dim])

        # Rearrange the dimensions of the tensor to move the real/imaginary dimension last.
        permuted_tensor = flattened_tensor.permute([0, 2, 3, 1])

        # Combine real and imaginary parts into a complex tensor.
        complex_tensor = permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j

        return complex_tensor

    def inverse(self, input_tensor):
        # Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA).
        is_non_standard_device = input_tensor.device.type not in ("cuda", "cpu")

        # If on a non-standard device, temporarily move the tensor to CPU for processing.
        if is_non_standard_device:
            input_tensor = input_tensor.cpu()

        # Transfer the pre-defined Hann window tensor to the same device as the input tensor.
        stft_window = self.hann_window.to(input_tensor.device)

        batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)

        padded_tensor = self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins)

        complex_tensor = self.prepare_for_istft(padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim)

        # Perform the Inverse Short-Time Fourier Transform (ISTFT).
        istft_result = torch.istft(complex_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True)

        # Reshape the ISTFT result to restore the original batch and channel dimensions.
        final_output = istft_result.reshape([*batch_dimensions, 2, -1])

        # If the original tensor was on a non-standard device, move the processed tensor back to that device.
        if is_non_standard_device:
            final_output = final_output.to(self.device)

        return final_output
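A usage sketch for the class above (the parameter values here are assumptions, not taken from the upload, and a torch version that still accepts return_complex=False is assumed): the forward call packs real/imaginary parts into channels, and inverse() undoes it. Because dim_f trims the top frequency bins, the inverse is only exact for content below that cutoff.

    import torch

    stft = STFT(logger=None, n_fft=2048, hop_length=512, dim_f=1024, device="cpu")
    audio = torch.randn(1, 2, 44100)  # (batch, channels, samples)
    spec = stft(audio)  # -> (1, 4, 1024, frames): 2 channels x (real, imag)
    wave = stft.inverse(spec)  # -> (1, 2, samples), with the trimmed bins zeroed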
audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py
ADDED
@@ -0,0 +1,253 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
class STFT:
|
6 |
+
def __init__(self, n_fft, hop_length, dim_f, device):
|
7 |
+
self.n_fft = n_fft
|
8 |
+
self.hop_length = hop_length
|
9 |
+
self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
|
10 |
+
self.dim_f = dim_f
|
11 |
+
self.device = device
|
12 |
+
|
13 |
+
def __call__(self, x):
|
14 |
+
|
15 |
+
x_is_mps = not x.device.type in ["cuda", "cpu"]
|
16 |
+
if x_is_mps:
|
17 |
+
x = x.cpu()
|
18 |
+
|
19 |
+
window = self.window.to(x.device)
|
20 |
+
batch_dims = x.shape[:-2]
|
21 |
+
c, t = x.shape[-2:]
|
22 |
+
x = x.reshape([-1, t])
|
23 |
+
x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True,return_complex=False)
|
24 |
+
x = x.permute([0, 3, 1, 2])
|
25 |
+
x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
|
26 |
+
|
27 |
+
if x_is_mps:
|
28 |
+
x = x.to(self.device)
|
29 |
+
|
30 |
+
return x[..., :self.dim_f, :]
|
31 |
+
|
32 |
+
def inverse(self, x):
|
33 |
+
|
34 |
+
x_is_mps = not x.device.type in ["cuda", "cpu"]
|
35 |
+
if x_is_mps:
|
36 |
+
x = x.cpu()
|
37 |
+
|
38 |
+
window = self.window.to(x.device)
|
39 |
+
batch_dims = x.shape[:-3]
|
40 |
+
c, f, t = x.shape[-3:]
|
41 |
+
n = self.n_fft // 2 + 1
|
42 |
+
f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
|
43 |
+
x = torch.cat([x, f_pad], -2)
|
44 |
+
x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
|
45 |
+
x = x.permute([0, 2, 3, 1])
|
46 |
+
x = x[..., 0] + x[..., 1] * 1.j
|
47 |
+
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
|
48 |
+
x = x.reshape([*batch_dims, 2, -1])
|
49 |
+
|
50 |
+
if x_is_mps:
|
51 |
+
x = x.to(self.device)
|
52 |
+
|
53 |
+
return x
|
54 |
+
|
55 |
+
def get_norm(norm_type):
|
56 |
+
def norm(c, norm_type):
|
57 |
+
if norm_type == 'BatchNorm':
|
58 |
+
return nn.BatchNorm2d(c)
|
59 |
+
elif norm_type == 'InstanceNorm':
|
60 |
+
return nn.InstanceNorm2d(c, affine=True)
|
61 |
+
elif 'GroupNorm' in norm_type:
|
62 |
+
g = int(norm_type.replace('GroupNorm', ''))
|
63 |
+
return nn.GroupNorm(num_groups=g, num_channels=c)
|
64 |
+
else:
|
65 |
+
return nn.Identity()
|
66 |
+
|
67 |
+
return partial(norm, norm_type=norm_type)
|
68 |
+
|
69 |
+
|
70 |
+
def get_act(act_type):
|
71 |
+
if act_type == 'gelu':
|
72 |
+
return nn.GELU()
|
73 |
+
elif act_type == 'relu':
|
74 |
+
return nn.ReLU()
|
75 |
+
elif act_type[:3] == 'elu':
|
76 |
+
alpha = float(act_type.replace('elu', ''))
|
77 |
+
return nn.ELU(alpha)
|
78 |
+
else:
|
79 |
+
raise Exception
|
80 |
+
|
81 |
+
|
82 |
+
class Upscale(nn.Module):
|
83 |
+
def __init__(self, in_c, out_c, scale, norm, act):
|
84 |
+
super().__init__()
|
85 |
+
self.conv = nn.Sequential(
|
86 |
+
norm(in_c),
|
87 |
+
act,
|
88 |
+
nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
|
89 |
+
)
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
return self.conv(x)
|
93 |
+
|
94 |
+
|
95 |
+
class Downscale(nn.Module):
|
96 |
+
def __init__(self, in_c, out_c, scale, norm, act):
|
97 |
+
super().__init__()
|
98 |
+
self.conv = nn.Sequential(
|
99 |
+
norm(in_c),
|
100 |
+
act,
|
101 |
+
nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
|
102 |
+
)
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
return self.conv(x)
|
106 |
+
|
107 |
+
|
108 |
+
class TFC_TDF(nn.Module):
|
109 |
+
def __init__(self, in_c, c, l, f, bn, norm, act):
|
110 |
+
super().__init__()
|
111 |
+
|
112 |
+
self.blocks = nn.ModuleList()
|
113 |
+
for i in range(l):
|
114 |
+
            block = nn.Module()

            block.tfc1 = nn.Sequential(
                norm(in_c),
                act,
                nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
            )
            block.tdf = nn.Sequential(
                norm(c),
                act,
                nn.Linear(f, f // bn, bias=False),
                norm(c),
                act,
                nn.Linear(f // bn, f, bias=False),
            )
            block.tfc2 = nn.Sequential(
                norm(c),
                act,
                nn.Conv2d(c, c, 3, 1, 1, bias=False),
            )
            block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)

            self.blocks.append(block)
            in_c = c

    def forward(self, x):
        for block in self.blocks:
            s = block.shortcut(x)
            x = block.tfc1(x)
            x = x + block.tdf(x)
            x = block.tfc2(x)
            x = x + s
        return x


class TFC_TDF_net(nn.Module):
    def __init__(self, config, device):
        super().__init__()
        self.config = config
        self.device = device

        norm = get_norm(norm_type=config.model.norm)
        act = get_act(act_type=config.model.act)

        self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
        self.num_subbands = config.model.num_subbands

        dim_c = self.num_subbands * config.audio.num_channels * 2
        n = config.model.num_scales
        scale = config.model.scale
        l = config.model.num_blocks_per_scale
        c = config.model.num_channels
        g = config.model.growth
        bn = config.model.bottleneck_factor
        f = config.audio.dim_f // self.num_subbands

        self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)

        self.encoder_blocks = nn.ModuleList()
        for i in range(n):
            block = nn.Module()
            block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act)
            block.downscale = Downscale(c, c + g, scale, norm, act)
            f = f // scale[1]
            c += g
            self.encoder_blocks.append(block)

        self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act)

        self.decoder_blocks = nn.ModuleList()
        for i in range(n):
            block = nn.Module()
            block.upscale = Upscale(c, c - g, scale, norm, act)
            f = f * scale[1]
            c -= g
            block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act)
            self.decoder_blocks.append(block)

        self.final_conv = nn.Sequential(
            nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
            act,
            nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
        )

        self.stft = STFT(config.audio.n_fft, config.audio.hop_length, config.audio.dim_f, self.device)

    def cac2cws(self, x):
        k = self.num_subbands
        b, c, f, t = x.shape
        x = x.reshape(b, c, k, f // k, t)
        x = x.reshape(b, c * k, f // k, t)
        return x

    def cws2cac(self, x):
        k = self.num_subbands
        b, c, f, t = x.shape
        x = x.reshape(b, c // k, k, f, t)
        x = x.reshape(b, c // k, f * k, t)
        return x

    def forward(self, x):

        x = self.stft(x)

        mix = x = self.cac2cws(x)

        first_conv_out = x = self.first_conv(x)

        x = x.transpose(-1, -2)

        encoder_outputs = []
        for block in self.encoder_blocks:
            x = block.tfc_tdf(x)
            encoder_outputs.append(x)
            x = block.downscale(x)

        x = self.bottleneck_block(x)

        for block in self.decoder_blocks:
            x = block.upscale(x)
            x = torch.cat([x, encoder_outputs.pop()], 1)
            x = block.tfc_tdf(x)

        x = x.transpose(-1, -2)

        x = x * first_conv_out  # reduce artifacts

        x = self.final_conv(torch.cat([mix, x], 1))

        x = self.cws2cac(x)

        if self.num_target_instruments > 1:
            b, c, f, t = x.shape
            x = x.reshape(b, self.num_target_instruments, -1, f, t)

        x = self.stft.inverse(x)

        return x
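The cac2cws/cws2cac pair above simply folds frequency subbands into the channel axis and unfolds them again. A minimal standalone sketch of that roundtrip, restated outside the class to show it is lossless; the subband count and tensor sizes here are illustrative assumptions, not values from any shipped config:

import torch

k = 4                                   # assumed num_subbands
x = torch.randn(2, 4, 256, 100)         # (batch, channels, freq, time); freq must be divisible by k

# cac2cws: fold k frequency subbands into the channel dimension
b, c, f, t = x.shape
cws = x.reshape(b, c, k, f // k, t).reshape(b, c * k, f // k, t)

# cws2cac: unfold the subbands back into the frequency dimension
b2, c2, f2, t2 = cws.shape
cac = cws.reshape(b2, c2 // k, k, f2, t2).reshape(b2, c2 // k, f2 * k, t2)

print(torch.equal(x, cac))              # True -- the roundtrip is exact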
audio_separator/separator/uvr_lib_v5/vr_network/__init__.py
ADDED
@@ -0,0 +1 @@
# VR init.
audio_separator/separator/uvr_lib_v5/vr_network/layers.py
ADDED
@@ -0,0 +1,294 @@
import torch
from torch import nn
import torch.nn.functional as F

from audio_separator.separator.uvr_lib_v5 import spec_utils


class Conv2DBNActiv(nn.Module):
    """
    This class implements a convolutional layer followed by batch normalization and an activation function.
    It is a common pattern in deep learning for processing images or feature maps. The convolutional layer
    applies a set of learnable filters to the input. Batch normalization then normalizes the output of the
    convolution, and finally, an activation function introduces non-linearity to the model, allowing it to
    learn more complex patterns.

    Attributes:
        conv (nn.Sequential): A sequential container of Conv2d, BatchNorm2d, and an activation layer.

    Args:
        nin (int): Number of input channels.
        nout (int): Number of output channels.
        ksize (int, optional): Size of the kernel. Defaults to 3.
        stride (int, optional): Stride of the convolution. Defaults to 1.
        pad (int, optional): Padding added to all sides of the input. Defaults to 1.
        dilation (int, optional): Spacing between kernel elements. Defaults to 1.
        activ (callable, optional): The activation function to use. Defaults to nn.ReLU.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
        super(Conv2DBNActiv, self).__init__()

        # The nn.Sequential container allows us to stack the Conv2d, BatchNorm2d, and activation layers
        # into a single module, simplifying the forward pass.
        self.conv = nn.Sequential(nn.Conv2d(nin, nout, kernel_size=ksize, stride=stride, padding=pad, dilation=dilation, bias=False), nn.BatchNorm2d(nout), activ())

    def __call__(self, input_tensor):
        # Defines the computation performed at every call.
        # Simply passes the input through the sequential container.
        return self.conv(input_tensor)


class SeperableConv2DBNActiv(nn.Module):
    """
    This class implements a separable convolutional layer followed by batch normalization and an activation function.
    Separable convolutions are a type of convolution that splits the convolution operation into two simpler operations:
    a depthwise convolution and a pointwise convolution. This can reduce the number of parameters and computational cost,
    making the network more efficient while maintaining similar performance.

    The depthwise convolution applies a single filter per input channel (input depth). The pointwise convolution,
    which follows, applies a 1x1 convolution to combine the outputs of the depthwise convolution across channels.
    Batch normalization is then applied to stabilize learning and reduce internal covariate shift. Finally,
    an activation function introduces non-linearity, allowing the network to learn complex patterns.

    Attributes:
        conv (nn.Sequential): A sequential container of depthwise Conv2d, pointwise Conv2d, BatchNorm2d, and an activation layer.

    Args:
        nin (int): Number of input channels.
        nout (int): Number of output channels.
        ksize (int, optional): Size of the kernel for the depthwise convolution. Defaults to 3.
        stride (int, optional): Stride of the convolution. Defaults to 1.
        pad (int, optional): Padding added to all sides of the input for the depthwise convolution. Defaults to 1.
        dilation (int, optional): Spacing between kernel elements for the depthwise convolution. Defaults to 1.
        activ (callable, optional): The activation function to use. Defaults to nn.ReLU.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
        super(SeperableConv2DBNActiv, self).__init__()

        # Initialize the sequential container with the depthwise convolution.
        # The number of groups in the depthwise convolution is set to nin, which means each input channel is treated separately.
        # The pointwise convolution then combines these separate channels into nout channels.
        # Batch normalization is applied to the output of the pointwise convolution.
        # Finally, the activation function is applied to introduce non-linearity.
        self.conv = nn.Sequential(
            nn.Conv2d(
                nin,
                nin,  # For depthwise convolution, in_channels = out_channels = nin
                kernel_size=ksize,
                stride=stride,
                padding=pad,
                dilation=dilation,
                groups=nin,  # This makes it a depthwise convolution
                bias=False,  # Bias is not used because it will be handled by BatchNorm2d
            ),
            nn.Conv2d(
                nin,
                nout,  # Pointwise convolution to combine channels
                kernel_size=1,  # Kernel size of 1 for pointwise convolution
                bias=False,  # Bias is not used because it will be handled by BatchNorm2d
            ),
            nn.BatchNorm2d(nout),  # Normalize the output of the pointwise convolution
            activ(),  # Apply the activation function
        )

    def __call__(self, input_tensor):
        # Pass the input through the sequential container.
        # This performs the depthwise convolution, followed by the pointwise convolution,
        # batch normalization, and finally applies the activation function.
        return self.conv(input_tensor)


class Encoder(nn.Module):
    """
    The Encoder class is a part of the neural network architecture that is responsible for processing the input data.
    It consists of two convolutional layers, each followed by batch normalization and an activation function.
    The purpose of the Encoder is to transform the input data into a higher-level, abstract representation.
    This is achieved by applying filters (through convolutions) that can capture patterns or features in the data.
    The Encoder can be thought of as a feature extractor that prepares the data for further processing by the network.

    Attributes:
        conv1 (Conv2DBNActiv): The first convolutional layer in the encoder.
        conv2 (Conv2DBNActiv): The second convolutional layer in the encoder.

    Args:
        nin (int): Number of input channels for the first convolutional layer.
        nout (int): Number of output channels for the convolutional layers.
        ksize (int): Kernel size for the convolutional layers.
        stride (int): Stride for the second convolution; the first convolution always uses a stride of 1.
        pad (int): Padding added to all sides of the input for the convolutional layers.
        activ (callable): The activation function to use after each convolutional layer.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
        super(Encoder, self).__init__()

        # The first convolutional layer takes the input and applies a convolution,
        # followed by batch normalization and an activation function specified by `activ`.
        # This layer is responsible for capturing the initial set of features from the input data.
        self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)

        # The second convolutional layer further processes the output from the first layer,
        # applying another set of convolution, batch normalization, and activation.
        # This layer helps in capturing more complex patterns in the data by building upon the initial features extracted by conv1.
        self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ)

    def __call__(self, input_tensor):
        # The input data `input_tensor` is passed through the first convolutional layer.
        # The output of this layer serves as a 'skip connection' that can be used later in the network to preserve spatial information.
        skip = self.conv1(input_tensor)

        # The output from the first layer is then passed through the second convolutional layer.
        # This processed data `hidden` is the final output of the Encoder, representing the abstracted features of the input.
        hidden = self.conv2(skip)

        # The Encoder returns two outputs: `hidden`, the abstracted feature representation, and `skip`, the intermediate representation from conv1.
        return hidden, skip


class Decoder(nn.Module):
    """
    The Decoder class is part of the neural network architecture, specifically designed to perform the inverse operation of an encoder.
    Its main role is to reconstruct or generate data from encoded representations, which is crucial in tasks like image segmentation or audio processing.
    This class uses upsampling, convolution, optional dropout for regularization, and concatenation of skip connections to achieve its goal.

    Attributes:
        conv (Conv2DBNActiv): A convolutional layer with batch normalization and activation function.
        dropout (nn.Dropout2d): An optional dropout layer for regularization to prevent overfitting.

    Args:
        nin (int): Number of input channels for the convolutional layer.
        nout (int): Number of output channels for the convolutional layer.
        ksize (int): Kernel size for the convolutional layer.
        stride (int): Stride for the convolutional operations.
        pad (int): Padding added to all sides of the input for the convolutional layer.
        activ (callable): The activation function to use after the convolutional layer.
        dropout (bool): Whether to include a dropout layer for regularization.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
        super(Decoder, self).__init__()

        # Initialize the convolutional layer with specified parameters.
        self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)

        # Initialize the dropout layer if dropout is set to True.
        self.dropout = nn.Dropout2d(0.1) if dropout else None

    def __call__(self, input_tensor, skip=None):
        # Upsample the input tensor to a higher resolution using bilinear interpolation.
        input_tensor = F.interpolate(input_tensor, scale_factor=2, mode="bilinear", align_corners=True)
        # If a skip connection is provided, crop it to match the size of input_tensor and concatenate them along the channel dimension.
        if skip is not None:
            skip = spec_utils.crop_center(skip, input_tensor)  # Crop the skip connection to match input_tensor's dimensions.
            input_tensor = torch.cat([input_tensor, skip], dim=1)  # Concatenate input_tensor and the skip connection along the channel dimension.

        # Pass the concatenated tensor (or just input_tensor if no skip connection is provided) through the convolutional layer.
        output_tensor = self.conv(input_tensor)

        # If dropout is enabled, apply it to the output of the convolutional layer.
        if self.dropout is not None:
            output_tensor = self.dropout(output_tensor)

        # Return the final output tensor.
        return output_tensor


class ASPPModule(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (ASPP) Module is designed for capturing multi-scale context by applying
    atrous convolution at multiple rates. This is particularly useful in segmentation tasks where capturing
    objects at various scales is beneficial. The module applies several parallel dilated convolutions with
    different dilation rates to the input feature map, allowing it to efficiently capture information at
    multiple scales.

    Attributes:
        conv1 (nn.Sequential): Applies adaptive average pooling followed by a 1x1 convolution.
        nn_architecture (int): Identifier for the neural network architecture being used.
        six_layer (list): List containing architecture identifiers that require six layers.
        seven_layer (list): List containing architecture identifiers that require seven layers.
        conv2-conv7 (nn.Module): Convolutional layers with varying dilation rates for multi-scale feature extraction.
        bottleneck (nn.Sequential): A 1x1 convolutional layer that combines all features followed by dropout for regularization.
    """

    def __init__(self, nn_architecture, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU):
        """
        Initializes the ASPP module with specified parameters.

        Args:
            nn_architecture (int): Identifier for the neural network architecture.
            nin (int): Number of input channels.
            nout (int): Number of output channels.
            dilations (tuple): Tuple of dilation rates for the atrous convolutions.
            activ (callable): Activation function to use after convolutional layers.
        """
        super(ASPPModule, self).__init__()

        # Adaptive average pooling collapses the frequency axis to size 1 while keeping the time axis, focusing on global context,
        # followed by a 1x1 convolution to project back to the desired channel dimension.
        self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1, None)), Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ))

        self.nn_architecture = nn_architecture
        # Architecture identifiers for models requiring additional layers.
        self.six_layer = [129605]
        self.seven_layer = [537238, 537227, 33966]

        # Extra convolutional layer used for the six- and seven-layer configurations.
        # Note that the same module instance is assigned to conv6 (and conv7), so those branches share weights.
        extra_conv = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)

        # Standard 1x1 convolution for channel reduction.
        self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)

        # Separable convolutions with different dilation rates for multi-scale feature extraction.
        self.conv3 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[0], dilations[0], activ=activ)
        self.conv4 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[1], dilations[1], activ=activ)
        self.conv5 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)

        # Depending on the architecture, include the extra convolutional layers.
        if self.nn_architecture in self.six_layer:
            self.conv6 = extra_conv
            nin_x = 6
        elif self.nn_architecture in self.seven_layer:
            self.conv6 = extra_conv
            self.conv7 = extra_conv
            nin_x = 7
        else:
            nin_x = 5

        # Bottleneck layer combines all the multi-scale features into the desired number of output channels.
        self.bottleneck = nn.Sequential(Conv2DBNActiv(nin * nin_x, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1))

    def forward(self, input_tensor):
        """
        Forward pass of the ASPP module.

        Args:
            input_tensor (Tensor): Input tensor.

        Returns:
            Tensor: Output tensor after applying ASPP.
        """
        _, _, h, w = input_tensor.size()

        # Apply the first convolutional sequence and upsample to the original resolution.
        feat1 = F.interpolate(self.conv1(input_tensor), size=(h, w), mode="bilinear", align_corners=True)

        # Apply the remaining convolutions directly on the input.
        feat2 = self.conv2(input_tensor)
        feat3 = self.conv3(input_tensor)
        feat4 = self.conv4(input_tensor)
        feat5 = self.conv5(input_tensor)

        # Concatenate features from all layers. Depending on the architecture, include the extra features.
        if self.nn_architecture in self.six_layer:
            feat6 = self.conv6(input_tensor)
            out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6), dim=1)
        elif self.nn_architecture in self.seven_layer:
            feat6 = self.conv6(input_tensor)
            feat7 = self.conv7(input_tensor)
            out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1)
        else:
            out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)

        # Apply the bottleneck layer to combine and reduce the channel dimensions.
        bottleneck_output = self.bottleneck(out)
        return bottleneck_output
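A minimal shape-check sketch of the Encoder/Decoder skip-connection flow defined in layers.py above, assuming torch and the audio-separator package are importable; the channel counts and tensor sizes are illustrative, not taken from any model config:

import torch
from audio_separator.separator.uvr_lib_v5.vr_network.layers import Encoder, Decoder

x = torch.randn(1, 2, 64, 128)           # (batch, channels, freq, time)
enc = Encoder(2, 16, ksize=3, stride=2, pad=1)
hidden, skip = enc(x)                     # conv2's stride downsamples hidden; skip keeps full resolution
print(hidden.shape, skip.shape)           # torch.Size([1, 16, 32, 64]) torch.Size([1, 16, 64, 128])

dec = Decoder(16 + 16, 8)                 # nin = hidden channels + skip channels after concatenation
y = dec(hidden, skip)                     # upsamples 2x, crops the skip to match, then convolves
print(y.shape)                            # torch.Size([1, 8, 64, 128])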
audio_separator/separator/uvr_lib_v5/vr_network/layers_new.py
ADDED
@@ -0,0 +1,149 @@
import torch
from torch import nn
import torch.nn.functional as F

from audio_separator.separator.uvr_lib_v5 import spec_utils


class Conv2DBNActiv(nn.Module):
    """
    Conv2DBNActiv Class:
    This class implements a convolutional layer followed by batch normalization and an activation function.
    It is a fundamental building block for constructing neural networks, especially useful in image and audio processing tasks.
    The class encapsulates the pattern of applying a convolution, normalizing the output, and then applying a non-linear activation.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
        super(Conv2DBNActiv, self).__init__()

        # Sequential model combining Conv2D, BatchNorm, and activation function into a single module
        self.conv = nn.Sequential(nn.Conv2d(nin, nout, kernel_size=ksize, stride=stride, padding=pad, dilation=dilation, bias=False), nn.BatchNorm2d(nout), activ())

    def __call__(self, input_tensor):
        # Forward pass through the sequential model
        return self.conv(input_tensor)


class Encoder(nn.Module):
    """
    Encoder Class:
    This class defines an encoder module typically used in autoencoder architectures.
    It consists of two convolutional layers, each followed by batch normalization and an activation function.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
        super(Encoder, self).__init__()

        # First convolutional layer of the encoder
        self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
        # Second convolutional layer of the encoder
        self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)

    def __call__(self, input_tensor):
        # Applying the first and then the second convolutional layers
        hidden = self.conv1(input_tensor)
        hidden = self.conv2(hidden)

        return hidden


class Decoder(nn.Module):
    """
    Decoder Class:
    This class defines a decoder module, which is the counterpart of the Encoder class in autoencoder architectures.
    It applies a convolutional layer followed by batch normalization and an activation function, with an optional dropout layer for regularization.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
        super(Decoder, self).__init__()
        # Convolutional layer with optional dropout for regularization
        self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
        # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
        self.dropout = nn.Dropout2d(0.1) if dropout else None

    def __call__(self, input_tensor, skip=None):
        # Forward pass through the convolutional layer and optional dropout
        input_tensor = F.interpolate(input_tensor, scale_factor=2, mode="bilinear", align_corners=True)

        if skip is not None:
            skip = spec_utils.crop_center(skip, input_tensor)
            input_tensor = torch.cat([input_tensor, skip], dim=1)

        hidden = self.conv1(input_tensor)
        # hidden = self.conv2(hidden)

        if self.dropout is not None:
            hidden = self.dropout(hidden)

        return hidden


class ASPPModule(nn.Module):
    """
    ASPPModule Class:
    This class implements the Atrous Spatial Pyramid Pooling (ASPP) module, which is useful for semantic image segmentation tasks.
    It captures multi-scale contextual information by applying convolutions at multiple dilation rates.
    """

    def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
        super(ASPPModule, self).__init__()

        # Global context convolution captures the overall context
        self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1, None)), Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ))
        self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
        self.conv3 = Conv2DBNActiv(nin, nout, 3, 1, dilations[0], dilations[0], activ=activ)
        self.conv4 = Conv2DBNActiv(nin, nout, 3, 1, dilations[1], dilations[1], activ=activ)
        self.conv5 = Conv2DBNActiv(nin, nout, 3, 1, dilations[2], dilations[2], activ=activ)
        self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
        self.dropout = nn.Dropout2d(0.1) if dropout else None

    def forward(self, input_tensor):
        _, _, h, w = input_tensor.size()

        # Upsample global context to match input size and combine with local and multi-scale features
        feat1 = F.interpolate(self.conv1(input_tensor), size=(h, w), mode="bilinear", align_corners=True)
        feat2 = self.conv2(input_tensor)
        feat3 = self.conv3(input_tensor)
        feat4 = self.conv4(input_tensor)
        feat5 = self.conv5(input_tensor)
        out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
        out = self.bottleneck(out)

        if self.dropout is not None:
            out = self.dropout(out)

        return out


class LSTMModule(nn.Module):
    """
    LSTMModule Class:
    This class defines a module that combines convolutional feature extraction with a bidirectional LSTM for sequence modeling.
    It is useful for tasks that require understanding temporal dynamics in data, such as speech and audio processing.
    """

    def __init__(self, nin_conv, nin_lstm, nout_lstm):
        super(LSTMModule, self).__init__()
        # Convolutional layer for initial feature extraction
        self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)

        # Bidirectional LSTM for capturing temporal dynamics
        self.lstm = nn.LSTM(input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True)

        # Dense layer for output dimensionality matching
        self.dense = nn.Sequential(nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU())

    def forward(self, input_tensor):
        N, _, nbins, nframes = input_tensor.size()

        # Extract features and prepare for LSTM
        hidden = self.conv(input_tensor)[:, 0]  # N, nbins, nframes
        hidden = hidden.permute(2, 0, 1)  # nframes, N, nbins
        hidden, _ = self.lstm(hidden)

        # Apply dense layer and reshape to match expected output format
        hidden = self.dense(hidden.reshape(-1, hidden.size()[-1]))  # nframes * N, nbins
        hidden = hidden.reshape(nframes, N, 1, nbins)
        hidden = hidden.permute(1, 2, 3, 0)

        return hidden
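A minimal sketch of the LSTMModule's shape flow, assuming torch and the audio-separator package are importable; the sizes below are illustrative. Note that nin_lstm must equal the number of frequency bins, since the LSTM runs over time frames with the bins as its feature dimension:

import torch
from audio_separator.separator.uvr_lib_v5.vr_network.layers_new import LSTMModule

nbins, nframes = 32, 16
x = torch.randn(2, 4, nbins, nframes)     # (batch, channels, freq bins, time frames)
lstm = LSTMModule(nin_conv=4, nin_lstm=nbins, nout_lstm=64)
out = lstm(x)                              # conv -> (N, nbins, nframes), BiLSTM over frames, dense back to nbins
print(out.shape)                           # torch.Size([2, 1, 32, 16])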
audio_separator/separator/uvr_lib_v5/vr_network/model_param_init.py
ADDED
@@ -0,0 +1,71 @@
import json

default_param = {}
default_param["bins"] = -1
default_param["unstable_bins"] = -1  # training only
default_param["stable_bins"] = -1  # training only
default_param["sr"] = 44100
default_param["pre_filter_start"] = -1
default_param["pre_filter_stop"] = -1
default_param["band"] = {}

N_BINS = "n_bins"


def int_keys(d):
    """
    Converts string keys that represent integers into actual integer keys.

    This function is particularly useful when dealing with JSON data that may represent
    integer keys as strings due to the nature of JSON encoding. By converting these keys
    back to integers, it ensures that the data can be used in a manner consistent with
    its original representation, especially in contexts where the distinction between
    string and integer keys is important.

    Args:
        d (list of tuples): A list of (key, value) pairs where keys are strings
                            that may represent integers.

    Returns:
        dict: A dictionary with keys converted to integers where applicable.
    """
    # Initialize an empty dictionary to hold the converted key-value pairs.
    result_dict = {}
    # Iterate through each key-value pair in the input list.
    for key, value in d:
        # Check if the key is a digit (i.e., represents an integer).
        if key.isdigit():
            # Convert the key from a string to an integer.
            key = int(key)
        result_dict[key] = value
    return result_dict


class ModelParameters(object):
    """
    A class to manage model parameters, including loading from a configuration file.

    Attributes:
        param (dict): Dictionary holding all parameters for the model.
    """

    def __init__(self, config_path=""):
        """
        Initializes the ModelParameters object by loading parameters from a JSON configuration file.

        Args:
            config_path (str): Path to the JSON configuration file.
        """

        # Load parameters from the given configuration file path.
        with open(config_path, "r") as f:
            self.param = json.loads(f.read(), object_pairs_hook=int_keys)

        # Ensure certain parameters are set to False if not specified in the configuration.
        for k in ["mid_side", "mid_side_b", "mid_side_b2", "stereo_w", "stereo_n", "reverse"]:
            if k not in self.param:
                self.param[k] = False

        # If 'n_bins' is specified in the parameters, it's used as the value for 'bins'.
        if N_BINS in self.param:
            self.param["bins"] = self.param[N_BINS]
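A small sketch of how int_keys behaves when passed as json.loads' object_pairs_hook, assuming the package is importable; the JSON string is an illustrative fragment, not a real config:

import json
from audio_separator.separator.uvr_lib_v5.vr_network.model_param_init import int_keys

raw = '{"bins": 1024, "band": {"1": {"sr": 16000}}}'
parsed = json.loads(raw, object_pairs_hook=int_keys)

print(parsed["band"][1]["sr"])   # 16000 -- the "1" band key was converted to the integer 1
print("bins" in parsed)          # True  -- non-numeric keys are left as strings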
audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr16000_hl512.json
ADDED
@@ -0,0 +1,19 @@
{
  "bins": 1024,
  "unstable_bins": 0,
  "reduction_bins": 0,
  "band": {
    "1": {
      "sr": 16000,
      "hl": 512,
      "n_fft": 2048,
      "crop_start": 0,
      "crop_stop": 1024,
      "hpf_start": -1,
      "res_type": "sinc_best"
    }
  },
  "sr": 16000,
  "pre_filter_start": 1023,
  "pre_filter_stop": 1024
}
audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr32000_hl512.json
ADDED
@@ -0,0 +1,19 @@
{
  "bins": 1024,
  "unstable_bins": 0,
  "reduction_bins": 0,
  "band": {
    "1": {
      "sr": 32000,
      "hl": 512,
      "n_fft": 2048,
      "crop_start": 0,
      "crop_stop": 1024,
      "hpf_start": -1,
      "res_type": "kaiser_fast"
    }
  },
  "sr": 32000,
  "pre_filter_start": 1000,
  "pre_filter_stop": 1021
}
audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr33075_hl384.json
ADDED
@@ -0,0 +1,19 @@
{
  "bins": 1024,
  "unstable_bins": 0,
  "reduction_bins": 0,
  "band": {
    "1": {
      "sr": 33075,
      "hl": 384,
      "n_fft": 2048,
      "crop_start": 0,
      "crop_stop": 1024,
      "hpf_start": -1,
      "res_type": "sinc_best"
    }
  },
  "sr": 33075,
  "pre_filter_start": 1000,
  "pre_filter_stop": 1021
}
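A short sketch of loading one of these parameter files through ModelParameters, assuming the package is importable and the script runs from the repository root (the relative path is a guess at the checkout layout):

from audio_separator.separator.uvr_lib_v5.vr_network.model_param_init import ModelParameters

mp = ModelParameters("audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr16000_hl512.json")

print(mp.param["sr"])             # 16000
print(mp.param["band"][1]["hl"])  # 512 -- band keys come back as integers via int_keys
print(mp.param["mid_side"])       # False -- defaulted because the config does not set it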