ASesYusuf1 committed on
Commit 01f8b5b · verified · 1 Parent(s): 23546b1

Upload 131 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +20 -0
  2. audio_separator/__init__.py +0 -0
  3. audio_separator/model-data.json +22 -0
  4. audio_separator/models-scores.json +0 -0
  5. audio_separator/models.json +216 -0
  6. audio_separator/separator/__init__.py +1 -0
  7. audio_separator/separator/architectures/__init__.py +0 -0
  8. audio_separator/separator/architectures/demucs_separator.py +195 -0
  9. audio_separator/separator/architectures/mdx_separator.py +451 -0
  10. audio_separator/separator/architectures/mdxc_separator.py +423 -0
  11. audio_separator/separator/architectures/vr_separator.py +357 -0
  12. audio_separator/separator/common_separator.py +403 -0
  13. audio_separator/separator/separator.py +959 -0
  14. audio_separator/separator/uvr_lib_v5/__init__.py +0 -0
  15. audio_separator/separator/uvr_lib_v5/demucs/__init__.py +5 -0
  16. audio_separator/separator/uvr_lib_v5/demucs/__main__.py +212 -0
  17. audio_separator/separator/uvr_lib_v5/demucs/apply.py +294 -0
  18. audio_separator/separator/uvr_lib_v5/demucs/demucs.py +453 -0
  19. audio_separator/separator/uvr_lib_v5/demucs/filtering.py +451 -0
  20. audio_separator/separator/uvr_lib_v5/demucs/hdemucs.py +783 -0
  21. audio_separator/separator/uvr_lib_v5/demucs/htdemucs.py +620 -0
  22. audio_separator/separator/uvr_lib_v5/demucs/model.py +204 -0
  23. audio_separator/separator/uvr_lib_v5/demucs/model_v2.py +222 -0
  24. audio_separator/separator/uvr_lib_v5/demucs/pretrained.py +181 -0
  25. audio_separator/separator/uvr_lib_v5/demucs/repo.py +146 -0
  26. audio_separator/separator/uvr_lib_v5/demucs/spec.py +38 -0
  27. audio_separator/separator/uvr_lib_v5/demucs/states.py +131 -0
  28. audio_separator/separator/uvr_lib_v5/demucs/tasnet.py +401 -0
  29. audio_separator/separator/uvr_lib_v5/demucs/tasnet_v2.py +404 -0
  30. audio_separator/separator/uvr_lib_v5/demucs/transformer.py +675 -0
  31. audio_separator/separator/uvr_lib_v5/demucs/utils.py +496 -0
  32. audio_separator/separator/uvr_lib_v5/mdxnet.py +136 -0
  33. audio_separator/separator/uvr_lib_v5/mixer.ckpt +3 -0
  34. audio_separator/separator/uvr_lib_v5/modules.py +74 -0
  35. audio_separator/separator/uvr_lib_v5/playsound.py +241 -0
  36. audio_separator/separator/uvr_lib_v5/pyrb.py +92 -0
  37. audio_separator/separator/uvr_lib_v5/results.py +48 -0
  38. audio_separator/separator/uvr_lib_v5/roformer/attend.py +112 -0
  39. audio_separator/separator/uvr_lib_v5/roformer/bs_roformer.py +535 -0
  40. audio_separator/separator/uvr_lib_v5/roformer/mel_band_roformer.py +445 -0
  41. audio_separator/separator/uvr_lib_v5/spec_utils.py +1327 -0
  42. audio_separator/separator/uvr_lib_v5/stft.py +126 -0
  43. audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py +253 -0
  44. audio_separator/separator/uvr_lib_v5/vr_network/__init__.py +1 -0
  45. audio_separator/separator/uvr_lib_v5/vr_network/layers.py +294 -0
  46. audio_separator/separator/uvr_lib_v5/vr_network/layers_new.py +149 -0
  47. audio_separator/separator/uvr_lib_v5/vr_network/model_param_init.py +71 -0
  48. audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr16000_hl512.json +19 -0
  49. audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr32000_hl512.json +19 -0
  50. audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr33075_hl384.json +19 -0
.gitattributes CHANGED
@@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/mardy20s.flac filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Bass)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Drum-Bass)_model_bs_roformer_ep_937_sdr_10_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Drums)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Guitar)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Instrumental)_2_HP-UVR_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Instrumental)_kuielab_b_vocals_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Instrumental)_MGM_MAIN_v4_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Instrumental)_model_bs_roformer_ep_317_sdr_12_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Instrumental)_UVR-MDX-NET-Inst_HQ_4_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(No[[:space:]]Drum-Bass)_model_bs_roformer_ep_937_sdr_10_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Other)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Piano)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Vocals)_2_HP-UVR_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Vocals)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Vocals)_kuielab_b_vocals_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Vocals)_MGM_MAIN_v4_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Vocals)_model_bs_roformer_ep_317_sdr_12_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_(Vocals)_UVR-MDX-NET-Inst_HQ_4_spectrogram.png filter=lfs diff=lfs merge=lfs -text
+ tests/inputs/reference/expected_mardy20s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
audio_separator/__init__.py ADDED
File without changes
audio_separator/model-data.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "vr_model_data": {
+     "97dc361a7a88b2c4542f68364b32c7f6": {
+       "vr_model_param": "4band_v4_ms_fullband",
+       "primary_stem": "Dry",
+       "nout": 32,
+       "nout_lstm": 128,
+       "is_karaoke": false,
+       "is_bv_model": false,
+       "is_bv_model_rebalanced": 0.0
+     }
+   },
+   "mdx_model_data": {
+     "cb790d0c913647ced70fc6b38f5bea1a": {
+       "compensate": 1.010,
+       "mdx_dim_f_set": 2560,
+       "mdx_dim_t_set": 8,
+       "mdx_n_fft_scale_set": 5120,
+       "primary_stem": "Instrumental"
+     }
+   }
+   }
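The keys in model-data.json look like MD5 hashes of the model files, with each model's parameters stored under its hash. As a rough illustration only (the helper name and the hashing convention are assumptions, not something shown in this diff), an entry could be resolved like this:

```python
import hashlib
import json

def lookup_model_data(model_path, model_data_path="audio_separator/model-data.json"):
    """Hypothetical helper: find a model's parameters by the MD5 hash of its file."""
    with open(model_path, "rb") as f:
        model_hash = hashlib.md5(f.read()).hexdigest()
    with open(model_data_path, encoding="utf-8") as f:
        data = json.load(f)
    # VR and MDX models are stored under separate top-level keys.
    for section in ("vr_model_data", "mdx_model_data"):
        if model_hash in data.get(section, {}):
            return data[section][model_hash]
    return None
```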
audio_separator/models-scores.json ADDED
The diff for this file is too large to render. See raw diff
 
audio_separator/models.json ADDED
@@ -0,0 +1,216 @@
+ {
+   "vr_download_list": {
+     "VR Arch Single Model v4: UVR-De-Reverb by aufr33-jarredou": "UVR-De-Reverb-aufr33-jarredou.pth"
+   },
+   "mdx_download_list": {
+     "MDX-Net Model: UVR-MDX-NET Inst HQ 5": "UVR-MDX-NET-Inst_HQ_5.onnx"
+   },
+   "mdx23c_download_list": {
+     "MDX23C Model: MDX23C De-Reverb by aufr33-jarredou": {
+       "MDX23C-De-Reverb-aufr33-jarredou.ckpt": "config_dereverb_mdx23c.yaml"
+     },
+     "MDX23C Model: MDX23C DrumSep by aufr33-jarredou": {
+       "MDX23C-DrumSep-aufr33-jarredou.ckpt": "config_drumsep_mdx23c.yaml"
+     }
+   },
+   "roformer_download_list": {
+     "Roformer Model: Mel-Roformer-Karaoke-Aufr33-Viperx": {
+       "mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt": "mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956_config.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Karaoke by Gabox": {
+       "mel_band_roformer_karaoke_gabox.ckpt": "mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956_config.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Karaoke by becruily": {
+       "mel_band_roformer_karaoke_becruily.ckpt": "config_mel_band_roformer_karaoke_becruily.yaml"
+     },
+     "Roformer Model: Mel-Roformer-Denoise-Aufr33": {
+       "denoise_mel_band_roformer_aufr33_sdr_27.9959.ckpt": "denoise_mel_band_roformer_aufr33_sdr_27.9959_config.yaml"
+     },
+     "Roformer Model: Mel-Roformer-Denoise-Aufr33-Aggr": {
+       "denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt": "denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768_config.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Denoise-Debleed by Gabox": {
+       "mel_band_roformer_denoise_debleed_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: Mel-Roformer-Crowd-Aufr33-Viperx": {
+       "mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144.ckpt": "mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144_config.yaml"
+     },
+     "Roformer Model: BS-Roformer-De-Reverb": {
+       "deverb_bs_roformer_8_384dim_10depth.ckpt": "deverb_bs_roformer_8_384dim_10depth_config.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Vocals by Kimberley Jensen": {
+       "vocals_mel_band_roformer.ckpt": "vocals_mel_band_roformer.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | FT by unwa": {
+       "mel_band_roformer_kim_ft_unwa.ckpt": "config_mel_band_roformer_kim_ft_unwa.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | FT 2 by unwa": {
+       "mel_band_roformer_kim_ft2_unwa.ckpt": "config_mel_band_roformer_kim_ft_unwa.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | FT 2 Bleedless by unwa": {
+       "mel_band_roformer_kim_ft2_bleedless_unwa.ckpt": "config_mel_band_roformer_kim_ft_unwa.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | FT 3 by unwa": {
+       "mel_band_roformer_kim_ft3_unwa.ckpt": "config_mel_band_roformer_kim_ft_unwa.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | Inst V1 Plus by Unwa": {
+       "melband_roformer_inst_v1_plus.ckpt": "config_melbandroformer_inst.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | Inst V1 (E) by Unwa": {
+       "melband_roformer_inst_v1e.ckpt": "config_melbandroformer_inst.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | Inst V1 (E) Plus by Unwa": {
+       "melband_roformer_inst_v1e_plus.ckpt": "config_melbandroformer_inst.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Vocals by becruily": {
+       "mel_band_roformer_vocals_becruily.ckpt": "config_mel_band_roformer_vocals_becruily.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Instrumental by becruily": {
+       "mel_band_roformer_instrumental_becruily.ckpt": "config_mel_band_roformer_instrumental_becruily.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Vocals Fullness by Aname": {
+       "mel_band_roformer_vocal_fullness_aname.ckpt": "config_mel_band_roformer_vocal_fullness_aname.yaml"
+     },
+     "Roformer Model: BS Roformer | Vocals by Gabox": {
+       "bs_roformer_vocals_gabox.ckpt": "config_bs_roformer_vocals_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Vocals by Gabox": {
+       "mel_band_roformer_vocals_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Vocals FV1 by Gabox": {
+       "mel_band_roformer_vocals_fv1_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Vocals FV2 by Gabox": {
+       "mel_band_roformer_vocals_fv2_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Vocals FV3 by Gabox": {
+       "mel_band_roformer_vocals_fv3_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Vocals FV4 by Gabox": {
+       "mel_band_roformer_vocals_fv4_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Instrumental by Gabox": {
+       "mel_band_roformer_instrumental_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Instrumental 2 by Gabox": {
+       "mel_band_roformer_instrumental_2_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Instrumental 3 by Gabox": {
+       "mel_band_roformer_instrumental_3_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Instrumental Bleedless V1 by Gabox": {
+       "mel_band_roformer_instrumental_bleedless_v1_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Instrumental Bleedless V2 by Gabox": {
+       "mel_band_roformer_instrumental_bleedless_v2_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Instrumental Bleedless V3 by Gabox": {
+       "mel_band_roformer_instrumental_bleedless_v3_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Instrumental Fullness V1 by Gabox": {
+       "mel_band_roformer_instrumental_fullness_v1_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Instrumental Fullness V2 by Gabox": {
+       "mel_band_roformer_instrumental_fullness_v2_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Instrumental Fullness V3 by Gabox": {
+       "mel_band_roformer_instrumental_fullness_v3_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Instrumental Fullness Noisy V4 by Gabox": {
+       "mel_band_roformer_instrumental_fullness_noise_v4_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | INSTV5 by Gabox": {
+       "mel_band_roformer_instrumental_instv5_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | INSTV5N by Gabox": {
+       "mel_band_roformer_instrumental_instv5n_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | INSTV6 by Gabox": {
+       "mel_band_roformer_instrumental_instv6_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | INSTV6N by Gabox": {
+       "mel_band_roformer_instrumental_instv6n_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | INSTV7 by Gabox": {
+       "mel_band_roformer_instrumental_instv7_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | INSTV7N by Gabox": {
+       "mel_band_roformer_instrumental_instv7n_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | INSTV8 by Gabox": {
+       "mel_band_roformer_instrumental_instv8_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | INSTV8N by Gabox": {
+       "mel_band_roformer_instrumental_instv8n_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | FVX by Gabox": {
+       "mel_band_roformer_instrumental_fvx_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
+     },
+     "Roformer Model: MelBand Roformer | De-Reverb by anvuew": {
+       "dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt": "dereverb_mel_band_roformer_anvuew.yaml"
+     },
+     "Roformer Model: MelBand Roformer | De-Reverb Less Aggressive by anvuew": {
+       "dereverb_mel_band_roformer_less_aggressive_anvuew_sdr_18.8050.ckpt": "dereverb_mel_band_roformer_anvuew.yaml"
+     },
+     "Roformer Model: MelBand Roformer | De-Reverb Mono by anvuew": {
+       "dereverb_mel_band_roformer_mono_anvuew.ckpt": "dereverb_mel_band_roformer_anvuew.yaml"
+     },
+     "Roformer Model: MelBand Roformer | De-Reverb Big by Sucial": {
+       "dereverb_big_mbr_ep_362.ckpt": "config_dereverb_echo_mel_band_roformer_v2.yaml"
+     },
+     "Roformer Model: MelBand Roformer | De-Reverb Super Big by Sucial": {
+       "dereverb_super_big_mbr_ep_346.ckpt": "config_dereverb_echo_mel_band_roformer_v2.yaml"
+     },
+     "Roformer Model: MelBand Roformer | De-Reverb-Echo by Sucial": {
+       "dereverb-echo_mel_band_roformer_sdr_10.0169.ckpt": "config_dereverb-echo_mel_band_roformer.yaml"
+     },
+     "Roformer Model: MelBand Roformer | De-Reverb-Echo V2 by Sucial": {
+       "dereverb-echo_mel_band_roformer_sdr_13.4843_v2.ckpt": "config_dereverb-echo_mel_band_roformer_sdr_13.4843_v2.yaml"
+     },
+     "Roformer Model: MelBand Roformer | De-Reverb-Echo Fused by Sucial": {
+       "dereverb_echo_mbr_fused.ckpt": "config_dereverb_echo_mel_band_roformer_v2.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | SYHFT by SYH99999": {
+       "MelBandRoformerSYHFT.ckpt": "config_vocals_mel_band_roformer_ft.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | SYHFT V2 by SYH99999": {
+       "MelBandRoformerSYHFTV2.ckpt": "config_vocals_mel_band_roformer_ft.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | SYHFT V2.5 by SYH99999": {
+       "MelBandRoformerSYHFTV2.5.ckpt": "config_vocals_mel_band_roformer_ft.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | SYHFT V3 by SYH99999": {
+       "MelBandRoformerSYHFTV3Epsilon.ckpt": "config_vocals_mel_band_roformer_ft.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | Big SYHFT V1 by SYH99999": {
+       "MelBandRoformerBigSYHFTV1.ckpt": "config_vocals_mel_band_roformer_big_v1_ft.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | Big Beta 4 FT by unwa": {
+       "melband_roformer_big_beta4.ckpt": "config_melbandroformer_big_beta4.yaml"
+     },
+     "Roformer Model: MelBand Roformer Kim | Big Beta 5e FT by unwa": {
+       "melband_roformer_big_beta5e.ckpt": "config_melband_roformer_big_beta5e.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Big Beta 6 by unwa": {
+       "melband_roformer_big_beta6.ckpt": "config_melbandroformer_big_beta6.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Big Beta 6X by unwa": {
+       "melband_roformer_big_beta6x.ckpt": "config_melbandroformer_big_beta6x.yaml"
+     },
+     "Roformer Model: BS Roformer | Chorus Male-Female by Sucial": {
+       "model_chorus_bs_roformer_ep_267_sdr_24.1275.ckpt": "config_chorus_male_female_bs_roformer.yaml"
+     },
+     "Roformer Model: BS Roformer | Male-Female by aufr33": {
+       "bs_roformer_male_female_by_aufr33_sdr_7.2889.ckpt": "config_chorus_male_female_bs_roformer.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Aspiration by Sucial": {
+       "aspiration_mel_band_roformer_sdr_18.9845.ckpt": "config_aspiration_mel_band_roformer.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Aspiration Less Aggressive by Sucial": {
+       "aspiration_mel_band_roformer_less_aggr_sdr_18.1201.ckpt": "config_aspiration_mel_band_roformer.yaml"
+     },
+     "Roformer Model: MelBand Roformer | Bleed Suppressor V1 by unwa-97chris": {
+       "mel_band_roformer_bleed_suppressor_v1.ckpt": "config_mel_band_roformer_bleed_suppressor_v1.yaml"
+     }
+   }
+ }
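Most of models.json maps a human-readable model name either to a single model filename or to a {checkpoint: config-yaml} pair. A minimal sketch of turning the Roformer section into download pairs (the helper name and flattening behaviour are my own, not part of this diff):

```python
import json

def roformer_download_pairs(models_json_path="audio_separator/models.json"):
    """Hypothetical helper: yield (friendly_name, checkpoint_file, config_yaml) tuples."""
    with open(models_json_path, encoding="utf-8") as f:
        models = json.load(f)
    for friendly_name, files in models["roformer_download_list"].items():
        # Each entry is a single {checkpoint: config_yaml} mapping.
        for checkpoint_file, config_yaml in files.items():
            yield friendly_name, checkpoint_file, config_yaml

# Note how several Gabox instrumental checkpoints share config_mel_band_roformer_instrumental_gabox.yaml.
for name, ckpt, cfg in roformer_download_pairs():
    print(f"{name}: {ckpt} -> {cfg}")
```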
audio_separator/separator/__init__.py ADDED
@@ -0,0 +1 @@
+ from .separator import Separator
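Because the package re-exports Separator here, that class (defined in separator.py, which this truncated view does not render) is the usual entry point rather than the architecture-specific separators below. A minimal usage sketch, with the method names and parameters assumed from the library's conventional API rather than taken from this diff:

```python
from audio_separator.separator import Separator

# Assumed API: load_model() downloads/loads a model, separate() writes stem files.
separator = Separator(output_dir="separated")
separator.load_model(model_filename="UVR-MDX-NET-Inst_HQ_4.onnx")
output_files = separator.separate("tests/inputs/mardy20s.flac")
print(output_files)
```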
audio_separator/separator/architectures/__init__.py ADDED
File without changes
audio_separator/separator/architectures/demucs_separator.py ADDED
@@ -0,0 +1,195 @@
+ import os
+ import sys
+ from pathlib import Path
+ import torch
+ import numpy as np
+ from audio_separator.separator.common_separator import CommonSeparator
+ from audio_separator.separator.uvr_lib_v5.demucs.apply import apply_model, demucs_segments
+ from audio_separator.separator.uvr_lib_v5.demucs.hdemucs import HDemucs
+ from audio_separator.separator.uvr_lib_v5.demucs.pretrained import get_model as get_demucs_model
+ from audio_separator.separator.uvr_lib_v5 import spec_utils
+
+ DEMUCS_4_SOURCE = ["drums", "bass", "other", "vocals"]
+
+ DEMUCS_2_SOURCE_MAPPER = {CommonSeparator.INST_STEM: 0, CommonSeparator.VOCAL_STEM: 1}
+ DEMUCS_4_SOURCE_MAPPER = {CommonSeparator.BASS_STEM: 0, CommonSeparator.DRUM_STEM: 1, CommonSeparator.OTHER_STEM: 2, CommonSeparator.VOCAL_STEM: 3}
+ DEMUCS_6_SOURCE_MAPPER = {
+     CommonSeparator.BASS_STEM: 0,
+     CommonSeparator.DRUM_STEM: 1,
+     CommonSeparator.OTHER_STEM: 2,
+     CommonSeparator.VOCAL_STEM: 3,
+     CommonSeparator.GUITAR_STEM: 4,
+     CommonSeparator.PIANO_STEM: 5,
+ }
+
+
+ class DemucsSeparator(CommonSeparator):
+     """
+     DemucsSeparator is responsible for separating audio sources using Demucs models.
+     It initializes with configuration parameters and prepares the model for separation tasks.
+     """
+
+     def __init__(self, common_config, arch_config):
+         # Any configuration values which can be shared between architectures should be set already in CommonSeparator,
+         # e.g. user-specified functionality choices (self.output_single_stem) or common model parameters (self.primary_stem_name)
+         super().__init__(config=common_config)
+
+         # Initializing user-configurable parameters, passed through from the CLI or Separator instance
+
+         # Adjust segments to manage RAM or V-RAM usage:
+         # - Smaller sizes consume less resources.
+         # - Bigger sizes consume more resources, but may provide better results.
+         # - "Default" picks the optimal size.
+         # DEMUCS_SEGMENTS = (DEF_OPT, '1', '5', '10', '15', '20',
+         #                    '25', '30', '35', '40', '45', '50',
+         #                    '55', '60', '65', '70', '75', '80',
+         #                    '85', '90', '95', '100')
+         self.segment_size = arch_config.get("segment_size", "Default")
+
+         # Performs multiple predictions with random shifts of the input and averages them.
+         # The higher the number of shifts, the longer the prediction will take.
+         # Not recommended unless you have a GPU.
+         # DEMUCS_SHIFTS = (0, 1, 2, 3, 4, 5,
+         #                  6, 7, 8, 9, 10, 11,
+         #                  12, 13, 14, 15, 16, 17,
+         #                  18, 19, 20)
+         self.shifts = arch_config.get("shifts", 2)
+
+         # This option controls the amount of overlap between prediction windows.
+         # - Higher values can provide better results, but will lead to longer processing times.
+         # - You can choose between 0.001-0.999
+         # DEMUCS_OVERLAP = (0.25, 0.50, 0.75, 0.99)
+         self.overlap = arch_config.get("overlap", 0.25)
+
+         # Enables "Segments". Deselecting this option is only recommended for those with powerful PCs.
+         self.segments_enabled = arch_config.get("segments_enabled", True)
+
+         self.logger.debug(f"Demucs arch params: segment_size={self.segment_size}, segments_enabled={self.segments_enabled}")
+         self.logger.debug(f"Demucs arch params: shifts={self.shifts}, overlap={self.overlap}")
+
+         self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
+
+         self.audio_file_path = None
+         self.audio_file_base = None
+         self.demucs_model_instance = None
+
+         # Add uvr_lib_v5 folder to system path so pytorch serialization can find the demucs module
+         current_dir = os.path.dirname(__file__)
+         uvr_lib_v5_path = os.path.join(current_dir, "..", "uvr_lib_v5")
+         sys.path.insert(0, uvr_lib_v5_path)
+
+         self.logger.info("Demucs Separator initialisation complete")
+
+     def separate(self, audio_file_path, custom_output_names=None):
+         """
+         Separates the audio file into its component stems using the Demucs model.
+
+         Args:
+             audio_file_path (str): The path to the audio file to be processed.
+             custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
+
+         Returns:
+             list: A list of paths to the output files generated by the separation process.
+         """
+         self.logger.debug("Starting separation process...")
+         source = None
+         stem_source = None
+         inst_source = {}
+
+         self.audio_file_path = audio_file_path
+         self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
+
+         # Prepare the mix for processing
+         self.logger.debug("Preparing mix...")
+         mix = self.prepare_mix(self.audio_file_path)
+
+         self.logger.debug(f"Mix prepared for demixing. Shape: {mix.shape}")
+
+         self.logger.debug("Loading model for demixing...")
+
+         self.demucs_model_instance = HDemucs(sources=DEMUCS_4_SOURCE)
+         self.demucs_model_instance = get_demucs_model(name=os.path.splitext(os.path.basename(self.model_path))[0], repo=Path(os.path.dirname(self.model_path)))
+         self.demucs_model_instance = demucs_segments(self.segment_size, self.demucs_model_instance)
+         self.demucs_model_instance.to(self.torch_device)
+         self.demucs_model_instance.eval()
+
+         self.logger.debug("Model loaded and set to evaluation mode.")
+
+         source = self.demix_demucs(mix)
+
+         del self.demucs_model_instance
+         self.clear_gpu_cache()
+         self.logger.debug("Model and GPU cache cleared after demixing.")
+
+         output_files = []
+         self.logger.debug("Processing output files...")
+
+         if isinstance(inst_source, np.ndarray):
+             self.logger.debug("Processing instance source...")
+             source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]], source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]])
+             inst_source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]] = source_reshape
+             source = inst_source
+
+         if isinstance(source, np.ndarray):
+             source_length = len(source)
+             self.logger.debug(f"Processing source array, source length is {source_length}")
+             match source_length:
+                 case 2:
+                     self.logger.debug("Setting source map to 2-stem...")
+                     self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER
+                 case 6:
+                     self.logger.debug("Setting source map to 6-stem...")
+                     self.demucs_source_map = DEMUCS_6_SOURCE_MAPPER
+                 case _:
+                     self.logger.debug("Setting source map to 4-stem...")
+                     self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
+
+         self.logger.debug("Processing for all stems...")
+         for stem_name, stem_value in self.demucs_source_map.items():
+             if self.output_single_stem is not None:
+                 if stem_name.lower() != self.output_single_stem.lower():
+                     self.logger.debug(f"Skipping writing stem {stem_name} as output_single_stem is set to {self.output_single_stem}...")
+                     continue
+
+             stem_path = self.get_stem_output_path(stem_name, custom_output_names)
+             stem_source = source[stem_value].T
+
+             self.final_process(stem_path, stem_source, stem_name)
+             output_files.append(stem_path)
+
+         return output_files
+
+     def demix_demucs(self, mix):
+         """
+         Demixes the input mix using the demucs model.
+         """
+         self.logger.debug("Starting demixing process in demix_demucs...")
+
+         processed = {}
+         mix = torch.tensor(mix, dtype=torch.float32)
+         ref = mix.mean(0)
+         mix = (mix - ref.mean()) / ref.std()
+         mix_infer = mix
+
+         with torch.no_grad():
+             self.logger.debug("Running model inference...")
+             sources = apply_model(
+                 model=self.demucs_model_instance,
+                 mix=mix_infer[None],
+                 shifts=self.shifts,
+                 split=self.segments_enabled,
+                 overlap=self.overlap,
+                 static_shifts=1 if self.shifts == 0 else self.shifts,
+                 set_progress_bar=None,
+                 device=self.torch_device,
+                 progress=True,
+             )[0]
+
+         sources = (sources * ref.std() + ref.mean()).cpu().numpy()
+         sources[[0, 1]] = sources[[1, 0]]
+         processed[mix] = sources[:, :, 0:None].copy()
+         sources = list(processed.values())
+         sources = [s[:, :, 0:None] for s in sources]
+         sources = np.concatenate(sources, axis=-1)
+
+         return sources
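For reference, the Demucs-specific options read from arch_config above amount to a small dictionary of defaults. A sketch of those values, assuming they are supplied through the Separator front-end (that wiring lives in separator.py, which is not rendered in this view):

```python
# Assumed wiring: the Separator front-end forwards these to DemucsSeparator as arch_config.
demucs_params = {
    "segment_size": "Default",   # "Default" lets Demucs pick the segment length; smaller values use less (V)RAM
    "shifts": 2,                 # number of random-shift passes to average; higher is slower but can sound better
    "overlap": 0.25,             # overlap between prediction windows, valid range roughly 0.001-0.999
    "segments_enabled": True,    # disable only on machines with plenty of memory
}
```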
audio_separator/separator/architectures/mdx_separator.py ADDED
@@ -0,0 +1,451 @@
+ """Module for separating audio sources using MDX architecture models."""
+
+ import os
+ import platform
+ import torch
+ import onnx
+ import onnxruntime as ort
+ import numpy as np
+ import onnx2torch
+ from tqdm import tqdm
+ from audio_separator.separator.uvr_lib_v5 import spec_utils
+ from audio_separator.separator.uvr_lib_v5.stft import STFT
+ from audio_separator.separator.common_separator import CommonSeparator
+
+
+ class MDXSeparator(CommonSeparator):
+     """
+     MDXSeparator is responsible for separating audio sources using MDX models.
+     It initializes with configuration parameters and prepares the model for separation tasks.
+     """
+
+     def __init__(self, common_config, arch_config):
+         # Any configuration values which can be shared between architectures should be set already in CommonSeparator,
+         # e.g. user-specified functionality choices (self.output_single_stem) or common model parameters (self.primary_stem_name)
+         super().__init__(config=common_config)
+
+         # Initializing user-configurable parameters, passed through from the CLI or Separator instance
+
+         # Pick a segment size to balance speed, resource use, and quality:
+         # - Smaller sizes consume less resources.
+         # - Bigger sizes consume more resources, but may provide better results.
+         # - Default size is 256. Quality can change based on your pick.
+         self.segment_size = arch_config.get("segment_size")
+
+         # This option controls the amount of overlap between prediction windows.
+         # - Higher values can provide better results, but will lead to longer processing times.
+         # - For Non-MDX23C models: You can choose between 0.001-0.999
+         self.overlap = arch_config.get("overlap")
+
+         # Number of batches to be processed at a time.
+         # - Higher values mean more RAM usage but slightly faster processing times.
+         # - Lower values mean less RAM usage but slightly longer processing times.
+         # - Batch size value has no effect on output quality.
+         # BATCH_SIZE = ('1', '2', '3', '4', '5', '6', '7', '8', '9', '10')
+         self.batch_size = arch_config.get("batch_size", 1)
+
+         # hop_length is equivalent to the more commonly used term "stride" in convolutional neural networks.
+         # In machine learning, particularly in the context of convolutional neural networks (CNNs),
+         # the term "stride" refers to the number of pixels by which we move the filter across the input image.
+         # Strides are a crucial component in the convolution operation, a fundamental building block of CNNs used primarily in the field of computer vision.
+         # Stride is a parameter that dictates the movement of the kernel, or filter, across the input data, such as an image.
+         # When performing a convolution operation, the stride determines how many units the filter shifts at each step.
+         # The choice of stride affects the model in several ways:
+         # Output Size: A larger stride will result in a smaller output spatial dimension.
+         # Computational Efficiency: Increasing the stride can decrease the computational load.
+         # Field of View: A higher stride means that each step of the filter takes into account a wider area of the input image.
+         # This can be beneficial when the model needs to capture more global features rather than focusing on finer details.
+         self.hop_length = arch_config.get("hop_length")
+
+         # If enabled, model will be run twice to reduce noise in output audio.
+         self.enable_denoise = arch_config.get("enable_denoise")
+
+         self.logger.debug(f"MDX arch params: batch_size={self.batch_size}, segment_size={self.segment_size}")
+         self.logger.debug(f"MDX arch params: overlap={self.overlap}, hop_length={self.hop_length}, enable_denoise={self.enable_denoise}")
+
+         # Initializing model-specific parameters from model_data JSON
+         self.compensate = self.model_data["compensate"]
+         self.dim_f = self.model_data["mdx_dim_f_set"]
+         self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
+         self.n_fft = self.model_data["mdx_n_fft_scale_set"]
+         self.config_yaml = self.model_data.get("config_yaml", None)
+
+         self.logger.debug(f"MDX arch params: compensate={self.compensate}, dim_f={self.dim_f}, dim_t={self.dim_t}, n_fft={self.n_fft}")
+         self.logger.debug(f"MDX arch params: config_yaml={self.config_yaml}")
+
+         # In UVR, these variables are set but either aren't useful or are better handled in audio-separator.
+         # Leaving these comments explaining to help myself or future developers understand why these aren't in audio-separator.
+
+         # "chunks" is not actually used for anything in UVR...
+         # self.chunks = 0
+
+         # "adjust" is hard-coded to 1 in UVR, and only used as a multiplier in run_model, so it does nothing.
+         # self.adjust = 1
+
+         # "hop" is hard-coded to 1024 in UVR. We have a "hop_length" parameter instead
+         # self.hop = 1024
+
+         # "margin" maps to sample rate and is set from the GUI in UVR (default: 44100). We have a "sample_rate" parameter instead.
+         # self.margin = 44100
+
+         # "dim_c" is hard-coded to 4 in UVR, seems to be a parameter for the number of channels, and is only used for checkpoint models.
+         # We haven't implemented support for the checkpoint models here, so we're not using it.
+         # self.dim_c = 4
+
+         self.load_model()
+
+         self.n_bins = 0
+         self.trim = 0
+         self.chunk_size = 0
+         self.gen_size = 0
+         self.stft = None
+
+         self.primary_source = None
+         self.secondary_source = None
+         self.audio_file_path = None
+         self.audio_file_base = None
+
+     def load_model(self):
+         """
+         Load the model into memory from file on disk, initialize it with config from the model data,
+         and prepare for inferencing using hardware accelerated Torch device.
+         """
+         self.logger.debug("Loading ONNX model for inference...")
+
+         if self.segment_size == self.dim_t:
+             ort_session_options = ort.SessionOptions()
+             if self.log_level > 10:
+                 ort_session_options.log_severity_level = 3
+             else:
+                 ort_session_options.log_severity_level = 0
+
+             ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
+             self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
+             self.logger.debug("Model loaded successfully using ONNXruntime inferencing session.")
+         else:
+             if platform.system() == 'Windows':
+                 onnx_model = onnx.load(self.model_path)
+                 self.model_run = onnx2torch.convert(onnx_model)
+             else:
+                 self.model_run = onnx2torch.convert(self.model_path)
+
+             self.model_run.to(self.torch_device).eval()
+             self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")
+
+     def separate(self, audio_file_path, custom_output_names=None):
+         """
+         Separates the audio file into primary and secondary sources based on the model's configuration.
+         It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.
+
+         Args:
+             audio_file_path (str): The path to the audio file to be processed.
+             custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
+
+         Returns:
+             list: A list of paths to the output files generated by the separation process.
+         """
+         self.audio_file_path = audio_file_path
+         self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
+
+         # Prepare the mix for processing
+         self.logger.debug(f"Preparing mix for input audio file {self.audio_file_path}...")
+         mix = self.prepare_mix(self.audio_file_path)
+
+         self.logger.debug("Normalizing mix before demixing...")
+         mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)
+
+         # Start the demixing process
+         source = self.demix(mix)
+         self.logger.debug("Demixing completed.")
+
+         # In UVR, the source is cached here if it's a vocal split model, but we're not supporting that yet
+
+         # Initialize the list for output files
+         output_files = []
+         self.logger.debug("Processing output files...")
+
+         # Normalize and transpose the primary source if it's not already an array
+         if not isinstance(self.primary_source, np.ndarray):
+             self.logger.debug("Normalizing primary source...")
+             self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T
+
+         # Process the secondary source if not already an array
+         if not isinstance(self.secondary_source, np.ndarray):
+             self.logger.debug("Producing secondary source: demixing in match_mix mode")
+             raw_mix = self.demix(mix, is_match_mix=True)
+
+             if self.invert_using_spec:
+                 self.logger.debug("Inverting secondary stem using spectrogram as invert_using_spec is set to True")
+                 self.secondary_source = spec_utils.invert_stem(raw_mix, source)
+             else:
+                 self.logger.debug("Inverting secondary stem by subtracting the transposed demixed stem from the transposed original mix")
+                 self.secondary_source = mix.T - source.T
+
+         # Save and process the secondary stem if needed
+         if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
+             self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)
+
+             self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
+             self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
+             output_files.append(self.secondary_stem_output_path)
+
+         # Save and process the primary stem if needed
+         if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
+             self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)
+
+             if not isinstance(self.primary_source, np.ndarray):
+                 self.primary_source = source.T
+
+             self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
+             self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
+             output_files.append(self.primary_stem_output_path)
+
+         # Not yet implemented from UVR features:
+         # self.process_vocal_split_chain(secondary_sources)
+         # self.logger.debug("Vocal split chain processed.")
+
+         return output_files
+
+     def initialize_model_settings(self):
+         """
+         This function sets up the necessary parameters for the model, like the number of frequency bins (n_bins), the trimming size (trim),
+         the size of each audio chunk (chunk_size), and the window function for spectral transformations (window).
+         It ensures that the model is configured with the correct settings for processing the audio data.
+         """
+         self.logger.debug("Initializing model settings...")
+
+         # n_bins is half the FFT size plus one (self.n_fft // 2 + 1).
+         self.n_bins = self.n_fft // 2 + 1
+
+         # trim is half the FFT size (self.n_fft // 2).
+         self.trim = self.n_fft // 2
+
+         # chunk_size is the hop_length size times the segment size minus one
+         self.chunk_size = self.hop_length * (self.segment_size - 1)
+
+         # gen_size is the chunk size minus twice the trim size
+         self.gen_size = self.chunk_size - 2 * self.trim
+
+         self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)
+
+         self.logger.debug(f"Model input params: n_fft={self.n_fft} hop_length={self.hop_length} dim_f={self.dim_f}")
+         self.logger.debug(f"Model settings: n_bins={self.n_bins}, trim={self.trim}, chunk_size={self.chunk_size}, gen_size={self.gen_size}")
+
+     def initialize_mix(self, mix, is_ckpt=False):
+         """
+         After prepare_mix segments the audio, initialize_mix further processes each segment.
+         It ensures each audio segment is in the correct format for the model, applies necessary padding,
+         and converts the segments into tensors for processing with the model.
+         This step is essential for preparing the audio data in a format that the neural network can process.
+         """
+         # Log the initialization of the mix and whether checkpoint mode is used
+         self.logger.debug(f"Initializing mix with is_ckpt={is_ckpt}. Initial mix shape: {mix.shape}")
+
+         # Ensure the mix is a 2-channel (stereo) audio signal
+         if mix.shape[0] != 2:
+             error_message = f"Expected a 2-channel audio signal, but got {mix.shape[0]} channels"
+             self.logger.error(error_message)
+             raise ValueError(error_message)
+
+         # If in checkpoint mode, process the mix differently
+         if is_ckpt:
+             self.logger.debug("Processing in checkpoint mode...")
+             # Calculate padding based on the generation size and trim
+             pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
+             self.logger.debug(f"Padding calculated: {pad}")
+             # Add padding at the beginning and the end of the mix
+             mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
+             # Determine the number of chunks based on the mixture's length
+             num_chunks = mixture.shape[-1] // self.gen_size
+             self.logger.debug(f"Mixture shape after padding: {mixture.shape}, Number of chunks: {num_chunks}")
+             # Split the mixture into chunks
+             mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
+         else:
+             # If not in checkpoint mode, process normally
+             self.logger.debug("Processing in non-checkpoint mode...")
+             mix_waves = []
+             n_sample = mix.shape[1]
+             # Calculate necessary padding to make the total length divisible by the generation size
+             pad = self.gen_size - n_sample % self.gen_size
+             self.logger.debug(f"Number of samples: {n_sample}, Padding calculated: {pad}")
+             # Apply padding to the mix
+             mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
+             self.logger.debug(f"Shape of mix after padding: {mix_p.shape}")
+
+             # Process the mix in chunks
+             i = 0
+             while i < n_sample + pad:
+                 waves = np.array(mix_p[:, i : i + self.chunk_size])
+                 mix_waves.append(waves)
+                 self.logger.debug(f"Processed chunk {len(mix_waves)}: Start {i}, End {i + self.chunk_size}")
+                 i += self.gen_size
+
+         # Convert the list of wave chunks into a tensor for processing on the specified device
+         mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
+         self.logger.debug(f"Converted mix_waves to tensor. Tensor shape: {mix_waves_tensor.shape}")
+
+         return mix_waves_tensor, pad
+
+     def demix(self, mix, is_match_mix=False):
+         """
+         Demixes the input mix into its constituent sources. If is_match_mix is True, the function adjusts the processing
+         to better match the mix, affecting chunk sizes and overlaps. The demixing process involves padding the mix,
+         processing it in chunks, applying windowing for overlaps, and accumulating the results to separate the sources.
+         """
+         self.logger.debug(f"Starting demixing process with is_match_mix: {is_match_mix}...")
+         self.initialize_model_settings()
+
+         # Preserves the original mix for later use.
+         # In UVR, this is used for the pitch fix and VR denoise processes, which aren't yet implemented here.
+         org_mix = mix
+         self.logger.debug(f"Original mix stored. Shape: {org_mix.shape}")
+
+         # Initializes a list to store the separated waveforms.
+         tar_waves_ = []
+
+         # Handling different chunk sizes and overlaps based on the matching requirement.
+         if is_match_mix:
+             # Sets a smaller chunk size specifically for matching the mix.
+             chunk_size = self.hop_length * (self.segment_size - 1)
+             # Sets a small overlap for the chunks.
+             overlap = 0.02
+             self.logger.debug(f"Chunk size for matching mix: {chunk_size}, Overlap: {overlap}")
+         else:
+             # Uses the regular chunk size defined in model settings.
+             chunk_size = self.chunk_size
+             # Uses the overlap specified in the model settings.
+             overlap = self.overlap
+             self.logger.debug(f"Standard chunk size: {chunk_size}, Overlap: {overlap}")
+
+         # Calculates the generated size after subtracting the trim from both ends of the chunk.
+         gen_size = chunk_size - 2 * self.trim
+         self.logger.debug(f"Generated size calculated: {gen_size}")
+
+         # Calculates padding to make the mix length a multiple of the generated size.
+         pad = gen_size + self.trim - ((mix.shape[-1]) % gen_size)
+         # Prepares the mixture with padding at the beginning and the end.
+         mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
+         self.logger.debug(f"Mixture prepared with padding. Mixture shape: {mixture.shape}")
+
+         # Calculates the step size for processing chunks based on the overlap.
+         step = int((1 - overlap) * chunk_size)
+         self.logger.debug(f"Step size for processing chunks: {step} as overlap is set to {overlap}.")
+
+         # Initializes arrays to store the results and to account for overlap.
+         result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
+         divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
+
+         # Initializes counters for processing chunks.
+         total = 0
+         total_chunks = (mixture.shape[-1] + step - 1) // step
+         self.logger.debug(f"Total chunks to process: {total_chunks}")
+
+         # Processes each chunk of the mixture.
+         for i in tqdm(range(0, mixture.shape[-1], step)):
+             total += 1
+             start = i
+             end = min(i + chunk_size, mixture.shape[-1])
+             self.logger.debug(f"Processing chunk {total}/{total_chunks}: Start {start}, End {end}")
+
+             # Handles windowing for overlapping chunks.
+             chunk_size_actual = end - start
+             window = None
+             if overlap != 0:
+                 window = np.hanning(chunk_size_actual)
+                 window = np.tile(window[None, None, :], (1, 2, 1))
+                 self.logger.debug("Window applied to the chunk.")
+
+             # Zero-pad the chunk to prepare it for processing.
+             mix_part_ = mixture[:, start:end]
+             if end != i + chunk_size:
+                 pad_size = (i + chunk_size) - end
+                 mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)
+
+             # Converts the chunk to a tensor for processing.
+             mix_part = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device)
+             # Splits the chunk into smaller batches if necessary.
+             mix_waves = mix_part.split(self.batch_size)
+             total_batches = len(mix_waves)
+             self.logger.debug(f"Mix part split into batches. Number of batches: {total_batches}")
+
+             with torch.no_grad():
+                 # Processes each batch in the chunk.
+                 batches_processed = 0
+                 for mix_wave in mix_waves:
+                     batches_processed += 1
+                     self.logger.debug(f"Processing mix_wave batch {batches_processed}/{total_batches}")
+
+                     # Runs the model to separate the sources.
+                     tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)
+
+                     # Applies windowing if needed and accumulates the results.
+                     if window is not None:
+                         tar_waves[..., :chunk_size_actual] *= window
+                         divider[..., start:end] += window
+                     else:
+                         divider[..., start:end] += 1
+
+                     result[..., start:end] += tar_waves[..., : end - start]
+
+         # Normalizes the results by the divider to account for overlap.
+         self.logger.debug("Normalizing result by dividing result by divider.")
+         tar_waves = result / divider
+         tar_waves_.append(tar_waves)
+
+         # Reshapes the results to match the original dimensions.
+         tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim : -self.trim]
+         tar_waves = np.concatenate(tar_waves_, axis=-1)[:, : mix.shape[-1]]
+
+         # Extracts the source from the results.
+         source = tar_waves[:, 0:None]
+         self.logger.debug(f"Concatenated tar_waves. Shape: {tar_waves.shape}")
+
+         # TODO: In UVR, pitch changing happens here. Consider implementing this as a feature.
+
+         # Compensates the source if not matching the mix.
+         if not is_match_mix:
+             source *= self.compensate
+             self.logger.debug("Compensate multiplier applied (not in match-mix mode).")
+
+         # TODO: In UVR, VR denoise model gets applied here. Consider implementing this as a feature.
+
+         self.logger.debug("Demixing process completed.")
+         return source
+
+     def run_model(self, mix, is_match_mix=False):
+         """
+         Processes the input mix through the model to separate the sources.
+         Applies STFT, handles spectrum modifications, and runs the model for source separation.
+         """
+         # Applying the STFT to the mix. The mix is moved to the specified device (e.g., GPU) before processing.
+         # self.logger.debug(f"Running STFT on the mix. Mix shape before STFT: {mix.shape}")
+         spek = self.stft(mix.to(self.torch_device))
+         self.logger.debug(f"STFT applied on mix. Spectrum shape: {spek.shape}")
+
+         # Zeroing out the first 3 bins of the spectrum. This is often done to reduce low-frequency noise.
+         spek[:, :, :3, :] *= 0
+         # self.logger.debug("First 3 bins of the spectrum zeroed out.")
+
+         # Handling the case where the mix needs to be matched (is_match_mix = True)
+         if is_match_mix:
+             # self.logger.debug("Match mix mode is enabled. Converting spectrum to NumPy array.")
+             spec_pred = spek.cpu().numpy()
+             self.logger.debug("is_match_mix: spectrum prediction obtained directly from STFT output.")
+         else:
+             # If denoising is enabled, the model is run on both the negative and positive spectrums.
+             if self.enable_denoise:
+                 # Assuming spek is a tensor and self.model_run can process it directly
+                 spec_pred_neg = self.model_run(-spek)  # Ensure this line correctly negates spek and runs the model
+                 spec_pred_pos = self.model_run(spek)
+                 # Ensure both spec_pred_neg and spec_pred_pos are tensors before applying operations
+                 spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)  # [invalid-unary-operand-type]
+                 self.logger.debug("Model run on both negative and positive spectrums for denoising.")
+             else:
+                 spec_pred = self.model_run(spek)
+                 self.logger.debug("Model run on the spectrum without denoising.")
+
+         # Applying the inverse STFT to convert the spectrum back to the time domain.
+         result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
+         self.logger.debug(f"Inverse STFT applied. Returning result with shape: {result.shape}")
+
+         return result
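The demix() loop above is a standard overlap-add scheme: each Hann-windowed chunk prediction is accumulated into result while the window itself is accumulated into divider, and dividing the two removes the overlap weighting. A small self-contained NumPy sketch of that accumulation (illustrative sizes only, not code from this repository):

```python
import numpy as np

signal_len, chunk_size, overlap = 16, 8, 0.5
step = int((1 - overlap) * chunk_size)

x = np.random.randn(signal_len).astype(np.float32)
result = np.zeros(signal_len, dtype=np.float32)
divider = np.zeros(signal_len, dtype=np.float32)

for start in range(0, signal_len, step):
    end = min(start + chunk_size, signal_len)
    window = np.hanning(end - start)
    # A real model would transform the chunk here; identity keeps the example checkable.
    predicted_chunk = x[start:end]
    result[start:end] += predicted_chunk * window
    divider[start:end] += window

# Dividing by the accumulated window weights recovers the interior samples;
# edge samples where the accumulated window is ~0 are not recovered.
reconstructed = result / np.maximum(divider, 1e-8)
```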
audio_separator/separator/architectures/mdxc_separator.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import torch
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from ml_collections import ConfigDict
8
+ from scipy import signal
9
+
10
+ from audio_separator.separator.common_separator import CommonSeparator
11
+ from audio_separator.separator.uvr_lib_v5 import spec_utils
12
+ from audio_separator.separator.uvr_lib_v5.tfc_tdf_v3 import TFC_TDF_net
13
+ from audio_separator.separator.uvr_lib_v5.roformer.mel_band_roformer import MelBandRoformer
14
+ from audio_separator.separator.uvr_lib_v5.roformer.bs_roformer import BSRoformer
15
+
16
+
17
+ class MDXCSeparator(CommonSeparator):
18
+ """
19
+ MDXCSeparator is responsible for separating audio sources using MDXC models.
20
+ It initializes with configuration parameters and prepares the model for separation tasks.
21
+ """
22
+
23
+ def __init__(self, common_config, arch_config):
24
+ # Any configuration values which can be shared between architectures should be set already in CommonSeparator,
25
+ # e.g. user-specified functionality choices (self.output_single_stem) or common model parameters (self.primary_stem_name)
26
+ super().__init__(config=common_config)
27
+
28
+ # Model data is basic overview metadata about the model, e.g. which stem is primary and whether it's a karaoke model
29
+ # It's loaded in from model_data_new.json in Separator.load_model and there are JSON examples in that method
30
+ # The instance variable self.model_data is passed through from Separator and set in CommonSeparator
31
+ self.logger.debug(f"Model data: {self.model_data}")
32
+
33
+ # Arch Config is the MDXC architecture specific user configuration options, which should all be configurable by the user
34
+ # either by their Separator class instantiation or by passing in a CLI parameter.
35
+ # While there are similarities between architectures for some of these (e.g. batch_size), they are deliberately configured
36
+ # this way as they have architecture-specific default values.
37
+ self.segment_size = arch_config.get("segment_size", 256)
38
+
39
+ # Whether or not to use the segment size from model config, or the default
40
+ # The segment size is set based on the value provided in a chosen model's associated config file (yaml).
41
+ self.override_model_segment_size = arch_config.get("override_model_segment_size", False)
42
+
43
+ self.overlap = arch_config.get("overlap", 8)
44
+ self.batch_size = arch_config.get("batch_size", 1)
45
+
46
+ # Amount of pitch shift to apply during processing (this does NOT affect the pitch of the output audio):
47
+ # • Whole numbers indicate semitones.
48
+ # • Using higher pitches may cut the upper bandwidth, even in high-quality models.
49
+ # • Upping the pitch can be better for tracks with deeper vocals.
50
+ # • Dropping the pitch may take more processing time but works well for tracks with high-pitched vocals.
51
+ self.pitch_shift = arch_config.get("pitch_shift", 0)
52
+
53
+ self.process_all_stems = arch_config.get("process_all_stems", True)
54
+
55
+ self.logger.debug(f"MDXC arch params: batch_size={self.batch_size}, segment_size={self.segment_size}, overlap={self.overlap}")
56
+ self.logger.debug(f"MDXC arch params: override_model_segment_size={self.override_model_segment_size}, pitch_shift={self.pitch_shift}")
57
+ self.logger.debug(f"MDXC multi-stem params: process_all_stems={self.process_all_stems}")
58
+
59
+ self.is_roformer = "is_roformer" in self.model_data
60
+
61
+ self.load_model()
62
+
63
+ self.primary_source = None
64
+ self.secondary_source = None
65
+ self.audio_file_path = None
66
+ self.audio_file_base = None
67
+
68
+ self.is_primary_stem_main_target = False
69
+ if self.model_data_cfgdict.training.target_instrument == "Vocals" or len(self.model_data_cfgdict.training.instruments) > 1:
70
+ self.is_primary_stem_main_target = True
71
+
72
+ self.logger.debug(f"is_primary_stem_main_target: {self.is_primary_stem_main_target}")
73
+
74
+ self.logger.info("MDXC Separator initialisation complete")
75
+
76
+ def load_model(self):
77
+ """
78
+ Load the model into memory from file on disk, initialize it with config from the model data,
79
+ and prepare for inferencing using hardware accelerated Torch device.
80
+ """
81
+ self.logger.debug("Loading checkpoint model for inference...")
82
+
83
+ self.model_data_cfgdict = ConfigDict(self.model_data)
84
+
85
+ try:
86
+ if self.is_roformer:
87
+ self.logger.debug("Loading Roformer model...")
88
+
89
+ # Determine the model type based on the configuration and instantiate it
90
+ if "num_bands" in self.model_data_cfgdict.model:
91
+ self.logger.debug("Loading MelBandRoformer model...")
92
+ model = MelBandRoformer(**self.model_data_cfgdict.model)
93
+ elif "freqs_per_bands" in self.model_data_cfgdict.model:
94
+ self.logger.debug("Loading BSRoformer model...")
95
+ model = BSRoformer(**self.model_data_cfgdict.model)
96
+ else:
97
+ raise ValueError("Unknown Roformer model type in the configuration.")
98
+
99
+ # Load model checkpoint
100
+ checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=True)
101
+ self.model_run = model if not isinstance(model, torch.nn.DataParallel) else model.module
102
+ self.model_run.load_state_dict(checkpoint)
103
+ self.model_run.to(self.torch_device).eval()
104
+
105
+ else:
106
+ self.logger.debug("Loading TFC_TDF_net model...")
107
+ self.model_run = TFC_TDF_net(self.model_data_cfgdict, device=self.torch_device)
108
+ self.logger.debug("Loading model onto cpu")
109
+ # For some reason, loading the state dict directly onto a hardware-accelerated device causes issues,
110
+ # so we load it onto CPU first then move it to the device
111
+ self.model_run.load_state_dict(torch.load(self.model_path, map_location="cpu"))
112
+ self.model_run.to(self.torch_device).eval()
113
+
114
+ except RuntimeError as e:
115
+ self.logger.error(f"Error: {e}")
116
+ self.logger.error("An error occurred while loading the model file. This often occurs when the model file is corrupt or incomplete.")
117
+ self.logger.error(f"Please try deleting the model file from {self.model_path} and run audio-separator again to re-download it.")
118
+ sys.exit(1)
119
+
120
+ def separate(self, audio_file_path, custom_output_names=None):
121
+ """
122
+ Separates the audio file into primary and secondary sources based on the model's configuration.
123
+ It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.
124
+
125
+ Args:
126
+ audio_file_path (str): The path to the audio file to be processed.
127
+ custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
128
+
129
+ Returns:
130
+ list: A list of paths to the output files generated by the separation process.
131
+ """
132
+ self.primary_source = None
133
+ self.secondary_source = None
134
+
135
+ self.audio_file_path = audio_file_path
136
+ self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
137
+
138
+ self.logger.debug(f"Preparing mix for input audio file {self.audio_file_path}...")
139
+ mix = self.prepare_mix(self.audio_file_path)
140
+
141
+ self.logger.debug("Normalizing mix before demixing...")
142
+ mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)
143
+
144
+ source = self.demix(mix=mix)
145
+ self.logger.debug("Demixing completed.")
146
+
147
+ output_files = []
148
+ self.logger.debug("Processing output files...")
149
+
150
+ if isinstance(source, dict):
151
+ self.logger.debug("Source is a dict, processing each stem...")
152
+
153
+ stem_list = []
154
+ if self.model_data_cfgdict.training.target_instrument:
155
+ stem_list = [self.model_data_cfgdict.training.target_instrument]
156
+ else:
157
+ stem_list = self.model_data_cfgdict.training.instruments
158
+
159
+ self.logger.debug(f"Available stems: {stem_list}")
160
+
161
+ is_multi_stem_model = len(stem_list) > 2
162
+ should_process_all_stems = self.process_all_stems and is_multi_stem_model
163
+
164
+ if should_process_all_stems:
165
+ self.logger.debug("Processing all stems from multi-stem model...")
166
+ for stem_name in stem_list:
167
+ stem_output_path = self.get_stem_output_path(stem_name, custom_output_names)
168
+ stem_source = spec_utils.normalize(
169
+ wave=source[stem_name],
170
+ max_peak=self.normalization_threshold,
171
+ min_peak=self.amplification_threshold
172
+ ).T
173
+
174
+ self.logger.info(f"Saving {stem_name} stem to {stem_output_path}...")
175
+ self.final_process(stem_output_path, stem_source, stem_name)
176
+ output_files.append(stem_output_path)
177
+ else:
178
+ # Standard processing for primary and secondary stems
179
+ if not isinstance(self.primary_source, np.ndarray):
180
+ self.logger.debug(f"Normalizing primary source for primary stem {self.primary_stem_name}...")
181
+ self.primary_source = spec_utils.normalize(
182
+ wave=source[self.primary_stem_name],
183
+ max_peak=self.normalization_threshold,
184
+ min_peak=self.amplification_threshold
185
+ ).T
186
+
187
+ if not isinstance(self.secondary_source, np.ndarray):
188
+ self.logger.debug(f"Normalizing secondary source for secondary stem {self.secondary_stem_name}...")
189
+ self.secondary_source = spec_utils.normalize(
190
+ wave=source[self.secondary_stem_name],
191
+ max_peak=self.normalization_threshold,
192
+ min_peak=self.amplification_threshold
193
+ ).T
194
+
195
+ if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
196
+ self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)
197
+
198
+ self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
199
+ self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
200
+ output_files.append(self.secondary_stem_output_path)
201
+
202
+ if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
203
+ self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)
204
+
205
+ self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
206
+ self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
207
+ output_files.append(self.primary_stem_output_path)
208
+
209
+ else:
210
+ # Handle case when source is not a dictionary (single source model)
211
+ if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
212
+ self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)
213
+
214
+ if not isinstance(self.primary_source, np.ndarray):
215
+ self.primary_source = source.T
216
+
217
+ self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
218
+ self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
219
+ output_files.append(self.primary_stem_output_path)
220
+
221
+ return output_files
222
+
223
+ def pitch_fix(self, source, sr_pitched, orig_mix):
224
+ """
225
+ Change the pitch of the source audio by a number of semitones.
226
+
227
+ Args:
228
+ source (np.ndarray): The source audio to be pitch-shifted.
229
+ sr_pitched (int): The sample rate of the pitch-shifted audio.
230
+ orig_mix (np.ndarray): The original mix, used to match the shape of the pitch-shifted audio.
231
+
232
+ Returns:
233
+ np.ndarray: The pitch-shifted source audio.
234
+ """
235
+ source = spec_utils.change_pitch_semitones(source, sr_pitched, semitone_shift=self.pitch_shift)[0]
236
+ source = spec_utils.match_array_shapes(source, orig_mix)
237
+ return source
238
+
239
+ def overlap_add(self, result, x, weights, start, length):
240
+ """
241
+ Adds the overlapping part of the result to the result tensor.
242
+ """
243
+ result[..., start : start + length] += x[..., :length] * weights[:length]
244
+ return result
245
+
246
+ def demix(self, mix: np.ndarray) -> dict:
247
+ """
248
+ Demixes the input mix into primary and secondary sources using the model and model data.
249
+
250
+ Args:
251
+ mix (np.ndarray): The mix to be demixed.
252
+ Returns:
253
+ dict: A dictionary containing the demixed sources.
254
+ """
255
+ orig_mix = mix
256
+
257
+ if self.pitch_shift != 0:
258
+ self.logger.debug(f"Shifting pitch by -{self.pitch_shift} semitones...")
259
+ mix, sample_rate = spec_utils.change_pitch_semitones(mix, self.sample_rate, semitone_shift=-self.pitch_shift)
260
+
261
+ if self.is_roformer:
262
+ # Note: Currently, for Roformer models, `batch_size` is not utilized due to negligible performance improvements.
263
+
264
+ mix = torch.tensor(mix, dtype=torch.float32)
265
+
266
+ if self.override_model_segment_size:
267
+ mdx_segment_size = self.segment_size
268
+ self.logger.debug(f"Using configured segment size: {mdx_segment_size}")
269
+ else:
270
+ mdx_segment_size = self.model_data_cfgdict.inference.dim_t
271
+ self.logger.debug(f"Using model default segment size: {mdx_segment_size}")
272
+
273
+ # num_stems aka "S" in UVR
274
+ num_stems = 1 if self.model_data_cfgdict.training.target_instrument else len(self.model_data_cfgdict.training.instruments)
275
+ self.logger.debug(f"Number of stems: {num_stems}")
276
+
277
+ # chunk_size aka "C" in UVR
278
+ chunk_size = self.model_data_cfgdict.audio.hop_length * (mdx_segment_size - 1)
279
+ self.logger.debug(f"Chunk size: {chunk_size}")
280
+
281
+ step = int(self.overlap * self.model_data_cfgdict.audio.sample_rate)
282
+ self.logger.debug(f"Step: {step}")
283
+
284
+ # Create a weighting table and convert it to a PyTorch tensor
285
+ window = torch.tensor(signal.windows.hamming(chunk_size), dtype=torch.float32)
286
+
287
+ device = next(self.model_run.parameters()).device
288
+
289
+
290
+ with torch.no_grad():
291
+ req_shape = (len(self.model_data_cfgdict.training.instruments),) + tuple(mix.shape)
292
+ result = torch.zeros(req_shape, dtype=torch.float32)
293
+ counter = torch.zeros(req_shape, dtype=torch.float32)
294
+
295
+ for i in tqdm(range(0, mix.shape[1], step)):
296
+ part = mix[:, i : i + chunk_size]
297
+ length = part.shape[-1]
298
+ if i + chunk_size > mix.shape[1]:
299
+ part = mix[:, -chunk_size:]
300
+ length = chunk_size
301
+ part = part.to(device)
302
+ x = self.model_run(part.unsqueeze(0))[0]
303
+ x = x.cpu()
304
+ # Perform overlap_add on CPU
305
+ if i + chunk_size > mix.shape[1]:
306
+ # Fixed to correctly add to the end of the tensor
307
+ result = self.overlap_add(result, x, window, result.shape[-1] - chunk_size, length)
308
+ counter[..., result.shape[-1] - chunk_size :] += window[:length]
309
+ else:
310
+ result = self.overlap_add(result, x, window, i, length)
311
+ counter[..., i : i + length] += window[:length]
312
+
313
+ inferenced_outputs = result / counter.clamp(min=1e-10)
314
+
315
+ else:
316
+ mix = torch.tensor(mix, dtype=torch.float32)
317
+
318
+ try:
319
+ num_stems = self.model_run.num_target_instruments
320
+ except AttributeError:
321
+ num_stems = self.model_run.module.num_target_instruments
322
+ self.logger.debug(f"Number of stems: {num_stems}")
323
+
324
+ if self.override_model_segment_size:
325
+ mdx_segment_size = self.segment_size
326
+ self.logger.debug(f"Using configured segment size: {mdx_segment_size}")
327
+ else:
328
+ mdx_segment_size = self.model_data_cfgdict.inference.dim_t
329
+ self.logger.debug(f"Using model default segment size: {mdx_segment_size}")
330
+
331
+ chunk_size = self.model_data_cfgdict.audio.hop_length * (mdx_segment_size - 1)
332
+ self.logger.debug(f"Chunk size: {chunk_size}")
333
+
334
+ hop_size = chunk_size // self.overlap
335
+ self.logger.debug(f"Hop size: {hop_size}")
336
+
337
+ mix_shape = mix.shape[1]
338
+ pad_size = hop_size - (mix_shape - chunk_size) % hop_size
339
+ self.logger.debug(f"Pad size: {pad_size}")
340
+
341
+ mix = torch.cat([torch.zeros(2, chunk_size - hop_size), mix, torch.zeros(2, pad_size + chunk_size - hop_size)], 1)
342
+ self.logger.debug(f"Mix shape: {mix.shape}")
343
+
344
+ chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1)
345
+ self.logger.debug(f"Chunks length: {len(chunks)} and shape: {chunks.shape}")
346
+
347
+ batches = [chunks[i : i + self.batch_size] for i in range(0, len(chunks), self.batch_size)]
348
+ self.logger.debug(f"Batch size: {self.batch_size}, number of batches: {len(batches)}")
349
+
350
+ # accumulated_outputs is used to accumulate the output from processing each batch of chunks through the model.
351
+ # It starts as a tensor of zeros and is updated in-place as the model processes each batch.
352
+ # The variable holds the combined result of all processed batches, which, after post-processing, represents the separated audio sources.
353
+ accumulated_outputs = torch.zeros(num_stems, *mix.shape) if num_stems > 1 else torch.zeros_like(mix)
354
+
355
+ with torch.no_grad():
356
+ count = 0
357
+ for batch in tqdm(batches):
358
+ # Since the model processes the audio data in batches, single_batch_result temporarily holds the model's output
359
+ # for each batch before it is accumulated into accumulated_outputs.
360
+ single_batch_result = self.model_run(batch.to(self.torch_device))
361
+
362
+ # Each individual output tensor from the current batch's processing result.
363
+ # Since single_batch_result can contain multiple output tensors (one for each piece of audio in the batch),
364
+ # individual_output is used to iterate through these tensors and accumulate them into accumulated_outputs.
365
+ for individual_output in single_batch_result:
366
+ individual_output_cpu = individual_output.cpu()
367
+ # Accumulate outputs on CPU
368
+ accumulated_outputs[..., count * hop_size : count * hop_size + chunk_size] += individual_output_cpu
369
+ count += 1
370
+
371
+ self.logger.debug("Calculating inferenced outputs based on accumulated outputs and overlap")
372
+ inferenced_outputs = accumulated_outputs[..., chunk_size - hop_size : -(pad_size + chunk_size - hop_size)] / self.overlap
373
+ self.logger.debug("Deleting accumulated outputs to free up memory")
374
+ del accumulated_outputs
375
+
376
+ if num_stems > 1 or self.is_primary_stem_main_target:
377
+ self.logger.debug("Number of stems is greater than 1 or vocals are main target, detaching individual sources and correcting pitch if necessary...")
378
+
379
+ sources = {}
380
+
381
+ # Iterates over each instrument specified in the model's configuration and its corresponding separated audio source.
382
+ # self.model_data_cfgdict.training.instruments provides the list of stems.
383
+ # estimated_sources.cpu().detach().numpy() converts the separated sources tensor to a NumPy array for processing.
384
+ # Each iteration provides an instrument name ('key') and its separated audio ('value') for further processing.
385
+ for key, value in zip(self.model_data_cfgdict.training.instruments, inferenced_outputs.cpu().detach().numpy()):
386
+ self.logger.debug(f"Processing instrument: {key}")
387
+ if self.pitch_shift != 0:
388
+ self.logger.debug(f"Applying pitch correction for {key}")
389
+ sources[key] = self.pitch_fix(value, sample_rate, orig_mix)
390
+ else:
391
+ sources[key] = value
392
+
393
+ if self.is_primary_stem_main_target:
394
+ self.logger.debug(f"Primary stem: {self.primary_stem_name} is main target, detaching and matching array shapes if necessary...")
395
+ if sources[self.primary_stem_name].shape[1] != orig_mix.shape[1]:
396
+ sources[self.primary_stem_name] = spec_utils.match_array_shapes(sources[self.primary_stem_name], orig_mix)
397
+ sources[self.secondary_stem_name] = orig_mix - sources[self.primary_stem_name]
398
+
399
+ self.logger.debug("Deleting inferenced outputs to free up memory")
400
+ del inferenced_outputs
401
+
402
+ self.logger.debug("Returning separated sources")
403
+ return sources
404
+ else:
405
+ self.logger.debug("Processing single source...")
406
+
407
+ if self.is_roformer:
408
+ sources = {k: v.cpu().detach().numpy() for k, v in zip([self.model_data_cfgdict.training.target_instrument], inferenced_outputs)}
409
+ inferenced_output = sources[self.model_data_cfgdict.training.target_instrument]
410
+ else:
411
+ inferenced_output = inferenced_outputs.cpu().detach().numpy()
412
+
413
+ self.logger.debug("Demix process completed for single source.")
414
+
415
+ self.logger.debug("Deleting inferenced outputs to free up memory")
416
+ del inferenced_outputs
417
+
418
+ if self.pitch_shift != 0:
419
+ self.logger.debug("Applying pitch correction for single instrument")
420
+ return self.pitch_fix(inferenced_output, sample_rate, orig_mix)
421
+ else:
422
+ self.logger.debug("Returning inferenced output for single instrument")
423
+ return inferenced_output
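
For readers tracing the Roformer branch of `demix` above, the windowed overlap-add can be reduced to a standalone sketch. This is only an illustration: the chunk size, step and the identity `model_fn` below are placeholder values, not the ones the separator derives from the model config.

import numpy as np
from scipy import signal

def overlap_add_demo(mix, model_fn, chunk_size=4096, step=1024):
    # Weight each chunk's output with a Hamming window, accumulate both the
    # weighted outputs and the window values, then normalize at the end.
    window = signal.windows.hamming(chunk_size)
    result = np.zeros_like(mix, dtype=np.float32)
    counter = np.zeros_like(mix, dtype=np.float32)
    for start in range(0, mix.shape[-1], step):
        chunk = mix[..., start:start + chunk_size]
        length = chunk.shape[-1]
        out = model_fn(chunk)  # stand-in for the separation network
        result[..., start:start + length] += out[..., :length] * window[:length]
        counter[..., start:start + length] += window[:length]
    return result / np.clip(counter, 1e-10, None)

# With an identity "model", the window weighting cancels out after normalization
demo_mix = np.random.randn(2, 44100).astype(np.float32)
assert np.allclose(overlap_add_demo(demo_mix, lambda x: x), demo_mix, atol=1e-4)

Dividing by the accumulated window values is what prevents audible level changes where neighbouring chunks overlap, which is the same reason the real implementation keeps a separate counter tensor.
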
audio_separator/separator/architectures/vr_separator.py ADDED
@@ -0,0 +1,357 @@
1
+ """Module for separating audio sources using VR architecture models."""
2
+
3
+ import os
4
+ import math
5
+
6
+ import torch
7
+ import librosa
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+
11
+ # Check if we really need the rerun_mp3 function, remove if not
12
+ import audioread
13
+
14
+ from audio_separator.separator.common_separator import CommonSeparator
15
+ from audio_separator.separator.uvr_lib_v5 import spec_utils
16
+ from audio_separator.separator.uvr_lib_v5.vr_network import nets
17
+ from audio_separator.separator.uvr_lib_v5.vr_network import nets_new
18
+ from audio_separator.separator.uvr_lib_v5.vr_network.model_param_init import ModelParameters
19
+
20
+
21
+ class VRSeparator(CommonSeparator):
22
+ """
23
+ VRSeparator is responsible for separating audio sources using VR models.
24
+ It initializes with configuration parameters and prepares the model for separation tasks.
25
+ """
26
+
27
+ def __init__(self, common_config, arch_config: dict):
28
+ # Any configuration values which can be shared between architectures should be set already in CommonSeparator,
29
+ # e.g. user-specified functionality choices (self.output_single_stem) or common model parameters (self.primary_stem_name)
30
+ super().__init__(config=common_config)
31
+
32
+ # Model data is basic overview metadata about the model, e.g. which stem is primary and whether it's a karaoke model
33
+ # It's loaded in from model_data_new.json in Separator.load_model and there are JSON examples in that method
34
+ # The instance variable self.model_data is passed through from Separator and set in CommonSeparator
35
+ self.logger.debug(f"Model data: {self.model_data}")
36
+
37
+ # Most of the VR models use the same number of output channels, but the VR 51 models have specific values set in model_data JSON
38
+ self.model_capacity = 32, 128
39
+ self.is_vr_51_model = False
40
+
41
+ if "nout" in self.model_data.keys() and "nout_lstm" in self.model_data.keys():
42
+ self.model_capacity = self.model_data["nout"], self.model_data["nout_lstm"]
43
+ self.is_vr_51_model = True
44
+
45
+ # Model params are additional technical parameter values from JSON files in separator/uvr_lib_v5/vr_network/modelparams/*.json,
46
+ # with filenames referenced by the model_data["vr_model_param"] value
47
+ package_root_filepath = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
48
+ vr_params_json_dir = os.path.join(package_root_filepath, "uvr_lib_v5", "vr_network", "modelparams")
49
+ vr_params_json_filename = f"{self.model_data['vr_model_param']}.json"
50
+ vr_params_json_filepath = os.path.join(vr_params_json_dir, vr_params_json_filename)
51
+ self.model_params = ModelParameters(vr_params_json_filepath)
52
+
53
+ self.logger.debug(f"Model params: {self.model_params.param}")
54
+
55
+ # Arch Config is the VR architecture specific user configuration options, which should all be configurable by the user
56
+ # either by their Separator class instantiation or by passing in a CLI parameter.
57
+ # While there are similarities between architectures for some of these (e.g. batch_size), they are deliberately configured
58
+ # this way as they have architecture-specific default values.
59
+
60
+ # This option performs Test-Time-Augmentation to improve the separation quality.
61
+ # Note: Having this selected will increase the time it takes to complete a conversion
62
+ self.enable_tta = arch_config.get("enable_tta", False)
63
+
64
+ # This option can potentially identify leftover instrumental artifacts within the vocal outputs; may improve the separation of some songs.
65
+ # Note: Selecting this option can adversely affect the conversion process, depending on the track. Because of this, it is only recommended as a last resort.
66
+ self.enable_post_process = arch_config.get("enable_post_process", False)
67
+
68
+ # post_process_threshold values = ('0.1', '0.2', '0.3')
69
+ self.post_process_threshold = arch_config.get("post_process_threshold", 0.2)
70
+
71
+ # Number of batches to be processed at a time.
72
+ # - Higher values mean more RAM usage but slightly faster processing times.
73
+ # - Lower values mean less RAM usage but slightly longer processing times.
74
+ # - Batch size value has no effect on output quality.
75
+
76
+ # Andrew note: for some reason, lower batch sizes seem to cause broken output for VR arch; need to investigate why
77
+ self.batch_size = arch_config.get("batch_size", 1)
78
+
79
+ # Select window size to balance quality and speed:
80
+ # - 1024 - Quick but lesser quality.
81
+ # - 512 - Medium speed and quality.
82
+ # - 320 - Takes longer but may offer better quality.
83
+ self.window_size = arch_config.get("window_size", 512)
84
+
85
+ # The application will mirror the missing frequency range of the output.
86
+ self.high_end_process = arch_config.get("high_end_process", False)
87
+ self.input_high_end_h = None
88
+ self.input_high_end = None
89
+
90
+ # Adjust the intensity of primary stem extraction:
91
+ # - Ranges from -100 - 100.
92
+ # - Bigger values mean deeper extractions.
93
+ # - Typically, it's set to 5 for vocals & instrumentals.
94
+ # - Values beyond 5 might muddy the sound for non-vocal models.
95
+ self.aggression = float(int(arch_config.get("aggression", 5)) / 100)
96
+
97
+ self.aggressiveness = {"value": self.aggression, "split_bin": self.model_params.param["band"][1]["crop_stop"], "aggr_correction": self.model_params.param.get("aggr_correction")}
98
+
99
+ self.model_samplerate = self.model_params.param["sr"]
100
+
101
+ self.logger.debug(f"VR arch params: enable_tta={self.enable_tta}, enable_post_process={self.enable_post_process}, post_process_threshold={self.post_process_threshold}")
102
+ self.logger.debug(f"VR arch params: batch_size={self.batch_size}, window_size={self.window_size}")
103
+ self.logger.debug(f"VR arch params: high_end_process={self.high_end_process}, aggression={self.aggression}")
104
+ self.logger.debug(f"VR arch params: is_vr_51_model={self.is_vr_51_model}, model_samplerate={self.model_samplerate}, model_capacity={self.model_capacity}")
105
+
106
+ self.model_run = lambda *args, **kwargs: self.logger.error("Model run method is not initialised yet.")
107
+
108
+ # This should go away once we refactor to remove soundfile.write and replace with pydub like we did for the MDX rewrite
109
+ self.wav_subtype = "PCM_16"
110
+
111
+ self.logger.info("VR Separator initialisation complete")
112
+
113
+ def separate(self, audio_file_path, custom_output_names=None):
114
+ """
115
+ Separates the audio file into primary and secondary sources based on the model's configuration.
116
+ It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.
117
+
118
+ Args:
119
+ audio_file_path (str): The path to the audio file to be processed.
120
+ custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
121
+
122
+ Returns:
123
+ list: A list of paths to the output files generated by the separation process.
124
+ """
125
+ self.primary_source = None
126
+ self.secondary_source = None
127
+
128
+ self.audio_file_path = audio_file_path
129
+ self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
130
+
131
+ self.logger.debug(f"Starting separation for input audio file {self.audio_file_path}...")
132
+
133
+ nn_arch_sizes = [31191, 33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227] # default
134
+ vr_5_1_models = [56817, 218409]
135
+ model_size = math.ceil(os.stat(self.model_path).st_size / 1024)
136
+ nn_arch_size = min(nn_arch_sizes, key=lambda x: abs(x - model_size))
137
+ self.logger.debug(f"Model size determined: {model_size}, NN architecture size: {nn_arch_size}")
138
+
139
+ if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
140
+ self.logger.debug("Using CascadedNet for VR 5.1 model...")
141
+ self.model_run = nets_new.CascadedNet(self.model_params.param["bins"] * 2, nn_arch_size, nout=self.model_capacity[0], nout_lstm=self.model_capacity[1])
142
+ self.is_vr_51_model = True
143
+ else:
144
+ self.logger.debug("Determining model capacity...")
145
+ self.model_run = nets.determine_model_capacity(self.model_params.param["bins"] * 2, nn_arch_size)
146
+
147
+ self.model_run.load_state_dict(torch.load(self.model_path, map_location="cpu"))
148
+ self.model_run.to(self.torch_device)
149
+ self.logger.debug("Model loaded and moved to device.")
150
+
151
+ y_spec, v_spec = self.inference_vr(self.loading_mix(), self.torch_device, self.aggressiveness)
152
+ self.logger.debug("Inference completed.")
153
+
154
+ # Sanitize y_spec and v_spec to replace NaN and infinite values
155
+ y_spec = np.nan_to_num(y_spec, nan=0.0, posinf=0.0, neginf=0.0)
156
+ v_spec = np.nan_to_num(v_spec, nan=0.0, posinf=0.0, neginf=0.0)
157
+
158
+ self.logger.debug("Sanitization completed. Replaced NaN and infinite values in y_spec and v_spec.")
159
+
160
+ # After inference_vr call
161
+ self.logger.debug(f"Inference VR completed. y_spec shape: {y_spec.shape}, v_spec shape: {v_spec.shape}")
162
+ self.logger.debug(f"y_spec stats - min: {np.min(y_spec)}, max: {np.max(y_spec)}, isnan: {np.isnan(y_spec).any()}, isinf: {np.isinf(y_spec).any()}")
163
+ self.logger.debug(f"v_spec stats - min: {np.min(v_spec)}, max: {np.max(v_spec)}, isnan: {np.isnan(v_spec).any()}, isinf: {np.isinf(v_spec).any()}")
164
+
165
+ # Not yet implemented from UVR features:
166
+ #
167
+ # if not self.is_vocal_split_model:
168
+ # self.cache_source((y_spec, v_spec))
169
+
170
+ # if self.is_secondary_model_activated and self.secondary_model:
171
+ # self.logger.debug("Processing secondary model...")
172
+ # self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(
173
+ # self.secondary_model, self.process_data, main_process_method=self.process_method, main_model_primary=self.primary_stem
174
+ # )
175
+
176
+ # Initialize the list for output files
177
+ output_files = []
178
+ self.logger.debug("Processing output files...")
179
+
180
+ # Note: logic similar to the following should probably be added to the other architectures
181
+ # Check if output_single_stem is set to a value that would result in no output files
182
+ if self.output_single_stem and (self.output_single_stem.lower() != self.primary_stem_name.lower() and self.output_single_stem.lower() != self.secondary_stem_name.lower()):
183
+ # If so, warn about the mismatch and then reset output_single_stem to None to save both stems
184
+ self.logger.warning(f"The output_single_stem setting '{self.output_single_stem}' does not match any of the output files: '{self.primary_stem_name}' and '{self.secondary_stem_name}'. For this model '{self.model_name}', the output_single_stem setting will be ignored and all output files will be saved.")
185
+ self.output_single_stem = None
186
+
187
+ # Save and process the primary stem if needed
188
+ if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
189
+ self.logger.debug(f"Processing primary stem: {self.primary_stem_name}")
190
+ if not isinstance(self.primary_source, np.ndarray):
191
+ self.logger.debug(f"Preparing to convert spectrogram to waveform. Spec shape: {y_spec.shape}")
192
+
193
+ self.primary_source = self.spec_to_wav(y_spec).T
194
+ self.logger.debug("Converting primary source spectrogram to waveform.")
195
+ if not self.model_samplerate == 44100:
196
+ self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
197
+ self.logger.debug("Resampling primary source to 44100Hz.")
198
+
199
+ self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)
200
+
201
+ self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
202
+ self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
203
+ output_files.append(self.primary_stem_output_path)
204
+
205
+ # Save and process the secondary stem if needed
206
+ if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
207
+ self.logger.debug(f"Processing secondary stem: {self.secondary_stem_name}")
208
+ if not isinstance(self.secondary_source, np.ndarray):
209
+ self.logger.debug(f"Preparing to convert spectrogram to waveform. Spec shape: {v_spec.shape}")
210
+
211
+ self.secondary_source = self.spec_to_wav(v_spec).T
212
+ self.logger.debug("Converting secondary source spectrogram to waveform.")
213
+ if not self.model_samplerate == 44100:
214
+ self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
215
+ self.logger.debug("Resampling secondary source to 44100Hz.")
216
+
217
+ self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)
218
+
219
+ self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
220
+ self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
221
+ output_files.append(self.secondary_stem_output_path)
222
+
223
+ # Not yet implemented from UVR features:
224
+ # self.process_vocal_split_chain(secondary_sources)
225
+ # self.logger.debug("Vocal split chain processed.")
226
+
227
+ return output_files
228
+
229
+ def loading_mix(self):
230
+ X_wave, X_spec_s = {}, {}
231
+
232
+ bands_n = len(self.model_params.param["band"])
233
+
234
+ audio_file = spec_utils.write_array_to_mem(self.audio_file_path, subtype=self.wav_subtype)
235
+ is_mp3 = audio_file.endswith(".mp3") if isinstance(audio_file, str) else False
236
+
237
+ self.logger.debug(f"loading_mix iterating through {bands_n} bands")
238
+ for d in tqdm(range(bands_n, 0, -1)):
239
+ bp = self.model_params.param["band"][d]
240
+
241
+ wav_resolution = bp["res_type"]
242
+
243
+ if self.torch_device_mps is not None:
244
+ wav_resolution = "polyphase"
245
+
246
+ if d == bands_n: # high-end band
247
+ X_wave[d], _ = librosa.load(audio_file, sr=bp["sr"], mono=False, dtype=np.float32, res_type=wav_resolution)
248
+ X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp["hl"], bp["n_fft"], self.model_params, band=d, is_v51_model=self.is_vr_51_model)
249
+
250
+ if not np.any(X_wave[d]) and is_mp3:
251
+ X_wave[d] = rerun_mp3(audio_file, bp["sr"])
252
+
253
+ if X_wave[d].ndim == 1:
254
+ X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
255
+ else: # lower bands
256
+ X_wave[d] = librosa.resample(X_wave[d + 1], orig_sr=self.model_params.param["band"][d + 1]["sr"], target_sr=bp["sr"], res_type=wav_resolution)
257
+ X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp["hl"], bp["n_fft"], self.model_params, band=d, is_v51_model=self.is_vr_51_model)
258
+
259
+ if d == bands_n and self.high_end_process:
260
+ self.input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (self.model_params.param["pre_filter_stop"] - self.model_params.param["pre_filter_start"])
261
+ self.input_high_end = X_spec_s[d][:, bp["n_fft"] // 2 - self.input_high_end_h : bp["n_fft"] // 2, :]
262
+
263
+ X_spec = spec_utils.combine_spectrograms(X_spec_s, self.model_params, is_v51_model=self.is_vr_51_model)
264
+
265
+ del X_wave, X_spec_s, audio_file
266
+
267
+ return X_spec
268
+
269
+ def inference_vr(self, X_spec, device, aggressiveness):
270
+ def _execute(X_mag_pad, roi_size):
271
+ X_dataset = []
272
+ patches = (X_mag_pad.shape[2] - 2 * self.model_run.offset) // roi_size
273
+
274
+ self.logger.debug(f"inference_vr appending to X_dataset for each of {patches} patches")
275
+ for i in tqdm(range(patches)):
276
+ start = i * roi_size
277
+ X_mag_window = X_mag_pad[:, :, start : start + self.window_size]
278
+ X_dataset.append(X_mag_window)
279
+
280
+ total_iterations = patches // self.batch_size if not self.enable_tta else (patches // self.batch_size) * 2
281
+ self.logger.debug(f"inference_vr iterating through {total_iterations} batches, batch_size = {self.batch_size}")
282
+
283
+ X_dataset = np.asarray(X_dataset)
284
+ self.model_run.eval()
285
+ with torch.no_grad():
286
+ mask = []
287
+
288
+ for i in tqdm(range(0, patches, self.batch_size)):
289
+
290
+ X_batch = X_dataset[i : i + self.batch_size]
291
+ X_batch = torch.from_numpy(X_batch).to(device)
292
+ pred = self.model_run.predict_mask(X_batch)
293
+ if not pred.size()[3] > 0:
294
+ raise ValueError("Window size error: h1_shape[3] must be greater than h2_shape[3]")
295
+ pred = pred.detach().cpu().numpy()
296
+ pred = np.concatenate(pred, axis=2)
297
+ mask.append(pred)
298
+ if len(mask) == 0:
299
+ raise ValueError("Window size error: h1_shape[3] must be greater than h2_shape[3]")
300
+
301
+ mask = np.concatenate(mask, axis=2)
302
+ return mask
303
+
304
+ def postprocess(mask, X_mag, X_phase):
305
+ is_non_accom_stem = False
306
+ for stem in CommonSeparator.NON_ACCOM_STEMS:
307
+ if stem == self.primary_stem_name:
308
+ is_non_accom_stem = True
309
+
310
+ mask = spec_utils.adjust_aggr(mask, is_non_accom_stem, aggressiveness)
311
+
312
+ if self.enable_post_process:
313
+ mask = spec_utils.merge_artifacts(mask, thres=self.post_process_threshold)
314
+
315
+ y_spec = mask * X_mag * np.exp(1.0j * X_phase)
316
+ v_spec = (1 - mask) * X_mag * np.exp(1.0j * X_phase)
317
+
318
+ return y_spec, v_spec
319
+
320
+ X_mag, X_phase = spec_utils.preprocess(X_spec)
321
+ n_frame = X_mag.shape[2]
322
+ pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset)
323
+ X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
324
+ X_mag_pad /= X_mag_pad.max()
325
+ mask = _execute(X_mag_pad, roi_size)
326
+
327
+ if self.enable_tta:
328
+ pad_l += roi_size // 2
329
+ pad_r += roi_size // 2
330
+ X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
331
+ X_mag_pad /= X_mag_pad.max()
332
+ mask_tta = _execute(X_mag_pad, roi_size)
333
+ mask_tta = mask_tta[:, :, roi_size // 2 :]
334
+ mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5
335
+ else:
336
+ mask = mask[:, :, :n_frame]
337
+
338
+ y_spec, v_spec = postprocess(mask, X_mag, X_phase)
339
+
340
+ return y_spec, v_spec
341
+
342
+ def spec_to_wav(self, spec):
343
+ if self.high_end_process and isinstance(self.input_high_end, np.ndarray) and self.input_high_end_h:
344
+ input_high_end_ = spec_utils.mirroring("mirroring", spec, self.input_high_end, self.model_params)
345
+ wav = spec_utils.cmb_spectrogram_to_wave(spec, self.model_params, self.input_high_end_h, input_high_end_, is_v51_model=self.is_vr_51_model)
346
+ else:
347
+ wav = spec_utils.cmb_spectrogram_to_wave(spec, self.model_params, is_v51_model=self.is_vr_51_model)
348
+
349
+ return wav
350
+
351
+
352
+ # Check if we really need the rerun_mp3 function, refactor or remove if not
353
+ def rerun_mp3(audio_file, sample_rate=44100):
354
+ with audioread.audio_open(audio_file) as f:
355
+ track_length = int(f.duration)
356
+
357
+ return librosa.load(audio_file, duration=track_length, mono=False, sr=sample_rate)[0]
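
A quick standalone illustration of the complementary masking done in `postprocess` inside `inference_vr` above: a single soft mask splits the mixture spectrogram into two parts that always sum back to the original. The array shapes and the random mask here are arbitrary; in the separator the mask comes from the network.

import numpy as np

rng = np.random.default_rng(0)

# Toy complex mixture spectrogram, split into magnitude and phase
mix_spec = rng.standard_normal((2, 1025, 16)) + 1j * rng.standard_normal((2, 1025, 16))
X_mag, X_phase = np.abs(mix_spec), np.angle(mix_spec)

# A soft mask in [0, 1]
mask = rng.uniform(0.0, 1.0, size=X_mag.shape)

y_spec = mask * X_mag * np.exp(1.0j * X_phase)        # primary stem spectrogram
v_spec = (1 - mask) * X_mag * np.exp(1.0j * X_phase)  # secondary stem spectrogram

# The two masked spectrograms reconstruct the mixture exactly
assert np.allclose(y_spec + v_spec, mix_spec)

This is also why the aggression setting only needs to adjust the mask: pushing the mask towards 1 or 0 trades energy between the two stems while their sum still reconstructs the mixture.
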
audio_separator/separator/common_separator.py ADDED
@@ -0,0 +1,403 @@
1
+ """ This file contains the CommonSeparator class, common to all architecture-specific Separator classes. """
2
+
3
+ from logging import Logger
4
+ import os
5
+ import re
6
+ import gc
7
+ import numpy as np
8
+ import librosa
9
+ import torch
10
+ from pydub import AudioSegment
11
+ import soundfile as sf
12
+ from audio_separator.separator.uvr_lib_v5 import spec_utils
13
+
14
+
15
+ class CommonSeparator:
16
+ """
17
+ This class contains the common methods and attributes common to all architecture-specific Separator classes.
18
+ """
19
+
20
+ ALL_STEMS = "All Stems"
21
+ VOCAL_STEM = "Vocals"
22
+ INST_STEM = "Instrumental"
23
+ OTHER_STEM = "Other"
24
+ BASS_STEM = "Bass"
25
+ DRUM_STEM = "Drums"
26
+ GUITAR_STEM = "Guitar"
27
+ PIANO_STEM = "Piano"
28
+ SYNTH_STEM = "Synthesizer"
29
+ STRINGS_STEM = "Strings"
30
+ WOODWINDS_STEM = "Woodwinds"
31
+ BRASS_STEM = "Brass"
32
+ WIND_INST_STEM = "Wind Inst"
33
+ NO_OTHER_STEM = "No Other"
34
+ NO_BASS_STEM = "No Bass"
35
+ NO_DRUM_STEM = "No Drums"
36
+ NO_GUITAR_STEM = "No Guitar"
37
+ NO_PIANO_STEM = "No Piano"
38
+ NO_SYNTH_STEM = "No Synthesizer"
39
+ NO_STRINGS_STEM = "No Strings"
40
+ NO_WOODWINDS_STEM = "No Woodwinds"
41
+ NO_WIND_INST_STEM = "No Wind Inst"
42
+ NO_BRASS_STEM = "No Brass"
43
+ PRIMARY_STEM = "Primary Stem"
44
+ SECONDARY_STEM = "Secondary Stem"
45
+ LEAD_VOCAL_STEM = "lead_only"
46
+ BV_VOCAL_STEM = "backing_only"
47
+ LEAD_VOCAL_STEM_I = "with_lead_vocals"
48
+ BV_VOCAL_STEM_I = "with_backing_vocals"
49
+ LEAD_VOCAL_STEM_LABEL = "Lead Vocals"
50
+ BV_VOCAL_STEM_LABEL = "Backing Vocals"
51
+ NO_STEM = "No "
52
+
53
+ STEM_PAIR_MAPPER = {VOCAL_STEM: INST_STEM, INST_STEM: VOCAL_STEM, LEAD_VOCAL_STEM: BV_VOCAL_STEM, BV_VOCAL_STEM: LEAD_VOCAL_STEM, PRIMARY_STEM: SECONDARY_STEM}
54
+
55
+ NON_ACCOM_STEMS = (VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM, SYNTH_STEM, STRINGS_STEM, WOODWINDS_STEM, BRASS_STEM, WIND_INST_STEM)
56
+
57
+ def __init__(self, config):
58
+
59
+ self.logger: Logger = config.get("logger")
60
+ self.log_level: int = config.get("log_level")
61
+
62
+ # Inferencing device / acceleration config
63
+ self.torch_device = config.get("torch_device")
64
+ self.torch_device_cpu = config.get("torch_device_cpu")
65
+ self.torch_device_mps = config.get("torch_device_mps")
66
+ self.onnx_execution_provider = config.get("onnx_execution_provider")
67
+
68
+ # Model data
69
+ self.model_name = config.get("model_name")
70
+ self.model_path = config.get("model_path")
71
+ self.model_data = config.get("model_data")
72
+
73
+ # Output directory and format
74
+ self.output_dir = config.get("output_dir")
75
+ self.output_format = config.get("output_format")
76
+ self.output_bitrate = config.get("output_bitrate")
77
+
78
+ # Functional options which are applicable to all architectures and the user may tweak to affect the output
79
+ self.normalization_threshold = config.get("normalization_threshold")
80
+ self.amplification_threshold = config.get("amplification_threshold")
81
+ self.enable_denoise = config.get("enable_denoise")
82
+ self.output_single_stem = config.get("output_single_stem")
83
+ self.invert_using_spec = config.get("invert_using_spec")
84
+ self.sample_rate = config.get("sample_rate")
85
+ self.use_soundfile = config.get("use_soundfile")
86
+
87
+ # Model specific properties
88
+
89
+ # Check if model_data has a "training" key with "instruments" list
90
+ self.primary_stem_name = None
91
+ self.secondary_stem_name = None
92
+
93
+ if "training" in self.model_data and "instruments" in self.model_data["training"]:
94
+ instruments = self.model_data["training"]["instruments"]
95
+ if instruments:
96
+ self.primary_stem_name = instruments[0]
97
+ self.secondary_stem_name = instruments[1] if len(instruments) > 1 else self.secondary_stem(self.primary_stem_name)
98
+
99
+ if self.primary_stem_name is None:
100
+ self.primary_stem_name = self.model_data.get("primary_stem", "Vocals")
101
+ self.secondary_stem_name = self.secondary_stem(self.primary_stem_name)
102
+
103
+ self.is_karaoke = self.model_data.get("is_karaoke", False)
104
+ self.is_bv_model = self.model_data.get("is_bv_model", False)
105
+ self.bv_model_rebalance = self.model_data.get("is_bv_model_rebalanced", 0)
106
+
107
+ self.logger.debug(f"Common params: model_name={self.model_name}, model_path={self.model_path}")
108
+ self.logger.debug(f"Common params: output_dir={self.output_dir}, output_format={self.output_format}")
109
+ self.logger.debug(f"Common params: normalization_threshold={self.normalization_threshold}, amplification_threshold={self.amplification_threshold}")
110
+ self.logger.debug(f"Common params: enable_denoise={self.enable_denoise}, output_single_stem={self.output_single_stem}")
111
+ self.logger.debug(f"Common params: invert_using_spec={self.invert_using_spec}, sample_rate={self.sample_rate}")
112
+
113
+ self.logger.debug(f"Common params: primary_stem_name={self.primary_stem_name}, secondary_stem_name={self.secondary_stem_name}")
114
+ self.logger.debug(f"Common params: is_karaoke={self.is_karaoke}, is_bv_model={self.is_bv_model}, bv_model_rebalance={self.bv_model_rebalance}")
115
+
116
+ # File-specific variables which need to be cleared between processing different audio inputs
117
+ self.audio_file_path = None
118
+ self.audio_file_base = None
119
+
120
+ self.primary_source = None
121
+ self.secondary_source = None
122
+
123
+ self.primary_stem_output_path = None
124
+ self.secondary_stem_output_path = None
125
+
126
+ self.cached_sources_map = {}
127
+
128
+ def secondary_stem(self, primary_stem: str):
129
+ """Determines secondary stem name based on the primary stem name."""
130
+ primary_stem = primary_stem if primary_stem else self.NO_STEM
131
+
132
+ if primary_stem in self.STEM_PAIR_MAPPER:
133
+ secondary_stem = self.STEM_PAIR_MAPPER[primary_stem]
134
+ else:
135
+ secondary_stem = primary_stem.replace(self.NO_STEM, "") if self.NO_STEM in primary_stem else f"{self.NO_STEM}{primary_stem}"
136
+
137
+ return secondary_stem
138
+
139
+ def separate(self, audio_file_path):
140
+ """
141
+ Placeholder method for separating audio sources. Should be overridden by subclasses.
142
+ """
143
+ raise NotImplementedError("This method should be overridden by subclasses.")
144
+
145
+ def final_process(self, stem_path, source, stem_name):
146
+ """
147
+ Finalizes the processing of a stem by writing the audio to a file and returning the processed source.
148
+ """
149
+ self.logger.debug(f"Finalizing {stem_name} stem processing and writing audio...")
150
+ self.write_audio(stem_path, source)
151
+
152
+ return {stem_name: source}
153
+
154
+ def cached_sources_clear(self):
155
+ """
156
+ Clears the cache dictionaries for VR, MDX, and Demucs models.
157
+
158
+ This function is essential for ensuring that the cache does not hold outdated or irrelevant data
159
+ between different processing sessions or when a new batch of audio files is processed.
160
+ It helps in managing memory efficiently and prevents potential errors due to stale data.
161
+ """
162
+ self.cached_sources_map = {}
163
+
164
+ def cached_source_callback(self, model_architecture, model_name=None):
165
+ """
166
+ Retrieves the model and sources from the cache based on the processing method and model name.
167
+
168
+ Args:
169
+ model_architecture: The architecture type (VR, MDX, or Demucs) being used for processing.
170
+ model_name: The specific model name within the architecture type, if applicable.
171
+
172
+ Returns:
173
+ A tuple containing the model and its sources if found in the cache; otherwise, None.
174
+
175
+ This function is crucial for optimizing performance by avoiding redundant processing.
176
+ If the requested model and its sources are already in the cache, they can be reused directly,
177
+ saving time and computational resources.
178
+ """
179
+ model, sources = None, None
180
+
181
+ mapper = self.cached_sources_map[model_architecture]
182
+
183
+ for key, value in mapper.items():
184
+ if model_name in key:
185
+ model = key
186
+ sources = value
187
+
188
+ return model, sources
189
+
190
+ def cached_model_source_holder(self, model_architecture, sources, model_name=None):
191
+ """
192
+ Update the dictionary for the given model_architecture with the new model name and its sources.
193
+ Use the model_architecture as a key to access the corresponding cache source mapper dictionary.
194
+ """
195
+ self.cached_sources_map[model_architecture] = {**self.cached_sources_map.get(model_architecture, {}), **{model_name: sources}}
196
+
197
+ def prepare_mix(self, mix):
198
+ """
199
+ Prepares the mix for processing. This includes loading the audio from a file if necessary,
200
+ ensuring the mix is in the correct format, and converting mono to stereo if needed.
201
+ """
202
+ # Store the original path or the mix itself for later checks
203
+ audio_path = mix
204
+
205
+ # Check if the input is a file path (string) and needs to be loaded
206
+ if not isinstance(mix, np.ndarray):
207
+ self.logger.debug(f"Loading audio from file: {mix}")
208
+ mix, sr = librosa.load(mix, mono=False, sr=self.sample_rate)
209
+ self.logger.debug(f"Audio loaded. Sample rate: {sr}, Audio shape: {mix.shape}")
210
+ else:
211
+ # Transpose the mix if it's already an ndarray (expected shape: [channels, samples])
212
+ self.logger.debug("Transposing the provided mix array.")
213
+ mix = mix.T
214
+ self.logger.debug(f"Transposed mix shape: {mix.shape}")
215
+
216
+ # If the original input was a filepath, check if the loaded mix is empty
217
+ if isinstance(audio_path, str):
218
+ if not np.any(mix):
219
+ error_msg = f"Audio file {audio_path} is empty or not valid"
220
+ self.logger.error(error_msg)
221
+ raise ValueError(error_msg)
222
+ else:
223
+ self.logger.debug("Audio file is valid and contains data.")
224
+
225
+ # Ensure the mix is in stereo format
226
+ if mix.ndim == 1:
227
+ self.logger.debug("Mix is mono. Converting to stereo.")
228
+ mix = np.asfortranarray([mix, mix])
229
+ self.logger.debug("Converted to stereo mix.")
230
+
231
+ # Final log indicating successful preparation of the mix
232
+ self.logger.debug("Mix preparation completed.")
233
+ return mix
234
+
235
+ def write_audio(self, stem_path: str, stem_source):
236
+ """
237
+ Writes the separated audio source to a file using pydub or soundfile
238
+ Pydub supports a much wider range of audio formats and produces better encoded lossy files for some formats.
239
+ Soundfile is used for very large files (longer than 1 hour), as pydub has memory issues with large files:
240
+ https://github.com/jiaaro/pydub/issues/135
241
+ """
242
+ # Get the duration of the input audio file
243
+ duration_seconds = librosa.get_duration(filename=self.audio_file_path)
244
+ duration_hours = duration_seconds / 3600
245
+ self.logger.info(f"Audio duration is {duration_hours:.2f} hours ({duration_seconds:.2f} seconds).")
246
+
247
+ if self.use_soundfile:
248
+ self.logger.warning(f"Using soundfile for writing.")
249
+ self.write_audio_soundfile(stem_path, stem_source)
250
+ else:
251
+ self.logger.info(f"Using pydub for writing.")
252
+ self.write_audio_pydub(stem_path, stem_source)
253
+
254
+ def write_audio_pydub(self, stem_path: str, stem_source):
255
+ """
256
+ Writes the separated audio source to a file using pydub (ffmpeg)
257
+ """
258
+ self.logger.debug(f"Entering write_audio_pydub with stem_path: {stem_path}")
259
+
260
+ stem_source = spec_utils.normalize(wave=stem_source, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)
261
+
262
+ # Check if the numpy array is empty or contains very low values
263
+ if np.max(np.abs(stem_source)) < 1e-6:
264
+ self.logger.warning("Warning: stem_source array is near-silent or empty.")
265
+ return
266
+
267
+ # If output_dir is specified, create it and join it with stem_path
268
+ if self.output_dir:
269
+ os.makedirs(self.output_dir, exist_ok=True)
270
+ stem_path = os.path.join(self.output_dir, stem_path)
271
+
272
+ self.logger.debug(f"Audio data shape before processing: {stem_source.shape}")
273
+ self.logger.debug(f"Data type before conversion: {stem_source.dtype}")
274
+
275
+ # Ensure the audio data is in the correct format (e.g., int16)
276
+ if stem_source.dtype != np.int16:
277
+ stem_source = (stem_source * 32767).astype(np.int16)
278
+ self.logger.debug("Converted stem_source to int16.")
279
+
280
+ # Correctly interleave stereo channels
281
+ stem_source_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
282
+ stem_source_interleaved[0::2] = stem_source[:, 0] # Left channel
283
+ stem_source_interleaved[1::2] = stem_source[:, 1] # Right channel
284
+
285
+ self.logger.debug(f"Interleaved audio data shape: {stem_source_interleaved.shape}")
286
+
287
+ # Create a pydub AudioSegment
288
+ try:
289
+ audio_segment = AudioSegment(stem_source_interleaved.tobytes(), frame_rate=self.sample_rate, sample_width=stem_source.dtype.itemsize, channels=2)
290
+ self.logger.debug("Created AudioSegment successfully.")
291
+ except (IOError, ValueError) as e:
292
+ self.logger.error(f"Specific error creating AudioSegment: {e}")
293
+ return
294
+
295
+ # Determine file format based on the file extension
296
+ file_format = stem_path.lower().split(".")[-1]
297
+
298
+ # For m4a files, specify mp4 as the container format as the extension doesn't match the format name
299
+ if file_format == "m4a":
300
+ file_format = "mp4"
301
+ elif file_format == "mka":
302
+ file_format = "matroska"
303
+
304
+ # Set the bitrate to 320k for mp3 files if output_bitrate is not specified
305
+ bitrate = "320k" if file_format == "mp3" and self.output_bitrate is None else self.output_bitrate
306
+
307
+ # Export using the determined format
308
+ try:
309
+ audio_segment.export(stem_path, format=file_format, bitrate=bitrate)
310
+ self.logger.debug(f"Exported audio file successfully to {stem_path}")
311
+ except (IOError, ValueError) as e:
312
+ self.logger.error(f"Error exporting audio file: {e}")
313
+
314
+ def write_audio_soundfile(self, stem_path: str, stem_source):
315
+ """
316
+ Writes the separated audio source to a file using soundfile library.
317
+ """
318
+ self.logger.debug(f"Entering write_audio_soundfile with stem_path: {stem_path}")
319
+
320
+ # Correctly interleave stereo channels if needed
321
+ if stem_source.shape[1] == 2:
322
+ # If the audio is already interleaved, ensure it's in the correct order
323
+ # Check if the array is Fortran contiguous (column-major)
324
+ if stem_source.flags["F_CONTIGUOUS"]:
325
+ # Convert to C contiguous (row-major)
326
+ stem_source = np.ascontiguousarray(stem_source)
327
+ # Otherwise, perform interleaving
328
+ else:
329
+ stereo_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
330
+ # Left channel
331
+ stereo_interleaved[0::2] = stem_source[:, 0]
332
+ # Right channel
333
+ stereo_interleaved[1::2] = stem_source[:, 1]
334
+ stem_source = stereo_interleaved
335
+
336
+ self.logger.debug(f"Interleaved audio data shape: {stem_source.shape}")
337
+
338
+ """
339
+ Write audio using soundfile (for formats other than M4A).
340
+ """
341
+ # Save audio using soundfile
342
+ try:
343
+ # Specify the subtype to define the sample width
344
+ sf.write(stem_path, stem_source, self.sample_rate)
345
+ self.logger.debug(f"Exported audio file successfully to {stem_path}")
346
+ except Exception as e:
347
+ self.logger.error(f"Error exporting audio file: {e}")
348
+
349
+ def clear_gpu_cache(self):
350
+ """
351
+ This method clears the GPU cache to free up memory.
352
+ """
353
+ self.logger.debug("Running garbage collection...")
354
+ gc.collect()
355
+ if self.torch_device == torch.device("mps"):
356
+ self.logger.debug("Clearing MPS cache...")
357
+ torch.mps.empty_cache()
358
+ if self.torch_device == torch.device("cuda"):
359
+ self.logger.debug("Clearing CUDA cache...")
360
+ torch.cuda.empty_cache()
361
+
362
+ def clear_file_specific_paths(self):
363
+ """
364
+ Clears the file-specific variables which need to be cleared between processing different audio inputs.
365
+ """
366
+ self.logger.info("Clearing input audio file paths, sources and stems...")
367
+
368
+ self.audio_file_path = None
369
+ self.audio_file_base = None
370
+
371
+ self.primary_source = None
372
+ self.secondary_source = None
373
+
374
+ self.primary_stem_output_path = None
375
+ self.secondary_stem_output_path = None
376
+
377
+ def sanitize_filename(self, filename):
378
+ """
379
+ Cleans the filename by replacing invalid characters with underscores.
380
+ """
381
+ sanitized = re.sub(r'[<>:"/\\|?*]', '_', filename)
382
+ sanitized = re.sub(r'_+', '_', sanitized)
383
+ sanitized = sanitized.strip('_. ')
384
+ return sanitized
385
+
386
+ def get_stem_output_path(self, stem_name, custom_output_names):
387
+ """
388
+ Gets the output path for a stem based on the stem name and custom output names.
389
+ """
390
+ # Convert custom_output_names keys to lowercase for case-insensitive comparison
391
+ if custom_output_names:
392
+ custom_output_names_lower = {k.lower(): v for k, v in custom_output_names.items()}
393
+ stem_name_lower = stem_name.lower()
394
+ if stem_name_lower in custom_output_names_lower:
395
+ sanitized_custom_name = self.sanitize_filename(custom_output_names_lower[stem_name_lower])
396
+ return os.path.join(f"{sanitized_custom_name}.{self.output_format.lower()}")
397
+
398
+ sanitized_audio_base = self.sanitize_filename(self.audio_file_base)
399
+ sanitized_stem_name = self.sanitize_filename(stem_name)
400
+ sanitized_model_name = self.sanitize_filename(self.model_name)
401
+
402
+ filename = f"{sanitized_audio_base}_({sanitized_stem_name})_{sanitized_model_name}.{self.output_format.lower()}"
403
+ return os.path.join(filename)
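
The int16 interleaving in `write_audio_pydub` above is easier to see on a tiny example. The sample values below are made up; only the scaling and the interleave pattern match the code above.

import numpy as np

# A small float32 stereo clip shaped (samples, channels), values in [-1.0, 1.0]
stem_source = np.array([[0.5, -0.5],
                        [0.25, 0.75],
                        [-1.0, 1.0]], dtype=np.float32)

# Scale to the 16-bit integer range before building the AudioSegment
stem_int16 = (stem_source * 32767).astype(np.int16)

# Interleave: even indices carry the left channel, odd indices the right channel
interleaved = np.empty((2 * stem_int16.shape[0],), dtype=np.int16)
interleaved[0::2] = stem_int16[:, 0]
interleaved[1::2] = stem_int16[:, 1]

print(interleaved)  # [ 16383 -16383   8191  24575 -32767  32767]

pydub's AudioSegment expects exactly this interleaved byte layout for 2-channel, 2-byte-per-sample audio, which is why the array is flattened this way before calling .tobytes().
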
audio_separator/separator/separator.py ADDED
@@ -0,0 +1,959 @@
1
+ """ This file contains the Separator class, to facilitate the separation of stems from audio. """
2
+
3
+ from importlib import metadata, resources
4
+ import os
5
+ import sys
6
+ import platform
7
+ import subprocess
8
+ import time
9
+ import logging
10
+ import warnings
11
+ import importlib
12
+ import io
13
+ from typing import Optional
14
+
15
+ import hashlib
16
+ import json
17
+ import yaml
18
+ import requests
19
+ import torch
20
+ import torch.amp.autocast_mode as autocast_mode
21
+ import onnxruntime as ort
22
+ from tqdm import tqdm
23
+
24
+
25
+ class Separator:
26
+ """
27
+ The Separator class is designed to facilitate the separation of audio sources from a given audio file.
28
+ It supports various separation architectures and models, including MDX, VR, and Demucs. The class provides
29
+ functionalities to configure separation parameters, load models, and perform audio source separation.
30
+ It also handles logging, normalization, and output formatting of the separated audio stems.
31
+
32
+ The actual separation task is handled by one of the architecture-specific classes in the `architectures` module;
33
+ this class is responsible for initialising logging, configuring hardware acceleration, loading the model,
34
+ initiating the separation process and passing outputs back to the caller.
35
+
36
+ Common Attributes:
37
+ log_level (int): The logging level.
38
+ log_formatter (logging.Formatter): The logging formatter.
39
+ model_file_dir (str): The directory where model files are stored.
40
+ output_dir (str): The directory where output files will be saved.
41
+ output_format (str): The format of the output audio file.
42
+ output_bitrate (str): The bitrate of the output audio file.
43
+ amplification_threshold (float): The threshold for audio amplification.
44
+ normalization_threshold (float): The threshold for audio normalization.
45
+ output_single_stem (str): Option to output a single stem.
46
+ invert_using_spec (bool): Flag to invert using spectrogram.
47
+ sample_rate (int): The sample rate of the audio.
48
+ use_soundfile (bool): Use soundfile for audio writing, can solve OOM issues.
49
+ use_autocast (bool): Flag to use PyTorch autocast for faster inference.
50
+
51
+ MDX Architecture Specific Attributes:
52
+ hop_length (int): The hop length for STFT.
53
+ segment_size (int): The segment size for processing.
54
+ overlap (float): The overlap between segments.
55
+ batch_size (int): The batch size for processing.
56
+ enable_denoise (bool): Flag to enable or disable denoising.
57
+
58
+ VR Architecture Specific Attributes & Defaults:
59
+ batch_size: 16
60
+ window_size: 512
61
+ aggression: 5
62
+ enable_tta: False
63
+ enable_post_process: False
64
+ post_process_threshold: 0.2
65
+ high_end_process: False
66
+
67
+ Demucs Architecture Specific Attributes & Defaults:
68
+ segment_size: "Default"
69
+ shifts: 2
70
+ overlap: 0.25
71
+ segments_enabled: True
72
+
73
+ MDXC Architecture Specific Attributes & Defaults:
74
+ segment_size: 256
75
+ override_model_segment_size: False
76
+ batch_size: 1
77
+ overlap: 8
78
+ pitch_shift: 0
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ log_level=logging.INFO,
84
+ log_formatter=None,
85
+ model_file_dir="/tmp/audio-separator-models/",
86
+ output_dir=None,
87
+ output_format="WAV",
88
+ output_bitrate=None,
89
+ normalization_threshold=0.9,
90
+ amplification_threshold=0.0,
91
+ output_single_stem=None,
92
+ invert_using_spec=False,
93
+ sample_rate=44100,
94
+ use_soundfile=False,
95
+ use_autocast=False,
96
+ use_directml=False,
97
+ mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
98
+ vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
99
+ demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
100
+ mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0},
101
+ info_only=False,
102
+ ):
103
+ """Initialize the separator."""
104
+ self.logger = logging.getLogger(__name__)
105
+ self.logger.setLevel(log_level)
106
+ self.log_level = log_level
107
+ self.log_formatter = log_formatter
108
+
109
+ self.log_handler = logging.StreamHandler()
110
+
111
+ if self.log_formatter is None:
112
+ self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")
113
+
114
+ self.log_handler.setFormatter(self.log_formatter)
115
+
116
+ if not self.logger.hasHandlers():
117
+ self.logger.addHandler(self.log_handler)
118
+
119
+ # Filter out noisy warnings from PyTorch for users who don't care about them
120
+ if log_level > logging.DEBUG:
121
+ warnings.filterwarnings("ignore")
122
+
123
+ # Skip initialization logs if info_only is True
124
+ if not info_only:
125
+ package_version = self.get_package_distribution("audio-separator").version
126
+ self.logger.info(f"Separator version {package_version} instantiating with output_dir: {output_dir}, output_format: {output_format}")
127
+
128
+ if output_dir is None:
129
+ output_dir = os.getcwd()
130
+ if not info_only:
131
+ self.logger.info("Output directory not specified. Using current working directory.")
132
+
133
+ self.output_dir = output_dir
134
+
135
+ # Check for environment variable to override model_file_dir
136
+ env_model_dir = os.environ.get("AUDIO_SEPARATOR_MODEL_DIR")
137
+ if env_model_dir:
138
+ self.model_file_dir = env_model_dir
139
+ self.logger.info(f"Using model directory from AUDIO_SEPARATOR_MODEL_DIR env var: {self.model_file_dir}")
140
+ if not os.path.exists(self.model_file_dir):
141
+ raise FileNotFoundError(f"The specified model directory does not exist: {self.model_file_dir}")
142
+ else:
143
+ self.logger.info(f"Using model directory from model_file_dir parameter: {model_file_dir}")
144
+ self.model_file_dir = model_file_dir
145
+
146
+ # Create the model directory if it does not exist
147
+ os.makedirs(self.model_file_dir, exist_ok=True)
148
+ os.makedirs(self.output_dir, exist_ok=True)
149
+
150
+ self.output_format = output_format
151
+ self.output_bitrate = output_bitrate
152
+
153
+ if self.output_format is None:
154
+ self.output_format = "WAV"
155
+
156
+ self.normalization_threshold = normalization_threshold
157
+ if normalization_threshold <= 0 or normalization_threshold > 1:
158
+ raise ValueError("The normalization_threshold must be greater than 0 and less than or equal to 1.")
159
+
160
+ self.amplification_threshold = amplification_threshold
161
+ if amplification_threshold < 0 or amplification_threshold > 1:
162
+ raise ValueError("The amplification_threshold must be greater than or equal to 0 and less than or equal to 1.")
163
+
164
+ self.output_single_stem = output_single_stem
165
+ if output_single_stem is not None:
166
+ self.logger.debug(f"Single stem output requested, so only one output file ({output_single_stem}) will be written")
167
+
168
+ self.invert_using_spec = invert_using_spec
169
+ if self.invert_using_spec:
170
+ self.logger.debug(f"Secondary step will be inverted using spectogram rather than waveform. This may improve quality but is slightly slower.")
171
+
172
+ try:
173
+ self.sample_rate = int(sample_rate)
174
+ if self.sample_rate <= 0:
175
+ raise ValueError(f"The sample rate setting is {self.sample_rate} but it must be a non-zero whole number.")
176
+ if self.sample_rate > 12800000:
177
+ raise ValueError(f"The sample rate setting is {self.sample_rate}. Enter something less ambitious.")
178
+ except ValueError:
179
+ raise ValueError("The sample rate must be a non-zero whole number. Please provide a valid integer.")
180
+
181
+ self.use_soundfile = use_soundfile
182
+ self.use_autocast = use_autocast
183
+ self.use_directml = use_directml
184
+
185
+ # These are parameters which users may want to configure so we expose them to the top-level Separator class,
186
+ # even though they are specific to a single model architecture
187
+ self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}
188
+
189
+ self.torch_device = None
190
+ self.torch_device_cpu = None
191
+ self.torch_device_mps = None
192
+
193
+ self.onnx_execution_provider = None
194
+ self.model_instance = None
195
+
196
+ self.model_is_uvr_vip = False
197
+ self.model_friendly_name = None
198
+
199
+ if not info_only:
200
+ self.setup_accelerated_inferencing_device()
201
+
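For orientation, a minimal usage sketch of the class defined here (assuming the package's __init__ re-exports Separator, as in upstream python-audio-separator; paths and the input file are placeholders):

from audio_separator.separator import Separator

separator = Separator(output_dir="/tmp/stems", output_format="FLAC")
separator.load_model()                        # defaults to the Mel-Band Roformer checkpoint
output_files = separator.separate("song.mp3")
print(output_files)                           # paths of the separated stem files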
202
+ def setup_accelerated_inferencing_device(self):
203
+ """
204
+ This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
205
+ """
206
+ system_info = self.get_system_info()
207
+ self.check_ffmpeg_installed()
208
+ self.log_onnxruntime_packages()
209
+ self.setup_torch_device(system_info)
210
+
211
+ def get_system_info(self):
212
+ """
213
+ This method logs system information, including the operating system, CPU architecture, and Python version.
214
+ """
215
+ os_name = platform.system()
216
+ os_version = platform.version()
217
+ self.logger.info(f"Operating System: {os_name} {os_version}")
218
+
219
+ system_info = platform.uname()
220
+ self.logger.info(f"System: {system_info.system} Node: {system_info.node} Release: {system_info.release} Machine: {system_info.machine} Proc: {system_info.processor}")
221
+
222
+ python_version = platform.python_version()
223
+ self.logger.info(f"Python Version: {python_version}")
224
+
225
+ pytorch_version = torch.__version__
226
+ self.logger.info(f"PyTorch Version: {pytorch_version}")
227
+ return system_info
228
+
229
+ def check_ffmpeg_installed(self):
230
+ """
231
+ This method checks if ffmpeg is installed and logs its version.
232
+ """
233
+ try:
234
+ ffmpeg_version_output = subprocess.check_output(["ffmpeg", "-version"], text=True)
235
+ first_line = ffmpeg_version_output.splitlines()[0]
236
+ self.logger.info(f"FFmpeg installed: {first_line}")
237
+ except FileNotFoundError:
238
+ self.logger.error("FFmpeg is not installed. Please install FFmpeg to use this package.")
239
+ # Raise an exception if this is being run by a user, as ffmpeg is required for pydub to write audio
240
+ # but if we're just running unit tests in CI, no reason to throw
241
+ if "PYTEST_CURRENT_TEST" not in os.environ:
242
+ raise
243
+
244
+ def log_onnxruntime_packages(self):
245
+ """
246
+ This method logs the ONNX Runtime package versions, including the GPU and Silicon packages if available.
247
+ """
248
+ onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
249
+ onnxruntime_silicon_package = self.get_package_distribution("onnxruntime-silicon")
250
+ onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")
251
+ onnxruntime_dml_package = self.get_package_distribution("onnxruntime-directml")
252
+
253
+ if onnxruntime_gpu_package is not None:
254
+ self.logger.info(f"ONNX Runtime GPU package installed with version: {onnxruntime_gpu_package.version}")
255
+ if onnxruntime_silicon_package is not None:
256
+ self.logger.info(f"ONNX Runtime Silicon package installed with version: {onnxruntime_silicon_package.version}")
257
+ if onnxruntime_cpu_package is not None:
258
+ self.logger.info(f"ONNX Runtime CPU package installed with version: {onnxruntime_cpu_package.version}")
259
+ if onnxruntime_dml_package is not None:
260
+ self.logger.info(f"ONNX Runtime DirectML package installed with version: {onnxruntime_dml_package.version}")
261
+
262
+ def setup_torch_device(self, system_info):
263
+ """
264
+ This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
265
+ """
266
+ hardware_acceleration_enabled = False
267
+ ort_providers = ort.get_available_providers()
268
+ has_torch_dml_installed = self.get_package_distribution("torch_directml")
269
+
270
+ self.torch_device_cpu = torch.device("cpu")
271
+
272
+ if torch.cuda.is_available():
273
+ self.configure_cuda(ort_providers)
274
+ hardware_acceleration_enabled = True
275
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
276
+ self.configure_mps(ort_providers)
277
+ hardware_acceleration_enabled = True
278
+ elif self.use_directml and has_torch_dml_installed:
279
+ import torch_directml
280
+ if torch_directml.is_available():
281
+ self.configure_dml(ort_providers)
282
+ hardware_acceleration_enabled = True
283
+
284
+ if not hardware_acceleration_enabled:
285
+ self.logger.info("No hardware acceleration could be configured, running in CPU mode")
286
+ self.torch_device = self.torch_device_cpu
287
+ self.onnx_execution_provider = ["CPUExecutionProvider"]
288
+
289
+ def configure_cuda(self, ort_providers):
290
+ """
291
+ This method configures the CUDA device for PyTorch and ONNX Runtime, if available.
292
+ """
293
+ self.logger.info("CUDA is available in Torch, setting Torch device to CUDA")
294
+ self.torch_device = torch.device("cuda")
295
+ if "CUDAExecutionProvider" in ort_providers:
296
+ self.logger.info("ONNXruntime has CUDAExecutionProvider available, enabling acceleration")
297
+ self.onnx_execution_provider = ["CUDAExecutionProvider"]
298
+ else:
299
+ self.logger.warning("CUDAExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
300
+
301
+ def configure_mps(self, ort_providers):
302
+ """
303
+ This method configures the Apple Silicon MPS/CoreML device for PyTorch and ONNX Runtime, if available.
304
+ """
305
+ self.logger.info("Apple Silicon MPS/CoreML is available in Torch and processor is ARM, setting Torch device to MPS")
306
+ self.torch_device_mps = torch.device("mps")
307
+
308
+ self.torch_device = self.torch_device_mps
309
+
310
+ if "CoreMLExecutionProvider" in ort_providers:
311
+ self.logger.info("ONNXruntime has CoreMLExecutionProvider available, enabling acceleration")
312
+ self.onnx_execution_provider = ["CoreMLExecutionProvider"]
313
+ else:
314
+ self.logger.warning("CoreMLExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
315
+
316
+ def configure_dml(self, ort_providers):
317
+ """
318
+ This method configures the DirectML device for PyTorch and ONNX Runtime, if available.
319
+ """
320
+ import torch_directml
321
+ self.logger.info("DirectML is available in Torch, setting Torch device to DirectML")
322
+ self.torch_device_dml = torch_directml.device()
323
+ self.torch_device = self.torch_device_dml
324
+
325
+ if "DmlExecutionProvider" in ort_providers:
326
+ self.logger.info("ONNXruntime has DmlExecutionProvider available, enabling acceleration")
327
+ self.onnx_execution_provider = ["DmlExecutionProvider"]
328
+ else:
329
+ self.logger.warning("DmlExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
330
+
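The provider list selected above is presumably consumed by the architecture-specific separators when they build their ONNX Runtime sessions (that code is not in this file); a hedged sketch of what that hand-off typically looks like:

import onnxruntime as ort

# Assumption: model_path points at a downloaded .onnx model file; the provider list
# falls back to ["CPUExecutionProvider"] when no acceleration was configured.
model_path = "/tmp/audio-separator-models/UVR-MDX-NET-Inst_HQ_3.onnx"
session = ort.InferenceSession(model_path, providers=separator.onnx_execution_provider)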
331
+ def get_package_distribution(self, package_name):
332
+ """
333
+ This method returns the package distribution for a given package name if installed, or None otherwise.
334
+ """
335
+ try:
336
+ return metadata.distribution(package_name)
337
+ except metadata.PackageNotFoundError:
338
+ self.logger.debug(f"Python package: {package_name} not installed")
339
+ return None
340
+
341
+ def get_model_hash(self, model_path):
342
+ """
343
+ This method returns the MD5 hash of a given model file.
344
+ """
345
+ self.logger.debug(f"Calculating hash of model file {model_path}")
346
+ # Use the specific byte count from the original logic
347
+ BYTES_TO_HASH = 10000 * 1024 # 10,240,000 bytes
348
+
349
+ try:
350
+ file_size = os.path.getsize(model_path)
351
+
352
+ with open(model_path, "rb") as f:
353
+ if file_size < BYTES_TO_HASH:
354
+ # Hash the entire file if smaller than the target byte count
355
+ self.logger.debug(f"File size {file_size} < {BYTES_TO_HASH}, hashing entire file.")
356
+ hash_value = hashlib.md5(f.read()).hexdigest()
357
+ else:
358
+ # Seek to BYTES_TO_HASH bytes before the end (measured from the start of the file) and hash the remainder
359
+ seek_pos = file_size - BYTES_TO_HASH
360
+ self.logger.debug(f"File size {file_size} >= {BYTES_TO_HASH}, seeking to {seek_pos} and hashing remaining bytes.")
361
+ f.seek(seek_pos, io.SEEK_SET)
362
+ hash_value = hashlib.md5(f.read()).hexdigest()
363
+
364
+ # Log the calculated hash
365
+ self.logger.info(f"Hash of model file {model_path} is {hash_value}")
366
+ return hash_value
367
+
368
+ except FileNotFoundError:
369
+ self.logger.error(f"Model file not found at {model_path}")
370
+ raise # Re-raise the specific error
371
+ except Exception as e:
372
+ # Catch other potential errors (e.g., permissions, other IOErrors)
373
+ self.logger.error(f"Error calculating hash for {model_path}: {e}")
374
+ raise # Re-raise other errors
375
+
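Only the final 10,000 * 1024 bytes of large model files are hashed, which is what the keys in UVR's model_data_new.json are derived from; a compact, illustrative equivalent (helper name is hypothetical):

import hashlib
import os

def tail_md5(path, tail_bytes=10000 * 1024):
    """Hash the last `tail_bytes` of a file, or the whole file if it is smaller."""
    size = os.path.getsize(path)
    with open(path, "rb") as f:
        if size >= tail_bytes:
            f.seek(size - tail_bytes)
        return hashlib.md5(f.read()).hexdigest()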
376
+ def download_file_if_not_exists(self, url, output_path):
377
+ """
378
+ This method downloads a file from a given URL to a given output path, if the file does not already exist.
379
+ """
380
+
381
+ if os.path.isfile(output_path):
382
+ self.logger.debug(f"File already exists at {output_path}, skipping download")
383
+ return
384
+
385
+ self.logger.debug(f"Downloading file from {url} to {output_path} with timeout 300s")
386
+ response = requests.get(url, stream=True, timeout=300)
387
+
388
+ if response.status_code == 200:
389
+ total_size_in_bytes = int(response.headers.get("content-length", 0))
390
+ progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
391
+
392
+ with open(output_path, "wb") as f:
393
+ for chunk in response.iter_content(chunk_size=8192):
394
+ progress_bar.update(len(chunk))
395
+ f.write(chunk)
396
+ progress_bar.close()
397
+ else:
398
+ raise RuntimeError(f"Failed to download file from {url}, response code: {response.status_code}")
399
+
400
+ def list_supported_model_files(self):
401
+ """
402
+ This method lists the supported model files for audio-separator, by fetching the same file UVR uses to list these.
403
+ Also includes model performance scores where available.
404
+
405
+ Example response object:
406
+
407
+ {
408
+ "MDX": {
409
+ "MDX-Net Model VIP: UVR-MDX-NET-Inst_full_292": {
410
+ "filename": "UVR-MDX-NET-Inst_full_292.onnx",
411
+ "scores": {
412
+ "vocals": {
413
+ "SDR": 10.6497,
414
+ "SIR": 20.3786,
415
+ "SAR": 10.692,
416
+ "ISR": 14.848
417
+ },
418
+ "instrumental": {
419
+ "SDR": 15.2149,
420
+ "SIR": 25.6075,
421
+ "SAR": 17.1363,
422
+ "ISR": 17.7893
423
+ }
424
+ },
425
+ "download_files": [
426
+ "UVR-MDX-NET-Inst_full_292.onnx"
427
+ ]
428
+ }
429
+ },
430
+ "Demucs": {
431
+ "Demucs v4: htdemucs_ft": {
432
+ "filename": "htdemucs_ft.yaml",
433
+ "scores": {
434
+ "vocals": {
435
+ "SDR": 11.2685,
436
+ "SIR": 21.257,
437
+ "SAR": 11.0359,
438
+ "ISR": 19.3753
439
+ },
440
+ "drums": {
441
+ "SDR": 13.235,
442
+ "SIR": 23.3053,
443
+ "SAR": 13.0313,
444
+ "ISR": 17.2889
445
+ },
446
+ "bass": {
447
+ "SDR": 9.72743,
448
+ "SIR": 19.5435,
449
+ "SAR": 9.20801,
450
+ "ISR": 13.5037
451
+ }
452
+ },
453
+ "download_files": [
454
+ "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th",
455
+ "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/d12395a8-e57c48e6.th",
456
+ "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/92cfc3b6-ef3bcb9c.th",
457
+ "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th",
458
+ "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs_ft.yaml"
459
+ ]
460
+ }
461
+ },
462
+ "MDXC": {
463
+ "MDX23C Model: MDX23C-InstVoc HQ": {
464
+ "filename": "MDX23C-8KFFT-InstVoc_HQ.ckpt",
465
+ "scores": {
466
+ "vocals": {
467
+ "SDR": 11.9504,
468
+ "SIR": 23.1166,
469
+ "SAR": 12.093,
470
+ "ISR": 15.4782
471
+ },
472
+ "instrumental": {
473
+ "SDR": 16.3035,
474
+ "SIR": 26.6161,
475
+ "SAR": 18.5167,
476
+ "ISR": 18.3939
477
+ }
478
+ },
479
+ "download_files": [
480
+ "MDX23C-8KFFT-InstVoc_HQ.ckpt",
481
+ "model_2_stem_full_band_8k.yaml"
482
+ ]
483
+ }
484
+ }
485
+ }
486
+ """
487
+ download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")
488
+
489
+ self.download_file_if_not_exists("https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json", download_checks_path)
490
+
491
+ model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
492
+ self.logger.debug(f"UVR model download list loaded")
493
+
494
+ # Load the model scores with error handling
495
+ model_scores = {}
496
+ try:
497
+ with resources.open_text("audio_separator", "models-scores.json") as f:
498
+ model_scores = json.load(f)
499
+ self.logger.debug(f"Model scores loaded")
500
+ except json.JSONDecodeError as e:
501
+ self.logger.warning(f"Failed to load model scores: {str(e)}")
502
+ self.logger.warning("Continuing without model scores")
503
+
504
+ # Only show Demucs v4 models as we've only implemented support for v4
505
+ filtered_demucs_v4 = {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}
506
+
507
+ # Modified Demucs handling to use YAML files as identifiers and include download files
508
+ demucs_models = {}
509
+ for name, files in filtered_demucs_v4.items():
510
+ # Find the YAML file in the model files
511
+ yaml_file = next((filename for filename in files.keys() if filename.endswith(".yaml")), None)
512
+ if yaml_file:
513
+ model_score_data = model_scores.get(yaml_file, {})
514
+ demucs_models[name] = {
515
+ "filename": yaml_file,
516
+ "scores": model_score_data.get("median_scores", {}),
517
+ "stems": model_score_data.get("stems", []),
518
+ "target_stem": model_score_data.get("target_stem"),
519
+ "download_files": list(files.values()), # List of all download URLs/filenames
520
+ }
521
+
522
+ # Load the JSON file using importlib.resources
523
+ with resources.open_text("audio_separator", "models.json") as f:
524
+ audio_separator_models_list = json.load(f)
525
+ self.logger.debug(f"Audio-Separator model list loaded")
526
+
527
+ # Return object with list of model names
528
+ model_files_grouped_by_type = {
529
+ "VR": {
530
+ name: {
531
+ "filename": filename,
532
+ "scores": model_scores.get(filename, {}).get("median_scores", {}),
533
+ "stems": model_scores.get(filename, {}).get("stems", []),
534
+ "target_stem": model_scores.get(filename, {}).get("target_stem"),
535
+ "download_files": [filename],
536
+ } # Just the filename for VR models
537
+ for name, filename in {**model_downloads_list["vr_download_list"], **audio_separator_models_list["vr_download_list"]}.items()
538
+ },
539
+ "MDX": {
540
+ name: {
541
+ "filename": filename,
542
+ "scores": model_scores.get(filename, {}).get("median_scores", {}),
543
+ "stems": model_scores.get(filename, {}).get("stems", []),
544
+ "target_stem": model_scores.get(filename, {}).get("target_stem"),
545
+ "download_files": [filename],
546
+ } # Just the filename for MDX models
547
+ for name, filename in {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"], **audio_separator_models_list["mdx_download_list"]}.items()
548
+ },
549
+ "Demucs": demucs_models,
550
+ "MDXC": {
551
+ name: {
552
+ "filename": next(iter(files.keys())),
553
+ "scores": model_scores.get(next(iter(files.keys())), {}).get("median_scores", {}),
554
+ "stems": model_scores.get(next(iter(files.keys())), {}).get("stems", []),
555
+ "target_stem": model_scores.get(next(iter(files.keys())), {}).get("target_stem"),
556
+ "download_files": list(files.keys()) + list(files.values()), # List of both model filenames and config filenames
557
+ }
558
+ for name, files in {
559
+ **model_downloads_list["mdx23c_download_list"],
560
+ **model_downloads_list["mdx23c_download_vip_list"],
561
+ **model_downloads_list["roformer_download_list"],
562
+ **audio_separator_models_list["mdx23c_download_list"],
563
+ **audio_separator_models_list["roformer_download_list"],
564
+ }.items()
565
+ },
566
+ }
567
+
568
+ return model_files_grouped_by_type
569
+
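A usage sketch for the grouped structure returned above (note the first call downloads UVR's download_checks.json into model_file_dir):

models = separator.list_supported_model_files()
for arch, entries in models.items():                 # "VR", "MDX", "Demucs", "MDXC"
    for friendly_name, info in entries.items():
        print(arch, friendly_name, info["filename"], info.get("scores", {}))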
570
+ def print_uvr_vip_message(self):
571
+ """
572
+ This method prints a message to the user if they have downloaded a VIP model, reminding them to support Anjok07 on Patreon.
573
+ """
574
+ if self.model_is_uvr_vip:
575
+ self.logger.warning(f"The model: '{self.model_friendly_name}' is a VIP model, intended by Anjok07 for access by paying subscribers only.")
576
+ self.logger.warning("If you are not already subscribed, please consider supporting the developer of UVR, Anjok07 by subscribing here: https://patreon.com/uvr")
577
+
578
+ def download_model_files(self, model_filename):
579
+ """
580
+ This method downloads the model files for a given model filename, if they are not already present.
581
+ Returns tuple of (model_filename, model_type, model_friendly_name, model_path, yaml_config_filename)
582
+ """
583
+ model_path = os.path.join(self.model_file_dir, f"{model_filename}")
584
+
585
+ supported_model_files_grouped = self.list_supported_model_files()
586
+ public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
587
+ vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
588
+ audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"
589
+
590
+ yaml_config_filename = None
591
+
592
+ self.logger.debug(f"Searching for model_filename {model_filename} in supported_model_files_grouped")
593
+
594
+ # Iterate through model types (MDX, Demucs, MDXC)
595
+ for model_type, models in supported_model_files_grouped.items():
596
+ # Iterate through each model in this type
597
+ for model_friendly_name, model_info in models.items():
598
+ self.model_is_uvr_vip = "VIP" in model_friendly_name
599
+ model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix
600
+
601
+ # Check if this model matches our target filename
602
+ if model_info["filename"] == model_filename or model_filename in model_info["download_files"]:
603
+ self.logger.debug(f"Found matching model: {model_friendly_name}")
604
+ self.model_friendly_name = model_friendly_name
605
+ self.print_uvr_vip_message()
606
+
607
+ # Download each required file for this model
608
+ for file_to_download in model_info["download_files"]:
609
+ # For URLs, extract just the filename portion
610
+ if file_to_download.startswith("http"):
611
+ filename = file_to_download.split("/")[-1]
612
+ download_path = os.path.join(self.model_file_dir, filename)
613
+ self.download_file_if_not_exists(file_to_download, download_path)
614
+ continue
615
+
616
+ download_path = os.path.join(self.model_file_dir, file_to_download)
617
+
618
+ # For MDXC models, handle YAML config files specially
619
+ if model_type == "MDXC" and file_to_download.endswith(".yaml"):
620
+ yaml_config_filename = file_to_download
621
+ try:
622
+ yaml_url = f"{model_repo_url_prefix}/mdx_model_data/mdx_c_configs/{file_to_download}"
623
+ self.download_file_if_not_exists(yaml_url, download_path)
624
+ except RuntimeError:
625
+ self.logger.debug("YAML config not found in UVR repo, trying audio-separator models repo...")
626
+ yaml_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
627
+ self.download_file_if_not_exists(yaml_url, download_path)
628
+ continue
629
+
630
+ # For regular model files, try UVR repo first, then audio-separator repo
631
+ try:
632
+ download_url = f"{model_repo_url_prefix}/{file_to_download}"
633
+ self.download_file_if_not_exists(download_url, download_path)
634
+ except RuntimeError:
635
+ self.logger.debug("Model not found in UVR repo, trying audio-separator models repo...")
636
+ download_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
637
+ self.download_file_if_not_exists(download_url, download_path)
638
+
639
+ return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
640
+
641
+ raise ValueError(f"Model file {model_filename} not found in supported model files")
642
+
643
+ def load_model_data_from_yaml(self, yaml_config_filename):
644
+ """
645
+ This method loads model-specific parameters from the YAML file for that model.
646
+ The parameters in the YAML are critical to inferencing, as they need to match whatever was used during training.
647
+ """
648
+ # Verify if the YAML filename includes a full path or just the filename
649
+ if not os.path.exists(yaml_config_filename):
650
+ model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename)
651
+ else:
652
+ model_data_yaml_filepath = yaml_config_filename
653
+
654
+ self.logger.debug(f"Loading model data from YAML at path {model_data_yaml_filepath}")
655
+
656
+ model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
657
+ self.logger.debug(f"Model data loaded from YAML file: {model_data}")
658
+
659
+ if "roformer" in model_data_yaml_filepath:
660
+ model_data["is_roformer"] = True
661
+
662
+ return model_data
663
+
664
+ def load_model_data_using_hash(self, model_path):
665
+ """
666
+ This method loads model-specific parameters from UVR model data files.
667
+ These parameters are critical to inferencing using a given model, as they need to match whatever was used during training.
668
+ The correct parameters are identified by calculating the hash of the model file and looking up the hash in the UVR data files.
669
+ """
670
+ # Model data and configuration sources from UVR
671
+ model_data_url_prefix = "https://raw.githubusercontent.com/TRvlvr/application_data/main"
672
+
673
+ vr_model_data_url = f"{model_data_url_prefix}/vr_model_data/model_data_new.json"
674
+ mdx_model_data_url = f"{model_data_url_prefix}/mdx_model_data/model_data_new.json"
675
+
676
+ # Calculate hash for the downloaded model
677
+ self.logger.debug("Calculating MD5 hash for model file to identify model parameters from UVR data...")
678
+ model_hash = self.get_model_hash(model_path)
679
+ self.logger.debug(f"Model {model_path} has hash {model_hash}")
680
+
681
+ # Setting up the path for model data and checking its existence
682
+ vr_model_data_path = os.path.join(self.model_file_dir, "vr_model_data.json")
683
+ self.logger.debug(f"VR model data path set to {vr_model_data_path}")
684
+ self.download_file_if_not_exists(vr_model_data_url, vr_model_data_path)
685
+
686
+ mdx_model_data_path = os.path.join(self.model_file_dir, "mdx_model_data.json")
687
+ self.logger.debug(f"MDX model data path set to {mdx_model_data_path}")
688
+ self.download_file_if_not_exists(mdx_model_data_url, mdx_model_data_path)
689
+
690
+ # Loading model data from UVR
691
+ self.logger.debug("Loading MDX and VR model parameters from UVR model data files...")
692
+ vr_model_data_object = json.load(open(vr_model_data_path, encoding="utf-8"))
693
+ mdx_model_data_object = json.load(open(mdx_model_data_path, encoding="utf-8"))
694
+
695
+ # Load additional model data from audio-separator
696
+ self.logger.debug("Loading additional model parameters from audio-separator model data file...")
697
+ with resources.open_text("audio_separator", "model-data.json") as f:
698
+ audio_separator_model_data = json.load(f)
699
+
700
+ # Merge the model data objects, with audio-separator data taking precedence
701
+ vr_model_data_object = {**vr_model_data_object, **audio_separator_model_data.get("vr_model_data", {})}
702
+ mdx_model_data_object = {**mdx_model_data_object, **audio_separator_model_data.get("mdx_model_data", {})}
703
+
704
+ if model_hash in mdx_model_data_object:
705
+ model_data = mdx_model_data_object[model_hash]
706
+ elif model_hash in vr_model_data_object:
707
+ model_data = vr_model_data_object[model_hash]
708
+ else:
709
+ raise ValueError(f"Unsupported Model File: parameters for MD5 hash {model_hash} could not be found in UVR model data file for MDX or VR arch.")
710
+
711
+ self.logger.debug(f"Model data loaded using hash {model_hash}: {model_data}")
712
+
713
+ return model_data
714
+
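On the merge step above, the right-hand dict wins for duplicate hashes, which is how the audio-separator data takes precedence over UVR's; a one-line illustration:

base = {"abc123": {"source": "uvr"}}
override = {"abc123": {"source": "audio-separator"}}
merged = {**base, **override}
assert merged["abc123"]["source"] == "audio-separator"   # later dict wins on key collisions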
715
+ def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt"):
716
+ """
717
+ This method instantiates the architecture-specific separation class,
718
+ loading the separation model into memory, downloading it first if necessary.
719
+ """
720
+ self.logger.info(f"Loading model {model_filename}...")
721
+
722
+ load_model_start_time = time.perf_counter()
723
+
724
+ # Setting up the model path
725
+ model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
726
+ model_name = model_filename.split(".")[0]
727
+ self.logger.debug(f"Model downloaded, friendly name: {model_friendly_name}, model_path: {model_path}")
728
+
729
+ if model_path.lower().endswith(".yaml"):
730
+ yaml_config_filename = model_path
731
+
732
+ if yaml_config_filename is not None:
733
+ model_data = self.load_model_data_from_yaml(yaml_config_filename)
734
+ else:
735
+ model_data = self.load_model_data_using_hash(model_path)
736
+
737
+ common_params = {
738
+ "logger": self.logger,
739
+ "log_level": self.log_level,
740
+ "torch_device": self.torch_device,
741
+ "torch_device_cpu": self.torch_device_cpu,
742
+ "torch_device_mps": self.torch_device_mps,
743
+ "onnx_execution_provider": self.onnx_execution_provider,
744
+ "model_name": model_name,
745
+ "model_path": model_path,
746
+ "model_data": model_data,
747
+ "output_format": self.output_format,
748
+ "output_bitrate": self.output_bitrate,
749
+ "output_dir": self.output_dir,
750
+ "normalization_threshold": self.normalization_threshold,
751
+ "amplification_threshold": self.amplification_threshold,
752
+ "output_single_stem": self.output_single_stem,
753
+ "invert_using_spec": self.invert_using_spec,
754
+ "sample_rate": self.sample_rate,
755
+ "use_soundfile": self.use_soundfile,
756
+ }
757
+
758
+ # Instantiate the appropriate separator class depending on the model type
759
+ separator_classes = {"MDX": "mdx_separator.MDXSeparator", "VR": "vr_separator.VRSeparator", "Demucs": "demucs_separator.DemucsSeparator", "MDXC": "mdxc_separator.MDXCSeparator"}
760
+
761
+ if model_type not in self.arch_specific_params or model_type not in separator_classes:
762
+ raise ValueError(f"Model type not supported (yet): {model_type}")
763
+
764
+ if model_type == "Demucs" and sys.version_info < (3, 10):
765
+ raise Exception("Demucs models require Python version 3.10 or newer.")
766
+
767
+ self.logger.debug(f"Importing module for model type {model_type}: {separator_classes[model_type]}")
768
+
769
+ module_name, class_name = separator_classes[model_type].split(".")
770
+ module = importlib.import_module(f"audio_separator.separator.architectures.{module_name}")
771
+ separator_class = getattr(module, class_name)
772
+
773
+ self.logger.debug(f"Instantiating separator class for model type {model_type}: {separator_class}")
774
+ self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])
775
+
776
+ # Log the completion of the model load process
777
+ self.logger.debug("Loading model completed.")
778
+ self.logger.info(f'Load model duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - load_model_start_time)))}')
779
+
780
+ def separate(self, audio_file_path, custom_output_names=None):
781
+ """
782
+ Separates the audio file(s) into different stems (e.g., vocals, instruments) using the loaded model.
783
+
784
+ This method takes the path to an audio file or a directory containing audio files, processes them through
785
+ the loaded separation model, and returns the paths to the output files containing the separated audio stems.
786
+ It handles the entire flow from loading the audio, running the separation, clearing up resources, and logging the process.
787
+
788
+ Parameters:
789
+ - audio_file_path (str or list): The path to the audio file or directory, or a list of paths.
790
+ - custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
791
+
792
+ Returns:
793
+ - output_files (list of str): A list containing the paths to the separated audio stem files.
794
+ """
795
+ # Check if the model and device are properly initialized
796
+ if not (self.torch_device and self.model_instance):
797
+ raise ValueError("Initialization failed or model not loaded. Please load a model before attempting to separate.")
798
+
799
+ # If audio_file_path is a string, convert it to a list for uniform processing
800
+ if isinstance(audio_file_path, str):
801
+ audio_file_path = [audio_file_path]
802
+
803
+ # Initialize a list to store paths of all output files
804
+ output_files = []
805
+
806
+ # Process each path in the list
807
+ for path in audio_file_path:
808
+ if os.path.isdir(path):
809
+ # If the path is a directory, recursively search for all audio files
810
+ for root, dirs, files in os.walk(path):
811
+ for file in files:
812
+ # Check the file extension to ensure it's an audio file
813
+ if file.endswith((".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff", ".ac3")): # Add other formats if needed
814
+ full_path = os.path.join(root, file)
815
+ self.logger.info(f"Processing file: {full_path}")
816
+ try:
817
+ # Perform separation for each file
818
+ files_output = self._separate_file(full_path, custom_output_names)
819
+ output_files.extend(files_output)
820
+ except Exception as e:
821
+ self.logger.error(f"Failed to process file {full_path}: {e}")
822
+ else:
823
+ # If the path is a file, process it directly
824
+ self.logger.info(f"Processing file: {path}")
825
+ try:
826
+ files_output = self._separate_file(path, custom_output_names)
827
+ output_files.extend(files_output)
828
+ except Exception as e:
829
+ self.logger.error(f"Failed to process file {path}: {e}")
830
+
831
+ return output_files
832
+
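A usage sketch for the entry point above; the stem names accepted in custom_output_names depend on the loaded model and are matched case-insensitively (file and folder names are placeholders):

# Single file, default output names:
outputs = separator.separate("mix.flac")

# A folder plus an extra file, with custom stem names:
outputs = separator.separate(
    ["/data/album_folder", "extra_track.mp3"],
    custom_output_names={"Vocals": "lead_vocals", "Instrumental": "backing_track"},
)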
833
+ def _separate_file(self, audio_file_path, custom_output_names=None):
834
+ """
835
+ Internal method to handle separation for a single audio file.
836
+ This method performs the actual separation process for a single audio file. It logs the start and end of the process,
837
+ handles autocast if enabled, and ensures GPU cache is cleared after processing.
838
+ Parameters:
839
+ - audio_file_path (str): The path to the audio file.
840
+ - custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
841
+ Returns:
842
+ - output_files (list of str): A list containing the paths to the separated audio stem files.
843
+ """
844
+ # Log the start of the separation process
845
+ self.logger.info(f"Starting separation process for audio_file_path: {audio_file_path}")
846
+ separate_start_time = time.perf_counter()
847
+
848
+ # Log normalization and amplification thresholds
849
+ self.logger.debug(f"Normalization threshold set to {self.normalization_threshold}, waveform will be lowered to this max amplitude to avoid clipping.")
850
+ self.logger.debug(f"Amplification threshold set to {self.amplification_threshold}, waveform will be scaled up to this max amplitude if below it.")
851
+
852
+ # Run separation method for the loaded model with autocast enabled if supported by the device
853
+ output_files = None
854
+ if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type):
855
+ self.logger.debug("Autocast available.")
856
+ with autocast_mode.autocast(self.torch_device.type):
857
+ output_files = self.model_instance.separate(audio_file_path, custom_output_names)
858
+ else:
859
+ self.logger.debug("Autocast unavailable.")
860
+ output_files = self.model_instance.separate(audio_file_path, custom_output_names)
861
+
862
+ # Clear GPU cache to free up memory
863
+ self.model_instance.clear_gpu_cache()
864
+
865
+ # Unset separation parameters to prevent accidentally re-using the wrong source files or output paths
866
+ self.model_instance.clear_file_specific_paths()
867
+
868
+ # Remind the user one more time if they used a VIP model, so the message doesn't get lost in the logs
869
+ self.print_uvr_vip_message()
870
+
871
+ # Log the completion of the separation process
872
+ self.logger.debug("Separation process completed.")
873
+ self.logger.info(f'Separation duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - separate_start_time)))}')
874
+
875
+ return output_files
876
+
877
+ def download_model_and_data(self, model_filename):
878
+ """
879
+ Downloads the model file without loading it into memory.
880
+ """
881
+ self.logger.info(f"Downloading model {model_filename}...")
882
+
883
+ model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
884
+
885
+ if model_path.lower().endswith(".yaml"):
886
+ yaml_config_filename = model_path
887
+
888
+ if yaml_config_filename is not None:
889
+ model_data = self.load_model_data_from_yaml(yaml_config_filename)
890
+ else:
891
+ model_data = self.load_model_data_using_hash(model_path)
892
+
893
+ model_data_dict_size = len(model_data)
894
+
895
+ self.logger.info(f"Model downloaded, type: {model_type}, friendly name: {model_friendly_name}, model_path: {model_path}, model_data: {model_data_dict_size} items")
896
+
897
+ def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
898
+ """
899
+ Returns a simplified, user-friendly list of models with their key metrics.
900
+ Optionally sorts the list based on the specified criteria.
901
+
902
+ :param filter_sort_by: Criteria to filter and sort by. Can be "name", "filename", or any stem name (e.g. "vocals")
903
+ """
904
+ model_files = self.list_supported_model_files()
905
+ simplified_list = {}
906
+
907
+ for model_type, models in model_files.items():
908
+ for name, data in models.items():
909
+ filename = data["filename"]
910
+ scores = data.get("scores") or {}
911
+ stems = data.get("stems") or []
912
+ target_stem = data.get("target_stem")
913
+
914
+ # Format stems with their SDR scores where available
915
+ stems_with_scores = []
916
+ stem_sdr_dict = {}
917
+
918
+ # Process each stem from the model's stem list
919
+ for stem in stems:
920
+ stem_scores = scores.get(stem, {})
921
+ # Add asterisk if this is the target stem
922
+ stem_display = f"{stem}*" if stem == target_stem else stem
923
+
924
+ if isinstance(stem_scores, dict) and "SDR" in stem_scores:
925
+ sdr = round(stem_scores["SDR"], 1)
926
+ stems_with_scores.append(f"{stem_display} ({sdr})")
927
+ stem_sdr_dict[stem.lower()] = sdr
928
+ else:
929
+ # Include stem without SDR score
930
+ stems_with_scores.append(stem_display)
931
+ stem_sdr_dict[stem.lower()] = None
932
+
933
+ # If no stems listed, mark as Unknown
934
+ if not stems_with_scores:
935
+ stems_with_scores = ["Unknown"]
936
+ stem_sdr_dict["unknown"] = None
937
+
938
+ simplified_list[filename] = {"Name": name, "Type": model_type, "Stems": stems_with_scores, "SDR": stem_sdr_dict}
939
+
940
+ # Sort and filter the list if a sort_by parameter is provided
941
+ if filter_sort_by:
942
+ if filter_sort_by == "name":
943
+ return dict(sorted(simplified_list.items(), key=lambda x: x[1]["Name"]))
944
+ elif filter_sort_by == "filename":
945
+ return dict(sorted(simplified_list.items()))
946
+ else:
947
+ # Convert filter_sort_by to lowercase for case-insensitive comparison
948
+ sort_by_lower = filter_sort_by.lower()
949
+ # Filter out models that don't have the specified stem
950
+ filtered_list = {k: v for k, v in simplified_list.items() if sort_by_lower in v["SDR"]}
951
+
952
+ # Sort by SDR score if available, putting None values last
953
+ def sort_key(item):
954
+ sdr = item[1]["SDR"][sort_by_lower]
955
+ return (0 if sdr is None else 1, sdr if sdr is not None else float("-inf"))
956
+
957
+ return dict(sorted(filtered_list.items(), key=sort_key, reverse=True))
958
+
959
+ return simplified_list
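A usage sketch for the simplified listing (the "vocals" filter keeps only models reporting a vocals stem and sorts scored models by SDR, highest first):

all_models = separator.get_simplified_model_list()

vocal_models = separator.get_simplified_model_list(filter_sort_by="vocals")
for filename, info in list(vocal_models.items())[:5]:
    print(filename, info["Name"], info["Stems"])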
audio_separator/separator/uvr_lib_v5/__init__.py ADDED
File without changes
audio_separator/separator/uvr_lib_v5/demucs/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
audio_separator/separator/uvr_lib_v5/demucs/__main__.py ADDED
@@ -0,0 +1,212 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ import sys
10
+ import time
11
+ from dataclasses import dataclass, field
12
+ from fractions import Fraction
13
+
14
+ import torch as th
15
+ from torch import distributed, nn
16
+ from torch.nn.parallel.distributed import DistributedDataParallel
17
+
18
+ from .augment import FlipChannels, FlipSign, Remix, Shift
19
+ from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks
20
+ from .model import Demucs
21
+ from .parser import get_name, get_parser
22
+ from .raw import Rawset
23
+ from .tasnet import ConvTasNet
24
+ from .test import evaluate
25
+ from .train import train_model, validate_model
26
+ from .utils import human_seconds, load_model, save_model, sizeof_fmt
27
+
28
+
29
+ @dataclass
30
+ class SavedState:
31
+ metrics: list = field(default_factory=list)
32
+ last_state: dict = None
33
+ best_state: dict = None
34
+ optimizer: dict = None
35
+
36
+
37
+ def main():
38
+ parser = get_parser()
39
+ args = parser.parse_args()
40
+ name = get_name(parser, args)
41
+ print(f"Experiment {name}")
42
+
43
+ if args.musdb is None and args.rank == 0:
44
+ print("You must provide the path to the MusDB dataset with the --musdb flag. " "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.", file=sys.stderr)
45
+ sys.exit(1)
46
+
47
+ eval_folder = args.evals / name
48
+ eval_folder.mkdir(exist_ok=True, parents=True)
49
+ args.logs.mkdir(exist_ok=True)
50
+ metrics_path = args.logs / f"{name}.json"
51
+ eval_folder.mkdir(exist_ok=True, parents=True)
52
+ args.checkpoints.mkdir(exist_ok=True, parents=True)
53
+ args.models.mkdir(exist_ok=True, parents=True)
54
+
55
+ if args.device is None:
56
+ device = "cpu"
57
+ if th.cuda.is_available():
58
+ device = "cuda"
59
+ else:
60
+ device = args.device
61
+
62
+ th.manual_seed(args.seed)
63
+ # Prevents too many threads from being started when running `museval`, as it can be quite
64
+ # inefficient on NUMA architectures.
65
+ os.environ["OMP_NUM_THREADS"] = "1"
66
+
67
+ if args.world_size > 1:
68
+ if device != "cuda" and args.rank == 0:
69
+ print("Error: distributed training is only available with cuda device", file=sys.stderr)
70
+ sys.exit(1)
71
+ th.cuda.set_device(args.rank % th.cuda.device_count())
72
+ distributed.init_process_group(backend="nccl", init_method="tcp://" + args.master, rank=args.rank, world_size=args.world_size)
73
+
74
+ checkpoint = args.checkpoints / f"{name}.th"
75
+ checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
76
+ if args.restart and checkpoint.exists():
77
+ checkpoint.unlink()
78
+
79
+ if args.test:
80
+ args.epochs = 1
81
+ args.repeat = 0
82
+ model = load_model(args.models / args.test)
83
+ elif args.tasnet:
84
+ model = ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X)
85
+ else:
86
+ model = Demucs(
87
+ audio_channels=args.audio_channels,
88
+ channels=args.channels,
89
+ context=args.context,
90
+ depth=args.depth,
91
+ glu=args.glu,
92
+ growth=args.growth,
93
+ kernel_size=args.kernel_size,
94
+ lstm_layers=args.lstm_layers,
95
+ rescale=args.rescale,
96
+ rewrite=args.rewrite,
97
+ sources=4,
98
+ stride=args.conv_stride,
99
+ upsample=args.upsample,
100
+ samplerate=args.samplerate,
101
+ )
102
+ model.to(device)
103
+ if args.show:
104
+ print(model)
105
+ size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
106
+ print(f"Model size {size}")
107
+ return
108
+
109
+ optimizer = th.optim.Adam(model.parameters(), lr=args.lr)
110
+
111
+ try:
112
+ saved = th.load(checkpoint, map_location="cpu")
113
+ except IOError:
114
+ saved = SavedState()
115
+ else:
116
+ model.load_state_dict(saved.last_state)
117
+ optimizer.load_state_dict(saved.optimizer)
118
+
119
+ if args.save_model:
120
+ if args.rank == 0:
121
+ model.to("cpu")
122
+ model.load_state_dict(saved.best_state)
123
+ save_model(model, args.models / f"{name}.th")
124
+ return
125
+
126
+ if args.rank == 0:
127
+ done = args.logs / f"{name}.done"
128
+ if done.exists():
129
+ done.unlink()
130
+
131
+ if args.augment:
132
+ augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride), Remix(group_size=args.remix_group_size)).to(device)
133
+ else:
134
+ augment = Shift(args.data_stride)
135
+
136
+ if args.mse:
137
+ criterion = nn.MSELoss()
138
+ else:
139
+ criterion = nn.L1Loss()
140
+
141
+ # Setting number of samples so that all convolution windows are full.
142
+ # Prevents a hard-to-debug mistake where the prediction is shifted compared
143
+ # to the input mixture.
144
+ samples = model.valid_length(args.samples)
145
+ print(f"Number of training samples adjusted to {samples}")
146
+
147
+ if args.raw:
148
+ train_set = Rawset(args.raw / "train", samples=samples + args.data_stride, channels=args.audio_channels, streams=[0, 1, 2, 3, 4], stride=args.data_stride)
149
+
150
+ valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
151
+ else:
152
+ if not args.metadata.is_file() and args.rank == 0:
153
+ build_musdb_metadata(args.metadata, args.musdb, args.workers)
154
+ if args.world_size > 1:
155
+ distributed.barrier()
156
+ metadata = json.load(open(args.metadata))
157
+ duration = Fraction(samples + args.data_stride, args.samplerate)
158
+ stride = Fraction(args.data_stride, args.samplerate)
159
+ train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"), metadata, duration=duration, stride=stride, samplerate=args.samplerate, channels=args.audio_channels)
160
+ valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"), metadata, samplerate=args.samplerate, channels=args.audio_channels)
161
+
162
+ best_loss = float("inf")
163
+ for epoch, metrics in enumerate(saved.metrics):
164
+ print(f"Epoch {epoch:03d}: " f"train={metrics['train']:.8f} " f"valid={metrics['valid']:.8f} " f"best={metrics['best']:.4f} " f"duration={human_seconds(metrics['duration'])}")
165
+ best_loss = metrics["best"]
166
+
167
+ if args.world_size > 1:
168
+ dmodel = DistributedDataParallel(model, device_ids=[th.cuda.current_device()], output_device=th.cuda.current_device())
169
+ else:
170
+ dmodel = model
171
+
172
+ for epoch in range(len(saved.metrics), args.epochs):
173
+ begin = time.time()
174
+ model.train()
175
+ train_loss = train_model(
176
+ epoch, train_set, dmodel, criterion, optimizer, augment, batch_size=args.batch_size, device=device, repeat=args.repeat, seed=args.seed, workers=args.workers, world_size=args.world_size
177
+ )
178
+ model.eval()
179
+ valid_loss = validate_model(epoch, valid_set, model, criterion, device=device, rank=args.rank, split=args.split_valid, world_size=args.world_size)
180
+
181
+ duration = time.time() - begin
182
+ if valid_loss < best_loss:
183
+ best_loss = valid_loss
184
+ saved.best_state = {key: value.to("cpu").clone() for key, value in model.state_dict().items()}
185
+ saved.metrics.append({"train": train_loss, "valid": valid_loss, "best": best_loss, "duration": duration})
186
+ if args.rank == 0:
187
+ json.dump(saved.metrics, open(metrics_path, "w"))
188
+
189
+ saved.last_state = model.state_dict()
190
+ saved.optimizer = optimizer.state_dict()
191
+ if args.rank == 0 and not args.test:
192
+ th.save(saved, checkpoint_tmp)
193
+ checkpoint_tmp.rename(checkpoint)
194
+
195
+ print(f"Epoch {epoch:03d}: " f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} " f"duration={human_seconds(duration)}")
196
+
197
+ del dmodel
198
+ model.load_state_dict(saved.best_state)
199
+ if args.eval_cpu:
200
+ device = "cpu"
201
+ model.to(device)
202
+ model.eval()
203
+ evaluate(model, args.musdb, eval_folder, rank=args.rank, world_size=args.world_size, device=device, save=args.save, split=args.split_valid, shifts=args.shifts, workers=args.eval_workers)
204
+ model.to("cpu")
205
+ save_model(model, args.models / f"{name}.th")
206
+ if args.rank == 0:
207
+ print("done")
208
+ done.write_text("done")
209
+
210
+
211
+ if __name__ == "__main__":
212
+ main()
audio_separator/separator/uvr_lib_v5/demucs/apply.py ADDED
@@ -0,0 +1,294 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Code to apply a model to a mix. It will handle chunking with overlaps and
8
+ interpolation between chunks, as well as the "shift trick".
9
+ """
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ import random
12
+ import typing as tp
13
+
14
+ import torch as th
15
+ from torch import nn
16
+ from torch.nn import functional as F
17
+ import tqdm
18
+
19
+ from .demucs import Demucs
20
+ from .hdemucs import HDemucs
21
+ from .utils import center_trim, DummyPoolExecutor
22
+
23
+ Model = tp.Union[Demucs, HDemucs]
24
+
25
+ progress_bar_num = 0
26
+
27
+
28
+ class BagOfModels(nn.Module):
29
+ def __init__(self, models: tp.List[Model], weights: tp.Optional[tp.List[tp.List[float]]] = None, segment: tp.Optional[float] = None):
30
+ """
31
+ Represents a bag of models with specific weights.
32
+ You should call `apply_model` rather than calling forward directly here, for
33
+ optimal performance.
34
+
35
+ Args:
36
+ models (list[nn.Module]): list of Demucs/HDemucs models.
37
+ weights (list[list[float]]): list of weights. If None, assumed to
38
+ be all ones, otherwise it should be a list of N lists (N = number of models),
39
+ each containing S floats (S = number of sources).
40
+ segment (None or float): overrides the `segment` attribute of each model
41
+ (this is performed inplace, be careful if you reuse the models passed).
42
+ """
43
+
44
+ super().__init__()
45
+ assert len(models) > 0
46
+ first = models[0]
47
+ for other in models:
48
+ assert other.sources == first.sources
49
+ assert other.samplerate == first.samplerate
50
+ assert other.audio_channels == first.audio_channels
51
+ if segment is not None:
52
+ other.segment = segment
53
+
54
+ self.audio_channels = first.audio_channels
55
+ self.samplerate = first.samplerate
56
+ self.sources = first.sources
57
+ self.models = nn.ModuleList(models)
58
+
59
+ if weights is None:
60
+ weights = [[1.0 for _ in first.sources] for _ in models]
61
+ else:
62
+ assert len(weights) == len(models)
63
+ for weight in weights:
64
+ assert len(weight) == len(first.sources)
65
+ self.weights = weights
66
+
67
+ def forward(self, x):
68
+ raise NotImplementedError("Call `apply_model` on this.")
69
+
70
+
71
+ class TensorChunk:
72
+ def __init__(self, tensor, offset=0, length=None):
73
+ total_length = tensor.shape[-1]
74
+ assert offset >= 0
75
+ assert offset < total_length
76
+
77
+ if length is None:
78
+ length = total_length - offset
79
+ else:
80
+ length = min(total_length - offset, length)
81
+
82
+ if isinstance(tensor, TensorChunk):
83
+ self.tensor = tensor.tensor
84
+ self.offset = offset + tensor.offset
85
+ else:
86
+ self.tensor = tensor
87
+ self.offset = offset
88
+ self.length = length
89
+ self.device = tensor.device
90
+
91
+ @property
92
+ def shape(self):
93
+ shape = list(self.tensor.shape)
94
+ shape[-1] = self.length
95
+ return shape
96
+
97
+ def padded(self, target_length):
98
+ delta = target_length - self.length
99
+ total_length = self.tensor.shape[-1]
100
+ assert delta >= 0
101
+
102
+ start = self.offset - delta // 2
103
+ end = start + target_length
104
+
105
+ correct_start = max(0, start)
106
+ correct_end = min(total_length, end)
107
+
108
+ pad_left = correct_start - start
109
+ pad_right = end - correct_end
110
+
111
+ out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
112
+ assert out.shape[-1] == target_length
113
+ return out
114
+
115
+
116
+ def tensor_chunk(tensor_or_chunk):
117
+ if isinstance(tensor_or_chunk, TensorChunk):
118
+ return tensor_or_chunk
119
+ else:
120
+ assert isinstance(tensor_or_chunk, th.Tensor)
121
+ return TensorChunk(tensor_or_chunk)
122
+
123
+
124
+ def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1.0, static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
125
+ """
126
+ Apply model to a given mixture.
127
+
128
+ Args:
129
+ shifts (int): if > 0, will shift `mix` in time by a random amount between 0 and 0.5 sec
130
+ and apply the opposite shift to the output. This is repeated `shifts` times and
131
+ all predictions are averaged. This effectively makes the model time equivariant
132
+ and improves SDR by up to 0.2 points.
133
+ split (bool): if True, the input will be broken down in 8 seconds extracts
134
+ and predictions will be performed individually on each and concatenated.
135
+ Useful for models with a large memory footprint like Tasnet.
136
+ progress (bool): if True, show a progress bar (requires split=True)
137
+ device (torch.device, str, or None): if provided, device on which to
138
+ execute the computation, otherwise `mix.device` is assumed.
139
+ When `device` is different from `mix.device`, only local computations will
140
+ be on `device`, while the entire tracks will be stored on `mix.device`.
141
+ """
142
+
143
+ global fut_length
144
+ global bag_num
145
+ global prog_bar
146
+
147
+ if device is None:
148
+ device = mix.device
149
+ else:
150
+ device = th.device(device)
151
+ if pool is None:
152
+ if num_workers > 0 and device.type == "cpu":
153
+ pool = ThreadPoolExecutor(num_workers)
154
+ else:
155
+ pool = DummyPoolExecutor()
156
+
157
+ kwargs = {
158
+ "shifts": shifts,
159
+ "split": split,
160
+ "overlap": overlap,
161
+ "transition_power": transition_power,
162
+ "progress": progress,
163
+ "device": device,
164
+ "pool": pool,
165
+ "set_progress_bar": set_progress_bar,
166
+ "static_shifts": static_shifts,
167
+ }
168
+
169
+ if isinstance(model, BagOfModels):
170
+ # Special treatment for a bag of models.
171
+ # We explicitly apply `apply_model` multiple times so that the random shifts
172
+ # are different for each model.
173
+
174
+ estimates = 0
175
+ totals = [0] * len(model.sources)
176
+ bag_num = len(model.models)
177
+ fut_length = 0
178
+ prog_bar = 0
179
+ current_model = 0 # (bag_num + 1)
180
+ for sub_model, weight in zip(model.models, model.weights):
181
+ original_model_device = next(iter(sub_model.parameters())).device
182
+ sub_model.to(device)
183
+ fut_length += fut_length
184
+ current_model += 1
185
+ out = apply_model(sub_model, mix, **kwargs)
186
+ sub_model.to(original_model_device)
187
+ for k, inst_weight in enumerate(weight):
188
+ out[:, k, :, :] *= inst_weight
189
+ totals[k] += inst_weight
190
+ estimates += out
191
+ del out
192
+
193
+ for k in range(estimates.shape[1]):
194
+ estimates[:, k, :, :] /= totals[k]
195
+ return estimates
196
+
197
+ model.to(device)
198
+ model.eval()
199
+ assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
200
+ batch, channels, length = mix.shape
201
+
202
+ if shifts:
203
+ kwargs["shifts"] = 0
204
+ max_shift = int(0.5 * model.samplerate)
205
+ mix = tensor_chunk(mix)
206
+ padded_mix = mix.padded(length + 2 * max_shift)
207
+ out = 0
208
+ for _ in range(shifts):
209
+ offset = random.randint(0, max_shift)
210
+ shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
211
+ shifted_out = apply_model(model, shifted, **kwargs)
212
+ out += shifted_out[..., max_shift - offset :]
213
+ out /= shifts
214
+ return out
215
+ elif split:
216
+ kwargs["split"] = False
217
+ out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
218
+ sum_weight = th.zeros(length, device=mix.device)
219
+ segment = int(model.samplerate * model.segment)
220
+ stride = int((1 - overlap) * segment)
221
+ offsets = range(0, length, stride)
222
+ scale = float(format(stride / model.samplerate, ".2f"))
223
+ # We start from a triangle shaped weight, with maximal weight in the middle
224
+ # of the segment. Then we normalize and take to the power `transition_power`.
225
+ # Large values of transition power will lead to sharper transitions.
226
+ weight = th.cat([th.arange(1, segment // 2 + 1, device=device), th.arange(segment - segment // 2, 0, -1, device=device)])
227
+ assert len(weight) == segment
228
+ # If the overlap < 50%, this will translate to linear transition when
229
+ # transition_power is 1.
230
+ weight = (weight / weight.max()) ** transition_power
231
+ futures = []
232
+ for offset in offsets:
233
+ chunk = TensorChunk(mix, offset, segment)
234
+ future = pool.submit(apply_model, model, chunk, **kwargs)
235
+ futures.append((future, offset))
236
+ offset += segment
237
+ if progress:
238
+ futures = tqdm.tqdm(futures)
239
+ for future, offset in futures:
240
+ if set_progress_bar:
241
+ fut_length = len(futures) * bag_num * static_shifts
242
+ prog_bar += 1
243
+ set_progress_bar(0.1, (0.8 / fut_length * prog_bar))
244
+ chunk_out = future.result()
245
+ chunk_length = chunk_out.shape[-1]
246
+ out[..., offset : offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
247
+ sum_weight[offset : offset + segment] += weight[:chunk_length].to(mix.device)
248
+ assert sum_weight.min() > 0
249
+ out /= sum_weight
250
+ return out
251
+ else:
252
+ if hasattr(model, "valid_length"):
253
+ valid_length = model.valid_length(length)
254
+ else:
255
+ valid_length = length
256
+ mix = tensor_chunk(mix)
257
+ padded_mix = mix.padded(valid_length).to(device)
258
+ with th.no_grad():
259
+ out = model(padded_mix)
260
+ return center_trim(out, length)
261
+
262
+
263
+ def demucs_segments(demucs_segment, demucs_model):
264
+
265
+ if demucs_segment == "Default":
266
+ segment = None
267
+ if isinstance(demucs_model, BagOfModels):
268
+ if segment is not None:
269
+ for sub in demucs_model.models:
270
+ sub.segment = segment
271
+ else:
272
+ if segment is not None:
273
+ demucs_model.segment = segment
274
+ else:
275
+ try:
276
+ segment = int(demucs_segment)
277
+ if isinstance(demucs_model, BagOfModels):
278
+ if segment is not None:
279
+ for sub in demucs_model.models:
280
+ sub.segment = segment
281
+ else:
282
+ if segment is not None:
283
+ demucs_model.segment = segment
284
+ except:
285
+ segment = None
286
+ if isinstance(demucs_model, BagOfModels):
287
+ if segment is not None:
288
+ for sub in demucs_model.models:
289
+ sub.segment = segment
290
+ else:
291
+ if segment is not None:
292
+ demucs_model.segment = segment
293
+
294
+ return demucs_model
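`apply_model` is the single inference entry point defined in this file: it handles the shift trick, overlapping segments and `BagOfModels` weighting, and returns one waveform per source. Below is a hedged sketch of a typical call, using a tiny untrained `Demucs` and random noise as stand-ins for a real checkpoint and mixture; the import paths assume the package layout of this commit.

```python
# Hedged usage sketch for apply_model; the untrained model and noise mixture
# are placeholders, and the import paths assume this repository's layout.
import torch as th

from audio_separator.separator.uvr_lib_v5.demucs.apply import apply_model
from audio_separator.separator.uvr_lib_v5.demucs.demucs import Demucs

model = Demucs(sources=["drums", "bass", "other", "vocals"], channels=16, depth=4)
mix = th.randn(1, model.audio_channels, 44100 * 5)   # 5 s of stereo noise

with th.no_grad():
    stems = apply_model(
        model, mix,
        shifts=1,       # average over one random time shift (the "shift trick")
        split=True,     # process in overlapping segments to bound memory use
        overlap=0.25,   # 25% overlap between segments
        device="cpu",
    )

# stems: (batch, len(model.sources), channels, samples)
for name, stem in zip(model.sources, stems[0]):
    print(name, tuple(stem.shape))
```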
audio_separator/separator/uvr_lib_v5/demucs/demucs.py ADDED
@@ -0,0 +1,453 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+
10
+ import julius
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+ from .states import capture_init
16
+ from .utils import center_trim, unfold
17
+
18
+
19
+ class BLSTM(nn.Module):
20
+ """
21
+ BiLSTM with same hidden units as input dim.
22
+ If `max_steps` is not None, the input will be split into overlapping
23
+ chunks and the LSTM applied separately on each chunk.
24
+ """
25
+
26
+ def __init__(self, dim, layers=1, max_steps=None, skip=False):
27
+ super().__init__()
28
+ assert max_steps is None or max_steps % 4 == 0
29
+ self.max_steps = max_steps
30
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
31
+ self.linear = nn.Linear(2 * dim, dim)
32
+ self.skip = skip
33
+
34
+ def forward(self, x):
35
+ B, C, T = x.shape
36
+ y = x
37
+ framed = False
38
+ if self.max_steps is not None and T > self.max_steps:
39
+ width = self.max_steps
40
+ stride = width // 2
41
+ frames = unfold(x, width, stride)
42
+ nframes = frames.shape[2]
43
+ framed = True
44
+ x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
45
+
46
+ x = x.permute(2, 0, 1)
47
+
48
+ x = self.lstm(x)[0]
49
+ x = self.linear(x)
50
+ x = x.permute(1, 2, 0)
51
+ if framed:
52
+ out = []
53
+ frames = x.reshape(B, -1, C, width)
54
+ limit = stride // 2
55
+ for k in range(nframes):
56
+ if k == 0:
57
+ out.append(frames[:, k, :, :-limit])
58
+ elif k == nframes - 1:
59
+ out.append(frames[:, k, :, limit:])
60
+ else:
61
+ out.append(frames[:, k, :, limit:-limit])
62
+ out = torch.cat(out, -1)
63
+ out = out[..., :T]
64
+ x = out
65
+ if self.skip:
66
+ x = x + y
67
+ return x
68
+
69
+
70
+ def rescale_conv(conv, reference):
71
+ """Rescale initial weight scale. It is unclear why it helps but it certainly does."""
72
+ std = conv.weight.std().detach()
73
+ scale = (std / reference) ** 0.5
74
+ conv.weight.data /= scale
75
+ if conv.bias is not None:
76
+ conv.bias.data /= scale
77
+
78
+
79
+ def rescale_module(module, reference):
80
+ for sub in module.modules():
81
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
82
+ rescale_conv(sub, reference)
83
+
84
+
85
+ class LayerScale(nn.Module):
86
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
87
+ This rescales residual outputs diagonally, close to 0 initially, then learnt.
88
+ """
89
+
90
+ def __init__(self, channels: int, init: float = 0):
91
+ super().__init__()
92
+ self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
93
+ self.scale.data[:] = init
94
+
95
+ def forward(self, x):
96
+ return self.scale[:, None] * x
97
+
98
+
99
+ class DConv(nn.Module):
100
+ """
101
+ New residual branches in each encoder layer.
102
+ This alternates dilated convolutions, potentially with LSTMs and attention.
103
+ Also, before entering each residual branch, the dimension is projected onto a smaller subspace,
104
+ e.g. of dim `channels // compress`.
105
+ """
106
+
107
+ def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4, norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, kernel=3, dilate=True):
108
+ """
109
+ Args:
110
+ channels: input/output channels for residual branch.
111
+ compress: amount of channel compression inside the branch.
112
+ depth: number of layers in the residual branch. Each layer has its own
113
+ projection, and potentially LSTM and attention.
114
+ init: initial scale for LayerNorm.
115
+ norm: use GroupNorm.
116
+ attn: use LocalAttention.
117
+ heads: number of heads for the LocalAttention.
118
+ ndecay: number of decay controls in the LocalAttention.
119
+ lstm: use LSTM.
120
+ gelu: Use GELU activation.
121
+ kernel: kernel size for the (dilated) convolutions.
122
+ dilate: if true, use dilation, increasing with the depth.
123
+ """
124
+
125
+ super().__init__()
126
+ assert kernel % 2 == 1
127
+ self.channels = channels
128
+ self.compress = compress
129
+ self.depth = abs(depth)
130
+ dilate = depth > 0
131
+
132
+ norm_fn: tp.Callable[[int], nn.Module]
133
+ norm_fn = lambda d: nn.Identity() # noqa
134
+ if norm:
135
+ norm_fn = lambda d: nn.GroupNorm(1, d) # noqa
136
+
137
+ hidden = int(channels / compress)
138
+
139
+ act: tp.Type[nn.Module]
140
+ if gelu:
141
+ act = nn.GELU
142
+ else:
143
+ act = nn.ReLU
144
+
145
+ self.layers = nn.ModuleList([])
146
+ for d in range(self.depth):
147
+ dilation = 2**d if dilate else 1
148
+ padding = dilation * (kernel // 2)
149
+ mods = [
150
+ nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
151
+ norm_fn(hidden),
152
+ act(),
153
+ nn.Conv1d(hidden, 2 * channels, 1),
154
+ norm_fn(2 * channels),
155
+ nn.GLU(1),
156
+ LayerScale(channels, init),
157
+ ]
158
+ if attn:
159
+ mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
160
+ if lstm:
161
+ mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
162
+ layer = nn.Sequential(*mods)
163
+ self.layers.append(layer)
164
+
165
+ def forward(self, x):
166
+ for layer in self.layers:
167
+ x = x + layer(x)
168
+ return x
169
+
170
+
171
+ class LocalState(nn.Module):
172
+ """Local state allows to have attention based only on data (no positional embedding),
173
+ but while setting a constraint on the time window (e.g. decaying penalty term).
174
+
175
+ Also includes a failed experiment at providing some frequency-based attention.
176
+ """
177
+
178
+ def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
179
+ super().__init__()
180
+ assert channels % heads == 0, (channels, heads)
181
+ self.heads = heads
182
+ self.nfreqs = nfreqs
183
+ self.ndecay = ndecay
184
+ self.content = nn.Conv1d(channels, channels, 1)
185
+ self.query = nn.Conv1d(channels, channels, 1)
186
+ self.key = nn.Conv1d(channels, channels, 1)
187
+ if nfreqs:
188
+ self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
189
+ if ndecay:
190
+ self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
191
+ # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
192
+ self.query_decay.weight.data *= 0.01
193
+ assert self.query_decay.bias is not None # stupid type checker
194
+ self.query_decay.bias.data[:] = -2
195
+ self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
196
+
197
+ def forward(self, x):
198
+ B, C, T = x.shape
199
+ heads = self.heads
200
+ indexes = torch.arange(T, device=x.device, dtype=x.dtype)
201
+ # left index are keys, right index are queries
202
+ delta = indexes[:, None] - indexes[None, :]
203
+
204
+ queries = self.query(x).view(B, heads, -1, T)
205
+ keys = self.key(x).view(B, heads, -1, T)
206
+ # t are keys, s are queries
207
+ dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
208
+ dots /= keys.shape[2] ** 0.5
209
+ if self.nfreqs:
210
+ periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
211
+ freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
212
+ freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs**0.5
213
+ dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
214
+ if self.ndecay:
215
+ decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
216
+ decay_q = self.query_decay(x).view(B, heads, -1, T)
217
+ decay_q = torch.sigmoid(decay_q) / 2
218
+ decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
219
+ dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
220
+
221
+ # Kill self reference.
222
+ dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
223
+ weights = torch.softmax(dots, dim=2)
224
+
225
+ content = self.content(x).view(B, heads, -1, T)
226
+ result = torch.einsum("bhts,bhct->bhcs", weights, content)
227
+ if self.nfreqs:
228
+ time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
229
+ result = torch.cat([result, time_sig], 2)
230
+ result = result.reshape(B, -1, T)
231
+ return x + self.proj(result)
232
+
233
+
234
+ class Demucs(nn.Module):
235
+ @capture_init
236
+ def __init__(
237
+ self,
238
+ sources,
239
+ # Channels
240
+ audio_channels=2,
241
+ channels=64,
242
+ growth=2.0,
243
+ # Main structure
244
+ depth=6,
245
+ rewrite=True,
246
+ lstm_layers=0,
247
+ # Convolutions
248
+ kernel_size=8,
249
+ stride=4,
250
+ context=1,
251
+ # Activations
252
+ gelu=True,
253
+ glu=True,
254
+ # Normalization
255
+ norm_starts=4,
256
+ norm_groups=4,
257
+ # DConv residual branch
258
+ dconv_mode=1,
259
+ dconv_depth=2,
260
+ dconv_comp=4,
261
+ dconv_attn=4,
262
+ dconv_lstm=4,
263
+ dconv_init=1e-4,
264
+ # Pre/post processing
265
+ normalize=True,
266
+ resample=True,
267
+ # Weight init
268
+ rescale=0.1,
269
+ # Metadata
270
+ samplerate=44100,
271
+ segment=4 * 10,
272
+ ):
273
+ """
274
+ Args:
275
+ sources (list[str]): list of source names
276
+ audio_channels (int): stereo or mono
277
+ channels (int): first convolution channels
278
+ depth (int): number of encoder/decoder layers
279
+ growth (float): multiply (resp divide) number of channels by that
280
+ for each layer of the encoder (resp decoder)
281
+ depth (int): number of layers in the encoder and in the decoder.
282
+ rewrite (bool): add 1x1 convolution to each layer.
283
+ lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
284
+ by default, as this is now replaced by the smaller and faster small LSTMs
285
+ in the DConv branches.
286
+ kernel_size (int): kernel size for convolutions
287
+ stride (int): stride for convolutions
288
+ context (int): kernel size of the convolution in the
289
+ decoder before the transposed convolution. If > 1,
290
+ will provide some context from neighboring time steps.
291
+ gelu: use GELU activation function.
292
+ glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
293
+ norm_starts: layer at which group norm starts being used.
294
+ decoder layers are numbered in reverse order.
295
+ norm_groups: number of groups for group norm.
296
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
297
+ dconv_depth: depth of residual DConv branch.
298
+ dconv_comp: compression of DConv branch.
299
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
300
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
301
+ dconv_init: initial scale for the DConv branch LayerScale.
302
+ normalize (bool): normalizes the input audio on the fly, and scales back
303
+ the output by the same amount.
304
+ resample (bool): upsample x2 the input and downsample /2 the output.
305
+ rescale (int): rescale initial weights of convolutions
306
+ to get their standard deviation closer to `rescale`.
307
+ samplerate (int): stored as meta information for easing
308
+ future evaluations of the model.
309
+ segment (float): duration of the chunks of audio to ideally evaluate the model on.
310
+ This is used by `demucs.apply.apply_model`.
311
+ """
312
+
313
+ super().__init__()
314
+ self.audio_channels = audio_channels
315
+ self.sources = sources
316
+ self.kernel_size = kernel_size
317
+ self.context = context
318
+ self.stride = stride
319
+ self.depth = depth
320
+ self.resample = resample
321
+ self.channels = channels
322
+ self.normalize = normalize
323
+ self.samplerate = samplerate
324
+ self.segment = segment
325
+ self.encoder = nn.ModuleList()
326
+ self.decoder = nn.ModuleList()
327
+ self.skip_scales = nn.ModuleList()
328
+
329
+ if glu:
330
+ activation = nn.GLU(dim=1)
331
+ ch_scale = 2
332
+ else:
333
+ activation = nn.ReLU()
334
+ ch_scale = 1
335
+ if gelu:
336
+ act2 = nn.GELU
337
+ else:
338
+ act2 = nn.ReLU
339
+
340
+ in_channels = audio_channels
341
+ padding = 0
342
+ for index in range(depth):
343
+ norm_fn = lambda d: nn.Identity() # noqa
344
+ if index >= norm_starts:
345
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
346
+
347
+ encode = []
348
+ encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), norm_fn(channels), act2()]
349
+ attn = index >= dconv_attn
350
+ lstm = index >= dconv_lstm
351
+ if dconv_mode & 1:
352
+ encode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
353
+ if rewrite:
354
+ encode += [nn.Conv1d(channels, ch_scale * channels, 1), norm_fn(ch_scale * channels), activation]
355
+ self.encoder.append(nn.Sequential(*encode))
356
+
357
+ decode = []
358
+ if index > 0:
359
+ out_channels = in_channels
360
+ else:
361
+ out_channels = len(self.sources) * audio_channels
362
+ if rewrite:
363
+ decode += [nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), norm_fn(ch_scale * channels), activation]
364
+ if dconv_mode & 2:
365
+ decode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
366
+ decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride, padding=padding)]
367
+ if index > 0:
368
+ decode += [norm_fn(out_channels), act2()]
369
+ self.decoder.insert(0, nn.Sequential(*decode))
370
+ in_channels = channels
371
+ channels = int(growth * channels)
372
+
373
+ channels = in_channels
374
+ if lstm_layers:
375
+ self.lstm = BLSTM(channels, lstm_layers)
376
+ else:
377
+ self.lstm = None
378
+
379
+ if rescale:
380
+ rescale_module(self, reference=rescale)
381
+
382
+ def valid_length(self, length):
383
+ """
384
+ Return the nearest valid length to use with the model so that
385
+ there are no time steps left over in a convolution, e.g. for all
386
+ layers, size of the input - kernel_size % stride = 0.
387
+
388
+ Note that inputs are automatically padded if necessary to ensure that the output
389
+ has the same length as the input.
390
+ """
391
+ if self.resample:
392
+ length *= 2
393
+
394
+ for _ in range(self.depth):
395
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
396
+ length = max(1, length)
397
+
398
+ for idx in range(self.depth):
399
+ length = (length - 1) * self.stride + self.kernel_size
400
+
401
+ if self.resample:
402
+ length = math.ceil(length / 2)
403
+ return int(length)
404
+
405
+ def forward(self, mix):
406
+ x = mix
407
+ length = x.shape[-1]
408
+
409
+ if self.normalize:
410
+ mono = mix.mean(dim=1, keepdim=True)
411
+ mean = mono.mean(dim=-1, keepdim=True)
412
+ std = mono.std(dim=-1, keepdim=True)
413
+ x = (x - mean) / (1e-5 + std)
414
+ else:
415
+ mean = 0
416
+ std = 1
417
+
418
+ delta = self.valid_length(length) - length
419
+ x = F.pad(x, (delta // 2, delta - delta // 2))
420
+
421
+ if self.resample:
422
+ x = julius.resample_frac(x, 1, 2)
423
+
424
+ saved = []
425
+ for encode in self.encoder:
426
+ x = encode(x)
427
+ saved.append(x)
428
+
429
+ if self.lstm:
430
+ x = self.lstm(x)
431
+
432
+ for decode in self.decoder:
433
+ skip = saved.pop(-1)
434
+ skip = center_trim(skip, x)
435
+ x = decode(x + skip)
436
+
437
+ if self.resample:
438
+ x = julius.resample_frac(x, 2, 1)
439
+ x = x * std + mean
440
+ x = center_trim(x, length)
441
+ x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
442
+ return x
443
+
444
+ def load_state_dict(self, state, strict=True):
445
+ # fix a mismatch with previous generation Demucs models.
446
+ for idx in range(self.depth):
447
+ for a in ["encoder", "decoder"]:
448
+ for b in ["bias", "weight"]:
449
+ new = f"{a}.{idx}.3.{b}"
450
+ old = f"{a}.{idx}.2.{b}"
451
+ if old in state and new not in state:
452
+ state[new] = state.pop(old)
453
+ super().load_state_dict(state, strict=strict)
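`valid_length` is the contract that keeps the strided encoder/decoder symmetric: inputs are padded up to the nearest valid length before the convolutions and trimmed back afterwards, which `forward` also does internally. Here is a small sketch of that arithmetic with an illustrative, untrained configuration (not a trained setting from this commit):

```python
# Sketch of the valid_length/padding contract; the model configuration below is
# illustrative only and the waveform is random noise.
import torch
from torch.nn import functional as F

from audio_separator.separator.uvr_lib_v5.demucs.demucs import Demucs

model = Demucs(sources=["drums", "bass", "other", "vocals"], channels=16, depth=4)

length = 44100 * 3                        # 3 s of audio
valid = model.valid_length(length)        # >= length, leaves no leftover samples per layer
delta = valid - length

mix = torch.randn(1, model.audio_channels, length)
padded = F.pad(mix, (delta // 2, delta - delta // 2))
print(length, valid, padded.shape[-1])    # forward() applies this same padding and trims back
```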
audio_separator/separator/uvr_lib_v5/demucs/filtering.py ADDED
@@ -0,0 +1,451 @@
1
+ from typing import Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch import Tensor
5
+ from torch.utils.data import DataLoader
6
+
7
+
8
+ def atan2(y, x):
9
+ r"""Element-wise arctangent function of y/x.
10
+ Returns a new tensor with signed angles in radians.
11
+ It is an alternative implementation of torch.atan2
12
+
13
+ Args:
14
+ y (Tensor): First input tensor
15
+ x (Tensor): Second input tensor [shape=y.shape]
16
+
17
+ Returns:
18
+ Tensor: [shape=y.shape].
19
+ """
20
+ pi = 2 * torch.asin(torch.tensor(1.0))
21
+ x += ((x == 0) & (y == 0)) * 1.0
22
+ out = torch.atan(y / x)
23
+ out += ((y >= 0) & (x < 0)) * pi
24
+ out -= ((y < 0) & (x < 0)) * pi
25
+ out *= 1 - ((y > 0) & (x == 0)) * 1.0
26
+ out += ((y > 0) & (x == 0)) * (pi / 2)
27
+ out *= 1 - ((y < 0) & (x == 0)) * 1.0
28
+ out += ((y < 0) & (x == 0)) * (-pi / 2)
29
+ return out
30
+
31
+
32
+ # Define basic complex operations on torch.Tensor objects whose last dimension
33
+ # consists in the concatenation of the real and imaginary parts.
34
+
35
+
36
+ def _norm(x: torch.Tensor) -> torch.Tensor:
37
+ r"""Computes the norm value of a torch Tensor, assuming that it
38
+ comes as real and imaginary part in its last dimension.
39
+
40
+ Args:
41
+ x (Tensor): Input Tensor of shape [shape=(..., 2)]
42
+
43
+ Returns:
44
+ Tensor: shape as x excluding the last dimension.
45
+ """
46
+ return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
47
+
48
+
49
+ def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
50
+ """Element-wise multiplication of two complex Tensors described
51
+ through their real and imaginary parts.
52
+ The result is added to the `out` tensor"""
53
+
54
+ # check `out` and allocate it if needed
55
+ target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
56
+ if out is None or out.shape != target_shape:
57
+ out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
58
+ if out is a:
59
+ real_a = a[..., 0]
60
+ out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
61
+ out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
62
+ else:
63
+ out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
64
+ out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
65
+ return out
66
+
67
+
68
+ def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
69
+ """Element-wise multiplication of two complex Tensors described
70
+ through their real and imaginary parts
71
+ can work in place when `out` is `a`."""
72
+ target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
73
+ if out is None or out.shape != target_shape:
74
+ out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
75
+ if out is a:
76
+ real_a = a[..., 0]
77
+ out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
78
+ out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
79
+ else:
80
+ out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
81
+ out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
82
+ return out
83
+
84
+
85
+ def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
86
+ """Element-wise multiplicative inverse of a Tensor with complex
87
+ entries described through their real and imaginary parts.
88
+ can work in place when `out` is `z`."""
89
+ ez = _norm(z)
90
+ if out is None or out.shape != z.shape:
91
+ out = torch.zeros_like(z)
92
+ out[..., 0] = z[..., 0] / ez
93
+ out[..., 1] = -z[..., 1] / ez
94
+ return out
95
+
96
+
97
+ def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
98
+ """Element-wise complex conjugate of a Tensor with complex entries
99
+ described through their real and imaginary parts.
100
+ can work in place when `out` is `z`."""
101
+ if out is None or out.shape != z.shape:
102
+ out = torch.zeros_like(z)
103
+ out[..., 0] = z[..., 0]
104
+ out[..., 1] = -z[..., 1]
105
+ return out
106
+
107
+
108
+ def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
109
+ """
110
+ Invert 1x1 or 2x2 matrices
111
+
112
+ Will generate errors if the matrices are singular: the user must handle this
113
+ through their own regularization schemes.
114
+
115
+ Args:
116
+ M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
117
+ matrices to invert: must be square along dimensions -3 and -2
118
+
119
+ Returns:
120
+ invM (Tensor): [shape=M.shape]
121
+ inverses of M
122
+ """
123
+ nb_channels = M.shape[-2]
124
+
125
+ if out is None or out.shape != M.shape:
126
+ out = torch.empty_like(M)
127
+
128
+ if nb_channels == 1:
129
+ # scalar case
130
+ out = _inv(M, out)
131
+ elif nb_channels == 2:
132
+ # two channels case: analytical expression
133
+
134
+ # first compute the determinant
135
+ det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
136
+ det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
137
+ # invert it
138
+ invDet = _inv(det)
139
+
140
+ # then fill out the matrix with the inverse
141
+ out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
142
+ out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
143
+ out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
144
+ out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
145
+ else:
146
+ raise Exception("Only 2 channels are supported for the torch version.")
147
+ return out
148
+
149
+
150
+ # Now define the signal-processing low-level functions used by the Separator
151
+
152
+
153
+ def expectation_maximization(y: torch.Tensor, x: torch.Tensor, iterations: int = 2, eps: float = 1e-10, batch_size: int = 200):
154
+ r"""Expectation maximization algorithm, for refining source separation
155
+ estimates.
156
+
157
+ This algorithm improves source separation results by
158
+ enforcing multichannel consistency for the estimates. This usually means
159
+ a better perceptual quality in terms of spatial artifacts.
160
+
161
+ The implementation follows the details presented in [1]_, taking
162
+ inspiration from the original EM algorithm proposed in [2]_ and its
163
+ weighted refinement proposed in [3]_, [4]_.
164
+ It works by iteratively:
165
+
166
+ * Re-estimate source parameters (power spectral densities and spatial
167
+ covariance matrices) through :func:`get_local_gaussian_model`.
168
+
169
+ * Separate again the mixture with the new parameters by first computing
170
+ the new modelled mixture covariance matrices with :func:`get_mix_model`,
171
+ prepare the Wiener filters through :func:`wiener_gain` and apply them
172
+ with :func:`apply_filter``.
173
+
174
+ References
175
+ ----------
176
+ .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
177
+ N. Takahashi and Y. Mitsufuji, "Improving music source separation based
178
+ on deep neural networks through data augmentation and network
179
+ blending." 2017 IEEE International Conference on Acoustics, Speech
180
+ and Signal Processing (ICASSP). IEEE, 2017.
181
+
182
+ .. [2] N.Q. Duong and E. Vincent and R.Gribonval. "Under-determined
183
+ reverberant audio source separation using a full-rank spatial
184
+ covariance model." IEEE Transactions on Audio, Speech, and Language
185
+ Processing 18.7 (2010): 1830-1840.
186
+
187
+ .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
188
+ separation with deep neural networks." IEEE/ACM Transactions on Audio,
189
+ Speech, and Language Processing 24.9 (2016): 1652-1664.
190
+
191
+ .. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
192
+ separation with deep neural networks." 2016 24th European Signal
193
+ Processing Conference (EUSIPCO). IEEE, 2016.
194
+
195
+ .. [5] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
196
+ source separation." IEEE Transactions on Signal Processing
197
+ 62.16 (2014): 4298-4310.
198
+
199
+ Args:
200
+ y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
201
+ initial estimates for the sources
202
+ x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
203
+ complex STFT of the mixture signal
204
+ iterations (int): [scalar]
205
+ number of iterations for the EM algorithm.
206
+ eps (float or None): [scalar]
207
+ The epsilon value to use for regularization and filters.
208
+
209
+ Returns:
210
+ y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
211
+ estimated sources after iterations
212
+ v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
213
+ estimated power spectral densities
214
+ R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)]
215
+ estimated spatial covariance matrices
216
+
217
+ Notes:
218
+ * You need an initial estimate for the sources to apply this
219
+ algorithm. This is precisely what the :func:`wiener` function does.
220
+ * This algorithm *is not* an implementation of the "exact" EM
221
+ proposed in [1]_. In particular, it does compute the posterior
222
+ covariance matrices the same (exact) way. Instead, it uses the
223
+ simplified approximate scheme initially proposed in [5]_ and further
224
+ refined in [3]_, [4]_, that boils down to just take the empirical
225
+ covariance of the recent source estimates, followed by a weighted
226
+ average for the update of the spatial covariance matrix. It has been
227
+ empirically demonstrated that this simplified algorithm is more
228
+ robust for music separation.
229
+
230
+ Warning:
231
+ It is *very* important to make sure `x.dtype` is `torch.float64`
232
+ if you want double precision, because this function will **not**
233
+ do such conversion for you from `torch.complex32`, in case you want the
234
+ smaller RAM usage on purpose.
235
+
236
+ It is usually always better in terms of quality to have double
237
+ precision, by e.g. calling :func:`expectation_maximization`
238
+ with ``x.to(torch.float64)``.
239
+ """
240
+ # dimensions
241
+ (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
242
+ nb_sources = y.shape[-1]
243
+
244
+ regularization = torch.cat((torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None], torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device)), dim=2)
245
+ regularization = torch.sqrt(torch.as_tensor(eps)) * (regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1)))
246
+
247
+ # allocate the spatial covariance matrices
248
+ R = [torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device) for j in range(nb_sources)]
249
+ weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)
250
+
251
+ v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
252
+ for it in range(iterations):
253
+ # constructing the mixture covariance matrix. Doing it with a loop
254
+ # to avoid storing anytime in RAM the whole 6D tensor
255
+
256
+ # update the PSD as the average spectrogram over channels
257
+ v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)
258
+
259
+ # update spatial covariance matrices (weighted update)
260
+ for j in range(nb_sources):
261
+ R[j] = torch.tensor(0.0, device=x.device)
262
+ weight = torch.tensor(eps, device=x.device)
263
+ pos: int = 0
264
+ batch_size = batch_size if batch_size else nb_frames
265
+ while pos < nb_frames:
266
+ t = torch.arange(pos, min(nb_frames, pos + batch_size))
267
+ pos = int(t[-1]) + 1
268
+
269
+ R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
270
+ weight = weight + torch.sum(v[t, ..., j], dim=0)
271
+ R[j] = R[j] / weight[..., None, None, None]
272
+ weight = torch.zeros_like(weight)
273
+
274
+ # cloning y if we track gradient, because we're going to update it
275
+ if y.requires_grad:
276
+ y = y.clone()
277
+
278
+ pos = 0
279
+ while pos < nb_frames:
280
+ t = torch.arange(pos, min(nb_frames, pos + batch_size))
281
+ pos = int(t[-1]) + 1
282
+
283
+ y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)
284
+
285
+ # compute mix covariance matrix
286
+ Cxx = regularization
287
+ for j in range(nb_sources):
288
+ Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())
289
+
290
+ # invert it
291
+ inv_Cxx = _invert(Cxx)
292
+
293
+ # separate the sources
294
+ for j in range(nb_sources):
295
+
296
+ # create a wiener gain for this source
297
+ gain = torch.zeros_like(inv_Cxx)
298
+
299
+ # computes multichannel Wiener gain as v_j R_j inv_Cxx
300
+ indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels), torch.arange(nb_channels))
301
+ for index in indices:
302
+ gain[:, :, index[0], index[1], :] = _mul_add(R[j][None, :, index[0], index[2], :].clone(), inv_Cxx[:, :, index[2], index[1], :], gain[:, :, index[0], index[1], :])
303
+ gain = gain * v[t, ..., None, None, None, j]
304
+
305
+ # apply it to the mixture
306
+ for i in range(nb_channels):
307
+ y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])
308
+
309
+ return y, v, R
310
+
311
+
312
+ def wiener(targets_spectrograms: torch.Tensor, mix_stft: torch.Tensor, iterations: int = 1, softmask: bool = False, residual: bool = False, scale_factor: float = 10.0, eps: float = 1e-10):
313
+ """Wiener-based separation for multichannel audio.
314
+
315
+ The method uses the (possibly multichannel) spectrograms of the
316
+ sources to separate the (complex) Short Term Fourier Transform of the
317
+ mix. Separation is done in a sequential way by:
318
+
319
+ * Getting an initial estimate. This can be done in two ways: either by
320
+ directly using the spectrograms with the mixture phase, or
321
+ by using a softmasking strategy. This initial phase is controlled
322
+ by the `softmask` flag.
323
+
324
+ * If required, adding an additional residual target as the mix minus
325
+ all targets.
326
+
327
+ * Refining these initial estimates through a call to
328
+ :func:`expectation_maximization` if the number of iterations is nonzero.
329
+
330
+ This implementation also allows specifying the epsilon value used for
331
+ regularization. It is based on [1]_, [2]_, [3]_, [4]_.
332
+
333
+ References
334
+ ----------
335
+ .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
336
+ N. Takahashi and Y. Mitsufuji, "Improving music source separation based
337
+ on deep neural networks through data augmentation and network
338
+ blending." 2017 IEEE International Conference on Acoustics, Speech
339
+ and Signal Processing (ICASSP). IEEE, 2017.
340
+
341
+ .. [2] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
342
+ separation with deep neural networks." IEEE/ACM Transactions on Audio,
343
+ Speech, and Language Processing 24.9 (2016): 1652-1664.
344
+
345
+ .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
346
+ separation with deep neural networks." 2016 24th European Signal
347
+ Processing Conference (EUSIPCO). IEEE, 2016.
348
+
349
+ .. [4] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
350
+ source separation." IEEE Transactions on Signal Processing
351
+ 62.16 (2014): 4298-4310.
352
+
353
+ Args:
354
+ targets_spectrograms (Tensor): spectrograms of the sources
355
+ [shape=(nb_frames, nb_bins, nb_channels, nb_sources)].
356
+ This is a nonnegative tensor that is
357
+ usually the output of the actual separation method of the user. The
358
+ spectrograms may be mono, but they need to be 4-dimensional in all
359
+ cases.
360
+ mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)]
361
+ STFT of the mixture signal.
362
+ iterations (int): [scalar]
363
+ number of iterations for the EM algorithm
364
+ softmask (bool): Describes how the initial estimates are obtained.
365
+ * if `False`, then the mixture phase will directly be used with the
366
+ spectrogram as initial estimates.
367
+ * if `True`, initial estimates are obtained by multiplying the
368
+ complex mix element-wise with the ratio of each target spectrogram
369
+ with the sum of them all. This strategy is better if the models are
370
+ not really good, and worse otherwise.
371
+ residual (bool): if `True`, an additional target is created, which is
372
+ equal to the mixture minus the other targets, before application of
373
+ expectation maximization
374
+ eps (float): Epsilon value to use for computing the separations.
375
+ This is used whenever division with a model energy is
376
+ performed, i.e. when softmasking and when iterating the EM.
377
+ It can be understood as the energy of the additional white noise
378
+ that is taken out when separating.
379
+
380
+ Returns:
381
+ Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
382
+ STFT of estimated sources
383
+
384
+ Notes:
385
+ * Be careful that you need *magnitude spectrogram estimates* for the
386
+ case `softmask==False`.
387
+ * `softmask=False` is recommended
388
+ * The epsilon value will have a huge impact on performance. If it's
389
+ large, only the parts of the signal with a significant energy will
390
+ be kept in the sources. This epsilon then directly controls the
391
+ energy of the reconstruction error.
392
+
393
+ Warning:
394
+ As in :func:`expectation_maximization`, we recommend converting the
395
+ mixture `x` to double precision `torch.float64` *before* calling
396
+ :func:`wiener`.
397
+ """
398
+ if softmask:
399
+ # if we use softmask, we compute the ratio mask for all targets and
400
+ # multiply by the mix stft
401
+ y = mix_stft[..., None] * (targets_spectrograms / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype)))[..., None, :]
402
+ else:
403
+ # otherwise, we just multiply the targets spectrograms with mix phase
404
+ # we tacitly assume that we have magnitude estimates.
405
+ angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
406
+ nb_sources = targets_spectrograms.shape[-1]
407
+ y = torch.zeros(mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device)
408
+ y[..., 0, :] = targets_spectrograms * torch.cos(angle)
409
+ y[..., 1, :] = targets_spectrograms * torch.sin(angle)
410
+
411
+ if residual:
412
+ # if required, adding an additional target as the mix minus
413
+ # available targets
414
+ y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)
415
+
416
+ if iterations == 0:
417
+ return y
418
+
419
+ # we need to refine the estimates. Scales down the estimates for
420
+ # numerical stability
421
+ max_abs = torch.max(torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device), torch.sqrt(_norm(mix_stft)).max() / scale_factor)
422
+
423
+ mix_stft = mix_stft / max_abs
424
+ y = y / max_abs
425
+
426
+ # call expectation maximization
427
+ y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]
428
+
429
+ # scale estimates up again
430
+ y = y * max_abs
431
+ return y
432
+
433
+
434
+ def _covariance(y_j):
435
+ """
436
+ Compute the empirical covariance for a source.
437
+
438
+ Args:
439
+ y_j (Tensor): complex stft of the source.
440
+ [shape=(nb_frames, nb_bins, nb_channels, 2)].
441
+
442
+ Returns:
443
+ Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
444
+ just y_j * conj(y_j.T): empirical covariance for each TF bin.
445
+ """
446
+ (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
447
+ Cj = torch.zeros((nb_frames, nb_bins, nb_channels, nb_channels, 2), dtype=y_j.dtype, device=y_j.device)
448
+ indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))
449
+ for index in indices:
450
+ Cj[:, :, index[0], index[1], :] = _mul_add(y_j[:, :, index[0], :], _conj(y_j[:, :, index[1], :]), Cj[:, :, index[0], index[1], :])
451
+ return Cj
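`wiener` consumes nonnegative magnitude spectrograms for each target plus the complex mixture STFT stored as real/imaginary pairs in the last dimension, and returns one filtered complex STFT per source. A hedged sketch with random placeholder tensors, only to show the documented shapes and the recommended double-precision call:

```python
# Hedged shape-only sketch of calling wiener(); all tensors are random placeholders.
import torch

from audio_separator.separator.uvr_lib_v5.demucs.filtering import wiener

nb_frames, nb_bins, nb_channels, nb_sources = 50, 129, 2, 4
targets = torch.rand(nb_frames, nb_bins, nb_channels, nb_sources)   # magnitude estimates
mix_stft = torch.randn(nb_frames, nb_bins, nb_channels, 2)          # complex STFT as (re, im)

# Double precision is recommended by the docstring above.
y = wiener(targets.double(), mix_stft.double(), iterations=1, softmask=False)
print(tuple(y.shape))   # (nb_frames, nb_bins, nb_channels, 2, nb_sources)
```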
audio_separator/separator/uvr_lib_v5/demucs/hdemucs.py ADDED
@@ -0,0 +1,783 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ This code contains the spectrogram and Hybrid version of Demucs.
8
+ """
9
+ from copy import deepcopy
10
+ import math
11
+ import typing as tp
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+ from .filtering import wiener
16
+ from .demucs import DConv, rescale_module
17
+ from .states import capture_init
18
+ from .spec import spectro, ispectro
19
+
20
+
21
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = "constant", value: float = 0.0):
22
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
23
+ If this is the case, we insert extra 0 padding to the right before the reflection happens."""
24
+ x0 = x
25
+ length = x.shape[-1]
26
+ padding_left, padding_right = paddings
27
+ if mode == "reflect":
28
+ max_pad = max(padding_left, padding_right)
29
+ if length <= max_pad:
30
+ extra_pad = max_pad - length + 1
31
+ extra_pad_right = min(padding_right, extra_pad)
32
+ extra_pad_left = extra_pad - extra_pad_right
33
+ paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
34
+ x = F.pad(x, (extra_pad_left, extra_pad_right))
35
+ out = F.pad(x, paddings, mode, value)
36
+ assert out.shape[-1] == length + padding_left + padding_right
37
+ assert (out[..., padding_left : padding_left + length] == x0).all()
38
+ return out
39
+
40
+
41
+ class ScaledEmbedding(nn.Module):
42
+ """
43
+ Boost learning rate for embeddings (with `scale`).
44
+ Also, can make embeddings continuous with `smooth`.
45
+ """
46
+
47
+ def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth=False):
48
+ super().__init__()
49
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
50
+ if smooth:
51
+ weight = torch.cumsum(self.embedding.weight.data, dim=0)
52
+ # when summing gaussians, the scale grows as sqrt(n), so we normalize by that.
53
+ weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
54
+ self.embedding.weight.data[:] = weight
55
+ self.embedding.weight.data /= scale
56
+ self.scale = scale
57
+
58
+ @property
59
+ def weight(self):
60
+ return self.embedding.weight * self.scale
61
+
62
+ def forward(self, x):
63
+ out = self.embedding(x) * self.scale
64
+ return out
65
+
66
+
67
+ class HEncLayer(nn.Module):
68
+ def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, rewrite=True):
69
+ """Encoder layer. This used both by the time and the frequency branch.
70
+
71
+ Args:
72
+ chin: number of input channels.
73
+ chout: number of output channels.
74
+ norm_groups: number of groups for group norm.
75
+ empty: used to make a layer with just the first conv. This is used
76
+ before merging the time and freq. branches.
77
+ freq: this is acting on frequencies.
78
+ dconv: insert DConv residual branches.
79
+ norm: use GroupNorm.
80
+ context: context size for the 1x1 conv.
81
+ dconv_kw: list of kwargs for the DConv class.
82
+ pad: pad the input. Padding is done so that the output size is
83
+ always the input size / stride.
84
+ rewrite: add 1x1 conv at the end of the layer.
85
+ """
86
+ super().__init__()
87
+ norm_fn = lambda d: nn.Identity() # noqa
88
+ if norm:
89
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
90
+ if pad:
91
+ pad = kernel_size // 4
92
+ else:
93
+ pad = 0
94
+ klass = nn.Conv1d
95
+ self.freq = freq
96
+ self.kernel_size = kernel_size
97
+ self.stride = stride
98
+ self.empty = empty
99
+ self.norm = norm
100
+ self.pad = pad
101
+ if freq:
102
+ kernel_size = [kernel_size, 1]
103
+ stride = [stride, 1]
104
+ pad = [pad, 0]
105
+ klass = nn.Conv2d
106
+ self.conv = klass(chin, chout, kernel_size, stride, pad)
107
+ if self.empty:
108
+ return
109
+ self.norm1 = norm_fn(chout)
110
+ self.rewrite = None
111
+ if rewrite:
112
+ self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
113
+ self.norm2 = norm_fn(2 * chout)
114
+
115
+ self.dconv = None
116
+ if dconv:
117
+ self.dconv = DConv(chout, **dconv_kw)
118
+
119
+ def forward(self, x, inject=None):
120
+ """
121
+ `inject` is used to inject the result from the time branch into the frequency branch,
122
+ when both have the same stride.
123
+ """
124
+ if not self.freq and x.dim() == 4:
125
+ B, C, Fr, T = x.shape
126
+ x = x.view(B, -1, T)
127
+
128
+ if not self.freq:
129
+ le = x.shape[-1]
130
+ if not le % self.stride == 0:
131
+ x = F.pad(x, (0, self.stride - (le % self.stride)))
132
+ y = self.conv(x)
133
+ if self.empty:
134
+ return y
135
+ if inject is not None:
136
+ assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
137
+ if inject.dim() == 3 and y.dim() == 4:
138
+ inject = inject[:, :, None]
139
+ y = y + inject
140
+ y = F.gelu(self.norm1(y))
141
+ if self.dconv:
142
+ if self.freq:
143
+ B, C, Fr, T = y.shape
144
+ y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
145
+ y = self.dconv(y)
146
+ if self.freq:
147
+ y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
148
+ if self.rewrite:
149
+ z = self.norm2(self.rewrite(y))
150
+ z = F.glu(z, dim=1)
151
+ else:
152
+ z = y
153
+ return z
154
+
155
+
156
+ class MultiWrap(nn.Module):
157
+ """
158
+ Takes one layer and replicates it N times. Each replica will act
159
+ on a frequency band. All is done so that if the N replica have the same weights,
160
+ then this is exactly equivalent to applying the original module on all frequencies.
161
+
162
+ This is a bit over-engineered to avoid edge artifacts when splitting
163
+ the frequency bands, but it is possible the naive implementation would work as well...
164
+ """
165
+
166
+ def __init__(self, layer, split_ratios):
167
+ """
168
+ Args:
169
+ layer: module to clone, must be either HEncLayer or HDecLayer.
170
+ split_ratios: list of float indicating which ratio to keep for each band.
171
+ """
172
+ super().__init__()
173
+ self.split_ratios = split_ratios
174
+ self.layers = nn.ModuleList()
175
+ self.conv = isinstance(layer, HEncLayer)
176
+ assert not layer.norm
177
+ assert layer.freq
178
+ assert layer.pad
179
+ if not self.conv:
180
+ assert not layer.context_freq
181
+ for k in range(len(split_ratios) + 1):
182
+ lay = deepcopy(layer)
183
+ if self.conv:
184
+ lay.conv.padding = (0, 0)
185
+ else:
186
+ lay.pad = False
187
+ for m in lay.modules():
188
+ if hasattr(m, "reset_parameters"):
189
+ m.reset_parameters()
190
+ self.layers.append(lay)
191
+
192
+ def forward(self, x, skip=None, length=None):
193
+ B, C, Fr, T = x.shape
194
+
195
+ ratios = list(self.split_ratios) + [1]
196
+ start = 0
197
+ outs = []
198
+ for ratio, layer in zip(ratios, self.layers):
199
+ if self.conv:
200
+ pad = layer.kernel_size // 4
201
+ if ratio == 1:
202
+ limit = Fr
203
+ frames = -1
204
+ else:
205
+ limit = int(round(Fr * ratio))
206
+ le = limit - start
207
+ if start == 0:
208
+ le += pad
209
+ frames = round((le - layer.kernel_size) / layer.stride + 1)
210
+ limit = start + (frames - 1) * layer.stride + layer.kernel_size
211
+ if start == 0:
212
+ limit -= pad
213
+ assert limit - start > 0, (limit, start)
214
+ assert limit <= Fr, (limit, Fr)
215
+ y = x[:, :, start:limit, :]
216
+ if start == 0:
217
+ y = F.pad(y, (0, 0, pad, 0))
218
+ if ratio == 1:
219
+ y = F.pad(y, (0, 0, 0, pad))
220
+ outs.append(layer(y))
221
+ start = limit - layer.kernel_size + layer.stride
222
+ else:
223
+ if ratio == 1:
224
+ limit = Fr
225
+ else:
226
+ limit = int(round(Fr * ratio))
227
+ last = layer.last
228
+ layer.last = True
229
+
230
+ y = x[:, :, start:limit]
231
+ s = skip[:, :, start:limit]
232
+ out, _ = layer(y, s, None)
233
+ if outs:
234
+ outs[-1][:, :, -layer.stride :] += out[:, :, : layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)
235
+ out = out[:, :, layer.stride :]
236
+ if ratio == 1:
237
+ out = out[:, :, : -layer.stride // 2, :]
238
+ if start == 0:
239
+ out = out[:, :, layer.stride // 2 :, :]
240
+ outs.append(out)
241
+ layer.last = last
242
+ start = limit
243
+ out = torch.cat(outs, dim=2)
244
+ if not self.conv and not last:
245
+ out = F.gelu(out)
246
+ if self.conv:
247
+ return out
248
+ else:
249
+ return out, None
250
+
251
+
252
+ class HDecLayer(nn.Module):
253
+ def __init__(
254
+ self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, context_freq=True, rewrite=True
255
+ ):
256
+ """
257
+ Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
258
+ """
259
+ super().__init__()
260
+ norm_fn = lambda d: nn.Identity() # noqa
261
+ if norm:
262
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
263
+ if pad:
264
+ pad = kernel_size // 4
265
+ else:
266
+ pad = 0
267
+ self.pad = pad
268
+ self.last = last
269
+ self.freq = freq
270
+ self.chin = chin
271
+ self.empty = empty
272
+ self.stride = stride
273
+ self.kernel_size = kernel_size
274
+ self.norm = norm
275
+ self.context_freq = context_freq
276
+ klass = nn.Conv1d
277
+ klass_tr = nn.ConvTranspose1d
278
+ if freq:
279
+ kernel_size = [kernel_size, 1]
280
+ stride = [stride, 1]
281
+ klass = nn.Conv2d
282
+ klass_tr = nn.ConvTranspose2d
283
+ self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
284
+ self.norm2 = norm_fn(chout)
285
+ if self.empty:
286
+ return
287
+ self.rewrite = None
288
+ if rewrite:
289
+ if context_freq:
290
+ self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
291
+ else:
292
+ self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, [0, context])
293
+ self.norm1 = norm_fn(2 * chin)
294
+
295
+ self.dconv = None
296
+ if dconv:
297
+ self.dconv = DConv(chin, **dconv_kw)
298
+
299
+ def forward(self, x, skip, length):
300
+ if self.freq and x.dim() == 3:
301
+ B, C, T = x.shape
302
+ x = x.view(B, self.chin, -1, T)
303
+
304
+ if not self.empty:
305
+ x = x + skip
306
+
307
+ if self.rewrite:
308
+ y = F.glu(self.norm1(self.rewrite(x)), dim=1)
309
+ else:
310
+ y = x
311
+ if self.dconv:
312
+ if self.freq:
313
+ B, C, Fr, T = y.shape
314
+ y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
315
+ y = self.dconv(y)
316
+ if self.freq:
317
+ y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
318
+ else:
319
+ y = x
320
+ assert skip is None
321
+ z = self.norm2(self.conv_tr(y))
322
+ if self.freq:
323
+ if self.pad:
324
+ z = z[..., self.pad : -self.pad, :]
325
+ else:
326
+ z = z[..., self.pad : self.pad + length]
327
+ assert z.shape[-1] == length, (z.shape[-1], length)
328
+ if not self.last:
329
+ z = F.gelu(z)
330
+ return z, y
331
+
332
+
333
+ class HDemucs(nn.Module):
334
+ """
335
+ Spectrogram and hybrid Demucs model.
336
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
337
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
338
+ Frequency layers can still access information across time steps thanks to the DConv residual.
339
+
340
+ Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
341
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
342
+
343
+ Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
344
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
345
+ Open Unmix implementation [Stoter et al. 2019].
346
+
347
+ The loss is always on the temporal domain, by backpropagating through the above
348
+ output methods and iSTFT. This allows defining hybrid models nicely. However, this somewhat breaks
349
+ Wiener filtering, as doing more iterations at test time will change the spectrogram
350
+ contribution, without changing the one from the waveform, which will lead to worse performance.
351
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
352
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
353
+ hybrid models.
354
+
355
+ This model also uses frequency embeddings to improve efficiency on convolutions
356
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
357
+
358
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
359
+ """
360
+
361
+ @capture_init
362
+ def __init__(
363
+ self,
364
+ sources,
365
+ # Channels
366
+ audio_channels=2,
367
+ channels=48,
368
+ channels_time=None,
369
+ growth=2,
370
+ # STFT
371
+ nfft=4096,
372
+ wiener_iters=0,
373
+ end_iters=0,
374
+ wiener_residual=False,
375
+ cac=True,
376
+ # Main structure
377
+ depth=6,
378
+ rewrite=True,
379
+ hybrid=True,
380
+ hybrid_old=False,
381
+ # Frequency branch
382
+ multi_freqs=None,
383
+ multi_freqs_depth=2,
384
+ freq_emb=0.2,
385
+ emb_scale=10,
386
+ emb_smooth=True,
387
+ # Convolutions
388
+ kernel_size=8,
389
+ time_stride=2,
390
+ stride=4,
391
+ context=1,
392
+ context_enc=0,
393
+ # Normalization
394
+ norm_starts=4,
395
+ norm_groups=4,
396
+ # DConv residual branch
397
+ dconv_mode=1,
398
+ dconv_depth=2,
399
+ dconv_comp=4,
400
+ dconv_attn=4,
401
+ dconv_lstm=4,
402
+ dconv_init=1e-4,
403
+ # Weight init
404
+ rescale=0.1,
405
+ # Metadata
406
+ samplerate=44100,
407
+ segment=4 * 10,
408
+ ):
409
+ """
410
+ Args:
411
+ sources (list[str]): list of source names.
412
+ audio_channels (int): input/output audio channels.
413
+ channels (int): initial number of hidden channels.
414
+ channels_time: if not None, use a different `channels` value for the time branch.
415
+ growth: increase the number of hidden channels by this factor at each layer.
416
+ nfft: number of fft bins. Note that changing this requires careful computation of
417
+ various shape parameters and will not work out of the box for hybrid models.
418
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
419
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
420
+ wiener_residual: add residual source before wiener filtering.
421
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
422
+ in input and output. no further processing is done before ISTFT.
423
+ depth (int): number of layers in the encoder and in the decoder.
424
+ rewrite (bool): add 1x1 convolution to each layer.
425
+ hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
426
+ hybrid_old: some models trained for MDX had a padding bug. This replicates
427
+ this bug to avoid retraining them.
428
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
429
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
430
+ layers will be wrapped.
431
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
432
+ the actual value controls the weight of the embedding.
433
+ emb_scale: equivalent to scaling the embedding learning rate
434
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
435
+ kernel_size: kernel_size for encoder and decoder layers.
436
+ stride: stride for encoder and decoder layers.
437
+ time_stride: stride for the final time layer, after the merge.
438
+ context: context for 1x1 conv in the decoder.
439
+ context_enc: context for 1x1 conv in the encoder.
440
+ norm_starts: layer at which group norm starts being used.
441
+ decoder layers are numbered in reverse order.
442
+ norm_groups: number of groups for group norm.
443
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
444
+ dconv_depth: depth of residual DConv branch.
445
+ dconv_comp: compression of DConv branch.
446
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
447
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
448
+ dconv_init: initial scale for the DConv branch LayerScale.
449
+ rescale: weight rescaling trick
450
+
451
+ """
452
+ super().__init__()
453
+
454
+ self.cac = cac
455
+ self.wiener_residual = wiener_residual
456
+ self.audio_channels = audio_channels
457
+ self.sources = sources
458
+ self.kernel_size = kernel_size
459
+ self.context = context
460
+ self.stride = stride
461
+ self.depth = depth
462
+ self.channels = channels
463
+ self.samplerate = samplerate
464
+ self.segment = segment
465
+
466
+ self.nfft = nfft
467
+ self.hop_length = nfft // 4
468
+ self.wiener_iters = wiener_iters
469
+ self.end_iters = end_iters
470
+ self.freq_emb = None
471
+ self.hybrid = hybrid
472
+ self.hybrid_old = hybrid_old
473
+ if hybrid_old:
474
+ assert hybrid, "hybrid_old must come with hybrid=True"
475
+ if hybrid:
476
+ assert wiener_iters == end_iters
477
+
478
+ self.encoder = nn.ModuleList()
479
+ self.decoder = nn.ModuleList()
480
+
481
+ if hybrid:
482
+ self.tencoder = nn.ModuleList()
483
+ self.tdecoder = nn.ModuleList()
484
+
485
+ chin = audio_channels
486
+ chin_z = chin # number of channels for the freq branch
487
+ if self.cac:
488
+ chin_z *= 2
489
+ chout = channels_time or channels
490
+ chout_z = channels
491
+ freqs = nfft // 2
492
+
493
+ for index in range(depth):
494
+ lstm = index >= dconv_lstm
495
+ attn = index >= dconv_attn
496
+ norm = index >= norm_starts
497
+ freq = freqs > 1
498
+ stri = stride
499
+ ker = kernel_size
500
+ if not freq:
501
+ assert freqs == 1
502
+ ker = time_stride * 2
503
+ stri = time_stride
504
+
505
+ pad = True
506
+ last_freq = False
507
+ if freq and freqs <= kernel_size:
508
+ ker = freqs
509
+ pad = False
510
+ last_freq = True
511
+
512
+ kw = {
513
+ "kernel_size": ker,
514
+ "stride": stri,
515
+ "freq": freq,
516
+ "pad": pad,
517
+ "norm": norm,
518
+ "rewrite": rewrite,
519
+ "norm_groups": norm_groups,
520
+ "dconv_kw": {"lstm": lstm, "attn": attn, "depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
521
+ }
522
+ kwt = dict(kw)
523
+ kwt["freq"] = 0
524
+ kwt["kernel_size"] = kernel_size
525
+ kwt["stride"] = stride
526
+ kwt["pad"] = True
527
+ kw_dec = dict(kw)
528
+ multi = False
529
+ if multi_freqs and index < multi_freqs_depth:
530
+ multi = True
531
+ kw_dec["context_freq"] = False
532
+
533
+ if last_freq:
534
+ chout_z = max(chout, chout_z)
535
+ chout = chout_z
536
+
537
+ enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
538
+ if hybrid and freq:
539
+ tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
540
+ self.tencoder.append(tenc)
541
+
542
+ if multi:
543
+ enc = MultiWrap(enc, multi_freqs)
544
+ self.encoder.append(enc)
545
+ if index == 0:
546
+ chin = self.audio_channels * len(self.sources)
547
+ chin_z = chin
548
+ if self.cac:
549
+ chin_z *= 2
550
+ dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
551
+ if multi:
552
+ dec = MultiWrap(dec, multi_freqs)
553
+ if hybrid and freq:
554
+ tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
555
+ self.tdecoder.insert(0, tdec)
556
+ self.decoder.insert(0, dec)
557
+
558
+ chin = chout
559
+ chin_z = chout_z
560
+ chout = int(growth * chout)
561
+ chout_z = int(growth * chout_z)
562
+ if freq:
563
+ if freqs <= kernel_size:
564
+ freqs = 1
565
+ else:
566
+ freqs //= stride
567
+ if index == 0 and freq_emb:
568
+ self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
569
+ self.freq_emb_scale = freq_emb
570
+
571
+ if rescale:
572
+ rescale_module(self, reference=rescale)
573
+
574
+ def _spec(self, x):
575
+ hl = self.hop_length
576
+ nfft = self.nfft
577
+ x0 = x # noqa
578
+
579
+ if self.hybrid:
580
+ # We re-pad the signal in order to keep the property
581
+ # that the size of the output is exactly the size of the input
582
+ # divided by the stride (here hop_length), when divisible.
583
+ # This is achieved by padding by 1/4th of the kernel size (here nfft),
584
+ # which is not supported by torch.stft.
585
+ # Having all convolution operations follow this convention allows us to easily
586
+ # align the time and frequency branches later on.
587
+ assert hl == nfft // 4
588
+ le = int(math.ceil(x.shape[-1] / hl))
589
+ pad = hl // 2 * 3
590
+ if not self.hybrid_old:
591
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
592
+ else:
593
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]))
594
+
595
+ z = spectro(x, nfft, hl)[..., :-1, :]
596
+ if self.hybrid:
597
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
598
+ z = z[..., 2 : 2 + le]
599
+ return z
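+ # Note (added, not in the original source): for the hybrid model the reflect padding
+ # of hl // 2 * 3 samples per side yields ceil(T / hop_length) + 4 STFT frames, and the
+ # two outermost frames on each side are dropped, giving exactly one frame per hop.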
600
+
601
+ def _ispec(self, z, length=None, scale=0):
602
+ hl = self.hop_length // (4**scale)
603
+ z = F.pad(z, (0, 0, 0, 1))
604
+ if self.hybrid:
605
+ z = F.pad(z, (2, 2))
606
+ pad = hl // 2 * 3
607
+ if not self.hybrid_old:
608
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
609
+ else:
610
+ le = hl * int(math.ceil(length / hl))
611
+ x = ispectro(z, hl, length=le)
612
+ if not self.hybrid_old:
613
+ x = x[..., pad : pad + length]
614
+ else:
615
+ x = x[..., :length]
616
+ else:
617
+ x = ispectro(z, hl, length)
618
+ return x
619
+
620
+ def _magnitude(self, z):
621
+ # return the magnitude of the spectrogram, except when cac is True,
622
+ # in which case we just move the complex dimension to the channel one.
623
+ if self.cac:
624
+ B, C, Fr, T = z.shape
625
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
626
+ m = m.reshape(B, C * 2, Fr, T)
627
+ else:
628
+ m = z.abs()
629
+ return m
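+ # Shape note (added, not in the original source): with cac=True a complex spectrogram
+ # of shape (B, C, Fr, T) is returned as a real tensor of shape (B, 2 * C, Fr, T);
+ # otherwise only the magnitude z.abs() is returned.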
630
+
631
+ def _mask(self, z, m):
632
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
633
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
634
+ niters = self.wiener_iters
635
+ if self.cac:
636
+ B, S, C, Fr, T = m.shape
637
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
638
+ out = torch.view_as_complex(out.contiguous())
639
+ return out
640
+ if self.training:
641
+ niters = self.end_iters
642
+ if niters < 0:
643
+ z = z[:, None]
644
+ return z / (1e-8 + z.abs()) * m
645
+ else:
646
+ return self._wiener(m, z, niters)
647
+
648
+ def _wiener(self, mag_out, mix_stft, niters):
649
+ # apply wiener filtering from OpenUnmix.
650
+ init = mix_stft.dtype
651
+ wiener_win_len = 300
652
+ residual = self.wiener_residual
653
+
654
+ B, S, C, Fq, T = mag_out.shape
655
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
656
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
657
+
658
+ outs = []
659
+ for sample in range(B):
660
+ pos = 0
661
+ out = []
662
+ for pos in range(0, T, wiener_win_len):
663
+ frame = slice(pos, pos + wiener_win_len)
664
+ z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
665
+ out.append(z_out.transpose(-1, -2))
666
+ outs.append(torch.cat(out, dim=0))
667
+ out = torch.view_as_complex(torch.stack(outs, 0))
668
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
669
+ if residual:
670
+ out = out[:, :-1]
671
+ assert list(out.shape) == [B, S, C, Fq, T]
672
+ return out.to(init)
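+ # Note (added for clarity, not in the original source): Wiener filtering is applied
+ # per sample in windows of wiener_win_len = 300 STFT frames to bound memory usage;
+ # the filtered windows are concatenated back along the time axis.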
673
+
674
+ def forward(self, mix):
675
+ x = mix
676
+ length = x.shape[-1]
677
+
678
+ z = self._spec(mix)
679
+ mag = self._magnitude(z).to(mix.device)
680
+ x = mag
681
+
682
+ B, C, Fq, T = x.shape
683
+
684
+ # unlike previous Demucs, we always normalize because it is easier.
685
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
686
+ std = x.std(dim=(1, 2, 3), keepdim=True)
687
+ x = (x - mean) / (1e-5 + std)
688
+ # x will be the freq. branch input.
689
+
690
+ if self.hybrid:
691
+ # Prepare the time branch input.
692
+ xt = mix
693
+ meant = xt.mean(dim=(1, 2), keepdim=True)
694
+ stdt = xt.std(dim=(1, 2), keepdim=True)
695
+ xt = (xt - meant) / (1e-5 + stdt)
696
+
697
+ # okay, this is a giant mess I know...
698
+ saved = [] # skip connections, freq.
699
+ saved_t = [] # skip connections, time.
700
+ lengths = [] # saved lengths to properly remove padding, freq branch.
701
+ lengths_t = [] # saved lengths for time branch.
702
+ for idx, encode in enumerate(self.encoder):
703
+ lengths.append(x.shape[-1])
704
+ inject = None
705
+ if self.hybrid and idx < len(self.tencoder):
706
+ # we have not yet merged branches.
707
+ lengths_t.append(xt.shape[-1])
708
+ tenc = self.tencoder[idx]
709
+ xt = tenc(xt)
710
+ if not tenc.empty:
711
+ # save for skip connection
712
+ saved_t.append(xt)
713
+ else:
714
+ # tenc contains just the first conv., so that now time and freq.
715
+ # branches have the same shape and can be merged.
716
+ inject = xt
717
+ x = encode(x, inject)
718
+ if idx == 0 and self.freq_emb is not None:
719
+ # add frequency embedding to allow for non equivariant convolutions
720
+ # over the frequency axis.
721
+ frs = torch.arange(x.shape[-2], device=x.device)
722
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
723
+ x = x + self.freq_emb_scale * emb
724
+
725
+ saved.append(x)
726
+
727
+ x = torch.zeros_like(x)
728
+ if self.hybrid:
729
+ xt = torch.zeros_like(x)
730
+ # initialize everything to zero (signal will go through u-net skips).
731
+
732
+ for idx, decode in enumerate(self.decoder):
733
+ skip = saved.pop(-1)
734
+ x, pre = decode(x, skip, lengths.pop(-1))
735
+ # `pre` contains the output just before final transposed convolution,
736
+ # which is used when the freq. and time branch separate.
737
+
738
+ if self.hybrid:
739
+ offset = self.depth - len(self.tdecoder)
740
+ if self.hybrid and idx >= offset:
741
+ tdec = self.tdecoder[idx - offset]
742
+ length_t = lengths_t.pop(-1)
743
+ if tdec.empty:
744
+ assert pre.shape[2] == 1, pre.shape
745
+ pre = pre[:, :, 0]
746
+ xt, _ = tdec(pre, None, length_t)
747
+ else:
748
+ skip = saved_t.pop(-1)
749
+ xt, _ = tdec(xt, skip, length_t)
750
+
751
+ # Let's make sure we used all stored skip connections.
752
+ assert len(saved) == 0
753
+ assert len(lengths_t) == 0
754
+ assert len(saved_t) == 0
755
+
756
+ S = len(self.sources)
757
+ x = x.view(B, S, -1, Fq, T)
758
+ x = x * std[:, None] + mean[:, None]
759
+
760
+ # to cpu as non-cuda GPUs don't support complex numbers
761
+ # demucs issue #435 ##432
762
+ # NOTE: in this case z already is on cpu
763
+ # TODO: remove this when mps supports complex numbers
764
+
765
+ device_type = x.device.type
766
+ device_load = f"{device_type}:{x.device.index}" if device_type != "mps" else device_type
767
+ x_is_other_gpu = device_type not in ["cuda", "cpu"]
768
+
769
+ if x_is_other_gpu:
770
+ x = x.cpu()
771
+
772
+ zout = self._mask(z, x)
773
+ x = self._ispec(zout, length)
774
+
775
+ # back to other device
776
+ if x_is_other_gpu:
777
+ x = x.to(device_load)
778
+
779
+ if self.hybrid:
780
+ xt = xt.view(B, S, -1, length)
781
+ xt = xt * stdt[:, None] + meant[:, None]
782
+ x = xt + x
783
+ return x
audio_separator/separator/uvr_lib_v5/demucs/htdemucs.py ADDED
@@ -0,0 +1,620 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # First author is Simon Rouard.
7
+ """
8
+ This code contains the spectrogram and Hybrid version of Demucs.
9
+ """
10
+ import math
11
+
12
+ from .filtering import wiener
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from fractions import Fraction
17
+ from einops import rearrange
18
+
19
+ from .transformer import CrossTransformerEncoder
20
+
21
+ from .demucs import rescale_module
22
+ from .states import capture_init
23
+ from .spec import spectro, ispectro
24
+ from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
25
+
26
+
27
+ class HTDemucs(nn.Module):
28
+ """
29
+ Spectrogram and hybrid Demucs model.
30
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
31
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
32
+ Frequency layers can still access information across time steps thanks to the DConv residual.
33
+
34
+ Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
35
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
36
+
37
+ Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
38
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
39
+ Open Unmix implementation [Stoter et al. 2019].
40
+
41
+ The loss is always on the temporal domain, by backpropagating through the above
42
+ output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
43
+ a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
44
+ contribution, without changing the one from the waveform, which will lead to worse performance.
45
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
46
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
47
+ hybrid models.
48
+
49
+ This model also uses frequency embeddings to improve efficiency on convolutions
50
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
51
+
52
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
53
+ """
54
+
55
+ @capture_init
56
+ def __init__(
57
+ self,
58
+ sources,
59
+ # Channels
60
+ audio_channels=2,
61
+ channels=48,
62
+ channels_time=None,
63
+ growth=2,
64
+ # STFT
65
+ nfft=4096,
66
+ wiener_iters=0,
67
+ end_iters=0,
68
+ wiener_residual=False,
69
+ cac=True,
70
+ # Main structure
71
+ depth=4,
72
+ rewrite=True,
73
+ # Frequency branch
74
+ multi_freqs=None,
75
+ multi_freqs_depth=3,
76
+ freq_emb=0.2,
77
+ emb_scale=10,
78
+ emb_smooth=True,
79
+ # Convolutions
80
+ kernel_size=8,
81
+ time_stride=2,
82
+ stride=4,
83
+ context=1,
84
+ context_enc=0,
85
+ # Normalization
86
+ norm_starts=4,
87
+ norm_groups=4,
88
+ # DConv residual branch
89
+ dconv_mode=1,
90
+ dconv_depth=2,
91
+ dconv_comp=8,
92
+ dconv_init=1e-3,
93
+ # Before the Transformer
94
+ bottom_channels=0,
95
+ # Transformer
96
+ t_layers=5,
97
+ t_emb="sin",
98
+ t_hidden_scale=4.0,
99
+ t_heads=8,
100
+ t_dropout=0.0,
101
+ t_max_positions=10000,
102
+ t_norm_in=True,
103
+ t_norm_in_group=False,
104
+ t_group_norm=False,
105
+ t_norm_first=True,
106
+ t_norm_out=True,
107
+ t_max_period=10000.0,
108
+ t_weight_decay=0.0,
109
+ t_lr=None,
110
+ t_layer_scale=True,
111
+ t_gelu=True,
112
+ t_weight_pos_embed=1.0,
113
+ t_sin_random_shift=0,
114
+ t_cape_mean_normalize=True,
115
+ t_cape_augment=True,
116
+ t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
117
+ t_sparse_self_attn=False,
118
+ t_sparse_cross_attn=False,
119
+ t_mask_type="diag",
120
+ t_mask_random_seed=42,
121
+ t_sparse_attn_window=500,
122
+ t_global_window=100,
123
+ t_sparsity=0.95,
124
+ t_auto_sparsity=False,
125
+ # ------ Particular parameters
126
+ t_cross_first=False,
127
+ # Weight init
128
+ rescale=0.1,
129
+ # Metadata
130
+ samplerate=44100,
131
+ segment=10,
132
+ use_train_segment=True,
133
+ ):
134
+ """
135
+ Args:
136
+ sources (list[str]): list of source names.
137
+ audio_channels (int): input/output audio channels.
138
+ channels (int): initial number of hidden channels.
139
+ channels_time: if not None, use a different `channels` value for the time branch.
140
+ growth: increase the number of hidden channels by this factor at each layer.
141
+ nfft: number of fft bins. Note that changing this requires careful computation of
142
+ various shape parameters and will not work out of the box for hybrid models.
143
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
144
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
145
+ wiener_residual: add residual source before wiener filtering.
146
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
147
+ in input and output. no further processing is done before ISTFT.
148
+ depth (int): number of layers in the encoder and in the decoder.
149
+ rewrite (bool): add 1x1 convolution to each layer.
150
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
151
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
152
+ layers will be wrapped.
153
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
154
+ the actual value controls the weight of the embedding.
155
+ emb_scale: equivalent to scaling the embedding learning rate
156
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
157
+ kernel_size: kernel_size for encoder and decoder layers.
158
+ stride: stride for encoder and decoder layers.
159
+ time_stride: stride for the final time layer, after the merge.
160
+ context: context for 1x1 conv in the decoder.
161
+ context_enc: context for 1x1 conv in the encoder.
162
+ norm_starts: layer at which group norm starts being used.
163
+ decoder layers are numbered in reverse order.
164
+ norm_groups: number of groups for group norm.
165
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
166
+ dconv_depth: depth of residual DConv branch.
167
+ dconv_comp: compression of DConv branch.
168
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
169
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
170
+ dconv_init: initial scale for the DConv branch LayerScale.
171
+ bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
172
+ transformer in order to change the number of channels
173
+ t_layers: number of layers in each branch (waveform and spec) of the transformer
174
+ t_emb: "sin", "cape" or "scaled"
175
+ t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
176
+ for instance if C = 384 (the number of channels in the transformer) and
177
+ t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
178
+ 384 * 4 = 1536
179
+ t_heads: number of heads for the transformer
180
+ t_dropout: dropout in the transformer
181
+ t_max_positions: max_positions for the "scaled" positional embedding, only
182
+ useful if t_emb="scaled"
183
+ t_norm_in: (bool) norm before adding the positional embedding and getting into the
184
+ transformer layers
185
+ t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
186
+ timesteps (GroupNorm with group=1)
187
+ t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
188
+ timesteps (GroupNorm with group=1)
189
+ t_norm_first: (bool) if True the norm is before the attention and before the FFN
190
+ t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
191
+ t_max_period: (float) denominator in the sinusoidal embedding expression
192
+ t_weight_decay: (float) weight decay for the transformer
193
+ t_lr: (float) specific learning rate for the transformer
194
+ t_layer_scale: (bool) Layer Scale for the transformer
195
+ t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
196
+ t_weight_pos_embed: (float) weighting of the positional embedding
197
+ t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
198
+ see: https://arxiv.org/abs/2106.03143
199
+ t_cape_augment: (bool) if t_emb="cape", must be True during training and False
200
+ during inference, see: https://arxiv.org/abs/2106.03143
201
+ t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
202
+ see: https://arxiv.org/abs/2106.03143
203
+ t_sparse_self_attn: (bool) if True, the self attentions are sparse
204
+ t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
205
+ unless you designed really specific masks)
206
+ t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
207
+ with '_' between: i.e. "diag_jmask_random" (note that this is permutation
208
+ invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
209
+ t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
210
+ that generated the random part of the mask
211
+ t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
212
+ a key (j), the mask is True if |i - j| <= t_sparse_attn_window
213
+ t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
214
+ and mask[:, :t_global_window] will be True
215
+ t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
216
+ level of the random part of the mask.
217
+ t_cross_first: (bool) if True cross attention is the first layer of the
218
+ transformer (False seems to be better)
219
+ rescale: weight rescaling trick
220
+ use_train_segment: (bool) if True, the segment length used during
221
+ training is also used during inference.
222
+ """
223
+ super().__init__()
224
+ self.cac = cac
225
+ self.wiener_residual = wiener_residual
226
+ self.audio_channels = audio_channels
227
+ self.sources = sources
228
+ self.kernel_size = kernel_size
229
+ self.context = context
230
+ self.stride = stride
231
+ self.depth = depth
232
+ self.bottom_channels = bottom_channels
233
+ self.channels = channels
234
+ self.samplerate = samplerate
235
+ self.segment = segment
236
+ self.use_train_segment = use_train_segment
237
+ self.nfft = nfft
238
+ self.hop_length = nfft // 4
239
+ self.wiener_iters = wiener_iters
240
+ self.end_iters = end_iters
241
+ self.freq_emb = None
242
+ assert wiener_iters == end_iters
243
+
244
+ self.encoder = nn.ModuleList()
245
+ self.decoder = nn.ModuleList()
246
+
247
+ self.tencoder = nn.ModuleList()
248
+ self.tdecoder = nn.ModuleList()
249
+
250
+ chin = audio_channels
251
+ chin_z = chin # number of channels for the freq branch
252
+ if self.cac:
253
+ chin_z *= 2
254
+ chout = channels_time or channels
255
+ chout_z = channels
256
+ freqs = nfft // 2
257
+
258
+ for index in range(depth):
259
+ norm = index >= norm_starts
260
+ freq = freqs > 1
261
+ stri = stride
262
+ ker = kernel_size
263
+ if not freq:
264
+ assert freqs == 1
265
+ ker = time_stride * 2
266
+ stri = time_stride
267
+
268
+ pad = True
269
+ last_freq = False
270
+ if freq and freqs <= kernel_size:
271
+ ker = freqs
272
+ pad = False
273
+ last_freq = True
274
+
275
+ kw = {
276
+ "kernel_size": ker,
277
+ "stride": stri,
278
+ "freq": freq,
279
+ "pad": pad,
280
+ "norm": norm,
281
+ "rewrite": rewrite,
282
+ "norm_groups": norm_groups,
283
+ "dconv_kw": {"depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
284
+ }
285
+ kwt = dict(kw)
286
+ kwt["freq"] = 0
287
+ kwt["kernel_size"] = kernel_size
288
+ kwt["stride"] = stride
289
+ kwt["pad"] = True
290
+ kw_dec = dict(kw)
291
+ multi = False
292
+ if multi_freqs and index < multi_freqs_depth:
293
+ multi = True
294
+ kw_dec["context_freq"] = False
295
+
296
+ if last_freq:
297
+ chout_z = max(chout, chout_z)
298
+ chout = chout_z
299
+
300
+ enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
301
+ if freq:
302
+ tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
303
+ self.tencoder.append(tenc)
304
+
305
+ if multi:
306
+ enc = MultiWrap(enc, multi_freqs)
307
+ self.encoder.append(enc)
308
+ if index == 0:
309
+ chin = self.audio_channels * len(self.sources)
310
+ chin_z = chin
311
+ if self.cac:
312
+ chin_z *= 2
313
+ dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
314
+ if multi:
315
+ dec = MultiWrap(dec, multi_freqs)
316
+ if freq:
317
+ tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
318
+ self.tdecoder.insert(0, tdec)
319
+ self.decoder.insert(0, dec)
320
+
321
+ chin = chout
322
+ chin_z = chout_z
323
+ chout = int(growth * chout)
324
+ chout_z = int(growth * chout_z)
325
+ if freq:
326
+ if freqs <= kernel_size:
327
+ freqs = 1
328
+ else:
329
+ freqs //= stride
330
+ if index == 0 and freq_emb:
331
+ self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
332
+ self.freq_emb_scale = freq_emb
333
+
334
+ if rescale:
335
+ rescale_module(self, reference=rescale)
336
+
337
+ transformer_channels = channels * growth ** (depth - 1)
338
+ if bottom_channels:
339
+ self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
340
+ self.channel_downsampler = nn.Conv1d(bottom_channels, transformer_channels, 1)
341
+ self.channel_upsampler_t = nn.Conv1d(transformer_channels, bottom_channels, 1)
342
+ self.channel_downsampler_t = nn.Conv1d(bottom_channels, transformer_channels, 1)
343
+
344
+ transformer_channels = bottom_channels
345
+
346
+ if t_layers > 0:
347
+ self.crosstransformer = CrossTransformerEncoder(
348
+ dim=transformer_channels,
349
+ emb=t_emb,
350
+ hidden_scale=t_hidden_scale,
351
+ num_heads=t_heads,
352
+ num_layers=t_layers,
353
+ cross_first=t_cross_first,
354
+ dropout=t_dropout,
355
+ max_positions=t_max_positions,
356
+ norm_in=t_norm_in,
357
+ norm_in_group=t_norm_in_group,
358
+ group_norm=t_group_norm,
359
+ norm_first=t_norm_first,
360
+ norm_out=t_norm_out,
361
+ max_period=t_max_period,
362
+ weight_decay=t_weight_decay,
363
+ lr=t_lr,
364
+ layer_scale=t_layer_scale,
365
+ gelu=t_gelu,
366
+ sin_random_shift=t_sin_random_shift,
367
+ weight_pos_embed=t_weight_pos_embed,
368
+ cape_mean_normalize=t_cape_mean_normalize,
369
+ cape_augment=t_cape_augment,
370
+ cape_glob_loc_scale=t_cape_glob_loc_scale,
371
+ sparse_self_attn=t_sparse_self_attn,
372
+ sparse_cross_attn=t_sparse_cross_attn,
373
+ mask_type=t_mask_type,
374
+ mask_random_seed=t_mask_random_seed,
375
+ sparse_attn_window=t_sparse_attn_window,
376
+ global_window=t_global_window,
377
+ sparsity=t_sparsity,
378
+ auto_sparsity=t_auto_sparsity,
379
+ )
380
+ else:
381
+ self.crosstransformer = None
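+ # Note (added, not in the original source): the cross-domain transformer operates on
+ # the innermost (most downsampled) representations of the frequency and time branches;
+ # when bottom_channels > 0, 1x1 convolutions project both branches to that width
+ # before the transformer and back afterwards.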
382
+
383
+ def _spec(self, x):
384
+ hl = self.hop_length
385
+ nfft = self.nfft
386
+ x0 = x # noqa
387
+
388
+ # We re-pad the signal in order to keep the property
389
+ # that the size of the output is exactly the size of the input
390
+ # divided by the stride (here hop_length), when divisible.
391
+ # This is achieved by padding by 1/4th of the kernel size (here nfft),
392
+ # which is not supported by torch.stft.
393
+ # Having all convolution operations follow this convention allows us to easily
394
+ # align the time and frequency branches later on.
395
+ assert hl == nfft // 4
396
+ le = int(math.ceil(x.shape[-1] / hl))
397
+ pad = hl // 2 * 3
398
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
399
+
400
+ z = spectro(x, nfft, hl)[..., :-1, :]
401
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
402
+ z = z[..., 2 : 2 + le]
403
+ return z
404
+
405
+ def _ispec(self, z, length=None, scale=0):
406
+ hl = self.hop_length // (4**scale)
407
+ z = F.pad(z, (0, 0, 0, 1))
408
+ z = F.pad(z, (2, 2))
409
+ pad = hl // 2 * 3
410
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
411
+ x = ispectro(z, hl, length=le)
412
+ x = x[..., pad : pad + length]
413
+ return x
414
+
415
+ def _magnitude(self, z):
416
+ # return the magnitude of the spectrogram, except when cac is True,
417
+ # in which case we just move the complex dimension to the channel one.
418
+ if self.cac:
419
+ B, C, Fr, T = z.shape
420
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
421
+ m = m.reshape(B, C * 2, Fr, T)
422
+ else:
423
+ m = z.abs()
424
+ return m
425
+
426
+ def _mask(self, z, m):
427
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
428
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
429
+ niters = self.wiener_iters
430
+ if self.cac:
431
+ B, S, C, Fr, T = m.shape
432
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
433
+ out = torch.view_as_complex(out.contiguous())
434
+ return out
435
+ if self.training:
436
+ niters = self.end_iters
437
+ if niters < 0:
438
+ z = z[:, None]
439
+ return z / (1e-8 + z.abs()) * m
440
+ else:
441
+ return self._wiener(m, z, niters)
442
+
443
+ def _wiener(self, mag_out, mix_stft, niters):
444
+ # apply wiener filtering from OpenUnmix.
445
+ init = mix_stft.dtype
446
+ wiener_win_len = 300
447
+ residual = self.wiener_residual
448
+
449
+ B, S, C, Fq, T = mag_out.shape
450
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
451
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
452
+
453
+ outs = []
454
+ for sample in range(B):
455
+ pos = 0
456
+ out = []
457
+ for pos in range(0, T, wiener_win_len):
458
+ frame = slice(pos, pos + wiener_win_len)
459
+ z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
460
+ out.append(z_out.transpose(-1, -2))
461
+ outs.append(torch.cat(out, dim=0))
462
+ out = torch.view_as_complex(torch.stack(outs, 0))
463
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
464
+ if residual:
465
+ out = out[:, :-1]
466
+ assert list(out.shape) == [B, S, C, Fq, T]
467
+ return out.to(init)
468
+
469
+ def valid_length(self, length: int):
470
+ """
471
+ Return a length that is appropriate for evaluation.
472
+ In our case, always return the training length, unless
473
+ it is smaller than the given length, in which case this
474
+ raises an error.
475
+ """
476
+ if not self.use_train_segment:
477
+ return length
478
+ training_length = int(self.segment * self.samplerate)
479
+ if training_length < length:
480
+ raise ValueError(f"Given length {length} is longer than " f"training length {training_length}")
481
+ return training_length
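+ # Usage note (added, not in the original source): with the defaults segment=10 and
+ # samplerate=44100, valid_length() returns 441000 samples for any requested length
+ # up to that size when use_train_segment is True.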
482
+
483
+ def forward(self, mix):
484
+ length = mix.shape[-1]
485
+ length_pre_pad = None
486
+ if self.use_train_segment:
487
+ if self.training:
488
+ self.segment = Fraction(mix.shape[-1], self.samplerate)
489
+ else:
490
+ training_length = int(self.segment * self.samplerate)
491
+ if mix.shape[-1] < training_length:
492
+ length_pre_pad = mix.shape[-1]
493
+ mix = F.pad(mix, (0, training_length - length_pre_pad))
494
+ z = self._spec(mix)
495
+ mag = self._magnitude(z).to(mix.device)
496
+ x = mag
497
+
498
+ B, C, Fq, T = x.shape
499
+
500
+ # unlike previous Demucs, we always normalize because it is easier.
501
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
502
+ std = x.std(dim=(1, 2, 3), keepdim=True)
503
+ x = (x - mean) / (1e-5 + std)
504
+ # x will be the freq. branch input.
505
+
506
+ # Prepare the time branch input.
507
+ xt = mix
508
+ meant = xt.mean(dim=(1, 2), keepdim=True)
509
+ stdt = xt.std(dim=(1, 2), keepdim=True)
510
+ xt = (xt - meant) / (1e-5 + stdt)
511
+
512
+ # okay, this is a giant mess I know...
513
+ saved = [] # skip connections, freq.
514
+ saved_t = [] # skip connections, time.
515
+ lengths = [] # saved lengths to properly remove padding, freq branch.
516
+ lengths_t = [] # saved lengths for time branch.
517
+ for idx, encode in enumerate(self.encoder):
518
+ lengths.append(x.shape[-1])
519
+ inject = None
520
+ if idx < len(self.tencoder):
521
+ # we have not yet merged branches.
522
+ lengths_t.append(xt.shape[-1])
523
+ tenc = self.tencoder[idx]
524
+ xt = tenc(xt)
525
+ if not tenc.empty:
526
+ # save for skip connection
527
+ saved_t.append(xt)
528
+ else:
529
+ # tenc contains just the first conv., so that now time and freq.
530
+ # branches have the same shape and can be merged.
531
+ inject = xt
532
+ x = encode(x, inject)
533
+ if idx == 0 and self.freq_emb is not None:
534
+ # add frequency embedding to allow for non equivariant convolutions
535
+ # over the frequency axis.
536
+ frs = torch.arange(x.shape[-2], device=x.device)
537
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
538
+ x = x + self.freq_emb_scale * emb
539
+
540
+ saved.append(x)
541
+ if self.crosstransformer:
542
+ if self.bottom_channels:
543
+ b, c, f, t = x.shape
544
+ x = rearrange(x, "b c f t-> b c (f t)")
545
+ x = self.channel_upsampler(x)
546
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
547
+ xt = self.channel_upsampler_t(xt)
548
+
549
+ x, xt = self.crosstransformer(x, xt)
550
+
551
+ if self.bottom_channels:
552
+ x = rearrange(x, "b c f t-> b c (f t)")
553
+ x = self.channel_downsampler(x)
554
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
555
+ xt = self.channel_downsampler_t(xt)
556
+
557
+ for idx, decode in enumerate(self.decoder):
558
+ skip = saved.pop(-1)
559
+ x, pre = decode(x, skip, lengths.pop(-1))
560
+ # `pre` contains the output just before final transposed convolution,
561
+ # which is used when the freq. and time branch separate.
562
+
563
+ offset = self.depth - len(self.tdecoder)
564
+ if idx >= offset:
565
+ tdec = self.tdecoder[idx - offset]
566
+ length_t = lengths_t.pop(-1)
567
+ if tdec.empty:
568
+ assert pre.shape[2] == 1, pre.shape
569
+ pre = pre[:, :, 0]
570
+ xt, _ = tdec(pre, None, length_t)
571
+ else:
572
+ skip = saved_t.pop(-1)
573
+ xt, _ = tdec(xt, skip, length_t)
574
+
575
+ # Let's make sure we used all stored skip connections.
576
+ assert len(saved) == 0
577
+ assert len(lengths_t) == 0
578
+ assert len(saved_t) == 0
579
+
580
+ S = len(self.sources)
581
+ x = x.view(B, S, -1, Fq, T)
582
+ x = x * std[:, None] + mean[:, None]
583
+
584
+ # to cpu as non-cuda GPUs don't support complex numbers
585
+ # demucs issue #435 ##432
586
+ # NOTE: in this case z already is on cpu
587
+ # TODO: remove this when mps supports complex numbers
588
+
589
+ device_type = x.device.type
590
+ device_load = f"{device_type}:{x.device.index}" if device_type != "mps" else device_type
591
+ x_is_other_gpu = device_type not in ["cuda", "cpu"]
592
+
593
+ if x_is_other_gpu:
594
+ x = x.cpu()
595
+
596
+ zout = self._mask(z, x)
597
+ if self.use_train_segment:
598
+ if self.training:
599
+ x = self._ispec(zout, length)
600
+ else:
601
+ x = self._ispec(zout, training_length)
602
+ else:
603
+ x = self._ispec(zout, length)
604
+
605
+ # back to other device
606
+ if x_is_other_gpu:
607
+ x = x.to(device_load)
608
+
609
+ if self.use_train_segment:
610
+ if self.training:
611
+ xt = xt.view(B, S, -1, length)
612
+ else:
613
+ xt = xt.view(B, S, -1, training_length)
614
+ else:
615
+ xt = xt.view(B, S, -1, length)
616
+ xt = xt * stdt[:, None] + meant[:, None]
617
+ x = xt + x
618
+ if length_pre_pad:
619
+ x = x[..., :length_pre_pad]
620
+ return x
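+ # Note (added for clarity, not in the original source): when the input is shorter than
+ # the training segment, it is zero-padded to training_length before the STFT and the
+ # final estimate is trimmed back to length_pre_pad at the end of forward().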
audio_separator/separator/uvr_lib_v5/demucs/model.py ADDED
@@ -0,0 +1,204 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch as th
10
+ from torch import nn
11
+
12
+ from .utils import capture_init, center_trim
13
+
14
+
15
+ class BLSTM(nn.Module):
16
+ def __init__(self, dim, layers=1):
17
+ super().__init__()
18
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
19
+ self.linear = nn.Linear(2 * dim, dim)
20
+
21
+ def forward(self, x):
22
+ x = x.permute(2, 0, 1)
23
+ x = self.lstm(x)[0]
24
+ x = self.linear(x)
25
+ x = x.permute(1, 2, 0)
26
+ return x
27
+
28
+
29
+ def rescale_conv(conv, reference):
30
+ std = conv.weight.std().detach()
31
+ scale = (std / reference) ** 0.5
32
+ conv.weight.data /= scale
33
+ if conv.bias is not None:
34
+ conv.bias.data /= scale
35
+
36
+
37
+ def rescale_module(module, reference):
38
+ for sub in module.modules():
39
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
40
+ rescale_conv(sub, reference)
41
+
42
+
43
+ def upsample(x, stride):
44
+ """
45
+ Linear upsampling, the output will be `stride` times longer.
46
+ """
47
+ batch, channels, time = x.size()
48
+ weight = th.arange(stride, device=x.device, dtype=th.float) / stride
49
+ x = x.view(batch, channels, time, 1)
50
+ out = x[..., :-1, :] * (1 - weight) + x[..., 1:, :] * weight
51
+ return out.reshape(batch, channels, -1)
52
+
53
+
54
+ def downsample(x, stride):
55
+ """
56
+ Downsample x by decimation.
57
+ """
58
+ return x[:, :, ::stride]
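+ # Illustrative note (added, not in the original source): upsample() linearly
+ # interpolates between adjacent frames and yields (time - 1) * stride samples, while
+ # downsample() keeps every stride-th sample; together they provide the Wave-U-Net
+ # style resampling used when `upsample=True` in Demucs below.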
59
+
60
+
61
+ class Demucs(nn.Module):
62
+ @capture_init
63
+ def __init__(
64
+ self, sources=4, audio_channels=2, channels=64, depth=6, rewrite=True, glu=True, upsample=False, rescale=0.1, kernel_size=8, stride=4, growth=2.0, lstm_layers=2, context=3, samplerate=44100
65
+ ):
66
+ """
67
+ Args:
68
+ sources (int): number of sources to separate
69
+ audio_channels (int): stereo or mono
70
+ channels (int): first convolution channels
71
+ depth (int): number of encoder/decoder layers
72
+ rewrite (bool): add 1x1 convolution to each encoder layer
73
+ and a convolution to each decoder layer.
74
+ For the decoder layer, `context` gives the kernel size.
75
+ glu (bool): use glu instead of ReLU
76
+ upsample (bool): use linear upsampling with convolutions
77
+ Wave-U-Net style, instead of transposed convolutions
78
+ rescale (int): rescale initial weights of convolutions
79
+ to get their standard deviation closer to `rescale`
80
+ kernel_size (int): kernel size for convolutions
81
+ stride (int): stride for convolutions
82
+ growth (float): multiply (resp divide) number of channels by that
83
+ for each layer of the encoder (resp decoder)
84
+ lstm_layers (int): number of lstm layers, 0 = no lstm
85
+ context (int): kernel size of the convolution in the
86
+ decoder before the transposed convolution. If > 1,
87
+ will provide some context from neighboring time
88
+ steps.
89
+ """
90
+
91
+ super().__init__()
92
+ self.audio_channels = audio_channels
93
+ self.sources = sources
94
+ self.kernel_size = kernel_size
95
+ self.context = context
96
+ self.stride = stride
97
+ self.depth = depth
98
+ self.upsample = upsample
99
+ self.channels = channels
100
+ self.samplerate = samplerate
101
+
102
+ self.encoder = nn.ModuleList()
103
+ self.decoder = nn.ModuleList()
104
+
105
+ self.final = None
106
+ if upsample:
107
+ self.final = nn.Conv1d(channels + audio_channels, sources * audio_channels, 1)
108
+ stride = 1
109
+
110
+ if glu:
111
+ activation = nn.GLU(dim=1)
112
+ ch_scale = 2
113
+ else:
114
+ activation = nn.ReLU()
115
+ ch_scale = 1
116
+ in_channels = audio_channels
117
+ for index in range(depth):
118
+ encode = []
119
+ encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
120
+ if rewrite:
121
+ encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
122
+ self.encoder.append(nn.Sequential(*encode))
123
+
124
+ decode = []
125
+ if index > 0:
126
+ out_channels = in_channels
127
+ else:
128
+ if upsample:
129
+ out_channels = channels
130
+ else:
131
+ out_channels = sources * audio_channels
132
+ if rewrite:
133
+ decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
134
+ if upsample:
135
+ decode += [nn.Conv1d(channels, out_channels, kernel_size, stride=1)]
136
+ else:
137
+ decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
138
+ if index > 0:
139
+ decode.append(nn.ReLU())
140
+ self.decoder.insert(0, nn.Sequential(*decode))
141
+ in_channels = channels
142
+ channels = int(growth * channels)
143
+
144
+ channels = in_channels
145
+
146
+ if lstm_layers:
147
+ self.lstm = BLSTM(channels, lstm_layers)
148
+ else:
149
+ self.lstm = None
150
+
151
+ if rescale:
152
+ rescale_module(self, reference=rescale)
153
+
154
+ def valid_length(self, length):
155
+ """
156
+ Return the nearest valid length to use with the model so that
157
+ there are no time steps left over in a convolution, i.e. for all
158
+ layers, (size of the input - kernel_size) % stride == 0.
159
+
160
+ If the mixture has a valid length, the estimated sources
161
+ will have exactly the same length when context = 1. If context > 1,
162
+ the two signals can be center trimmed to match.
163
+
164
+ For training, extracts should have a valid length. For evaluation
165
+ on full tracks we recommend passing `pad = True` to :method:`forward`.
166
+ """
167
+ for _ in range(self.depth):
168
+ if self.upsample:
169
+ length = math.ceil(length / self.stride) + self.kernel_size - 1
170
+ else:
171
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
172
+ length = max(1, length)
173
+ length += self.context - 1
174
+ for _ in range(self.depth):
175
+ if self.upsample:
176
+ length = length * self.stride + self.kernel_size - 1
177
+ else:
178
+ length = (length - 1) * self.stride + self.kernel_size
179
+
180
+ return int(length)
181
+
182
+ def forward(self, mix):
183
+ x = mix
184
+ saved = [x]
185
+ for encode in self.encoder:
186
+ x = encode(x)
187
+ saved.append(x)
188
+ if self.upsample:
189
+ x = downsample(x, self.stride)
190
+ if self.lstm:
191
+ x = self.lstm(x)
192
+ for decode in self.decoder:
193
+ if self.upsample:
194
+ x = upsample(x, stride=self.stride)
195
+ skip = center_trim(saved.pop(-1), x)
196
+ x = x + skip
197
+ x = decode(x)
198
+ if self.final:
199
+ skip = center_trim(saved.pop(-1), x)
200
+ x = th.cat([x, skip], dim=1)
201
+ x = self.final(x)
202
+
203
+ x = x.view(x.size(0), self.sources, self.audio_channels, x.size(-1))
204
+ return x
audio_separator/separator/uvr_lib_v5/demucs/model_v2.py ADDED
@@ -0,0 +1,222 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import julius
10
+ from torch import nn
11
+ from .tasnet_v2 import ConvTasNet
12
+
13
+ from .utils import capture_init, center_trim
14
+
15
+
16
+ class BLSTM(nn.Module):
17
+ def __init__(self, dim, layers=1):
18
+ super().__init__()
19
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
20
+ self.linear = nn.Linear(2 * dim, dim)
21
+
22
+ def forward(self, x):
23
+ x = x.permute(2, 0, 1)
24
+ x = self.lstm(x)[0]
25
+ x = self.linear(x)
26
+ x = x.permute(1, 2, 0)
27
+ return x
28
+
29
+
30
+ def rescale_conv(conv, reference):
31
+ std = conv.weight.std().detach()
32
+ scale = (std / reference) ** 0.5
33
+ conv.weight.data /= scale
34
+ if conv.bias is not None:
35
+ conv.bias.data /= scale
36
+
37
+
38
+ def rescale_module(module, reference):
39
+ for sub in module.modules():
40
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
41
+ rescale_conv(sub, reference)
42
+
43
+
44
+ def auto_load_demucs_model_v2(sources, demucs_model_name):
45
+
46
+ if "48" in demucs_model_name:
47
+ channels = 48
48
+ elif "unittest" in demucs_model_name:
49
+ channels = 4
50
+ else:
51
+ channels = 64
52
+
53
+ if "tasnet" in demucs_model_name:
54
+ init_demucs_model = ConvTasNet(sources, X=10)
55
+ else:
56
+ init_demucs_model = Demucs(sources, channels=channels)
57
+
58
+ return init_demucs_model
59
+
60
+
61
+ class Demucs(nn.Module):
62
+ @capture_init
63
+ def __init__(
64
+ self,
65
+ sources,
66
+ audio_channels=2,
67
+ channels=64,
68
+ depth=6,
69
+ rewrite=True,
70
+ glu=True,
71
+ rescale=0.1,
72
+ resample=True,
73
+ kernel_size=8,
74
+ stride=4,
75
+ growth=2.0,
76
+ lstm_layers=2,
77
+ context=3,
78
+ normalize=False,
79
+ samplerate=44100,
80
+ segment_length=4 * 10 * 44100,
81
+ ):
82
+ """
83
+ Args:
84
+ sources (list[str]): list of source names
85
+ audio_channels (int): stereo or mono
86
+ channels (int): first convolution channels
87
+ depth (int): number of encoder/decoder layers
88
+ rewrite (bool): add 1x1 convolution to each encoder layer
89
+ and a convolution to each decoder layer.
90
+ For the decoder layer, `context` gives the kernel size.
91
+ glu (bool): use glu instead of ReLU
92
+ resample_input (bool): upsample x2 the input and downsample /2 the output.
93
+ rescale (int): rescale initial weights of convolutions
94
+ to get their standard deviation closer to `rescale`
95
+ kernel_size (int): kernel size for convolutions
96
+ stride (int): stride for convolutions
97
+ growth (float): multiply (resp divide) number of channels by that
98
+ for each layer of the encoder (resp decoder)
99
+ lstm_layers (int): number of lstm layers, 0 = no lstm
100
+ context (int): kernel size of the convolution in the
101
+ decoder before the transposed convolution. If > 1,
102
+ will provide some context from neighboring time
103
+ steps.
104
+ samplerate (int): stored as meta information for easing
105
+ future evaluations of the model.
106
+ segment_length (int): stored as meta information for easing
107
+ future evaluations of the model. Length of the segments on which
108
+ the model was trained.
109
+ """
110
+
111
+ super().__init__()
112
+ self.audio_channels = audio_channels
113
+ self.sources = sources
114
+ self.kernel_size = kernel_size
115
+ self.context = context
116
+ self.stride = stride
117
+ self.depth = depth
118
+ self.resample = resample
119
+ self.channels = channels
120
+ self.normalize = normalize
121
+ self.samplerate = samplerate
122
+ self.segment_length = segment_length
123
+
124
+ self.encoder = nn.ModuleList()
125
+ self.decoder = nn.ModuleList()
126
+
127
+ if glu:
128
+ activation = nn.GLU(dim=1)
129
+ ch_scale = 2
130
+ else:
131
+ activation = nn.ReLU()
132
+ ch_scale = 1
133
+ in_channels = audio_channels
134
+ for index in range(depth):
135
+ encode = []
136
+ encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
137
+ if rewrite:
138
+ encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
139
+ self.encoder.append(nn.Sequential(*encode))
140
+
141
+ decode = []
142
+ if index > 0:
143
+ out_channels = in_channels
144
+ else:
145
+ out_channels = len(self.sources) * audio_channels
146
+ if rewrite:
147
+ decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
148
+ decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
149
+ if index > 0:
150
+ decode.append(nn.ReLU())
151
+ self.decoder.insert(0, nn.Sequential(*decode))
152
+ in_channels = channels
153
+ channels = int(growth * channels)
154
+
155
+ channels = in_channels
156
+
157
+ if lstm_layers:
158
+ self.lstm = BLSTM(channels, lstm_layers)
159
+ else:
160
+ self.lstm = None
161
+
162
+ if rescale:
163
+ rescale_module(self, reference=rescale)
164
+
165
+ def valid_length(self, length):
166
+ """
167
+ Return the nearest valid length to use with the model so that
168
+ there is no time steps left over in a convolutions, e.g. for all
169
+ layers, size of the input - kernel_size % stride = 0.
170
+
171
+ If the mixture has a valid length, the estimated sources
172
+ will have exactly the same length when context = 1. If context > 1,
173
+ the two signals can be center trimmed to match.
174
+
175
+ For training, extracts should have a valid length.For evaluation
176
+ on full tracks we recommend passing `pad = True` to :method:`forward`.
177
+ """
178
+ if self.resample:
179
+ length *= 2
180
+ for _ in range(self.depth):
181
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
182
+ length = max(1, length)
183
+ length += self.context - 1
184
+ for _ in range(self.depth):
185
+ length = (length - 1) * self.stride + self.kernel_size
186
+
187
+ if self.resample:
188
+ length = math.ceil(length / 2)
189
+ return int(length)
190
+
191
+ def forward(self, mix):
192
+ x = mix
193
+
194
+ if self.normalize:
195
+ mono = mix.mean(dim=1, keepdim=True)
196
+ mean = mono.mean(dim=-1, keepdim=True)
197
+ std = mono.std(dim=-1, keepdim=True)
198
+ else:
199
+ mean = 0
200
+ std = 1
201
+
202
+ x = (x - mean) / (1e-5 + std)
203
+
204
+ if self.resample:
205
+ x = julius.resample_frac(x, 1, 2)
206
+
207
+ saved = []
208
+ for encode in self.encoder:
209
+ x = encode(x)
210
+ saved.append(x)
211
+ if self.lstm:
212
+ x = self.lstm(x)
213
+ for decode in self.decoder:
214
+ skip = center_trim(saved.pop(-1), x)
215
+ x = x + skip
216
+ x = decode(x)
217
+
218
+ if self.resample:
219
+ x = julius.resample_frac(x, 2, 1)
220
+ x = x * std + mean
221
+ x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
222
+ return x
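To make the pad-then-trim contract of `valid_length()` concrete, here is a minimal usage sketch (not part of the uploaded files). It assumes the module is importable as `audio_separator.separator.uvr_lib_v5.demucs.model`, that `center_trim` in the sibling `utils.py` behaves like upstream Demucs, and uses a deliberately small, randomly initialised configuration.

```python
import torch
import torch.nn.functional as F

from audio_separator.separator.uvr_lib_v5.demucs.model import Demucs
from audio_separator.separator.uvr_lib_v5.demucs.utils import center_trim  # assumed helper

# small config just for the demo; real checkpoints use the defaults
model = Demucs(sources=["drums", "bass", "other", "vocals"], audio_channels=2, channels=32, depth=4, lstm_layers=0)
mix = torch.randn(1, 2, 3 * 44100)                # (batch, channels, samples)

valid = model.valid_length(mix.shape[-1])         # nearest length with no leftover time steps
padded = F.pad(mix, (0, valid - mix.shape[-1]))
with torch.no_grad():
    out = model(padded)                           # (batch, sources, channels, time)
out = center_trim(out, mix)                       # trim back to the mixture length
print(out.shape)                                  # (1, 4, 2, 132300)
```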
audio_separator/separator/uvr_lib_v5/demucs/pretrained.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Loading pretrained models.
7
+ """
8
+
9
+ import logging
10
+ from pathlib import Path
11
+ import typing as tp
12
+
13
+ # from dora.log import fatal
14
+
15
+ import logging
16
+
17
+ from diffq import DiffQuantizer
18
+ import torch.hub
19
+
20
+ from .model import Demucs
21
+ from .tasnet_v2 import ConvTasNet
22
+ from .utils import set_state
23
+
24
+ from .hdemucs import HDemucs
25
+ from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
26
+
27
+ logger = logging.getLogger(__name__)
28
+ ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
29
+ REMOTE_ROOT = Path(__file__).parent / "remote"
30
+
31
+ SOURCES = ["drums", "bass", "other", "vocals"]
32
+
33
+
34
+ def demucs_unittest():
35
+ model = HDemucs(channels=4, sources=SOURCES)
36
+ return model
37
+
38
+
39
+ def add_model_flags(parser):
40
+ group = parser.add_mutually_exclusive_group(required=False)
41
+ group.add_argument("-s", "--sig", help="Locally trained XP signature.")
42
+ group.add_argument("-n", "--name", default="mdx_extra_q", help="Pretrained model name or signature. Default is mdx_extra_q.")
43
+ parser.add_argument("--repo", type=Path, help="Folder containing all pre-trained models for use with -n.")
44
+
45
+
46
+ def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
47
+ root: str = ""
48
+ models: tp.Dict[str, str] = {}
49
+ for line in remote_file_list.read_text().split("\n"):
50
+ line = line.strip()
51
+ if line.startswith("#"):
52
+ continue
53
+ elif line.startswith("root:"):
54
+ root = line.split(":", 1)[1].strip()
55
+ else:
56
+ sig = line.split("-", 1)[0]
57
+ assert sig not in models
58
+ models[sig] = ROOT_URL + root + line
59
+ return models
60
+
61
+
62
+ def get_model(name: str, repo: tp.Optional[Path] = None):
63
+ """`name` must be a bag of models name or a pretrained signature
64
+ from the remote AWS model repo or the specified local repo if `repo` is not None.
65
+ """
66
+ if name == "demucs_unittest":
67
+ return demucs_unittest()
68
+ model_repo: ModelOnlyRepo
69
+ if repo is None:
70
+ models = _parse_remote_files(REMOTE_ROOT / "files.txt")
71
+ model_repo = RemoteRepo(models)
72
+ bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
73
+ else:
74
+ if not repo.is_dir():
75
+ # NOTE: `fatal` is unavailable here (the dora.log import above is commented out),
+ # so fail loudly with the module's own error type instead of an undefined name.
+ raise ModelLoadingError(f"{repo} must exist and be a directory.")
76
+ model_repo = LocalRepo(repo)
77
+ bag_repo = BagOnlyRepo(repo, model_repo)
78
+ any_repo = AnyModelRepo(model_repo, bag_repo)
79
+ model = any_repo.get_model(name)
80
+ model.eval()
81
+ return model
82
+
83
+
84
+ def get_model_from_args(args):
85
+ """
86
+ Load local model package or pre-trained model.
87
+ """
88
+ return get_model(name=args.name, repo=args.repo)
89
+
90
+
91
+ logger = logging.getLogger(__name__)
92
+ ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
93
+
94
+ PRETRAINED_MODELS = {
95
+ "demucs": "e07c671f",
96
+ "demucs48_hq": "28a1282c",
97
+ "demucs_extra": "3646af93",
98
+ "demucs_quantized": "07afea75",
99
+ "tasnet": "beb46fac",
100
+ "tasnet_extra": "df3777b2",
101
+ "demucs_unittest": "09ebc15f",
102
+ }
103
+
104
+ SOURCES = ["drums", "bass", "other", "vocals"]
105
+
106
+
107
+ def get_url(name):
108
+ sig = PRETRAINED_MODELS[name]
109
+ return ROOT + name + "-" + sig[:8] + ".th"
110
+
111
+
112
+ def is_pretrained(name):
113
+ return name in PRETRAINED_MODELS
114
+
115
+
116
+ def load_pretrained(name):
117
+ if name == "demucs":
118
+ return demucs(pretrained=True)
119
+ elif name == "demucs48_hq":
120
+ return demucs(pretrained=True, hq=True, channels=48)
121
+ elif name == "demucs_extra":
122
+ return demucs(pretrained=True, extra=True)
123
+ elif name == "demucs_quantized":
124
+ return demucs(pretrained=True, quantized=True)
125
+ elif name == "demucs_unittest":
126
+ return demucs_unittest(pretrained=True)
127
+ elif name == "tasnet":
128
+ return tasnet(pretrained=True)
129
+ elif name == "tasnet_extra":
130
+ return tasnet(pretrained=True, extra=True)
131
+ else:
132
+ raise ValueError(f"Invalid pretrained name {name}")
133
+
134
+
135
+ def _load_state(name, model, quantizer=None):
136
+ url = get_url(name)
137
+ state = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=True)
138
+ set_state(model, quantizer, state)
139
+ if quantizer:
140
+ quantizer.detach()
141
+
142
+
143
+ def demucs_unittest(pretrained=True):
144
+ model = Demucs(channels=4, sources=SOURCES)
145
+ if pretrained:
146
+ _load_state("demucs_unittest", model)
147
+ return model
148
+
149
+
150
+ def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
151
+ if not pretrained and (extra or quantized or hq):
152
+ raise ValueError("if extra or quantized is True, pretrained must be True.")
153
+ model = Demucs(sources=SOURCES, channels=channels)
154
+ if pretrained:
155
+ name = "demucs"
156
+ if channels != 64:
157
+ name += str(channels)
158
+ quantizer = None
159
+ if sum([extra, quantized, hq]) > 1:
160
+ raise ValueError("Only one of extra, quantized, hq, can be True.")
161
+ if quantized:
162
+ quantizer = DiffQuantizer(model, group_size=8, min_size=1)
163
+ name += "_quantized"
164
+ if extra:
165
+ name += "_extra"
166
+ if hq:
167
+ name += "_hq"
168
+ _load_state(name, model, quantizer)
169
+ return model
170
+
171
+
172
+ def tasnet(pretrained=True, extra=False):
173
+ if not pretrained and extra:
174
+ raise ValueError("if extra is True, pretrained must be True.")
175
+ model = ConvTasNet(X=10, sources=SOURCES)
176
+ if pretrained:
177
+ name = "tasnet"
178
+ if extra:
179
+ name = "tasnet_extra"
180
+ _load_state(name, model)
181
+ return model
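As a quick orientation for the two loading paths defined above, here is a small sketch (not part of the uploaded files). The model names come straight from `PRETRAINED_MODELS`; the `get_model()` calls are left commented because they fetch weights and, for the remote path, assume a `remote/files.txt` listing next to this module. The local repo path is purely illustrative.

```python
from pathlib import Path

from audio_separator.separator.uvr_lib_v5.demucs.pretrained import get_url, is_pretrained, get_model

# v2-style checkpoints: URL is ROOT + name + "-" + first 8 chars of the signature + ".th"
print(is_pretrained("demucs_extra"))   # True
print(get_url("demucs_extra"))         # https://dl.fbaipublicfiles.com/demucs/v3.0/demucs_extra-3646af93.th

# v3-style loading goes through get_model(); both calls below download weights.
# model = get_model("mdx_extra_q")                                  # remote bag of models
# model = get_model("some_sig", repo=Path("/path/to/checkpoints"))  # local folder of .th files
```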
audio_separator/separator/uvr_lib_v5/demucs/repo.py ADDED
@@ -0,0 +1,146 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Represents a model repository, including pre-trained models and bags of models.
7
+ A repo can either be the main remote repository stored in AWS, or a local repository
8
+ with your own models.
9
+ """
10
+
11
+ from hashlib import sha256
12
+ from pathlib import Path
13
+ import typing as tp
14
+
15
+ import torch
16
+ import yaml
17
+
18
+ from .apply import BagOfModels, Model
19
+ from .states import load_model
20
+
21
+
22
+ AnyModel = tp.Union[Model, BagOfModels]
23
+
24
+
25
+ class ModelLoadingError(RuntimeError):
26
+ pass
27
+
28
+
29
+ def check_checksum(path: Path, checksum: str):
30
+ sha = sha256()
31
+ with open(path, "rb") as file:
32
+ while True:
33
+ buf = file.read(2**20)
34
+ if not buf:
35
+ break
36
+ sha.update(buf)
37
+ actual_checksum = sha.hexdigest()[: len(checksum)]
38
+ if actual_checksum != checksum:
39
+ raise ModelLoadingError(f"Invalid checksum for file {path}, " f"expected {checksum} but got {actual_checksum}")
40
+
41
+
42
+ class ModelOnlyRepo:
43
+ """Base class for all model only repos."""
44
+
45
+ def has_model(self, sig: str) -> bool:
46
+ raise NotImplementedError()
47
+
48
+ def get_model(self, sig: str) -> Model:
49
+ raise NotImplementedError()
50
+
51
+
52
+ class RemoteRepo(ModelOnlyRepo):
53
+ def __init__(self, models: tp.Dict[str, str]):
54
+ self._models = models
55
+
56
+ def has_model(self, sig: str) -> bool:
57
+ return sig in self._models
58
+
59
+ def get_model(self, sig: str) -> Model:
60
+ try:
61
+ url = self._models[sig]
62
+ except KeyError:
63
+ raise ModelLoadingError(f"Could not find a pre-trained model with signature {sig}.")
64
+ pkg = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=True)
65
+ return load_model(pkg)
66
+
67
+
68
+ class LocalRepo(ModelOnlyRepo):
69
+ def __init__(self, root: Path):
70
+ self.root = root
71
+ self.scan()
72
+
73
+ def scan(self):
74
+ self._models = {}
75
+ self._checksums = {}
76
+ for file in self.root.iterdir():
77
+ if file.suffix == ".th":
78
+ if "-" in file.stem:
79
+ xp_sig, checksum = file.stem.split("-")
80
+ self._checksums[xp_sig] = checksum
81
+ else:
82
+ xp_sig = file.stem
83
+ if xp_sig in self._models:
84
+ print("Whats xp? ", xp_sig)
85
+ raise ModelLoadingError(f"Duplicate pre-trained model exist for signature {xp_sig}. " "Please delete all but one.")
86
+ self._models[xp_sig] = file
87
+
88
+ def has_model(self, sig: str) -> bool:
89
+ return sig in self._models
90
+
91
+ def get_model(self, sig: str) -> Model:
92
+ try:
93
+ file = self._models[sig]
94
+ except KeyError:
95
+ raise ModelLoadingError(f"Could not find pre-trained model with signature {sig}.")
96
+ if sig in self._checksums:
97
+ check_checksum(file, self._checksums[sig])
98
+ return load_model(file)
99
+
100
+
101
+ class BagOnlyRepo:
102
+ """Handles only YAML files containing bag of models, leaving the actual
103
+ model loading to some Repo.
104
+ """
105
+
106
+ def __init__(self, root: Path, model_repo: ModelOnlyRepo):
107
+ self.root = root
108
+ self.model_repo = model_repo
109
+ self.scan()
110
+
111
+ def scan(self):
112
+ self._bags = {}
113
+ for file in self.root.iterdir():
114
+ if file.suffix == ".yaml":
115
+ self._bags[file.stem] = file
116
+
117
+ def has_model(self, name: str) -> bool:
118
+ return name in self._bags
119
+
120
+ def get_model(self, name: str) -> BagOfModels:
121
+ try:
122
+ yaml_file = self._bags[name]
123
+ except KeyError:
124
+ raise ModelLoadingError(f"{name} is neither a single pre-trained model or " "a bag of models.")
125
+ bag = yaml.safe_load(open(yaml_file))
126
+ signatures = bag["models"]
127
+ models = [self.model_repo.get_model(sig) for sig in signatures]
128
+ weights = bag.get("weights")
129
+ segment = bag.get("segment")
130
+ return BagOfModels(models, weights, segment)
131
+
132
+
133
+ class AnyModelRepo:
134
+ def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
135
+ self.model_repo = model_repo
136
+ self.bag_repo = bag_repo
137
+
138
+ def has_model(self, name_or_sig: str) -> bool:
139
+ return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)
140
+
141
+ def get_model(self, name_or_sig: str) -> AnyModel:
142
+ # print('name_or_sig: ', name_or_sig)
143
+ if self.model_repo.has_model(name_or_sig):
144
+ return self.model_repo.get_model(name_or_sig)
145
+ else:
146
+ return self.bag_repo.get_model(name_or_sig)
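A short sketch (not part of the uploaded files) of how the three repo classes above are meant to be composed; the folder path and model name are hypothetical.

```python
from pathlib import Path

from audio_separator.separator.uvr_lib_v5.demucs.repo import (
    LocalRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError,
)

root = Path("/models/demucs")              # hypothetical folder of *.th and *.yaml files
model_repo = LocalRepo(root)               # single models keyed by signature ("SIG.th" or "SIG-CHECKSUM.th")
bag_repo = BagOnlyRepo(root, model_repo)   # *.yaml bags; member loading is delegated to model_repo
any_repo = AnyModelRepo(model_repo, bag_repo)

try:
    model = any_repo.get_model("htdemucs_ft")   # accepts either a signature or a bag name
except ModelLoadingError as err:
    print("could not load:", err)
```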
audio_separator/separator/uvr_lib_v5/demucs/spec.py ADDED
@@ -0,0 +1,38 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """Convenience wrapper to perform STFT and iSTFT"""
+
+ import torch as th
+
+
+ def spectro(x, n_fft=512, hop_length=None, pad=0):
+     *other, length = x.shape
+     x = x.reshape(-1, length)
+
+     device_type = x.device.type
+     is_other_gpu = device_type not in ["cuda", "cpu"]
+
+     if is_other_gpu:
+         x = x.cpu()
+     z = th.stft(x, n_fft * (1 + pad), hop_length or n_fft // 4, window=th.hann_window(n_fft).to(x), win_length=n_fft, normalized=True, center=True, return_complex=True, pad_mode="reflect")
+     _, freqs, frame = z.shape
+     return z.view(*other, freqs, frame)
+
+
+ def ispectro(z, hop_length=None, length=None, pad=0):
+     *other, freqs, frames = z.shape
+     n_fft = 2 * freqs - 2
+     z = z.view(-1, freqs, frames)
+     win_length = n_fft // (1 + pad)
+
+     device_type = z.device.type
+     is_other_gpu = device_type not in ["cuda", "cpu"]
+
+     if is_other_gpu:
+         z = z.cpu()
+     x = th.istft(z, n_fft, hop_length, window=th.hann_window(win_length).to(z.real), win_length=win_length, normalized=True, length=length, center=True)
+     _, length = x.shape
+     return x.view(*other, length)
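A round-trip sketch (not part of the uploaded files) showing the shapes `spectro()`/`ispectro()` work with; the hop of `n_fft // 4` mirrors the default used inside `spectro()`.

```python
import torch as th

from audio_separator.separator.uvr_lib_v5.demucs.spec import spectro, ispectro

x = th.randn(2, 2, 44100)                        # (batch, channels, samples)
n_fft = 4096
hop = n_fft // 4

z = spectro(x, n_fft=n_fft, hop_length=hop)      # complex (batch, channels, n_fft // 2 + 1, frames)
y = ispectro(z, hop_length=hop, length=x.shape[-1])

print(z.shape, z.is_complex())                   # torch.Size([2, 2, 2049, 44]) True
print((x - y).abs().max())                       # near-zero reconstruction error
```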
audio_separator/separator/uvr_lib_v5/demucs/states.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Utilities to save and load models.
8
+ """
9
+ from contextlib import contextmanager
10
+
11
+ import functools
12
+ import hashlib
13
+ import inspect
14
+ import io
15
+ from pathlib import Path
16
+ import warnings
17
+
18
+ from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state
19
+ import torch
20
+
21
+
22
+ def get_quantizer(model, args, optimizer=None):
23
+ """Return the quantizer given the XP quantization args."""
24
+ quantizer = None
25
+ if args.diffq:
26
+ quantizer = DiffQuantizer(model, min_size=args.min_size, group_size=args.group_size)
27
+ if optimizer is not None:
28
+ quantizer.setup_optimizer(optimizer)
29
+ elif args.qat:
30
+ quantizer = UniformQuantizer(model, bits=args.qat, min_size=args.min_size)
31
+ return quantizer
32
+
33
+
34
+ def load_model(path_or_package, strict=False):
35
+ """Load a model from the given serialized model, either given as a dict (already loaded)
36
+ or a path to a file on disk."""
37
+ if isinstance(path_or_package, dict):
38
+ package = path_or_package
39
+ elif isinstance(path_or_package, (str, Path)):
40
+ with warnings.catch_warnings():
41
+ warnings.simplefilter("ignore")
42
+ path = path_or_package
43
+ package = torch.load(path, "cpu", weights_only=False)
44
+ else:
45
+ raise ValueError(f"Invalid type for {path_or_package}.")
46
+
47
+ klass = package["klass"]
48
+ args = package["args"]
49
+ kwargs = package["kwargs"]
50
+
51
+ if strict:
52
+ model = klass(*args, **kwargs)
53
+ else:
54
+ sig = inspect.signature(klass)
55
+ for key in list(kwargs):
56
+ if key not in sig.parameters:
57
+ warnings.warn("Dropping nonexistent parameter " + key)
58
+ del kwargs[key]
59
+ model = klass(*args, **kwargs)
60
+
61
+ state = package["state"]
62
+
63
+ set_state(model, state)
64
+ return model
65
+
66
+
67
+ def get_state(model, quantizer, half=False):
68
+ """Get the state from a model, potentially with quantization applied.
69
+ If `half` is True, model are stored as half precision, which shouldn't impact performance
70
+ but half the state size."""
71
+ if quantizer is None:
72
+ dtype = torch.half if half else None
73
+ state = {k: p.data.to(device="cpu", dtype=dtype) for k, p in model.state_dict().items()}
74
+ else:
75
+ state = quantizer.get_quantized_state()
76
+ state["__quantized"] = True
77
+ return state
78
+
79
+
80
+ def set_state(model, state, quantizer=None):
81
+ """Set the state on a given model."""
82
+ if state.get("__quantized"):
83
+ if quantizer is not None:
84
+ quantizer.restore_quantized_state(model, state["quantized"])
85
+ else:
86
+ restore_quantized_state(model, state)
87
+ else:
88
+ model.load_state_dict(state)
89
+ return state
90
+
91
+
92
+ def save_with_checksum(content, path):
93
+ """Save the given value on disk, along with a sha256 hash.
94
+ Should be used with the output of either `serialize_model` or `get_state`."""
95
+ buf = io.BytesIO()
96
+ torch.save(content, buf)
97
+ sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
98
+
99
+ path = path.parent / (path.stem + "-" + sig + path.suffix)
100
+ path.write_bytes(buf.getvalue())
101
+
102
+
103
+ def copy_state(state):
104
+ return {k: v.cpu().clone() for k, v in state.items()}
105
+
106
+
107
+ @contextmanager
108
+ def swap_state(model, state):
109
+ """
110
+ Context manager that swaps the state of a model, e.g:
111
+
112
+ # model is in old state
113
+ with swap_state(model, new_state):
114
+ # model in new state
115
+ # model back to old state
116
+ """
117
+ old_state = copy_state(model.state_dict())
118
+ model.load_state_dict(state, strict=False)
119
+ try:
120
+ yield
121
+ finally:
122
+ model.load_state_dict(old_state)
123
+
124
+
125
+ def capture_init(init):
126
+ @functools.wraps(init)
127
+ def __init__(self, *args, **kwargs):
128
+ self._init_args_kwargs = (args, kwargs)
129
+ init(self, *args, **kwargs)
130
+
131
+ return __init__
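The decorator and state helpers above fit together as in this sketch (not part of the uploaded files); `Net` is a stand-in module, not a class from this package.

```python
import torch
import torch.nn as nn

from audio_separator.separator.uvr_lib_v5.demucs.states import (
    capture_init, copy_state, get_state, set_state, swap_state,
)


class Net(nn.Module):
    @capture_init                      # records (args, kwargs) so a checkpoint can rebuild the class
    def __init__(self, dim=4):
        super().__init__()
        self.linear = nn.Linear(dim, dim)


net = Net(dim=4)
print(net._init_args_kwargs)           # ((), {'dim': 4}) -- what load_model() relies on

state = get_state(net, None)           # plain CPU state dict (no quantizer)
zeros = {k: torch.zeros_like(v) for k, v in copy_state(state).items()}
with swap_state(net, zeros):           # temporarily replace the weights
    assert float(net.linear.weight.abs().sum()) == 0.0
set_state(net, state)                  # weights are back to the originals
```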
audio_separator/separator/uvr_lib_v5/demucs/tasnet.py ADDED
@@ -0,0 +1,401 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Created on 2018/12
8
+ # Author: Kaituo XU
9
+ # Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
10
+ # Here is the original license:
11
+ # The MIT License (MIT)
12
+ #
13
+ # Copyright (c) 2018 Kaituo XU
14
+ #
15
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
16
+ # of this software and associated documentation files (the "Software"), to deal
17
+ # in the Software without restriction, including without limitation the rights
18
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
19
+ # copies of the Software, and to permit persons to whom the Software is
20
+ # furnished to do so, subject to the following conditions:
21
+ #
22
+ # The above copyright notice and this permission notice shall be included in all
23
+ # copies or substantial portions of the Software.
24
+ #
25
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31
+ # SOFTWARE.
32
+
33
+ import math
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+ from .utils import capture_init
40
+
41
+ EPS = 1e-8
42
+
43
+
44
+ def overlap_and_add(signal, frame_step):
45
+ outer_dimensions = signal.size()[:-2]
46
+ frames, frame_length = signal.size()[-2:]
47
+
48
+ subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
49
+ subframe_step = frame_step // subframe_length
50
+ subframes_per_frame = frame_length // subframe_length
51
+ output_size = frame_step * (frames - 1) + frame_length
52
+ output_subframes = output_size // subframe_length
53
+
54
+ subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
55
+
56
+ frame = torch.arange(0, output_subframes, device=signal.device).unfold(0, subframes_per_frame, subframe_step)
57
+ frame = frame.long() # signal may in GPU or CPU
58
+ frame = frame.contiguous().view(-1)
59
+
60
+ result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
61
+ result.index_add_(-2, frame, subframe_signal)
62
+ result = result.view(*outer_dimensions, -1)
63
+ return result
64
+
65
+
66
+ class ConvTasNet(nn.Module):
67
+ @capture_init
68
+ def __init__(self, N=256, L=20, B=256, H=512, P=3, X=8, R=4, C=4, audio_channels=1, samplerate=44100, norm_type="gLN", causal=False, mask_nonlinear="relu"):
69
+ """
70
+ Args:
71
+ N: Number of filters in autoencoder
72
+ L: Length of the filters (in samples)
73
+ B: Number of channels in bottleneck 1 × 1-conv block
74
+ H: Number of channels in convolutional blocks
75
+ P: Kernel size in convolutional blocks
76
+ X: Number of convolutional blocks in each repeat
77
+ R: Number of repeats
78
+ C: Number of speakers
79
+ norm_type: BN, gLN, cLN
80
+ causal: causal or non-causal
81
+ mask_nonlinear: use which non-linear function to generate mask
82
+ """
83
+ super(ConvTasNet, self).__init__()
84
+ # Hyper-parameter
85
+ self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C
86
+ self.norm_type = norm_type
87
+ self.causal = causal
88
+ self.mask_nonlinear = mask_nonlinear
89
+ self.audio_channels = audio_channels
90
+ self.samplerate = samplerate
91
+ # Components
92
+ self.encoder = Encoder(L, N, audio_channels)
93
+ self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear)
94
+ self.decoder = Decoder(N, L, audio_channels)
95
+ # init
96
+ for p in self.parameters():
97
+ if p.dim() > 1:
98
+ nn.init.xavier_normal_(p)
99
+
100
+ def valid_length(self, length):
101
+ return length
102
+
103
+ def forward(self, mixture):
104
+ """
105
+ Args:
106
+ mixture: [M, T], M is batch size, T is #samples
107
+ Returns:
108
+ est_source: [M, C, T]
109
+ """
110
+ mixture_w = self.encoder(mixture)
111
+ est_mask = self.separator(mixture_w)
112
+ est_source = self.decoder(mixture_w, est_mask)
113
+
114
+ # T changed after conv1d in encoder, fix it here
115
+ T_origin = mixture.size(-1)
116
+ T_conv = est_source.size(-1)
117
+ est_source = F.pad(est_source, (0, T_origin - T_conv))
118
+ return est_source
119
+
120
+
121
+ class Encoder(nn.Module):
122
+ """Estimation of the nonnegative mixture weight by a 1-D conv layer."""
123
+
124
+ def __init__(self, L, N, audio_channels):
125
+ super(Encoder, self).__init__()
126
+ # Hyper-parameter
127
+ self.L, self.N = L, N
128
+ # Components
129
+ # 50% overlap
130
+ self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
131
+
132
+ def forward(self, mixture):
133
+ """
134
+ Args:
135
+ mixture: [M, T], M is batch size, T is #samples
136
+ Returns:
137
+ mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
138
+ """
139
+ mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
140
+ return mixture_w
141
+
142
+
143
+ class Decoder(nn.Module):
144
+ def __init__(self, N, L, audio_channels):
145
+ super(Decoder, self).__init__()
146
+ # Hyper-parameter
147
+ self.N, self.L = N, L
148
+ self.audio_channels = audio_channels
149
+ # Components
150
+ self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
151
+
152
+ def forward(self, mixture_w, est_mask):
153
+ """
154
+ Args:
155
+ mixture_w: [M, N, K]
156
+ est_mask: [M, C, N, K]
157
+ Returns:
158
+ est_source: [M, C, T]
159
+ """
160
+ # D = W * M
161
+ source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
162
+ source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
163
+ # S = DV
164
+ est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
165
+ m, c, k, _ = est_source.size()
166
+ est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
167
+ est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
168
+ return est_source
169
+
170
+
171
+ class TemporalConvNet(nn.Module):
172
+ def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear="relu"):
173
+ """
174
+ Args:
175
+ N: Number of filters in autoencoder
176
+ B: Number of channels in bottleneck 1 × 1-conv block
177
+ H: Number of channels in convolutional blocks
178
+ P: Kernel size in convolutional blocks
179
+ X: Number of convolutional blocks in each repeat
180
+ R: Number of repeats
181
+ C: Number of speakers
182
+ norm_type: BN, gLN, cLN
183
+ causal: causal or non-causal
184
+ mask_nonlinear: use which non-linear function to generate mask
185
+ """
186
+ super(TemporalConvNet, self).__init__()
187
+ # Hyper-parameter
188
+ self.C = C
189
+ self.mask_nonlinear = mask_nonlinear
190
+ # Components
191
+ # [M, N, K] -> [M, N, K]
192
+ layer_norm = ChannelwiseLayerNorm(N)
193
+ # [M, N, K] -> [M, B, K]
194
+ bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
195
+ # [M, B, K] -> [M, B, K]
196
+ repeats = []
197
+ for r in range(R):
198
+ blocks = []
199
+ for x in range(X):
200
+ dilation = 2**x
201
+ padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
202
+ blocks += [TemporalBlock(B, H, P, stride=1, padding=padding, dilation=dilation, norm_type=norm_type, causal=causal)]
203
+ repeats += [nn.Sequential(*blocks)]
204
+ temporal_conv_net = nn.Sequential(*repeats)
205
+ # [M, B, K] -> [M, C*N, K]
206
+ mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
207
+ # Put together
208
+ self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1)
209
+
210
+ def forward(self, mixture_w):
211
+ """
212
+ Keep this API same with TasNet
213
+ Args:
214
+ mixture_w: [M, N, K], M is batch size
215
+ returns:
216
+ est_mask: [M, C, N, K]
217
+ """
218
+ M, N, K = mixture_w.size()
219
+ score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
220
+ score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
221
+ if self.mask_nonlinear == "softmax":
222
+ est_mask = F.softmax(score, dim=1)
223
+ elif self.mask_nonlinear == "relu":
224
+ est_mask = F.relu(score)
225
+ else:
226
+ raise ValueError("Unsupported mask non-linear function")
227
+ return est_mask
228
+
229
+
230
+ class TemporalBlock(nn.Module):
231
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type="gLN", causal=False):
232
+ super(TemporalBlock, self).__init__()
233
+ # [M, B, K] -> [M, H, K]
234
+ conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
235
+ prelu = nn.PReLU()
236
+ norm = chose_norm(norm_type, out_channels)
237
+ # [M, H, K] -> [M, B, K]
238
+ dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, dilation, norm_type, causal)
239
+ # Put together
240
+ self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
241
+
242
+ def forward(self, x):
243
+ """
244
+ Args:
245
+ x: [M, B, K]
246
+ Returns:
247
+ [M, B, K]
248
+ """
249
+ residual = x
250
+ out = self.net(x)
251
+ # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
252
+ return out + residual # look like w/o F.relu is better than w/ F.relu
253
+ # return F.relu(out + residual)
254
+
255
+
256
+ class DepthwiseSeparableConv(nn.Module):
257
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type="gLN", causal=False):
258
+ super(DepthwiseSeparableConv, self).__init__()
259
+ # Use `groups` option to implement depthwise convolution
260
+ # [M, H, K] -> [M, H, K]
261
+ depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=in_channels, bias=False)
262
+ if causal:
263
+ chomp = Chomp1d(padding)
264
+ prelu = nn.PReLU()
265
+ norm = chose_norm(norm_type, in_channels)
266
+ # [M, H, K] -> [M, B, K]
267
+ pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
268
+ # Put together
269
+ if causal:
270
+ self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
271
+ else:
272
+ self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
273
+
274
+ def forward(self, x):
275
+ """
276
+ Args:
277
+ x: [M, H, K]
278
+ Returns:
279
+ result: [M, B, K]
280
+ """
281
+ return self.net(x)
282
+
283
+
284
+ class Chomp1d(nn.Module):
285
+ """To ensure the output length is the same as the input."""
286
+
287
+ def __init__(self, chomp_size):
288
+ super(Chomp1d, self).__init__()
289
+ self.chomp_size = chomp_size
290
+
291
+ def forward(self, x):
292
+ """
293
+ Args:
294
+ x: [M, H, Kpad]
295
+ Returns:
296
+ [M, H, K]
297
+ """
298
+ return x[:, :, : -self.chomp_size].contiguous()
299
+
300
+
301
+ def chose_norm(norm_type, channel_size):
302
+ """The input of normalization will be (M, C, K), where M is batch size,
303
+ C is channel size and K is sequence length.
304
+ """
305
+ if norm_type == "gLN":
306
+ return GlobalLayerNorm(channel_size)
307
+ elif norm_type == "cLN":
308
+ return ChannelwiseLayerNorm(channel_size)
309
+ elif norm_type == "id":
310
+ return nn.Identity()
311
+ else: # norm_type == "BN":
312
+ # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statistics
313
+ # along M and K, so this BN usage is right.
314
+ return nn.BatchNorm1d(channel_size)
315
+
316
+
317
+ # TODO: Use nn.LayerNorm to impl cLN to speed up
318
+ class ChannelwiseLayerNorm(nn.Module):
319
+ """Channel-wise Layer Normalization (cLN)"""
320
+
321
+ def __init__(self, channel_size):
322
+ super(ChannelwiseLayerNorm, self).__init__()
323
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
324
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
325
+ self.reset_parameters()
326
+
327
+ def reset_parameters(self):
328
+ self.gamma.data.fill_(1)
329
+ self.beta.data.zero_()
330
+
331
+ def forward(self, y):
332
+ """
333
+ Args:
334
+ y: [M, N, K], M is batch size, N is channel size, K is length
335
+ Returns:
336
+ cLN_y: [M, N, K]
337
+ """
338
+ mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
339
+ var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
340
+ cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
341
+ return cLN_y
342
+
343
+
344
+ class GlobalLayerNorm(nn.Module):
345
+ """Global Layer Normalization (gLN)"""
346
+
347
+ def __init__(self, channel_size):
348
+ super(GlobalLayerNorm, self).__init__()
349
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
350
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
351
+ self.reset_parameters()
352
+
353
+ def reset_parameters(self):
354
+ self.gamma.data.fill_(1)
355
+ self.beta.data.zero_()
356
+
357
+ def forward(self, y):
358
+ """
359
+ Args:
360
+ y: [M, N, K], M is batch size, N is channel size, K is length
361
+ Returns:
362
+ gLN_y: [M, N, K]
363
+ """
364
+ # TODO: in torch 1.0, torch.mean() support dim list
365
+ mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
366
+ var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
367
+ gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
368
+ return gLN_y
369
+
370
+
371
+ if __name__ == "__main__":
372
+ torch.manual_seed(123)
373
+ M, N, L, T = 2, 3, 4, 12
374
+ K = 2 * T // L - 1
375
+ B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
376
+ mixture = torch.randint(3, (M, T))
377
+ # test Encoder
378
+ encoder = Encoder(L, N)
379
+ encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
380
+ mixture_w = encoder(mixture)
381
+ print("mixture", mixture)
382
+ print("U", encoder.conv1d_U.weight)
383
+ print("mixture_w", mixture_w)
384
+ print("mixture_w size", mixture_w.size())
385
+
386
+ # test TemporalConvNet
387
+ separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
388
+ est_mask = separator(mixture_w)
389
+ print("est_mask", est_mask)
390
+
391
+ # test Decoder
392
+ decoder = Decoder(N, L)
393
+ est_mask = torch.randint(2, (B, K, C, N))
394
+ est_source = decoder(mixture_w, est_mask)
395
+ print("est_source", est_source)
396
+
397
+ # test Conv-TasNet
398
+ conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
399
+ est_source = conv_tasnet(mixture)
400
+ print("est_source", est_source)
401
+ print("est_source size", est_source.size())
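`overlap_and_add()` is the piece that stitches the decoder's per-frame outputs back into a waveform; this toy sketch (not part of the uploaded files) shows its behaviour on 50%-overlapping frames, the same overlap the `Decoder` uses (`L // 2`).

```python
import torch

from audio_separator.separator.uvr_lib_v5.demucs.tasnet import overlap_and_add

frame_length, frame_step = 4, 2                       # 50% overlap, as in Decoder
signal = torch.arange(12, dtype=torch.float32)
frames = signal.unfold(0, frame_length, frame_step)   # (num_frames, frame_length) = (5, 4)

rebuilt = overlap_and_add(frames, frame_step)
print(rebuilt.shape)   # torch.Size([12])
print(rebuilt)         # interior samples are summed twice because consecutive frames overlap
```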
audio_separator/separator/uvr_lib_v5/demucs/tasnet_v2.py ADDED
@@ -0,0 +1,404 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Created on 2018/12
8
+ # Author: Kaituo XU
9
+ # Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
10
+ # Here is the original license:
11
+ # The MIT License (MIT)
12
+ #
13
+ # Copyright (c) 2018 Kaituo XU
14
+ #
15
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
16
+ # of this software and associated documentation files (the "Software"), to deal
17
+ # in the Software without restriction, including without limitation the rights
18
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
19
+ # copies of the Software, and to permit persons to whom the Software is
20
+ # furnished to do so, subject to the following conditions:
21
+ #
22
+ # The above copyright notice and this permission notice shall be included in all
23
+ # copies or substantial portions of the Software.
24
+ #
25
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31
+ # SOFTWARE.
32
+
33
+ import math
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+ from .utils import capture_init
40
+
41
+ EPS = 1e-8
42
+
43
+
44
+ def overlap_and_add(signal, frame_step):
45
+ outer_dimensions = signal.size()[:-2]
46
+ frames, frame_length = signal.size()[-2:]
47
+
48
+ subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
49
+ subframe_step = frame_step // subframe_length
50
+ subframes_per_frame = frame_length // subframe_length
51
+ output_size = frame_step * (frames - 1) + frame_length
52
+ output_subframes = output_size // subframe_length
53
+
54
+ subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
55
+
56
+ frame = torch.arange(0, output_subframes, device=signal.device).unfold(0, subframes_per_frame, subframe_step)
57
+ frame = frame.long() # signal may in GPU or CPU
58
+ frame = frame.contiguous().view(-1)
59
+
60
+ result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
61
+ result.index_add_(-2, frame, subframe_signal)
62
+ result = result.view(*outer_dimensions, -1)
63
+ return result
64
+
65
+
66
+ class ConvTasNet(nn.Module):
67
+ @capture_init
68
+ def __init__(self, sources, N=256, L=20, B=256, H=512, P=3, X=8, R=4, audio_channels=2, norm_type="gLN", causal=False, mask_nonlinear="relu", samplerate=44100, segment_length=44100 * 2 * 4):
69
+ """
70
+ Args:
71
+ sources: list of sources
72
+ N: Number of filters in autoencoder
73
+ L: Length of the filters (in samples)
74
+ B: Number of channels in bottleneck 1 × 1-conv block
75
+ H: Number of channels in convolutional blocks
76
+ P: Kernel size in convolutional blocks
77
+ X: Number of convolutional blocks in each repeat
78
+ R: Number of repeats
79
+ norm_type: BN, gLN, cLN
80
+ causal: causal or non-causal
81
+ mask_nonlinear: use which non-linear function to generate mask
82
+ """
83
+ super(ConvTasNet, self).__init__()
84
+ # Hyper-parameter
85
+ self.sources = sources
86
+ self.C = len(sources)
87
+ self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R
88
+ self.norm_type = norm_type
89
+ self.causal = causal
90
+ self.mask_nonlinear = mask_nonlinear
91
+ self.audio_channels = audio_channels
92
+ self.samplerate = samplerate
93
+ self.segment_length = segment_length
94
+ # Components
95
+ self.encoder = Encoder(L, N, audio_channels)
96
+ self.separator = TemporalConvNet(N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear)
97
+ self.decoder = Decoder(N, L, audio_channels)
98
+ # init
99
+ for p in self.parameters():
100
+ if p.dim() > 1:
101
+ nn.init.xavier_normal_(p)
102
+
103
+ def valid_length(self, length):
104
+ return length
105
+
106
+ def forward(self, mixture):
107
+ """
108
+ Args:
109
+ mixture: [M, T], M is batch size, T is #samples
110
+ Returns:
111
+ est_source: [M, C, T]
112
+ """
113
+ mixture_w = self.encoder(mixture)
114
+ est_mask = self.separator(mixture_w)
115
+ est_source = self.decoder(mixture_w, est_mask)
116
+
117
+ # T changed after conv1d in encoder, fix it here
118
+ T_origin = mixture.size(-1)
119
+ T_conv = est_source.size(-1)
120
+ est_source = F.pad(est_source, (0, T_origin - T_conv))
121
+ return est_source
122
+
123
+
124
+ class Encoder(nn.Module):
125
+ """Estimation of the nonnegative mixture weight by a 1-D conv layer."""
126
+
127
+ def __init__(self, L, N, audio_channels):
128
+ super(Encoder, self).__init__()
129
+ # Hyper-parameter
130
+ self.L, self.N = L, N
131
+ # Components
132
+ # 50% overlap
133
+ self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
134
+
135
+ def forward(self, mixture):
136
+ """
137
+ Args:
138
+ mixture: [M, T], M is batch size, T is #samples
139
+ Returns:
140
+ mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
141
+ """
142
+ mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
143
+ return mixture_w
144
+
145
+
146
+ class Decoder(nn.Module):
147
+ def __init__(self, N, L, audio_channels):
148
+ super(Decoder, self).__init__()
149
+ # Hyper-parameter
150
+ self.N, self.L = N, L
151
+ self.audio_channels = audio_channels
152
+ # Components
153
+ self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
154
+
155
+ def forward(self, mixture_w, est_mask):
156
+ """
157
+ Args:
158
+ mixture_w: [M, N, K]
159
+ est_mask: [M, C, N, K]
160
+ Returns:
161
+ est_source: [M, C, T]
162
+ """
163
+ # D = W * M
164
+ source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
165
+ source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
166
+ # S = DV
167
+ est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
168
+ m, c, k, _ = est_source.size()
169
+ est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
170
+ est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
171
+ return est_source
172
+
173
+
174
+ class TemporalConvNet(nn.Module):
175
+ def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear="relu"):
176
+ """
177
+ Args:
178
+ N: Number of filters in autoencoder
179
+ B: Number of channels in bottleneck 1 × 1-conv block
180
+ H: Number of channels in convolutional blocks
181
+ P: Kernel size in convolutional blocks
182
+ X: Number of convolutional blocks in each repeat
183
+ R: Number of repeats
184
+ C: Number of speakers
185
+ norm_type: BN, gLN, cLN
186
+ causal: causal or non-causal
187
+ mask_nonlinear: use which non-linear function to generate mask
188
+ """
189
+ super(TemporalConvNet, self).__init__()
190
+ # Hyper-parameter
191
+ self.C = C
192
+ self.mask_nonlinear = mask_nonlinear
193
+ # Components
194
+ # [M, N, K] -> [M, N, K]
195
+ layer_norm = ChannelwiseLayerNorm(N)
196
+ # [M, N, K] -> [M, B, K]
197
+ bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
198
+ # [M, B, K] -> [M, B, K]
199
+ repeats = []
200
+ for r in range(R):
201
+ blocks = []
202
+ for x in range(X):
203
+ dilation = 2**x
204
+ padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
205
+ blocks += [TemporalBlock(B, H, P, stride=1, padding=padding, dilation=dilation, norm_type=norm_type, causal=causal)]
206
+ repeats += [nn.Sequential(*blocks)]
207
+ temporal_conv_net = nn.Sequential(*repeats)
208
+ # [M, B, K] -> [M, C*N, K]
209
+ mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
210
+ # Put together
211
+ self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1)
212
+
213
+ def forward(self, mixture_w):
214
+ """
215
+ Keep this API same with TasNet
216
+ Args:
217
+ mixture_w: [M, N, K], M is batch size
218
+ returns:
219
+ est_mask: [M, C, N, K]
220
+ """
221
+ M, N, K = mixture_w.size()
222
+ score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
223
+ score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
224
+ if self.mask_nonlinear == "softmax":
225
+ est_mask = F.softmax(score, dim=1)
226
+ elif self.mask_nonlinear == "relu":
227
+ est_mask = F.relu(score)
228
+ else:
229
+ raise ValueError("Unsupported mask non-linear function")
230
+ return est_mask
231
+
232
+
233
+ class TemporalBlock(nn.Module):
234
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type="gLN", causal=False):
235
+ super(TemporalBlock, self).__init__()
236
+ # [M, B, K] -> [M, H, K]
237
+ conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
238
+ prelu = nn.PReLU()
239
+ norm = chose_norm(norm_type, out_channels)
240
+ # [M, H, K] -> [M, B, K]
241
+ dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, dilation, norm_type, causal)
242
+ # Put together
243
+ self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
244
+
245
+ def forward(self, x):
246
+ """
247
+ Args:
248
+ x: [M, B, K]
249
+ Returns:
250
+ [M, B, K]
251
+ """
252
+ residual = x
253
+ out = self.net(x)
254
+ # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
255
+ return out + residual # look like w/o F.relu is better than w/ F.relu
256
+ # return F.relu(out + residual)
257
+
258
+
259
+ class DepthwiseSeparableConv(nn.Module):
260
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type="gLN", causal=False):
261
+ super(DepthwiseSeparableConv, self).__init__()
262
+ # Use `groups` option to implement depthwise convolution
263
+ # [M, H, K] -> [M, H, K]
264
+ depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=in_channels, bias=False)
265
+ if causal:
266
+ chomp = Chomp1d(padding)
267
+ prelu = nn.PReLU()
268
+ norm = chose_norm(norm_type, in_channels)
269
+ # [M, H, K] -> [M, B, K]
270
+ pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
271
+ # Put together
272
+ if causal:
273
+ self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
274
+ else:
275
+ self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
276
+
277
+ def forward(self, x):
278
+ """
279
+ Args:
280
+ x: [M, H, K]
281
+ Returns:
282
+ result: [M, B, K]
283
+ """
284
+ return self.net(x)
285
+
286
+
287
+ class Chomp1d(nn.Module):
288
+ """To ensure the output length is the same as the input."""
289
+
290
+ def __init__(self, chomp_size):
291
+ super(Chomp1d, self).__init__()
292
+ self.chomp_size = chomp_size
293
+
294
+ def forward(self, x):
295
+ """
296
+ Args:
297
+ x: [M, H, Kpad]
298
+ Returns:
299
+ [M, H, K]
300
+ """
301
+ return x[:, :, : -self.chomp_size].contiguous()
302
+
303
+
304
+ def chose_norm(norm_type, channel_size):
305
+ """The input of normalization will be (M, C, K), where M is batch size,
306
+ C is channel size and K is sequence length.
307
+ """
308
+ if norm_type == "gLN":
309
+ return GlobalLayerNorm(channel_size)
310
+ elif norm_type == "cLN":
311
+ return ChannelwiseLayerNorm(channel_size)
312
+ elif norm_type == "id":
313
+ return nn.Identity()
314
+ else: # norm_type == "BN":
315
+ # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statistics
316
+ # along M and K, so this BN usage is right.
317
+ return nn.BatchNorm1d(channel_size)
318
+
319
+
320
+ # TODO: Use nn.LayerNorm to impl cLN to speed up
321
+ class ChannelwiseLayerNorm(nn.Module):
322
+ """Channel-wise Layer Normalization (cLN)"""
323
+
324
+ def __init__(self, channel_size):
325
+ super(ChannelwiseLayerNorm, self).__init__()
326
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
327
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
328
+ self.reset_parameters()
329
+
330
+ def reset_parameters(self):
331
+ self.gamma.data.fill_(1)
332
+ self.beta.data.zero_()
333
+
334
+ def forward(self, y):
335
+ """
336
+ Args:
337
+ y: [M, N, K], M is batch size, N is channel size, K is length
338
+ Returns:
339
+ cLN_y: [M, N, K]
340
+ """
341
+ mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
342
+ var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
343
+ cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
344
+ return cLN_y
345
+
346
+
347
+ class GlobalLayerNorm(nn.Module):
348
+ """Global Layer Normalization (gLN)"""
349
+
350
+ def __init__(self, channel_size):
351
+ super(GlobalLayerNorm, self).__init__()
352
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
353
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
354
+ self.reset_parameters()
355
+
356
+ def reset_parameters(self):
357
+ self.gamma.data.fill_(1)
358
+ self.beta.data.zero_()
359
+
360
+ def forward(self, y):
361
+ """
362
+ Args:
363
+ y: [M, N, K], M is batch size, N is channel size, K is length
364
+ Returns:
365
+ gLN_y: [M, N, K]
366
+ """
367
+ # TODO: in torch 1.0, torch.mean() support dim list
368
+ mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
369
+ var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
370
+ gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
371
+ return gLN_y
372
+
373
+
374
+ if __name__ == "__main__":
375
+ torch.manual_seed(123)
376
+ M, N, L, T = 2, 3, 4, 12
377
+ K = 2 * T // L - 1
378
+ B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
379
+ mixture = torch.randint(3, (M, T))
380
+ # test Encoder
381
+ encoder = Encoder(L, N)
382
+ encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
383
+ mixture_w = encoder(mixture)
384
+ print("mixture", mixture)
385
+ print("U", encoder.conv1d_U.weight)
386
+ print("mixture_w", mixture_w)
387
+ print("mixture_w size", mixture_w.size())
388
+
389
+ # test TemporalConvNet
390
+ separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
391
+ est_mask = separator(mixture_w)
392
+ print("est_mask", est_mask)
393
+
394
+ # test Decoder
395
+ decoder = Decoder(N, L)
396
+ est_mask = torch.randint(2, (B, K, C, N))
397
+ est_source = decoder(mixture_w, est_mask)
398
+ print("est_source", est_source)
399
+
400
+ # test Conv-TasNet
401
+ conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
402
+ est_source = conv_tasnet(mixture)
403
+ print("est_source", est_source)
404
+ print("est_source size", est_source.size())
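To see the multi-source variant end to end, here is a dummy forward pass (not part of the uploaded files). The small `X`/`R` values just keep the sketch light; pretrained checkpoints use the defaults.

```python
import torch

from audio_separator.separator.uvr_lib_v5.demucs.tasnet_v2 import ConvTasNet

sources = ["drums", "bass", "other", "vocals"]
model = ConvTasNet(sources, N=64, L=20, B=64, H=128, P=3, X=2, R=1, audio_channels=2)

mix = torch.randn(1, 2, 44100)            # (batch, audio_channels, samples)
with torch.no_grad():
    est = model(mix)

print(est.shape)                          # torch.Size([1, 4, 2, 44100]): one stereo stem per source
```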
audio_separator/separator/uvr_lib_v5/demucs/transformer.py ADDED
@@ -0,0 +1,675 @@
1
+ # Copyright (c) 2019-present, Meta, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # First author is Simon Rouard.
7
+
8
+ import random
9
+ import typing as tp
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import numpy as np
15
+ import math
16
+ from einops import rearrange
17
+
18
+
19
+ def create_sin_embedding(length: int, dim: int, shift: int = 0, device="cpu", max_period=10000):
20
+ # We aim for TBC format
21
+ assert dim % 2 == 0
22
+ pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
23
+ half_dim = dim // 2
24
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
25
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
26
+ return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
27
+
28
+
29
+ def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
30
+ """
31
+ :param d_model: dimension of the model
32
+ :param height: height of the positions
33
+ :param width: width of the positions
34
+ :return: d_model*height*width position matrix
35
+ """
36
+ if d_model % 4 != 0:
37
+ raise ValueError("Cannot use sin/cos positional encoding with " "odd dimension (got dim={:d})".format(d_model))
38
+ pe = torch.zeros(d_model, height, width)
39
+ # Each dimension use half of d_model
40
+ d_model = int(d_model / 2)
41
+ div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model))
42
+ pos_w = torch.arange(0.0, width).unsqueeze(1)
43
+ pos_h = torch.arange(0.0, height).unsqueeze(1)
44
+ pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
45
+ pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
46
+ pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
47
+ pe[d_model + 1 :: 2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
48
+
49
+ return pe[None, :].to(device)
50
+
51
+
52
+ def create_sin_embedding_cape(
53
+ length: int,
54
+ dim: int,
55
+ batch_size: int,
56
+ mean_normalize: bool,
57
+ augment: bool, # True during training
58
+ max_global_shift: float = 0.0, # delta max
59
+ max_local_shift: float = 0.0, # epsilon max
60
+ max_scale: float = 1.0,
61
+ device: str = "cpu",
62
+ max_period: float = 10000.0,
63
+ ):
64
+ # We aim for TBC format
65
+ assert dim % 2 == 0
66
+ pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1)
67
+ pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1)
68
+ if mean_normalize:
69
+ pos -= torch.nanmean(pos, dim=0, keepdim=True)
70
+
71
+ if augment:
72
+ delta = np.random.uniform(-max_global_shift, +max_global_shift, size=[1, batch_size, 1])
73
+ delta_local = np.random.uniform(-max_local_shift, +max_local_shift, size=[length, batch_size, 1])
74
+ log_lambdas = np.random.uniform(-np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1])
75
+ pos = (pos + delta + delta_local) * np.exp(log_lambdas)
76
+
77
+ pos = pos.to(device)
78
+
79
+ half_dim = dim // 2
80
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
81
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
82
+ return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1).float()
83
+
84
+
85
+ def get_causal_mask(length):
86
+ pos = torch.arange(length)
87
+ return pos > pos[:, None]
88
+
89
+
90
+ def get_elementary_mask(T1, T2, mask_type, sparse_attn_window, global_window, mask_random_seed, sparsity, device):
91
+ """
92
+ When the input of the Decoder has length T1 and the output T2
93
+ The mask matrix has shape (T2, T1)
94
+ """
95
+ assert mask_type in ["diag", "jmask", "random", "global"]
96
+
97
+ if mask_type == "global":
98
+ mask = torch.zeros(T2, T1, dtype=torch.bool)
99
+ mask[:, :global_window] = True
100
+ line_window = int(global_window * T2 / T1)
101
+ mask[:line_window, :] = True
102
+
103
+ if mask_type == "diag":
104
+
105
+ mask = torch.zeros(T2, T1, dtype=torch.bool)
106
+ rows = torch.arange(T2)[:, None]
107
+ cols = (T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1)).long().clamp(0, T1 - 1)
108
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
109
+
110
+ elif mask_type == "jmask":
111
+ mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
112
+ rows = torch.arange(T2 + 2)[:, None]
113
+ t = torch.arange(0, int((2 * T1) ** 0.5 + 1))
114
+ t = (t * (t + 1) / 2).int()
115
+ t = torch.cat([-t.flip(0)[:-1], t])
116
+ cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1)
117
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
118
+ mask = mask[1:-1, 1:-1]
119
+
120
+ elif mask_type == "random":
121
+ gene = torch.Generator(device=device)
122
+ gene.manual_seed(mask_random_seed)
123
+ mask = torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1) > sparsity
124
+
125
+ mask = mask.to(device)
126
+ return mask
127
+
128
+
129
+ def get_mask(T1, T2, mask_type, sparse_attn_window, global_window, mask_random_seed, sparsity, device):
130
+ """
131
+ Return a SparseCSRTensor mask that is a combination of elementary masks
132
+ mask_type can be a combination of multiple masks: for instance "diag_jmask_random"
133
+ """
134
+ from xformers.sparse import SparseCSRTensor
135
+
136
+ # create a list
137
+ mask_types = mask_type.split("_")
138
+
139
+ all_masks = [get_elementary_mask(T1, T2, mask, sparse_attn_window, global_window, mask_random_seed, sparsity, device) for mask in mask_types]
140
+
141
+ final_mask = torch.stack(all_masks).sum(axis=0) > 0
142
+
143
+ return SparseCSRTensor.from_dense(final_mask[None])
144
+
145
+
146
+ class ScaledEmbedding(nn.Module):
147
+ def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 1.0, boost: float = 3.0):
148
+ super().__init__()
149
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
150
+ self.embedding.weight.data *= scale / boost
151
+ self.boost = boost
152
+
153
+ @property
154
+ def weight(self):
155
+ return self.embedding.weight * self.boost
156
+
157
+ def forward(self, x):
158
+ return self.embedding(x) * self.boost
159
+
160
+
161
+ class LayerScale(nn.Module):
162
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
163
+ This rescales diagonaly residual outputs close to 0 initially, then learnt.
164
+ """
165
+
166
+ def __init__(self, channels: int, init: float = 0, channel_last=False):
167
+ """
168
+ channel_last = False corresponds to (B, C, T) tensors
169
+ channel_last = True corresponds to (T, B, C) tensors
170
+ """
171
+ super().__init__()
172
+ self.channel_last = channel_last
173
+ self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
174
+ self.scale.data[:] = init
175
+
176
+ def forward(self, x):
177
+ if self.channel_last:
178
+ return self.scale * x
179
+ else:
180
+ return self.scale[:, None] * x
181
+
182
+
183
+ class MyGroupNorm(nn.GroupNorm):
184
+ def __init__(self, *args, **kwargs):
185
+ super().__init__(*args, **kwargs)
186
+
187
+ def forward(self, x):
188
+ """
189
+ x: (B, T, C)
190
+ if num_groups=1: Normalisation on all T and C together for each B
191
+ """
192
+ x = x.transpose(1, 2)
193
+ return super().forward(x).transpose(1, 2)
194
+
195
+
196
+ class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
197
+ def __init__(
198
+ self,
199
+ d_model,
200
+ nhead,
201
+ dim_feedforward=2048,
202
+ dropout=0.1,
203
+ activation=F.relu,
204
+ group_norm=0,
205
+ norm_first=False,
206
+ norm_out=False,
207
+ layer_norm_eps=1e-5,
208
+ layer_scale=False,
209
+ init_values=1e-4,
210
+ device=None,
211
+ dtype=None,
212
+ sparse=False,
213
+ mask_type="diag",
214
+ mask_random_seed=42,
215
+ sparse_attn_window=500,
216
+ global_window=50,
217
+ auto_sparsity=False,
218
+ sparsity=0.95,
219
+ batch_first=False,
220
+ ):
221
+ factory_kwargs = {"device": device, "dtype": dtype}
222
+ super().__init__(
223
+ d_model=d_model,
224
+ nhead=nhead,
225
+ dim_feedforward=dim_feedforward,
226
+ dropout=dropout,
227
+ activation=activation,
228
+ layer_norm_eps=layer_norm_eps,
229
+ batch_first=batch_first,
230
+ norm_first=norm_first,
231
+ device=device,
232
+ dtype=dtype,
233
+ )
234
+ self.sparse = sparse
235
+ self.auto_sparsity = auto_sparsity
236
+ if sparse:
237
+ if not auto_sparsity:
238
+ self.mask_type = mask_type
239
+ self.sparse_attn_window = sparse_attn_window
240
+ self.global_window = global_window
241
+ self.sparsity = sparsity
242
+ if group_norm:
243
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
244
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
245
+
246
+ self.norm_out = None
247
+ if self.norm_first & norm_out:
248
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
249
+ self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
250
+ self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
251
+
252
+ if sparse:
253
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, auto_sparsity=sparsity if auto_sparsity else 0)
254
+ self.__setattr__("src_mask", torch.zeros(1, 1))
255
+ self.mask_random_seed = mask_random_seed
256
+
257
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
258
+ """
259
+ if batch_first = False, src shape is (T, B, C)
260
+ the case where batch_first=True is not covered
261
+ """
262
+ device = src.device
263
+ x = src
264
+ T, B, C = x.shape
265
+ if self.sparse and not self.auto_sparsity:
266
+ assert src_mask is None
267
+ src_mask = self.src_mask
268
+ if src_mask.shape[-1] != T:
269
+ src_mask = get_mask(T, T, self.mask_type, self.sparse_attn_window, self.global_window, self.mask_random_seed, self.sparsity, device)
270
+ self.__setattr__("src_mask", src_mask)
271
+
272
+ if self.norm_first:
273
+ x = x + self.gamma_1(self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
274
+ x = x + self.gamma_2(self._ff_block(self.norm2(x)))
275
+
276
+ if self.norm_out:
277
+ x = self.norm_out(x)
278
+ else:
279
+ x = self.norm1(x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)))
280
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
281
+
282
+ return x
283
+
284
+
285
+ class CrossTransformerEncoderLayer(nn.Module):
286
+ def __init__(
287
+ self,
288
+ d_model: int,
289
+ nhead: int,
290
+ dim_feedforward: int = 2048,
291
+ dropout: float = 0.1,
292
+ activation=F.relu,
293
+ layer_norm_eps: float = 1e-5,
294
+ layer_scale: bool = False,
295
+ init_values: float = 1e-4,
296
+ norm_first: bool = False,
297
+ group_norm: bool = False,
298
+ norm_out: bool = False,
299
+ sparse=False,
300
+ mask_type="diag",
301
+ mask_random_seed=42,
302
+ sparse_attn_window=500,
303
+ global_window=50,
304
+ sparsity=0.95,
305
+ auto_sparsity=None,
306
+ device=None,
307
+ dtype=None,
308
+ batch_first=False,
309
+ ):
310
+ factory_kwargs = {"device": device, "dtype": dtype}
311
+ super().__init__()
312
+
313
+ self.sparse = sparse
314
+ self.auto_sparsity = auto_sparsity
315
+ if sparse:
316
+ if not auto_sparsity:
317
+ self.mask_type = mask_type
318
+ self.sparse_attn_window = sparse_attn_window
319
+ self.global_window = global_window
320
+ self.sparsity = sparsity
321
+
322
+ self.cross_attn: nn.Module
323
+ self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
324
+ # Implementation of Feedforward model
325
+ self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
326
+ self.dropout = nn.Dropout(dropout)
327
+ self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
328
+
329
+ self.norm_first = norm_first
330
+ self.norm1: nn.Module
331
+ self.norm2: nn.Module
332
+ self.norm3: nn.Module
333
+ if group_norm:
334
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
335
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
336
+ self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
337
+ else:
338
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
339
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
340
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
341
+
342
+ self.norm_out = None
343
+ if self.norm_first & norm_out:
344
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
345
+
346
+ self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
347
+ self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
348
+
349
+ self.dropout1 = nn.Dropout(dropout)
350
+ self.dropout2 = nn.Dropout(dropout)
351
+
352
+ # Legacy string support for activation function.
353
+ if isinstance(activation, str):
354
+ self.activation = self._get_activation_fn(activation)
355
+ else:
356
+ self.activation = activation
357
+
358
+ if sparse:
359
+ self.cross_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, auto_sparsity=sparsity if auto_sparsity else 0)
360
+ if not auto_sparsity:
361
+ self.__setattr__("mask", torch.zeros(1, 1))
362
+ self.mask_random_seed = mask_random_seed
363
+
364
+ def forward(self, q, k, mask=None):
365
+ """
366
+ Args:
367
+ q: tensor of shape (T, B, C)
368
+ k: tensor of shape (S, B, C)
369
+ mask: tensor of shape (T, S)
370
+
371
+ """
372
+ device = q.device
373
+ T, B, C = q.shape
374
+ S, B, C = k.shape
375
+ if self.sparse and not self.auto_sparsity:
376
+ assert mask is None
377
+ mask = self.mask
378
+ if mask.shape[-1] != S or mask.shape[-2] != T:
379
+ mask = get_mask(S, T, self.mask_type, self.sparse_attn_window, self.global_window, self.mask_random_seed, self.sparsity, device)
380
+ self.__setattr__("mask", mask)
381
+
382
+ if self.norm_first:
383
+ x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
384
+ x = x + self.gamma_2(self._ff_block(self.norm3(x)))
385
+ if self.norm_out:
386
+ x = self.norm_out(x)
387
+ else:
388
+ x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
389
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
390
+
391
+ return x
392
+
393
+ # cross-attention block
394
+ def _ca_block(self, q, k, attn_mask=None):
395
+ x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
396
+ return self.dropout1(x)
397
+
398
+ # feed forward block
399
+ def _ff_block(self, x):
400
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
401
+ return self.dropout2(x)
402
+
403
+ def _get_activation_fn(self, activation):
404
+ if activation == "relu":
405
+ return F.relu
406
+ elif activation == "gelu":
407
+ return F.gelu
408
+
409
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
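# Untested usage sketch (illustrative sizes only): a single cross-attention layer
# mixing a query branch against a key/value branch, with the shapes documented in
# forward() above. It relies on the torch import at the top of this file.
layer = CrossTransformerEncoderLayer(d_model=64, nhead=4)
q = torch.randn(100, 2, 64)   # e.g. spectral tokens, (T, B, C)
k = torch.randn(80, 2, 64)    # e.g. temporal tokens, (S, B, C)
out = layer(q, k)             # same shape as q: (100, 2, 64)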
410
+
411
+
412
+ # ----------------- MULTI-BLOCKS MODELS: -----------------------
413
+
414
+
415
+ class CrossTransformerEncoder(nn.Module):
416
+ def __init__(
417
+ self,
418
+ dim: int,
419
+ emb: str = "sin",
420
+ hidden_scale: float = 4.0,
421
+ num_heads: int = 8,
422
+ num_layers: int = 6,
423
+ cross_first: bool = False,
424
+ dropout: float = 0.0,
425
+ max_positions: int = 1000,
426
+ norm_in: bool = True,
427
+ norm_in_group: bool = False,
428
+ group_norm: int = False,
429
+ norm_first: bool = False,
430
+ norm_out: bool = False,
431
+ max_period: float = 10000.0,
432
+ weight_decay: float = 0.0,
433
+ lr: tp.Optional[float] = None,
434
+ layer_scale: bool = False,
435
+ gelu: bool = True,
436
+ sin_random_shift: int = 0,
437
+ weight_pos_embed: float = 1.0,
438
+ cape_mean_normalize: bool = True,
439
+ cape_augment: bool = True,
440
+ cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
441
+ sparse_self_attn: bool = False,
442
+ sparse_cross_attn: bool = False,
443
+ mask_type: str = "diag",
444
+ mask_random_seed: int = 42,
445
+ sparse_attn_window: int = 500,
446
+ global_window: int = 50,
447
+ auto_sparsity: bool = False,
448
+ sparsity: float = 0.95,
449
+ ):
450
+ super().__init__()
451
+ """
452
+ """
453
+ assert dim % num_heads == 0
454
+
455
+ hidden_dim = int(dim * hidden_scale)
456
+
457
+ self.num_layers = num_layers
458
+ # classic parity = 1 means that if idx%2 == 1 there is a
459
+ # classical encoder else there is a cross encoder
460
+ self.classic_parity = 1 if cross_first else 0
461
+ self.emb = emb
462
+ self.max_period = max_period
463
+ self.weight_decay = weight_decay
464
+ self.weight_pos_embed = weight_pos_embed
465
+ self.sin_random_shift = sin_random_shift
466
+ if emb == "cape":
467
+ self.cape_mean_normalize = cape_mean_normalize
468
+ self.cape_augment = cape_augment
469
+ self.cape_glob_loc_scale = cape_glob_loc_scale
470
+ if emb == "scaled":
471
+ self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
472
+
473
+ self.lr = lr
474
+
475
+ activation: tp.Any = F.gelu if gelu else F.relu
476
+
477
+ self.norm_in: nn.Module
478
+ self.norm_in_t: nn.Module
479
+ if norm_in:
480
+ self.norm_in = nn.LayerNorm(dim)
481
+ self.norm_in_t = nn.LayerNorm(dim)
482
+ elif norm_in_group:
483
+ self.norm_in = MyGroupNorm(int(norm_in_group), dim)
484
+ self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
485
+ else:
486
+ self.norm_in = nn.Identity()
487
+ self.norm_in_t = nn.Identity()
488
+
489
+ # spectrogram layers
490
+ self.layers = nn.ModuleList()
491
+ # temporal layers
492
+ self.layers_t = nn.ModuleList()
493
+
494
+ kwargs_common = {
495
+ "d_model": dim,
496
+ "nhead": num_heads,
497
+ "dim_feedforward": hidden_dim,
498
+ "dropout": dropout,
499
+ "activation": activation,
500
+ "group_norm": group_norm,
501
+ "norm_first": norm_first,
502
+ "norm_out": norm_out,
503
+ "layer_scale": layer_scale,
504
+ "mask_type": mask_type,
505
+ "mask_random_seed": mask_random_seed,
506
+ "sparse_attn_window": sparse_attn_window,
507
+ "global_window": global_window,
508
+ "sparsity": sparsity,
509
+ "auto_sparsity": auto_sparsity,
510
+ "batch_first": True,
511
+ }
512
+
513
+ kwargs_classic_encoder = dict(kwargs_common)
514
+ kwargs_classic_encoder.update({"sparse": sparse_self_attn})
515
+ kwargs_cross_encoder = dict(kwargs_common)
516
+ kwargs_cross_encoder.update({"sparse": sparse_cross_attn})
517
+
518
+ for idx in range(num_layers):
519
+ if idx % 2 == self.classic_parity:
520
+
521
+ self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
522
+ self.layers_t.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
523
+
524
+ else:
525
+ self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
526
+
527
+ self.layers_t.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
528
+
529
+ def forward(self, x, xt):
530
+ B, C, Fr, T1 = x.shape
531
+ pos_emb_2d = create_2d_sin_embedding(C, Fr, T1, x.device, self.max_period) # (1, C, Fr, T1)
532
+ pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
533
+ x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
534
+ x = self.norm_in(x)
535
+ x = x + self.weight_pos_embed * pos_emb_2d
536
+
537
+ B, C, T2 = xt.shape
538
+ xt = rearrange(xt, "b c t2 -> b t2 c")  # now (B, T2, C)
539
+ pos_emb = self._get_pos_embedding(T2, B, C, x.device)
540
+ pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
541
+ xt = self.norm_in_t(xt)
542
+ xt = xt + self.weight_pos_embed * pos_emb
543
+
544
+ for idx in range(self.num_layers):
545
+ if idx % 2 == self.classic_parity:
546
+ x = self.layers[idx](x)
547
+ xt = self.layers_t[idx](xt)
548
+ else:
549
+ old_x = x
550
+ x = self.layers[idx](x, xt)
551
+ xt = self.layers_t[idx](xt, old_x)
552
+
553
+ x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
554
+ xt = rearrange(xt, "b t2 c -> b c t2")
555
+ return x, xt
556
+
557
+ def _get_pos_embedding(self, T, B, C, device):
558
+ if self.emb == "sin":
559
+ shift = random.randrange(self.sin_random_shift + 1)
560
+ pos_emb = create_sin_embedding(T, C, shift=shift, device=device, max_period=self.max_period)
561
+ elif self.emb == "cape":
562
+ if self.training:
563
+ pos_emb = create_sin_embedding_cape(
564
+ T,
565
+ C,
566
+ B,
567
+ device=device,
568
+ max_period=self.max_period,
569
+ mean_normalize=self.cape_mean_normalize,
570
+ augment=self.cape_augment,
571
+ max_global_shift=self.cape_glob_loc_scale[0],
572
+ max_local_shift=self.cape_glob_loc_scale[1],
573
+ max_scale=self.cape_glob_loc_scale[2],
574
+ )
575
+ else:
576
+ pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=False)
577
+
578
+ elif self.emb == "scaled":
579
+ pos = torch.arange(T, device=device)
580
+ pos_emb = self.position_embeddings(pos)[:, None]
581
+
582
+ return pos_emb
583
+
584
+ def make_optim_group(self):
585
+ group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
586
+ if self.lr is not None:
587
+ group["lr"] = self.lr
588
+ return group
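# Untested shape sketch (illustrative values): the two branches exchanged by
# CrossTransformerEncoder.forward. dim must be divisible by num_heads (and by 4
# for the 2D sinusoidal embedding); every size below is arbitrary.
enc = CrossTransformerEncoder(dim=64, num_heads=4, num_layers=2)
x = torch.randn(1, 64, 32, 50)    # spectrogram branch: (B, C, Fr, T1)
xt = torch.randn(1, 64, 800)      # waveform branch:    (B, C, T2)
x_out, xt_out = enc(x, xt)        # each branch keeps its input shape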
589
+
590
+
591
+ # Attention Modules
592
+
593
+
594
+ class MultiheadAttention(nn.Module):
595
+ def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, auto_sparsity=None):
596
+ super().__init__()
597
+ assert auto_sparsity is not None, "sanity check"
598
+ self.num_heads = num_heads
599
+ self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
600
+ self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
601
+ self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
602
+ self.attn_drop = torch.nn.Dropout(dropout)
603
+ self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
604
+ self.proj_drop = torch.nn.Dropout(dropout)
605
+ self.batch_first = batch_first
606
+ self.auto_sparsity = auto_sparsity
607
+
608
+ def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True):
609
+
610
+ if not self.batch_first: # N, B, C
611
+ query = query.permute(1, 0, 2) # B, N_q, C
612
+ key = key.permute(1, 0, 2) # B, N_k, C
613
+ value = value.permute(1, 0, 2) # B, N_k, C
614
+ B, N_q, C = query.shape
615
+ B, N_k, C = key.shape
616
+
617
+ q = self.q(query).reshape(B, N_q, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
618
+ q = q.flatten(0, 1)
619
+ k = self.k(key).reshape(B, N_k, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
620
+ k = k.flatten(0, 1)
621
+ v = self.v(value).reshape(B, N_k, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
622
+ v = v.flatten(0, 1)
623
+
624
+ if self.auto_sparsity:
625
+ assert attn_mask is None
626
+ x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
627
+ else:
628
+ x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)
629
+ x = x.reshape(B, self.num_heads, N_q, C // self.num_heads)
630
+
631
+ x = x.transpose(1, 2).reshape(B, N_q, C)
632
+ x = self.proj(x)
633
+ x = self.proj_drop(x)
634
+ if not self.batch_first:
635
+ x = x.permute(1, 0, 2)
636
+ return x, None
637
+
638
+
639
+ def scaled_query_key_softmax(q, k, att_mask):
640
+ from xformers.ops import masked_matmul
641
+
642
+ q = q / (k.size(-1)) ** 0.5
643
+ att = masked_matmul(q, k.transpose(-2, -1), att_mask)
644
+ att = torch.nn.functional.softmax(att, -1)
645
+ return att
646
+
647
+
648
+ def scaled_dot_product_attention(q, k, v, att_mask, dropout):
649
+ att = scaled_query_key_softmax(q, k, att_mask=att_mask)
650
+ att = dropout(att)
651
+ y = att @ v
652
+ return y
653
+
654
+
655
+ def _compute_buckets(x, R):
656
+ qq = torch.einsum("btf,bfhi->bhti", x, R)
657
+ qq = torch.cat([qq, -qq], dim=-1)
658
+ buckets = qq.argmax(dim=-1)
659
+
660
+ return buckets.permute(0, 2, 1).byte().contiguous()
661
+
662
+
663
+ def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
664
+ # assert False, "The code for the custom sparse kernel is not ready for release yet."
665
+ from xformers.ops import find_locations, sparse_memory_efficient_attention
666
+
667
+ n_hashes = 32
668
+ proj_size = 4
669
+ query, key, value = [x.contiguous() for x in [query, key, value]]
670
+ with torch.no_grad():
671
+ R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
672
+ bucket_query = _compute_buckets(query, R)
673
+ bucket_key = _compute_buckets(key, R)
674
+ row_offsets, column_indices = find_locations(bucket_query, bucket_key, sparsity, infer_sparsity)
675
+ return sparse_memory_efficient_attention(query, key, value, row_offsets, column_indices, attn_bias)
audio_separator/separator/uvr_lib_v5/demucs/utils.py ADDED
@@ -0,0 +1,496 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import defaultdict
8
+ from contextlib import contextmanager
9
+ import math
10
+ import os
11
+ import tempfile
12
+ import typing as tp
13
+
14
+ import errno
15
+ import functools
16
+ import hashlib
17
+ import inspect
18
+ import io
19
+ import os
20
+ import random
21
+ import socket
22
+ import tempfile
23
+ import warnings
24
+ import zlib
25
+
26
+ from diffq import UniformQuantizer, DiffQuantizer
27
+ import torch as th
28
+ import tqdm
29
+ from torch import distributed
30
+ from torch.nn import functional as F
31
+
32
+ import torch
33
+
34
+
35
+ def unfold(a, kernel_size, stride):
36
+ """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
37
+ with K the kernel size, by extracting frames with the given stride.
38
+
39
+ This will pad the input so that `F = ceil(T / stride)`.
40
+
41
+ see https://github.com/pytorch/pytorch/issues/60466
42
+ """
43
+ *shape, length = a.shape
44
+ n_frames = math.ceil(length / stride)
45
+ tgt_length = (n_frames - 1) * stride + kernel_size
46
+ a = F.pad(a, (0, tgt_length - length))
47
+ strides = list(a.stride())
48
+ assert strides[-1] == 1, "data should be contiguous"
49
+ strides = strides[:-1] + [stride, 1]
50
+ return a.as_strided([*shape, n_frames, kernel_size], strides)
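# Quick check of unfold(): frames of length 4 taken every 2 samples from a
# (1, 10) input give shape (1, 5, 4), since ceil(10 / 2) = 5 frames and the
# input is zero-padded on the right to cover the last frame.
_a = torch.arange(10.0).view(1, 10)
assert unfold(_a, kernel_size=4, stride=2).shape == (1, 5, 4)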
51
+
52
+
53
+ def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
54
+ """
55
+ Center trim `tensor` with respect to `reference`, along the last dimension.
56
+ `reference` can also be a number, representing the length to trim to.
57
+ If the size difference != 0 mod 2, the extra sample is removed on the right side.
58
+ """
59
+ ref_size: int
60
+ if isinstance(reference, torch.Tensor):
61
+ ref_size = reference.size(-1)
62
+ else:
63
+ ref_size = reference
64
+ delta = tensor.size(-1) - ref_size
65
+ if delta < 0:
66
+ raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
67
+ if delta:
68
+ tensor = tensor[..., delta // 2 : -(delta - delta // 2)]
69
+ return tensor
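# center_trim() in one line: trimming a length-10 signal to length 6 drops two
# samples from each side.
assert center_trim(torch.arange(10.0), 6).tolist() == [2.0, 3.0, 4.0, 5.0, 6.0, 7.0]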
70
+
71
+
72
+ def pull_metric(history: tp.List[dict], name: str):
73
+ out = []
74
+ for metrics in history:
75
+ metric = metrics
76
+ for part in name.split("."):
77
+ metric = metric[part]
78
+ out.append(metric)
79
+ return out
80
+
81
+
82
+ def EMA(beta: float = 1):
83
+ """
84
+ Exponential Moving Average callback.
85
+ Returns a single function that can be called to repeatedly update the EMA
86
+ with a dict of metrics. The callback will return
87
+ the new averaged dict of metrics.
88
+
89
+ Note that for `beta=1`, this is just plain averaging.
90
+ """
91
+ fix: tp.Dict[str, float] = defaultdict(float)
92
+ total: tp.Dict[str, float] = defaultdict(float)
93
+
94
+ def _update(metrics: dict, weight: float = 1) -> dict:
95
+ nonlocal total, fix
96
+ for key, value in metrics.items():
97
+ total[key] = total[key] * beta + weight * float(value)
98
+ fix[key] = fix[key] * beta + weight
99
+ return {key: tot / fix[key] for key, tot in total.items()}
100
+
101
+ return _update
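# EMA usage sketch: with beta=1 this reduces to a plain running average.
_update = EMA(beta=1)
_update({"loss": 2.0})
assert abs(_update({"loss": 4.0})["loss"] - 3.0) < 1e-9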
102
+
103
+
104
+ def sizeof_fmt(num: float, suffix: str = "B"):
105
+ """
106
+ Given `num` bytes, return human readable size.
107
+ Taken from https://stackoverflow.com/a/1094933
108
+ """
109
+ for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
110
+ if abs(num) < 1024.0:
111
+ return "%3.1f%s%s" % (num, unit, suffix)
112
+ num /= 1024.0
113
+ return "%.1f%s%s" % (num, "Yi", suffix)
114
+
115
+
116
+ @contextmanager
117
+ def temp_filenames(count: int, delete=True):
118
+ names = []
119
+ try:
120
+ for _ in range(count):
121
+ names.append(tempfile.NamedTemporaryFile(delete=False).name)
122
+ yield names
123
+ finally:
124
+ if delete:
125
+ for name in names:
126
+ os.unlink(name)
127
+
128
+
129
+ def average_metric(metric, count=1.0):
130
+ """
131
+ Average `metric` which should be a float across all hosts. `count` should be
132
+ the weight for this particular host (i.e. number of examples).
133
+ """
134
+ metric = th.tensor([count, count * metric], dtype=th.float32, device="cuda")
135
+ distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
136
+ return metric[1].item() / metric[0].item()
137
+
138
+
139
+ def free_port(host="", low=20000, high=40000):
140
+ """
141
+ Return a port number that is most likely free.
142
+ This could suffer from a race condition although
143
+ it should be quite rare.
144
+ """
145
+ sock = socket.socket()
146
+ while True:
147
+ port = random.randint(low, high)
148
+ try:
149
+ sock.bind((host, port))
150
+ except OSError as error:
151
+ if error.errno == errno.EADDRINUSE:
152
+ continue
153
+ raise
154
+ return port
155
+
156
+
157
+ def sizeof_fmt(num, suffix="B"):
158
+ """
159
+ Given `num` bytes, return human readable size.
160
+ Taken from https://stackoverflow.com/a/1094933
161
+ """
162
+ for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
163
+ if abs(num) < 1024.0:
164
+ return "%3.1f%s%s" % (num, unit, suffix)
165
+ num /= 1024.0
166
+ return "%.1f%s%s" % (num, "Yi", suffix)
167
+
168
+
169
+ def human_seconds(seconds, display=".2f"):
170
+ """
171
+ Given `seconds` seconds, return human readable duration.
172
+ """
173
+ value = seconds * 1e6
174
+ ratios = [1e3, 1e3, 60, 60, 24]
175
+ names = ["us", "ms", "s", "min", "hrs", "days"]
176
+ last = names.pop(0)
177
+ for name, ratio in zip(names, ratios):
178
+ if value / ratio < 0.3:
179
+ break
180
+ value /= ratio
181
+ last = name
182
+ return f"{format(value, display)} {last}"
183
+
184
+
185
+ class TensorChunk:
186
+ def __init__(self, tensor, offset=0, length=None):
187
+ total_length = tensor.shape[-1]
188
+ assert offset >= 0
189
+ assert offset < total_length
190
+
191
+ if length is None:
192
+ length = total_length - offset
193
+ else:
194
+ length = min(total_length - offset, length)
195
+
196
+ self.tensor = tensor
197
+ self.offset = offset
198
+ self.length = length
199
+ self.device = tensor.device
200
+
201
+ @property
202
+ def shape(self):
203
+ shape = list(self.tensor.shape)
204
+ shape[-1] = self.length
205
+ return shape
206
+
207
+ def padded(self, target_length):
208
+ delta = target_length - self.length
209
+ total_length = self.tensor.shape[-1]
210
+ assert delta >= 0
211
+
212
+ start = self.offset - delta // 2
213
+ end = start + target_length
214
+
215
+ correct_start = max(0, start)
216
+ correct_end = min(total_length, end)
217
+
218
+ pad_left = correct_start - start
219
+ pad_right = end - correct_end
220
+
221
+ out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
222
+ assert out.shape[-1] == target_length
223
+ return out
224
+
225
+
226
+ def tensor_chunk(tensor_or_chunk):
227
+ if isinstance(tensor_or_chunk, TensorChunk):
228
+ return tensor_or_chunk
229
+ else:
230
+ assert isinstance(tensor_or_chunk, th.Tensor)
231
+ return TensorChunk(tensor_or_chunk)
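# TensorChunk sketch (illustrative sizes): a 2-second window into a longer
# signal, re-padded to a slightly larger target length around its center.
_signal = th.randn(2, 44100 * 10)                     # (channels, samples)
_chunk = TensorChunk(_signal, offset=44100, length=44100 * 2)
assert _chunk.padded(44100 * 2 + 512).shape[-1] == 44100 * 2 + 512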
232
+
233
+
234
+ def apply_model_v1(model, mix, shifts=None, split=False, progress=False, set_progress_bar=None):
235
+ """
236
+ Apply model to a given mixture.
237
+
238
+ Args:
239
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
240
+ and apply the opposite shift to the output. This is repeated `shifts` times and
241
+ all predictions are averaged. This effectively makes the model time equivariant
242
+ and improves SDR by up to 0.2 points.
243
+ split (bool): if True, the input will be broken down into 8-second extracts
244
+ and predictions will be performed individually on each and concatenated.
245
+ Useful for models with a large memory footprint like Tasnet.
246
+ progress (bool): if True, show a progress bar (requires split=True)
247
+ """
248
+
249
+ channels, length = mix.size()
250
+ device = mix.device
251
+ progress_value = 0
252
+
253
+ if split:
254
+ out = th.zeros(4, channels, length, device=device)
255
+ shift = model.samplerate * 10
256
+ offsets = range(0, length, shift)
257
+ scale = 10
258
+ if progress:
259
+ offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit="seconds")
260
+ for offset in offsets:
261
+ chunk = mix[..., offset : offset + shift]
262
+ if set_progress_bar:
263
+ progress_value += 1
264
+ set_progress_bar(0.1, (0.8 / len(offsets) * progress_value))
265
+ chunk_out = apply_model_v1(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
266
+ else:
267
+ chunk_out = apply_model_v1(model, chunk, shifts=shifts)
268
+ out[..., offset : offset + shift] = chunk_out
269
+ offset += shift
270
+ return out
271
+ elif shifts:
272
+ max_shift = int(model.samplerate / 2)
273
+ mix = F.pad(mix, (max_shift, max_shift))
274
+ offsets = list(range(max_shift))
275
+ random.shuffle(offsets)
276
+ out = 0
277
+ for offset in offsets[:shifts]:
278
+ shifted = mix[..., offset : offset + length + max_shift]
279
+ if set_progress_bar:
280
+ shifted_out = apply_model_v1(model, shifted, set_progress_bar=set_progress_bar)
281
+ else:
282
+ shifted_out = apply_model_v1(model, shifted)
283
+ out += shifted_out[..., max_shift - offset : max_shift - offset + length]
284
+ out /= shifts
285
+ return out
286
+ else:
287
+ valid_length = model.valid_length(length)
288
+ delta = valid_length - length
289
+ padded = F.pad(mix, (delta // 2, delta - delta // 2))
290
+ with th.no_grad():
291
+ out = model(padded.unsqueeze(0))[0]
292
+ return center_trim(out, mix)
293
+
294
+
295
+ def apply_model_v2(model, mix, shifts=None, split=False, overlap=0.25, transition_power=1.0, progress=False, set_progress_bar=None):
296
+ """
297
+ Apply model to a given mixture.
298
+
299
+ Args:
300
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
301
+ and apply the opposite shift to the output. This is repeated `shifts` times and
302
+ all predictions are averaged. This effectively makes the model time equivariant
303
+ and improves SDR by up to 0.2 points.
304
+ split (bool): if True, the input will be broken down into 8-second extracts
305
+ and predictions will be performed individually on each and concatenated.
306
+ Useful for models with a large memory footprint like Tasnet.
307
+ progress (bool): if True, show a progress bar (requires split=True)
308
+ """
309
+
310
+ assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
311
+ device = mix.device
312
+ channels, length = mix.shape
313
+ progress_value = 0
314
+
315
+ if split:
316
+ out = th.zeros(len(model.sources), channels, length, device=device)
317
+ sum_weight = th.zeros(length, device=device)
318
+ segment = model.segment_length
319
+ stride = int((1 - overlap) * segment)
320
+ offsets = range(0, length, stride)
321
+ scale = stride / model.samplerate
322
+ if progress:
323
+ offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit="seconds")
324
+ # We start from a triangle shaped weight, with maximal weight in the middle
325
+ # of the segment. Then we normalize and take to the power `transition_power`.
326
+ # Large values of transition power will lead to sharper transitions.
327
+ weight = th.cat([th.arange(1, segment // 2 + 1), th.arange(segment - segment // 2, 0, -1)]).to(device)
328
+ assert len(weight) == segment
329
+ # If the overlap < 50%, this will translate to linear transition when
330
+ # transition_power is 1.
331
+ weight = (weight / weight.max()) ** transition_power
332
+ for offset in offsets:
333
+ chunk = TensorChunk(mix, offset, segment)
334
+ if set_progress_bar:
335
+ progress_value += 1
336
+ set_progress_bar(0.1, (0.8 / len(offsets) * progress_value))
337
+ chunk_out = apply_model_v2(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
338
+ else:
339
+ chunk_out = apply_model_v2(model, chunk, shifts=shifts)
340
+ chunk_length = chunk_out.shape[-1]
341
+ out[..., offset : offset + segment] += weight[:chunk_length] * chunk_out
342
+ sum_weight[offset : offset + segment] += weight[:chunk_length]
343
+ offset += segment
344
+ assert sum_weight.min() > 0
345
+ out /= sum_weight
346
+ return out
347
+ elif shifts:
348
+ max_shift = int(0.5 * model.samplerate)
349
+ mix = tensor_chunk(mix)
350
+ padded_mix = mix.padded(length + 2 * max_shift)
351
+ out = 0
352
+ for _ in range(shifts):
353
+ offset = random.randint(0, max_shift)
354
+ shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
355
+
356
+ if set_progress_bar:
357
+ progress_value += 1
358
+ shifted_out = apply_model_v2(model, shifted, set_progress_bar=set_progress_bar)
359
+ else:
360
+ shifted_out = apply_model_v2(model, shifted)
361
+ out += shifted_out[..., max_shift - offset :]
362
+ out /= shifts
363
+ return out
364
+ else:
365
+ valid_length = model.valid_length(length)
366
+ mix = tensor_chunk(mix)
367
+ padded_mix = mix.padded(valid_length)
368
+ with th.no_grad():
369
+ out = model(padded_mix.unsqueeze(0))[0]
370
+ return center_trim(out, length)
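# Sketch of the overlap-add weighting used in the split branch above: a
# triangular window, normalized to a 1.0 peak, then raised to transition_power.
_segment, _transition_power = 8, 1.0
_weight = th.cat([th.arange(1, _segment // 2 + 1), th.arange(_segment - _segment // 2, 0, -1)]).float()
_weight = (_weight / _weight.max()) ** _transition_power
# _weight == [0.25, 0.5, 0.75, 1.0, 1.0, 0.75, 0.5, 0.25]: overlapping chunk
# outputs are accumulated with these weights, then divided by the summed weights.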
371
+
372
+
373
+ @contextmanager
374
+ def temp_filenames(count, delete=True):
375
+ names = []
376
+ try:
377
+ for _ in range(count):
378
+ names.append(tempfile.NamedTemporaryFile(delete=False).name)
379
+ yield names
380
+ finally:
381
+ if delete:
382
+ for name in names:
383
+ os.unlink(name)
384
+
385
+
386
+ def get_quantizer(model, args, optimizer=None):
387
+ quantizer = None
388
+ if args.diffq:
389
+ quantizer = DiffQuantizer(model, min_size=args.q_min_size, group_size=8)
390
+ if optimizer is not None:
391
+ quantizer.setup_optimizer(optimizer)
392
+ elif args.qat:
393
+ quantizer = UniformQuantizer(model, bits=args.qat, min_size=args.q_min_size)
394
+ return quantizer
395
+
396
+
397
+ def load_model(path, strict=False):
398
+ with warnings.catch_warnings():
399
+ warnings.simplefilter("ignore")
400
+ load_from = path
401
+ package = th.load(load_from, "cpu")
402
+
403
+ klass = package["klass"]
404
+ args = package["args"]
405
+ kwargs = package["kwargs"]
406
+
407
+ if strict:
408
+ model = klass(*args, **kwargs)
409
+ else:
410
+ sig = inspect.signature(klass)
411
+ for key in list(kwargs):
412
+ if key not in sig.parameters:
413
+ warnings.warn("Dropping inexistant parameter " + key)
414
+ del kwargs[key]
415
+ model = klass(*args, **kwargs)
416
+
417
+ state = package["state"]
418
+ training_args = package["training_args"]
419
+ quantizer = get_quantizer(model, training_args)
420
+
421
+ set_state(model, quantizer, state)
422
+ return model
423
+
424
+
425
+ def get_state(model, quantizer):
426
+ if quantizer is None:
427
+ state = {k: p.data.to("cpu") for k, p in model.state_dict().items()}
428
+ else:
429
+ state = quantizer.get_quantized_state()
430
+ buf = io.BytesIO()
431
+ th.save(state, buf)
432
+ state = {"compressed": zlib.compress(buf.getvalue())}
433
+ return state
434
+
435
+
436
+ def set_state(model, quantizer, state):
437
+ if quantizer is None:
438
+ model.load_state_dict(state)
439
+ else:
440
+ buf = io.BytesIO(zlib.decompress(state["compressed"]))
441
+ state = th.load(buf, "cpu")
442
+ quantizer.restore_quantized_state(state)
443
+
444
+ return state
445
+
446
+
447
+ def save_state(state, path):
448
+ buf = io.BytesIO()
449
+ th.save(state, buf)
450
+ sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
451
+
452
+ path = path.parent / (path.stem + "-" + sig + path.suffix)
453
+ path.write_bytes(buf.getvalue())
454
+
455
+
456
+ def save_model(model, quantizer, training_args, path):
457
+ args, kwargs = model._init_args_kwargs
458
+ klass = model.__class__
459
+
460
+ state = get_state(model, quantizer)
461
+
462
+ save_to = path
463
+ package = {"klass": klass, "args": args, "kwargs": kwargs, "state": state, "training_args": training_args}
464
+ th.save(package, save_to)
465
+
466
+
467
+ def capture_init(init):
468
+ @functools.wraps(init)
469
+ def __init__(self, *args, **kwargs):
470
+ self._init_args_kwargs = (args, kwargs)
471
+ init(self, *args, **kwargs)
472
+
473
+ return __init__
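# capture_init usage sketch (hypothetical _Demo class): the decorator records the
# constructor arguments so that save_model() above can rebuild the object later.
class _Demo(th.nn.Module):
    @capture_init
    def __init__(self, hidden=16):
        super().__init__()
        self.hidden = hidden

assert _Demo(hidden=32)._init_args_kwargs == ((), {"hidden": 32})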
474
+
475
+
476
+ class DummyPoolExecutor:
477
+ class DummyResult:
478
+ def __init__(self, func, *args, **kwargs):
479
+ self.func = func
480
+ self.args = args
481
+ self.kwargs = kwargs
482
+
483
+ def result(self):
484
+ return self.func(*self.args, **self.kwargs)
485
+
486
+ def __init__(self, workers=0):
487
+ pass
488
+
489
+ def submit(self, func, *args, **kwargs):
490
+ return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
491
+
492
+ def __enter__(self):
493
+ return self
494
+
495
+ def __exit__(self, exc_type, exc_value, exc_tb):
496
+ return
audio_separator/separator/uvr_lib_v5/mdxnet.py ADDED
@@ -0,0 +1,136 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .modules import TFC_TDF
4
+ from pytorch_lightning import LightningModule
5
+
6
+ dim_s = 4
7
+
8
+ class AbstractMDXNet(LightningModule):
9
+ def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap):
10
+ super().__init__()
11
+ self.target_name = target_name
12
+ self.lr = lr
13
+ self.optimizer = optimizer
14
+ self.dim_c = dim_c
15
+ self.dim_f = dim_f
16
+ self.dim_t = dim_t
17
+ self.n_fft = n_fft
18
+ self.n_bins = n_fft // 2 + 1
19
+ self.hop_length = hop_length
20
+ self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False)
21
+ self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False)
22
+
23
+ def get_optimizer(self):
24
+ if self.optimizer == 'rmsprop':
25
+ return torch.optim.RMSprop(self.parameters(), self.lr)
26
+
27
+ if self.optimizer == 'adamw':
28
+ return torch.optim.AdamW(self.parameters(), self.lr)
29
+
30
+ class ConvTDFNet(AbstractMDXNet):
31
+ def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length,
32
+ num_blocks, l, g, k, bn, bias, overlap):
33
+
34
+ super(ConvTDFNet, self).__init__(
35
+ target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap)
36
+ #self.save_hyperparameters()
37
+
38
+ self.num_blocks = num_blocks
39
+ self.l = l
40
+ self.g = g
41
+ self.k = k
42
+ self.bn = bn
43
+ self.bias = bias
44
+
45
+ if optimizer == 'rmsprop':
46
+ norm = nn.BatchNorm2d
47
+
48
+ if optimizer == 'adamw':
49
+ norm = lambda input:nn.GroupNorm(2, input)
50
+
51
+ self.n = num_blocks // 2
52
+ scale = (2, 2)
53
+
54
+ self.first_conv = nn.Sequential(
55
+ nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)),
56
+ norm(g),
57
+ nn.ReLU(),
58
+ )
59
+
60
+ f = self.dim_f
61
+ c = g
62
+ self.encoding_blocks = nn.ModuleList()
63
+ self.ds = nn.ModuleList()
64
+ for i in range(self.n):
65
+ self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
66
+ self.ds.append(
67
+ nn.Sequential(
68
+ nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale),
69
+ norm(c + g),
70
+ nn.ReLU()
71
+ )
72
+ )
73
+ f = f // 2
74
+ c += g
75
+
76
+ self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)
77
+
78
+ self.decoding_blocks = nn.ModuleList()
79
+ self.us = nn.ModuleList()
80
+ for i in range(self.n):
81
+ self.us.append(
82
+ nn.Sequential(
83
+ nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale),
84
+ norm(c - g),
85
+ nn.ReLU()
86
+ )
87
+ )
88
+ f = f * 2
89
+ c -= g
90
+
91
+ self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
92
+
93
+ self.final_conv = nn.Sequential(
94
+ nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)),
95
+ )
96
+
97
+ def forward(self, x):
98
+
99
+ x = self.first_conv(x)
100
+
101
+ x = x.transpose(-1, -2)
102
+
103
+ ds_outputs = []
104
+ for i in range(self.n):
105
+ x = self.encoding_blocks[i](x)
106
+ ds_outputs.append(x)
107
+ x = self.ds[i](x)
108
+
109
+ x = self.bottleneck_block(x)
110
+
111
+ for i in range(self.n):
112
+ x = self.us[i](x)
113
+ x *= ds_outputs[-i - 1]
114
+ x = self.decoding_blocks[i](x)
115
+
116
+ x = x.transpose(-1, -2)
117
+
118
+ x = self.final_conv(x)
119
+
120
+ return x
121
+
122
+ class Mixer(nn.Module):
123
+ def __init__(self, device, mixer_path):
124
+
125
+ super(Mixer, self).__init__()
126
+
127
+ self.linear = nn.Linear((dim_s+1)*2, dim_s*2, bias=False)
128
+
129
+ self.load_state_dict(
130
+ torch.load(mixer_path, map_location=device)
131
+ )
132
+
133
+ def forward(self, x):
134
+ x = x.reshape(1,(dim_s+1)*2,-1).transpose(-1,-2)
135
+ x = self.linear(x)
136
+ return x.transpose(-1,-2).reshape(dim_s,2,-1)
audio_separator/separator/uvr_lib_v5/mixer.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea781bd52c6a523b825fa6cdbb6189f52e318edd8b17e6fe404f76f7af8caa9c
3
+ size 1208
audio_separator/separator/uvr_lib_v5/modules.py ADDED
@@ -0,0 +1,74 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class TFC(nn.Module):
6
+ def __init__(self, c, l, k, norm):
7
+ super(TFC, self).__init__()
8
+
9
+ self.H = nn.ModuleList()
10
+ for i in range(l):
11
+ self.H.append(
12
+ nn.Sequential(
13
+ nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
14
+ norm(c),
15
+ nn.ReLU(),
16
+ )
17
+ )
18
+
19
+ def forward(self, x):
20
+ for h in self.H:
21
+ x = h(x)
22
+ return x
23
+
24
+
25
+ class DenseTFC(nn.Module):
26
+ def __init__(self, c, l, k, norm):
27
+ super(DenseTFC, self).__init__()
28
+
29
+ self.conv = nn.ModuleList()
30
+ for i in range(l):
31
+ self.conv.append(
32
+ nn.Sequential(
33
+ nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
34
+ norm(c),
35
+ nn.ReLU(),
36
+ )
37
+ )
38
+
39
+ def forward(self, x):
40
+ for layer in self.conv[:-1]:
41
+ x = torch.cat([layer(x), x], 1)
42
+ return self.conv[-1](x)
43
+
44
+
45
+ class TFC_TDF(nn.Module):
46
+ def __init__(self, c, l, f, k, bn, dense=False, bias=True, norm=nn.BatchNorm2d):
47
+
48
+ super(TFC_TDF, self).__init__()
49
+
50
+ self.use_tdf = bn is not None
51
+
52
+ self.tfc = DenseTFC(c, l, k, norm) if dense else TFC(c, l, k, norm)
53
+
54
+ if self.use_tdf:
55
+ if bn == 0:
56
+ self.tdf = nn.Sequential(
57
+ nn.Linear(f, f, bias=bias),
58
+ norm(c),
59
+ nn.ReLU()
60
+ )
61
+ else:
62
+ self.tdf = nn.Sequential(
63
+ nn.Linear(f, f // bn, bias=bias),
64
+ norm(c),
65
+ nn.ReLU(),
66
+ nn.Linear(f // bn, f, bias=bias),
67
+ norm(c),
68
+ nn.ReLU()
69
+ )
70
+
71
+ def forward(self, x):
72
+ x = self.tfc(x)
73
+ return x + self.tdf(x) if self.use_tdf else x
74
+
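# Minimal TFC_TDF sanity sketch (illustrative values): the block keeps the
# (batch, channels, time, freq) shape, with the TDF bottleneck acting on the
# last (frequency) dimension of size f.
_block = TFC_TDF(c=4, l=2, f=32, k=3, bn=4)
_spec = torch.rand(1, 4, 16, 32)
assert _block(_spec).shape == _spec.shape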
audio_separator/separator/uvr_lib_v5/playsound.py ADDED
@@ -0,0 +1,241 @@
1
+ import logging
2
+ logger = logging.getLogger(__name__)
3
+
4
+ class PlaysoundException(Exception):
5
+ pass
6
+
7
+ def _canonicalizePath(path):
8
+ """
9
+ Support passing in a pathlib.Path-like object by converting to str.
10
+ """
11
+ import sys
12
+ if sys.version_info[0] >= 3:
13
+ return str(path)
14
+ else:
15
+ # On earlier Python versions, str is a byte string, so attempting to
16
+ # convert a unicode string to str will fail. Leave it alone in this case.
17
+ return path
18
+
19
+ def _playsoundWin(sound, block = True):
20
+ '''
21
+ Utilizes windll.winmm. Tested and known to work with MP3 and WAVE on
22
+ Windows 7 with Python 2.7. Probably works with more file formats.
23
+ Probably works on Windows XP thru Windows 10. Probably works with all
24
+ versions of Python.
25
+
26
+ Inspired by (but not copied from) Michael Gundlach <[email protected]>'s mp3play:
27
+ https://github.com/michaelgundlach/mp3play
28
+
29
+ I never would have tried using windll.winmm without seeing his code.
30
+ '''
31
+ sound = '"' + _canonicalizePath(sound) + '"'
32
+
33
+ from ctypes import create_unicode_buffer, windll, wintypes
34
+ windll.winmm.mciSendStringW.argtypes = [wintypes.LPCWSTR, wintypes.LPWSTR, wintypes.UINT, wintypes.HANDLE]
35
+ windll.winmm.mciGetErrorStringW.argtypes = [wintypes.DWORD, wintypes.LPWSTR, wintypes.UINT]
36
+
37
+ def winCommand(*command):
38
+ bufLen = 600
39
+ buf = create_unicode_buffer(bufLen)
40
+ command = ' '.join(command)
41
+ errorCode = int(windll.winmm.mciSendStringW(command, buf, bufLen - 1, 0)) # use widestring version of the function
42
+ if errorCode:
43
+ errorBuffer = create_unicode_buffer(bufLen)
44
+ windll.winmm.mciGetErrorStringW(errorCode, errorBuffer, bufLen - 1) # use widestring version of the function
45
+ exceptionMessage = ('\n Error ' + str(errorCode) + ' for command:'
46
+ '\n ' + command +
47
+ '\n ' + errorBuffer.value)
48
+ logger.error(exceptionMessage)
49
+ raise PlaysoundException(exceptionMessage)
50
+ return buf.value
51
+
52
+ try:
53
+ logger.debug('Starting')
54
+ winCommand(u'open {}'.format(sound))
55
+ winCommand(u'play {}{}'.format(sound, ' wait' if block else ''))
56
+ logger.debug('Returning')
57
+ finally:
58
+ try:
59
+ winCommand(u'close {}'.format(sound))
60
+ except PlaysoundException:
61
+ logger.warning(u'Failed to close the file: {}'.format(sound))
62
+ # If it fails, there's nothing more that can be done...
63
+ pass
64
+
65
+ def _handlePathOSX(sound):
66
+ sound = _canonicalizePath(sound)
67
+
68
+ if '://' not in sound:
69
+ if not sound.startswith('/'):
70
+ from os import getcwd
71
+ sound = getcwd() + '/' + sound
72
+ sound = 'file://' + sound
73
+
74
+ try:
75
+ # Don't double-encode it.
76
+ sound.encode('ascii')
77
+ return sound.replace(' ', '%20')
78
+ except UnicodeEncodeError:
79
+ try:
80
+ from urllib.parse import quote # Try the Python 3 import first...
81
+ except ImportError:
82
+ from urllib import quote # Try using the Python 2 import before giving up entirely...
83
+
84
+ parts = sound.split('://', 1)
85
+ return parts[0] + '://' + quote(parts[1].encode('utf-8')).replace(' ', '%20')
86
+
87
+
88
+ def _playsoundOSX(sound, block = True):
89
+ '''
90
+ Utilizes AppKit.NSSound. Tested and known to work with MP3 and WAVE on
91
+ OS X 10.11 with Python 2.7. Probably works with anything QuickTime supports.
92
+ Probably works on OS X 10.5 and newer. Probably works with all versions of
93
+ Python.
94
+
95
+ Inspired by (but not copied from) Aaron's Stack Overflow answer here:
96
+ http://stackoverflow.com/a/34568298/901641
97
+
98
+ I never would have tried using AppKit.NSSound without seeing his code.
99
+ '''
100
+ try:
101
+ from AppKit import NSSound
102
+ except ImportError:
103
+ logger.warning("playsound could not find a copy of AppKit - falling back to using macOS's system copy.")
104
+ sys.path.append('/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/PyObjC')
105
+ from AppKit import NSSound
106
+
107
+ from Foundation import NSURL
108
+ from time import sleep
109
+
110
+ sound = _handlePathOSX(sound)
111
+ url = NSURL.URLWithString_(sound)
112
+ if not url:
113
+ raise PlaysoundException('Cannot find a sound with filename: ' + sound)
114
+
115
+ for i in range(5):
116
+ nssound = NSSound.alloc().initWithContentsOfURL_byReference_(url, True)
117
+ if nssound:
118
+ break
119
+ else:
120
+ logger.debug('Failed to load sound, although url was good... ' + sound)
121
+ else:
122
+ raise PlaysoundException('Could not load sound with filename, although URL was good... ' + sound)
123
+ nssound.play()
124
+
125
+ if block:
126
+ sleep(nssound.duration())
127
+
128
+ def _playsoundNix(sound, block = True):
129
+ """Play a sound using GStreamer.
130
+
131
+ Inspired by this:
132
+ https://gstreamer.freedesktop.org/documentation/tutorials/playback/playbin-usage.html
133
+ """
134
+ sound = _canonicalizePath(sound)
135
+
136
+ # pathname2url escapes non-URL-safe characters
137
+ from os.path import abspath, exists
138
+ try:
139
+ from urllib.request import pathname2url
140
+ except ImportError:
141
+ # python 2
142
+ from urllib import pathname2url
143
+
144
+ import gi
145
+ gi.require_version('Gst', '1.0')
146
+ from gi.repository import Gst
147
+
148
+ Gst.init(None)
149
+
150
+ playbin = Gst.ElementFactory.make('playbin', 'playbin')
151
+ if sound.startswith(('http://', 'https://')):
152
+ playbin.props.uri = sound
153
+ else:
154
+ path = abspath(sound)
155
+ if not exists(path):
156
+ raise PlaysoundException(u'File not found: {}'.format(path))
157
+ playbin.props.uri = 'file://' + pathname2url(path)
158
+
159
+
160
+ set_result = playbin.set_state(Gst.State.PLAYING)
161
+ if set_result != Gst.StateChangeReturn.ASYNC:
162
+ raise PlaysoundException(
163
+ "playbin.set_state returned " + repr(set_result))
164
+
165
+ # FIXME: use some other bus method than poll() with block=False
166
+ # https://lazka.github.io/pgi-docs/#Gst-1.0/classes/Bus.html
167
+ logger.debug('Starting play')
168
+ if block:
169
+ bus = playbin.get_bus()
170
+ try:
171
+ bus.poll(Gst.MessageType.EOS, Gst.CLOCK_TIME_NONE)
172
+ finally:
173
+ playbin.set_state(Gst.State.NULL)
174
+
175
+ logger.debug('Finishing play')
176
+
177
+ def _playsoundAnotherPython(otherPython, sound, block = True, macOS = False):
178
+ '''
179
+ Mostly written so that when this is run on python3 on macOS, it can invoke
180
+ python2 on macOS... but maybe this idea could be useful on linux, too.
181
+ '''
182
+ from inspect import getsourcefile
183
+ from os.path import abspath, exists
184
+ from subprocess import check_call
185
+ from threading import Thread
186
+
187
+ sound = _canonicalizePath(sound)
188
+
189
+ class PropogatingThread(Thread):
190
+ def run(self):
191
+ self.exc = None
192
+ try:
193
+ self.ret = self._target(*self._args, **self._kwargs)
194
+ except BaseException as e:
195
+ self.exc = e
196
+
197
+ def join(self, timeout = None):
198
+ super().join(timeout)
199
+ if self.exc:
200
+ raise self.exc
201
+ return self.ret
202
+
203
+ # Check if the file exists...
204
+ if not exists(abspath(sound)):
205
+ raise PlaysoundException('Cannot find a sound with filename: ' + sound)
206
+
207
+ playsoundPath = abspath(getsourcefile(lambda: 0))
208
+ t = PropogatingThread(target = lambda: check_call([otherPython, playsoundPath, _handlePathOSX(sound) if macOS else sound]))
209
+ t.start()
210
+ if block:
211
+ t.join()
212
+
213
+ from platform import system
214
+ system = system()
215
+
216
+ if system == 'Windows':
217
+ playsound_func = _playsoundWin
218
+ elif system == 'Darwin':
219
+ playsound_func = _playsoundOSX
220
+ import sys
221
+ if sys.version_info[0] > 2:
222
+ try:
223
+ from AppKit import NSSound
224
+ except ImportError:
225
+ logger.warning("playsound is relying on a python 2 subprocess. Please use `pip3 install PyObjC` if you want playsound to run more efficiently.")
226
+ playsound_func = lambda sound, block = True: _playsoundAnotherPython('/System/Library/Frameworks/Python.framework/Versions/2.7/bin/python', sound, block, macOS = True)
227
+ else:
228
+ playsound_func = _playsoundNix
229
+ if __name__ != '__main__': # Ensure we don't infinitely recurse trying to get another python instance.
230
+ try:
231
+ import gi
232
+ gi.require_version('Gst', '1.0')
233
+ from gi.repository import Gst
234
+ except:
235
+ logger.warning("playsound is relying on another python subprocess. Please use `pip install pygobject` if you want playsound to run more efficiently.")
236
+ playsound_func = lambda sound, block = True: _playsoundAnotherPython('/usr/bin/python3', sound, block, macOS = False)
237
+
238
+ del system
239
+
240
+ def play(audio_filepath):
241
+ playsound_func(audio_filepath)
audio_separator/separator/uvr_lib_v5/pyrb.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+ import subprocess
3
+ import tempfile
4
+ import six
5
+ import numpy as np
6
+ import soundfile as sf
7
+ import sys
8
+
9
+ if getattr(sys, 'frozen', False):
10
+ BASE_PATH_RUB = sys._MEIPASS
11
+ else:
12
+ BASE_PATH_RUB = os.path.dirname(os.path.abspath(__file__))
13
+
14
+ __all__ = ['time_stretch', 'pitch_shift']
15
+
16
+ __RUBBERBAND_UTIL = os.path.join(BASE_PATH_RUB, 'rubberband')
17
+
18
+ if six.PY2:
19
+ DEVNULL = open(os.devnull, 'w')
20
+ else:
21
+ DEVNULL = subprocess.DEVNULL
22
+
23
+ def __rubberband(y, sr, **kwargs):
24
+
25
+ assert sr > 0
26
+
27
+ # Get the input and output tempfile
28
+ fd, infile = tempfile.mkstemp(suffix='.wav')
29
+ os.close(fd)
30
+ fd, outfile = tempfile.mkstemp(suffix='.wav')
31
+ os.close(fd)
32
+
33
+ # dump the audio
34
+ sf.write(infile, y, sr)
35
+
36
+ try:
37
+ # Execute rubberband
38
+ arguments = [__RUBBERBAND_UTIL, '-q']
39
+
40
+ for key, value in six.iteritems(kwargs):
41
+ arguments.append(str(key))
42
+ arguments.append(str(value))
43
+
44
+ arguments.extend([infile, outfile])
45
+
46
+ subprocess.check_call(arguments, stdout=DEVNULL, stderr=DEVNULL)
47
+
48
+ # Load the processed audio.
49
+ y_out, _ = sf.read(outfile, always_2d=True)
50
+
51
+ # make sure that output dimensions matches input
52
+ if y.ndim == 1:
53
+ y_out = np.squeeze(y_out)
54
+
55
+ except OSError as exc:
56
+ six.raise_from(RuntimeError('Failed to execute rubberband. '
57
+ 'Please verify that rubberband-cli '
58
+ 'is installed.'),
59
+ exc)
60
+
61
+ finally:
62
+ # Remove temp files
63
+ os.unlink(infile)
64
+ os.unlink(outfile)
65
+
66
+ return y_out
67
+
68
+ def time_stretch(y, sr, rate, rbargs=None):
69
+ if rate <= 0:
70
+ raise ValueError('rate must be strictly positive')
71
+
72
+ if rate == 1.0:
73
+ return y
74
+
75
+ if rbargs is None:
76
+ rbargs = dict()
77
+
78
+ rbargs.setdefault('--tempo', rate)
79
+
80
+ return __rubberband(y, sr, **rbargs)
81
+
82
+ def pitch_shift(y, sr, n_steps, rbargs=None):
83
+
84
+ if n_steps == 0:
85
+ return y
86
+
87
+ if rbargs is None:
88
+ rbargs = dict()
89
+
90
+ rbargs.setdefault('--pitch', n_steps)
91
+
92
+ return __rubberband(y, sr, **rbargs)
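# Usage sketch: both helpers shell out to the Rubber Band CLI, so a `rubberband`
# binary must sit next to this module (or come from the frozen build). The input
# path below is a placeholder, not a file from this repository.
if os.path.exists("input.wav"):          # purely illustrative guard
    y, sr = sf.read("input.wav")
    slower = time_stretch(y, sr, 0.8)    # stretch to 80% of the original tempo
    higher = pitch_shift(y, sr, 2)       # shift up by two semitones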
audio_separator/separator/uvr_lib_v5/results.py ADDED
@@ -0,0 +1,48 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ Matchering - Audio Matching and Mastering Python Library
5
+ Copyright (C) 2016-2022 Sergree
6
+
7
+ This program is free software: you can redistribute it and/or modify
8
+ it under the terms of the GNU General Public License as published by
9
+ the Free Software Foundation, either version 3 of the License, or
10
+ (at your option) any later version.
11
+
12
+ This program is distributed in the hope that it will be useful,
13
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
+ GNU General Public License for more details.
16
+
17
+ You should have received a copy of the GNU General Public License
18
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
19
+ """
20
+
21
+ import os
22
+ import soundfile as sf
23
+
24
+
25
+ class Result:
26
+ def __init__(
27
+ self, file: str, subtype: str, use_limiter: bool = True, normalize: bool = True
28
+ ):
29
+ _, file_ext = os.path.splitext(file)
30
+ file_ext = file_ext[1:].upper()
31
+ if not sf.check_format(file_ext):
32
+ raise TypeError(f"{file_ext} format is not supported")
33
+ if not sf.check_format(file_ext, subtype):
34
+ raise TypeError(f"{file_ext} format does not have {subtype} subtype")
35
+ self.file = file
36
+ self.subtype = subtype
37
+ self.use_limiter = use_limiter
38
+ self.normalize = normalize
39
+
40
+
41
+ def pcm16(file: str) -> Result:
42
+ return Result(file, "PCM_16")
43
+
44
+ def pcm24(file: str) -> Result:
45
+ return Result(file, "FLOAT")
46
+
47
+ def save_audiofile(file: str, wav_set="PCM_16") -> Result:
48
+ return Result(file, wav_set)
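# Result usage sketch: describe the output container before writing audio.
# Note that pcm24() above actually requests the FLOAT subtype, as written.
_target = save_audiofile("output.wav", wav_set="PCM_16")
assert (_target.file, _target.subtype) == ("output.wav", "PCM_16")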
audio_separator/separator/uvr_lib_v5/roformer/attend.py ADDED
@@ -0,0 +1,112 @@
1
+ from functools import wraps
2
+ from packaging import version
3
+ from collections import namedtuple
4
+
5
+ import torch
6
+ from torch import nn, einsum
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, reduce
10
+
11
+ # constants
12
+
13
+ FlashAttentionConfig = namedtuple("FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])
14
+
15
+ # helpers
16
+
17
+
18
+ def exists(val):
19
+ return val is not None
20
+
21
+
22
+ def once(fn):
23
+ called = False
24
+
25
+ @wraps(fn)
26
+ def inner(x):
27
+ nonlocal called
28
+ if called:
29
+ return
30
+ called = True
31
+ return fn(x)
32
+
33
+ return inner
34
+
35
+
36
+ print_once = once(print)
37
+
38
+ # main class
39
+
40
+
41
+ class Attend(nn.Module):
42
+ def __init__(self, dropout=0.0, flash=False):
43
+ super().__init__()
44
+ self.dropout = dropout
45
+ self.attn_dropout = nn.Dropout(dropout)
46
+
47
+ self.flash = flash
48
+ assert not (flash and version.parse(torch.__version__) < version.parse("2.0.0")), "in order to use flash attention, you must be using pytorch 2.0 or above"
49
+
50
+ # determine efficient attention configs for cuda and cpu
51
+
52
+ self.cpu_config = FlashAttentionConfig(True, True, True)
53
+ self.cuda_config = None
54
+
55
+ if not torch.cuda.is_available() or not flash:
56
+ return
57
+
58
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
59
+
60
+ if device_properties.major == 8 and device_properties.minor == 0:
61
+ print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
62
+ self.cuda_config = FlashAttentionConfig(True, False, False)
63
+ else:
64
+ self.cuda_config = FlashAttentionConfig(False, True, True)
65
+
66
+ def flash_attn(self, q, k, v):
67
+ _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
68
+
69
+ # Check if there is a compatible device for flash attention
70
+
71
+ config = self.cuda_config if is_cuda else self.cpu_config
72
+
73
+ # sdpa_flash kernel only supports float16 on sm80+ architecture gpu
74
+ if is_cuda and q.dtype != torch.float16:
75
+ config = FlashAttentionConfig(False, True, True)
76
+
77
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
78
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
79
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0)
80
+
81
+ return out
82
+
83
+ def forward(self, q, k, v):
84
+ """
85
+ einstein notation
86
+ b - batch
87
+ h - heads
88
+ n, i, j - sequence length (base sequence length, source, target)
89
+ d - feature dimension
90
+ """
91
+
92
+ q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
93
+
94
+ scale = self.scale if self.scale is not None else q.shape[-1] ** -0.5  # the flash path below uses PyTorch's default scaling and ignores a custom scale
95
+
96
+ if self.flash:
97
+ return self.flash_attn(q, k, v)
98
+
99
+ # similarity
100
+
101
+ sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
102
+
103
+ # attention
104
+
105
+ attn = sim.softmax(dim=-1)
106
+ attn = self.attn_dropout(attn)
107
+
108
+ # aggregate values
109
+
110
+ out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
111
+
112
+ return out
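
A minimal smoke test for the `Attend` module above, using the `(batch, heads, seq, dim_head)` layout from the forward docstring; the tensors are random and `flash=False` keeps it on the plain softmax path so it also runs on PyTorch versions older than 2.0.

```python
# Minimal smoke test for Attend (random tensors, non-flash softmax path).
import torch
from audio_separator.separator.uvr_lib_v5.roformer.attend import Attend

attend = Attend(dropout=0.0, flash=False)   # flash=True requires PyTorch >= 2.0
q = torch.randn(1, 8, 128, 64)              # (batch, heads, seq_len, dim_head)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

out = attend(q, k, v)
print(out.shape)                            # torch.Size([1, 8, 128, 64])
```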
audio_separator/separator/uvr_lib_v5/roformer/bs_roformer.py ADDED
@@ -0,0 +1,535 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from .attend import Attend
9
+
10
+ from beartype.typing import Tuple, Optional, List, Callable
11
+ from beartype import beartype
12
+
13
+ from rotary_embedding_torch import RotaryEmbedding
14
+
15
+ from einops import rearrange, pack, unpack
16
+ from einops.layers.torch import Rearrange
17
+
18
+ # helper functions
19
+
20
+
21
+ def exists(val):
22
+ return val is not None
23
+
24
+
25
+ def default(v, d):
26
+ return v if exists(v) else d
27
+
28
+
29
+ def pack_one(t, pattern):
30
+ return pack([t], pattern)
31
+
32
+
33
+ def unpack_one(t, ps, pattern):
34
+ return unpack(t, ps, pattern)[0]
35
+
36
+
37
+ # norm
38
+
39
+
40
+ def l2norm(t):
41
+ return F.normalize(t, dim=-1, p=2)
42
+
43
+
44
+ class RMSNorm(Module):
45
+ def __init__(self, dim):
46
+ super().__init__()
47
+ self.scale = dim**0.5
48
+ self.gamma = nn.Parameter(torch.ones(dim))
49
+
50
+ def forward(self, x):
51
+ x = x.to(self.gamma.device)
52
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
53
+
54
+
55
+ # attention
56
+
57
+
58
+ class FeedForward(Module):
59
+ def __init__(self, dim, mult=4, dropout=0.0):
60
+ super().__init__()
61
+ dim_inner = int(dim * mult)
62
+ self.net = nn.Sequential(RMSNorm(dim), nn.Linear(dim, dim_inner), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim_inner, dim), nn.Dropout(dropout))
63
+
64
+ def forward(self, x):
65
+ return self.net(x)
66
+
67
+
68
+ class Attention(Module):
69
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
70
+ super().__init__()
71
+ self.heads = heads
72
+ self.scale = dim_head**-0.5
73
+ dim_inner = heads * dim_head
74
+
75
+ self.rotary_embed = rotary_embed
76
+
77
+ self.attend = Attend(flash=flash, dropout=dropout)
78
+
79
+ self.norm = RMSNorm(dim)
80
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
81
+
82
+ self.to_gates = nn.Linear(dim, heads)
83
+
84
+ self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout))
85
+
86
+ def forward(self, x):
87
+ x = self.norm(x)
88
+
89
+ q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
90
+
91
+ if exists(self.rotary_embed):
92
+ q = self.rotary_embed.rotate_queries_or_keys(q)
93
+ k = self.rotary_embed.rotate_queries_or_keys(k)
94
+
95
+ out = self.attend(q, k, v)
96
+
97
+ gates = self.to_gates(x)
98
+ out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
99
+
100
+ out = rearrange(out, "b h n d -> b n (h d)")
101
+ return self.to_out(out)
102
+
103
+
104
+ class LinearAttention(Module):
105
+ """
106
+ this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
107
+ """
108
+
109
+ @beartype
110
+ def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
111
+ super().__init__()
112
+ dim_inner = dim_head * heads
113
+ self.norm = RMSNorm(dim)
114
+
115
+ self.to_qkv = nn.Sequential(nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads))
116
+
117
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
118
+
119
+ self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
120
+
121
+ self.to_out = nn.Sequential(Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))
122
+
123
+ def forward(self, x):
124
+ x = self.norm(x)
125
+
126
+ q, k, v = self.to_qkv(x)
127
+
128
+ q, k = map(l2norm, (q, k))
129
+ q = q * self.temperature.exp()
130
+
131
+ out = self.attend(q, k, v)
132
+
133
+ return self.to_out(out)
134
+
135
+
136
+ class Transformer(Module):
137
+ def __init__(self, *, dim, depth, dim_head=64, heads=8, attn_dropout=0.0, ff_dropout=0.0, ff_mult=4, norm_output=True, rotary_embed=None, flash_attn=True, linear_attn=False):
138
+ super().__init__()
139
+ self.layers = ModuleList([])
140
+
141
+ for _ in range(depth):
142
+ if linear_attn:
143
+ attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
144
+ else:
145
+ attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, rotary_embed=rotary_embed, flash=flash_attn)
146
+
147
+ self.layers.append(ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]))
148
+
149
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
150
+
151
+ def forward(self, x):
152
+
153
+ for attn, ff in self.layers:
154
+ x = attn(x) + x
155
+ x = ff(x) + x
156
+
157
+ return self.norm(x)
158
+
159
+
160
+ # bandsplit module
161
+
162
+
163
+ class BandSplit(Module):
164
+ @beartype
165
+ def __init__(self, dim, dim_inputs: Tuple[int, ...]):
166
+ super().__init__()
167
+ self.dim_inputs = dim_inputs
168
+ self.to_features = ModuleList([])
169
+
170
+ for dim_in in dim_inputs:
171
+ net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
172
+
173
+ self.to_features.append(net)
174
+
175
+ def forward(self, x):
176
+ x = x.split(self.dim_inputs, dim=-1)
177
+
178
+ outs = []
179
+ for split_input, to_feature in zip(x, self.to_features):
180
+ split_output = to_feature(split_input)
181
+ outs.append(split_output)
182
+
183
+ return torch.stack(outs, dim=-2)
184
+
185
+
186
+ def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
187
+ dim_hidden = default(dim_hidden, dim_in)
188
+
189
+ net = []
190
+ dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
191
+
192
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
193
+ is_last = ind == (len(dims) - 2)
194
+
195
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
196
+
197
+ if is_last:
198
+ continue
199
+
200
+ net.append(activation())
201
+
202
+ return nn.Sequential(*net)
203
+
204
+
205
+ class MaskEstimator(Module):
206
+ @beartype
207
+ def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
208
+ super().__init__()
209
+ self.dim_inputs = dim_inputs
210
+ self.to_freqs = ModuleList([])
211
+ dim_hidden = dim * mlp_expansion_factor
212
+
213
+ for dim_in in dim_inputs:
214
+ net = []
215
+
216
+ mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))
217
+
218
+ self.to_freqs.append(mlp)
219
+
220
+ def forward(self, x):
221
+ x = x.unbind(dim=-2)
222
+
223
+ outs = []
224
+
225
+ for band_features, mlp in zip(x, self.to_freqs):
226
+ freq_out = mlp(band_features)
227
+ outs.append(freq_out)
228
+
229
+ return torch.cat(outs, dim=-1)
230
+
231
+
232
+ # main class
233
+
234
+ DEFAULT_FREQS_PER_BANDS = (
235
+ 2,
236
+ 2,
237
+ 2,
238
+ 2,
239
+ 2,
240
+ 2,
241
+ 2,
242
+ 2,
243
+ 2,
244
+ 2,
245
+ 2,
246
+ 2,
247
+ 2,
248
+ 2,
249
+ 2,
250
+ 2,
251
+ 2,
252
+ 2,
253
+ 2,
254
+ 2,
255
+ 2,
256
+ 2,
257
+ 2,
258
+ 2,
259
+ 4,
260
+ 4,
261
+ 4,
262
+ 4,
263
+ 4,
264
+ 4,
265
+ 4,
266
+ 4,
267
+ 4,
268
+ 4,
269
+ 4,
270
+ 4,
271
+ 12,
272
+ 12,
273
+ 12,
274
+ 12,
275
+ 12,
276
+ 12,
277
+ 12,
278
+ 12,
279
+ 24,
280
+ 24,
281
+ 24,
282
+ 24,
283
+ 24,
284
+ 24,
285
+ 24,
286
+ 24,
287
+ 48,
288
+ 48,
289
+ 48,
290
+ 48,
291
+ 48,
292
+ 48,
293
+ 48,
294
+ 48,
295
+ 128,
296
+ 129,
297
+ )
298
+
299
+
300
+ class BSRoformer(Module):
301
+
302
+ @beartype
303
+ def __init__(
304
+ self,
305
+ dim,
306
+ *,
307
+ depth,
308
+ stereo=False,
309
+ num_stems=1,
310
+ time_transformer_depth=2,
311
+ freq_transformer_depth=2,
312
+ linear_transformer_depth=0,
313
+ freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
314
+ # in the paper, they divide into ~60 bands, test with 1 for starters
315
+ dim_head=64,
316
+ heads=8,
317
+ attn_dropout=0.0,
318
+ ff_dropout=0.0,
319
+ flash_attn=True,
320
+ dim_freqs_in=1025,
321
+ stft_n_fft=2048,
322
+ stft_hop_length=512,
323
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
324
+ stft_win_length=2048,
325
+ stft_normalized=False,
326
+ stft_window_fn: Optional[Callable] = None,
327
+ mask_estimator_depth=2,
328
+ multi_stft_resolution_loss_weight=1.0,
329
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
330
+ multi_stft_hop_size=147,
331
+ multi_stft_normalized=False,
332
+ multi_stft_window_fn: Callable = torch.hann_window,
333
+ ):
334
+ super().__init__()
335
+
336
+ self.stereo = stereo
337
+ self.audio_channels = 2 if stereo else 1
338
+ self.num_stems = num_stems
339
+
340
+ self.layers = ModuleList([])
341
+
342
+ transformer_kwargs = dict(dim=dim, heads=heads, dim_head=dim_head, attn_dropout=attn_dropout, ff_dropout=ff_dropout, flash_attn=flash_attn, norm_output=False)
343
+
344
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
345
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
346
+
347
+ for _ in range(depth):
348
+ tran_modules = []
349
+ if linear_transformer_depth > 0:
350
+ tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
351
+ tran_modules.append(Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs))
352
+ tran_modules.append(Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs))
353
+ self.layers.append(nn.ModuleList(tran_modules))
354
+
355
+ self.final_norm = RMSNorm(dim)
356
+
357
+ self.stft_kwargs = dict(n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized)
358
+
359
+ self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
360
+
361
+ freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
362
+
363
+ assert len(freqs_per_bands) > 1
364
+ assert sum(freqs_per_bands) == freqs, f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
365
+
366
+ freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
367
+
368
+ self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
369
+
370
+ self.mask_estimators = nn.ModuleList([])
371
+
372
+ for _ in range(num_stems):
373
+ mask_estimator = MaskEstimator(dim=dim, dim_inputs=freqs_per_bands_with_complex, depth=mask_estimator_depth)
374
+
375
+ self.mask_estimators.append(mask_estimator)
376
+
377
+ # for the multi-resolution stft loss
378
+
379
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
380
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
381
+ self.multi_stft_n_fft = stft_n_fft
382
+ self.multi_stft_window_fn = multi_stft_window_fn
383
+
384
+ self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)
385
+
386
+ def forward(self, raw_audio, target=None, return_loss_breakdown=False):
387
+ """
388
+ einops
389
+
390
+ b - batch
391
+ f - freq
392
+ t - time
393
+ s - audio channel (1 for mono, 2 for stereo)
394
+ n - number of 'stems'
395
+ c - complex (2)
396
+ d - feature dimension
397
+ """
398
+
399
+ original_device = raw_audio.device
400
+ x_is_mps = True if original_device.type == "mps" else False
401
+
402
+ # if x_is_mps:
403
+ # raw_audio = raw_audio.cpu()
404
+
405
+ device = raw_audio.device
406
+
407
+ if raw_audio.ndim == 2:
408
+ raw_audio = rearrange(raw_audio, "b t -> b 1 t")
409
+
410
+ channels = raw_audio.shape[1]
411
+ assert (not self.stereo and channels == 1) or (
412
+ self.stereo and channels == 2
413
+ ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
414
+
415
+ # to stft
416
+
417
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
418
+
419
+ stft_window = self.stft_window_fn().to(device)
420
+
421
+ stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
422
+ stft_repr = torch.view_as_real(stft_repr)
423
+
424
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
425
+ stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c") # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
426
+
427
+ x = rearrange(stft_repr, "b f t c -> b t (f c)")
428
+
429
+ x = self.band_split(x)
430
+
431
+ # axial / hierarchical attention
432
+
433
+ for transformer_block in self.layers:
434
+
435
+ if len(transformer_block) == 3:
436
+ linear_transformer, time_transformer, freq_transformer = transformer_block
437
+
438
+ x, ft_ps = pack([x], "b * d")
439
+ x = linear_transformer(x)
440
+ (x,) = unpack(x, ft_ps, "b * d")
441
+ else:
442
+ time_transformer, freq_transformer = transformer_block
443
+
444
+ x = rearrange(x, "b t f d -> b f t d")
445
+ x, ps = pack([x], "* t d")
446
+
447
+ x = time_transformer(x)
448
+
449
+ (x,) = unpack(x, ps, "* t d")
450
+ x = rearrange(x, "b f t d -> b t f d")
451
+ x, ps = pack([x], "* f d")
452
+
453
+ x = freq_transformer(x)
454
+
455
+ (x,) = unpack(x, ps, "* f d")
456
+
457
+ x = self.final_norm(x)
458
+
459
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
460
+ mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
461
+
462
+ # if x_is_mps:
463
+ # mask = mask.to('cpu')
464
+
465
+ # modulate frequency representation
466
+
467
+ stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
468
+
469
+ # complex number multiplication
470
+
471
+ stft_repr = torch.view_as_complex(stft_repr)
472
+ mask = torch.view_as_complex(mask)
473
+
474
+ stft_repr = stft_repr * mask
475
+
476
+ # istft
477
+
478
+ stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)
479
+
480
+ recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False).to(device)
481
+
482
+ recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=self.num_stems)
483
+
484
+ if self.num_stems == 1:
485
+ recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
486
+
487
+ # if a target is passed in, calculate loss for learning
488
+
489
+ if not exists(target):
490
+ return recon_audio
491
+
492
+ if self.num_stems > 1:
493
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
494
+
495
+ if target.ndim == 2:
496
+ target = rearrange(target, "... t -> ... 1 t")
497
+
498
+ target = target[..., : recon_audio.shape[-1]]
499
+
500
+ loss = F.l1_loss(recon_audio, target)
501
+
502
+ multi_stft_resolution_loss = 0.0
503
+
504
+ for window_size in self.multi_stft_resolutions_window_sizes:
505
+ res_stft_kwargs = dict(
506
+ n_fft=max(window_size, self.multi_stft_n_fft), win_length=window_size, return_complex=True, window=self.multi_stft_window_fn(window_size, device=device), **self.multi_stft_kwargs
507
+ )
508
+
509
+ recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
510
+ target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)
511
+
512
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
513
+
514
+ weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
515
+
516
+ total_loss = loss + weighted_multi_resolution_loss
517
+
518
+ if not return_loss_breakdown:
519
+ # Move the result back to the original device if it was moved to CPU for MPS compatibility
520
+ # if x_is_mps:
521
+ # total_loss = total_loss.to(original_device)
522
+ return total_loss
523
+
524
+ # For detailed loss breakdown, ensure all components are moved back to the original device for MPS
525
+ # if x_is_mps:
526
+ # loss = loss.to(original_device)
527
+ # multi_stft_resolution_loss = multi_stft_resolution_loss.to(original_device)
528
+ # weighted_multi_resolution_loss = weighted_multi_resolution_loss.to(original_device)
529
+
530
+ return total_loss, (loss, multi_stft_resolution_loss)
531
+
532
+ # if not return_loss_breakdown:
533
+ # return total_loss
534
+
535
+ # return total_loss, (loss, multi_stft_resolution_loss)
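
A tiny instantiation sketch for `BSRoformer` above, with deliberately small `dim` and `depth` so it runs on CPU; the checkpoints shipped with this package use much larger hyperparameters, which the separator classes read from the model config rather than from values like these.

```python
# Tiny BSRoformer smoke test; real checkpoints use far larger dim/depth.
import torch
from audio_separator.separator.uvr_lib_v5.roformer.bs_roformer import BSRoformer

model = BSRoformer(dim=32, depth=1, stereo=True, num_stems=1, flash_attn=False)
audio = torch.randn(1, 2, 44100)    # (batch, channels, samples): one second of noise

with torch.no_grad():
    separated = model(audio)        # no target given, so forward returns audio
print(separated.shape)              # (1, 2, ~44032) since num_stems == 1
```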
audio_separator/separator/uvr_lib_v5/roformer/mel_band_roformer.py ADDED
@@ -0,0 +1,445 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from .attend import Attend
9
+
10
+ from beartype.typing import Tuple, Optional, List, Callable
11
+ from beartype import beartype
12
+
13
+ from rotary_embedding_torch import RotaryEmbedding
14
+
15
+ from einops import rearrange, pack, unpack, reduce, repeat
16
+
17
+ from librosa import filters
18
+
19
+
20
+ def exists(val):
21
+ return val is not None
22
+
23
+
24
+ def default(v, d):
25
+ return v if exists(v) else d
26
+
27
+
28
+ def pack_one(t, pattern):
29
+ return pack([t], pattern)
30
+
31
+
32
+ def unpack_one(t, ps, pattern):
33
+ return unpack(t, ps, pattern)[0]
34
+
35
+
36
+ def pad_at_dim(t, pad, dim=-1, value=0.0):
37
+ dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
38
+ zeros = (0, 0) * dims_from_right
39
+ return F.pad(t, (*zeros, *pad), value=value)
40
+
41
+
42
+ class RMSNorm(Module):
43
+ def __init__(self, dim):
44
+ super().__init__()
45
+ self.scale = dim**0.5
46
+ self.gamma = nn.Parameter(torch.ones(dim))
47
+
48
+ def forward(self, x):
49
+ x = x.to(self.gamma.device)
50
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
51
+
52
+
53
+ class FeedForward(Module):
54
+ def __init__(self, dim, mult=4, dropout=0.0):
55
+ super().__init__()
56
+ dim_inner = int(dim * mult)
57
+ self.net = nn.Sequential(RMSNorm(dim), nn.Linear(dim, dim_inner), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim_inner, dim), nn.Dropout(dropout))
58
+
59
+ def forward(self, x):
60
+ return self.net(x)
61
+
62
+
63
+ class Attention(Module):
64
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
65
+ super().__init__()
66
+ self.heads = heads
67
+ self.scale = dim_head**-0.5
68
+ dim_inner = heads * dim_head
69
+
70
+ self.rotary_embed = rotary_embed
71
+
72
+ self.attend = Attend(flash=flash, dropout=dropout)
73
+
74
+ self.norm = RMSNorm(dim)
75
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
76
+
77
+ self.to_gates = nn.Linear(dim, heads)
78
+
79
+ self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout))
80
+
81
+ def forward(self, x):
82
+ x = self.norm(x)
83
+
84
+ q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
85
+
86
+ if exists(self.rotary_embed):
87
+ q = self.rotary_embed.rotate_queries_or_keys(q)
88
+ k = self.rotary_embed.rotate_queries_or_keys(k)
89
+
90
+ out = self.attend(q, k, v)
91
+
92
+ gates = self.to_gates(x)
93
+ out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
94
+
95
+ out = rearrange(out, "b h n d -> b n (h d)")
96
+ return self.to_out(out)
97
+
98
+
99
+ class Transformer(Module):
100
+ def __init__(self, *, dim, depth, dim_head=64, heads=8, attn_dropout=0.0, ff_dropout=0.0, ff_mult=4, norm_output=True, rotary_embed=None, flash_attn=True):
101
+ super().__init__()
102
+ self.layers = ModuleList([])
103
+
104
+ for _ in range(depth):
105
+ self.layers.append(
106
+ ModuleList(
107
+ [Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, rotary_embed=rotary_embed, flash=flash_attn), FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
108
+ )
109
+ )
110
+
111
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
112
+
113
+ def forward(self, x):
114
+
115
+ for attn, ff in self.layers:
116
+ x = attn(x) + x
117
+ x = ff(x) + x
118
+
119
+ return self.norm(x)
120
+
121
+
122
+ class BandSplit(Module):
123
+ @beartype
124
+ def __init__(self, dim, dim_inputs: Tuple[int, ...]):
125
+ super().__init__()
126
+ self.dim_inputs = dim_inputs
127
+ self.to_features = ModuleList([])
128
+
129
+ for dim_in in dim_inputs:
130
+ net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
131
+
132
+ self.to_features.append(net)
133
+
134
+ def forward(self, x):
135
+ x = x.split(self.dim_inputs, dim=-1)
136
+
137
+ outs = []
138
+ for split_input, to_feature in zip(x, self.to_features):
139
+ split_output = to_feature(split_input)
140
+ outs.append(split_output)
141
+
142
+ return torch.stack(outs, dim=-2)
143
+
144
+
145
+ def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
146
+ dim_hidden = default(dim_hidden, dim_in)
147
+
148
+ net = []
149
+ dims = (dim_in, *((dim_hidden,) * depth), dim_out)
150
+
151
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
152
+ is_last = ind == (len(dims) - 2)
153
+
154
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
155
+
156
+ if is_last:
157
+ continue
158
+
159
+ net.append(activation())
160
+
161
+ return nn.Sequential(*net)
162
+
163
+
164
+ class MaskEstimator(Module):
165
+ @beartype
166
+ def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
167
+ super().__init__()
168
+ self.dim_inputs = dim_inputs
169
+ self.to_freqs = ModuleList([])
170
+ dim_hidden = dim * mlp_expansion_factor
171
+
172
+ for dim_in in dim_inputs:
173
+ net = []
174
+
175
+ mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))
176
+
177
+ self.to_freqs.append(mlp)
178
+
179
+ def forward(self, x):
180
+ x = x.unbind(dim=-2)
181
+
182
+ outs = []
183
+
184
+ for band_features, mlp in zip(x, self.to_freqs):
185
+ freq_out = mlp(band_features)
186
+ outs.append(freq_out)
187
+
188
+ return torch.cat(outs, dim=-1)
189
+
190
+
191
+ class MelBandRoformer(Module):
192
+
193
+ @beartype
194
+ def __init__(
195
+ self,
196
+ dim,
197
+ *,
198
+ depth,
199
+ stereo=False,
200
+ num_stems=1,
201
+ time_transformer_depth=2,
202
+ freq_transformer_depth=2,
203
+ num_bands=60,
204
+ dim_head=64,
205
+ heads=8,
206
+ attn_dropout=0.1,
207
+ ff_dropout=0.1,
208
+ flash_attn=True,
209
+ dim_freqs_in=1025,
210
+ sample_rate=44100,
211
+ stft_n_fft=2048,
212
+ stft_hop_length=512,
213
+ stft_win_length=2048,
214
+ stft_normalized=False,
215
+ stft_window_fn: Optional[Callable] = None,
216
+ mask_estimator_depth=1,
217
+ multi_stft_resolution_loss_weight=1.0,
218
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
219
+ multi_stft_hop_size=147,
220
+ multi_stft_normalized=False,
221
+ multi_stft_window_fn: Callable = torch.hann_window,
222
+ match_input_audio_length=False,
223
+ ):
224
+ super().__init__()
225
+
226
+ self.stereo = stereo
227
+ self.audio_channels = 2 if stereo else 1
228
+ self.num_stems = num_stems
229
+
230
+ self.layers = ModuleList([])
231
+
232
+ transformer_kwargs = dict(dim=dim, heads=heads, dim_head=dim_head, attn_dropout=attn_dropout, ff_dropout=ff_dropout, flash_attn=flash_attn)
233
+
234
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
235
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
236
+
237
+ for _ in range(depth):
238
+ self.layers.append(
239
+ nn.ModuleList(
240
+ [
241
+ Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs),
242
+ Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs),
243
+ ]
244
+ )
245
+ )
246
+
247
+ self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
248
+
249
+ self.stft_kwargs = dict(n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized)
250
+
251
+ freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
252
+
253
+ mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
254
+
255
+ mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
256
+
257
+ mel_filter_bank[0][0] = 1.0
258
+
259
+ mel_filter_bank[-1, -1] = 1.0
260
+
261
+ freqs_per_band = mel_filter_bank > 0
262
+ assert freqs_per_band.any(dim=0).all(), "all frequencies need to be covered by all bands for now"
263
+
264
+ repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
265
+ freq_indices = repeated_freq_indices[freqs_per_band]
266
+
267
+ if stereo:
268
+ freq_indices = repeat(freq_indices, "f -> f s", s=2)
269
+ freq_indices = freq_indices * 2 + torch.arange(2)
270
+ freq_indices = rearrange(freq_indices, "f s -> (f s)")
271
+
272
+ self.register_buffer("freq_indices", freq_indices, persistent=False)
273
+ self.register_buffer("freqs_per_band", freqs_per_band, persistent=False)
274
+
275
+ num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
276
+ num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")
277
+
278
+ self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False)
279
+ self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False)
280
+
281
+ freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
282
+
283
+ self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
284
+
285
+ self.mask_estimators = nn.ModuleList([])
286
+
287
+ for _ in range(num_stems):
288
+ mask_estimator = MaskEstimator(dim=dim, dim_inputs=freqs_per_bands_with_complex, depth=mask_estimator_depth)
289
+
290
+ self.mask_estimators.append(mask_estimator)
291
+
292
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
293
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
294
+ self.multi_stft_n_fft = stft_n_fft
295
+ self.multi_stft_window_fn = multi_stft_window_fn
296
+
297
+ self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)
298
+
299
+ self.match_input_audio_length = match_input_audio_length
300
+
301
+ def forward(self, raw_audio, target=None, return_loss_breakdown=False):
302
+ """
303
+ einops
304
+
305
+ b - batch
306
+ f - freq
307
+ t - time
308
+ s - audio channel (1 for mono, 2 for stereo)
309
+ n - number of 'stems'
310
+ c - complex (2)
311
+ d - feature dimension
312
+ """
313
+
314
+ original_device = raw_audio.device
315
+ x_is_mps = True if original_device.type == "mps" else False
316
+
317
+ if x_is_mps:
318
+ raw_audio = raw_audio.cpu()
319
+
320
+ device = raw_audio.device
321
+
322
+ if raw_audio.ndim == 2:
323
+ raw_audio = rearrange(raw_audio, "b t -> b 1 t")
324
+
325
+ batch, channels, raw_audio_length = raw_audio.shape
326
+
327
+ istft_length = raw_audio_length if self.match_input_audio_length else None
328
+
329
+ assert (not self.stereo and channels == 1) or (
330
+ self.stereo and channels == 2
331
+ ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
332
+
333
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
334
+
335
+ stft_window = self.stft_window_fn().to(device)
336
+
337
+ stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
338
+ stft_repr = torch.view_as_real(stft_repr)
339
+
340
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
341
+ stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c") # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
342
+
343
+ batch_arange = torch.arange(batch, device=device)[..., None]
344
+
345
+ x = stft_repr[batch_arange, self.freq_indices.cpu()] if x_is_mps else stft_repr[batch_arange, self.freq_indices]
346
+
347
+ x = rearrange(x, "b f t c -> b t (f c)")
348
+
349
+ x = self.band_split(x)
350
+
351
+ for time_transformer, freq_transformer in self.layers:
352
+ x = rearrange(x, "b t f d -> b f t d")
353
+ x, ps = pack([x], "* t d")
354
+
355
+ x = time_transformer(x)
356
+
357
+ (x,) = unpack(x, ps, "* t d")
358
+ x = rearrange(x, "b f t d -> b t f d")
359
+ x, ps = pack([x], "* f d")
360
+
361
+ x = freq_transformer(x)
362
+
363
+ (x,) = unpack(x, ps, "* f d")
364
+
365
+ masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
366
+ masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2)
367
+
368
+ if x_is_mps:
369
+ masks = masks.cpu()
370
+
371
+ stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
372
+
373
+ stft_repr = torch.view_as_complex(stft_repr)
374
+ masks = torch.view_as_complex(masks)
375
+
376
+ masks = masks.type(stft_repr.dtype)
377
+
378
+ if x_is_mps:
379
+ scatter_indices = repeat(self.freq_indices.cpu(), "f -> b n f t", b=batch, n=self.num_stems, t=stft_repr.shape[-1])
380
+ else:
381
+ scatter_indices = repeat(self.freq_indices, "f -> b n f t", b=batch, n=self.num_stems, t=stft_repr.shape[-1])
382
+
383
+ stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... -> b n ...", n=self.num_stems)
384
+ masks_summed = (
385
+ torch.zeros_like(stft_repr_expanded_stems.cpu() if x_is_mps else stft_repr_expanded_stems)
386
+ .scatter_add_(2, scatter_indices.cpu() if x_is_mps else scatter_indices, masks.cpu() if x_is_mps else masks)
387
+ .to(device)
388
+ )
389
+
390
+ denom = repeat(self.num_bands_per_freq, "f -> (f r) 1", r=channels)
391
+
392
+ if x_is_mps:
393
+ denom = denom.cpu()
394
+
395
+ masks_averaged = masks_summed / denom.clamp(min=1e-8)
396
+
397
+ stft_repr = stft_repr * masks_averaged
398
+
399
+ stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)
400
+
401
+ recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=istft_length)
402
+
403
+ recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", b=batch, s=self.audio_channels, n=self.num_stems)
404
+
405
+ if self.num_stems == 1:
406
+ recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
407
+
408
+ if not exists(target):
409
+ return recon_audio
410
+
411
+ if self.num_stems > 1:
412
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
413
+
414
+ if target.ndim == 2:
415
+ target = rearrange(target, "... t -> ... 1 t")
416
+
417
+ target = target[..., : recon_audio.shape[-1]]
418
+
419
+ loss = F.l1_loss(recon_audio, target)
420
+
421
+ multi_stft_resolution_loss = 0.0
422
+
423
+ for window_size in self.multi_stft_resolutions_window_sizes:
424
+ res_stft_kwargs = dict(
425
+ n_fft=max(window_size, self.multi_stft_n_fft), win_length=window_size, return_complex=True, window=self.multi_stft_window_fn(window_size, device=device), **self.multi_stft_kwargs
426
+ )
427
+
428
+ recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
429
+ target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)
430
+
431
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
432
+
433
+ weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
434
+
435
+ total_loss = loss + weighted_multi_resolution_loss
436
+
437
+ # Move the total loss back to the original device if necessary
438
+ if x_is_mps:
439
+ total_loss = total_loss.to(original_device)
440
+
441
+ if not return_loss_breakdown:
442
+ return total_loss
443
+
444
+ # If detailed loss breakdown is requested, ensure all components are on the original device
445
+ return total_loss, (loss.to(original_device) if x_is_mps else loss, multi_stft_resolution_loss.to(original_device) if x_is_mps else multi_stft_resolution_loss)
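
The same kind of sketch for `MelBandRoformer`: instead of the fixed `DEFAULT_FREQS_PER_BANDS` split used by `BSRoformer`, the band layout here comes from a librosa mel filterbank, so `num_bands` together with the STFT settings determines how frequencies are grouped. Again, tiny hyperparameters purely for a CPU smoke test.

```python
# Tiny MelBandRoformer smoke test; shipped checkpoints use far larger settings.
import torch
from audio_separator.separator.uvr_lib_v5.roformer.mel_band_roformer import MelBandRoformer

model = MelBandRoformer(dim=32, depth=1, stereo=True, num_stems=1,
                        num_bands=60, flash_attn=False)
audio = torch.randn(1, 2, 44100)    # (batch, channels, samples)

with torch.no_grad():
    separated = model(audio)        # returns separated audio when no target is given
print(separated.shape)              # (1, 2, samples); pass match_input_audio_length=True
                                    # to make the output length equal the input length
```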
audio_separator/separator/uvr_lib_v5/spec_utils.py ADDED
@@ -0,0 +1,1327 @@
1
+ import audioread
2
+ import librosa
3
+ import numpy as np
4
+ import soundfile as sf
5
+ import math
6
+ import platform
7
+ import traceback
8
+ from audio_separator.separator.uvr_lib_v5 import pyrb
9
+ from scipy.signal import correlate, hilbert
10
+ import io
11
+
12
+ OPERATING_SYSTEM = platform.system()
13
+ SYSTEM_ARCH = platform.platform()
14
+ SYSTEM_PROC = platform.processor()
15
+ ARM = "arm"
16
+
17
+ AUTO_PHASE = "Automatic"
18
+ POSITIVE_PHASE = "Positive Phase"
19
+ NEGATIVE_PHASE = "Negative Phase"
20
+ NONE_P = ("None",)
21
+ LOW_P = ("Shifts: Low",)
22
+ MED_P = ("Shifts: Medium",)
23
+ HIGH_P = ("Shifts: High",)
24
+ VHIGH_P = "Shifts: Very High"
25
+ MAXIMUM_P = "Shifts: Maximum"
26
+
27
+ progress_value = 0
28
+ last_update_time = 0
29
+ is_macos = False
30
+
31
+
32
+ if OPERATING_SYSTEM == "Darwin":
33
+ wav_resolution = "polyphase" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else "sinc_fastest"
34
+ wav_resolution_float_resampling = "kaiser_best" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else wav_resolution
35
+ is_macos = True
36
+ else:
37
+ wav_resolution = "sinc_fastest"
38
+ wav_resolution_float_resampling = wav_resolution
39
+
40
+ MAX_SPEC = "Max Spec"
41
+ MIN_SPEC = "Min Spec"
42
+ LIN_ENSE = "Linear Ensemble"
43
+
44
+ MAX_WAV = MAX_SPEC
45
+ MIN_WAV = MIN_SPEC
46
+
47
+ AVERAGE = "Average"
48
+
49
+
50
+ def crop_center(h1, h2):
51
+ """
52
+ This function crops the center of the first input tensor to match the size of the second input tensor.
53
+ It is used to ensure that the two tensors have the same size in the time dimension.
54
+ """
55
+ h1_shape = h1.size()
56
+ h2_shape = h2.size()
57
+
58
+ # If the time dimensions are already equal, return the first tensor as is
59
+ if h1_shape[3] == h2_shape[3]:
60
+ return h1
61
+ # If the time dimension of the first tensor is smaller, raise an error
62
+ elif h1_shape[3] < h2_shape[3]:
63
+ raise ValueError("h1_shape[3] must be greater than h2_shape[3]")
64
+
65
+ # Calculate the start and end indices for cropping
66
+ s_time = (h1_shape[3] - h2_shape[3]) // 2
67
+ e_time = s_time + h2_shape[3]
68
+ # Crop the first tensor
69
+ h1 = h1[:, :, :, s_time:e_time]
70
+
71
+ return h1
72
+
73
+
74
+ def preprocess(X_spec):
75
+ """
76
+ This function preprocesses a spectrogram by separating it into magnitude and phase components.
77
+ This is a common preprocessing step in audio processing tasks.
78
+ """
79
+ X_mag = np.abs(X_spec)
80
+ X_phase = np.angle(X_spec)
81
+
82
+ return X_mag, X_phase
83
+
84
+
85
+ def make_padding(width, cropsize, offset):
86
+ """
87
+ This function calculates the padding needed to make the width of an image divisible by the crop size.
88
+ It is used in the process of splitting an image into smaller patches.
89
+ """
90
+ left = offset
91
+ roi_size = cropsize - offset * 2
92
+ if roi_size == 0:
93
+ roi_size = cropsize
94
+ right = roi_size - (width % roi_size) + left
95
+
96
+ return left, right, roi_size
97
+
98
+
99
+ def normalize(wave, max_peak=1.0, min_peak=None):
100
+ """Normalize (or amplify) audio waveform to a specified peak value.
101
+
102
+ Args:
103
+ wave (array-like): Audio waveform.
104
+ max_peak (float): Maximum peak value for normalization.
105
+
106
+ Returns:
107
+ array-like: Normalized or original waveform.
108
+ """
109
+ maxv = np.abs(wave).max()
110
+ if maxv > max_peak:
111
+ wave *= max_peak / maxv
112
+ elif min_peak is not None and maxv < min_peak:
113
+ wave *= min_peak / maxv
114
+
115
+ return wave
116
+
117
+
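
A quick worked example of `normalize` above (assuming it is in scope, e.g. imported from this module): a peak above `max_peak` is scaled down, a peak below `min_peak` (when given) is scaled up, and anything in between is returned unchanged. Note that it modifies the array in place.

```python
# Worked example for normalize() above; .copy() avoids mutating the originals.
import numpy as np

loud = np.array([0.2, -1.6, 0.8])
print(normalize(loud.copy(), max_peak=1.0))                 # [ 0.125 -1.     0.5  ]

quiet = np.array([0.05, -0.1, 0.02])
print(normalize(quiet.copy(), max_peak=1.0, min_peak=0.5))  # [ 0.25 -0.5   0.1  ]
```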
118
+ def auto_transpose(audio_array: np.ndarray):
119
+ """
120
+ Ensure that the audio array is in the (channels, samples) format.
121
+
122
+ Parameters:
123
+ audio_array (ndarray): Input audio array.
124
+
125
+ Returns:
126
+ ndarray: Transposed audio array if necessary.
127
+ """
128
+
129
+ # If the second dimension is 2 (indicating stereo channels), transpose the array
130
+ if audio_array.shape[1] == 2:
131
+ return audio_array.T
132
+ return audio_array
133
+
134
+
135
+ def write_array_to_mem(audio_data, subtype):
136
+ if isinstance(audio_data, np.ndarray):
137
+ audio_buffer = io.BytesIO()
138
+ sf.write(audio_buffer, audio_data, 44100, subtype=subtype, format="WAV")
139
+ audio_buffer.seek(0)
140
+ return audio_buffer
141
+ else:
142
+ return audio_data
143
+
144
+
145
+ def spectrogram_to_image(spec, mode="magnitude"):
146
+ if mode == "magnitude":
147
+ if np.iscomplexobj(spec):
148
+ y = np.abs(spec)
149
+ else:
150
+ y = spec
151
+ y = np.log10(y**2 + 1e-8)
152
+ elif mode == "phase":
153
+ if np.iscomplexobj(spec):
154
+ y = np.angle(spec)
155
+ else:
156
+ y = spec
157
+
158
+ y -= y.min()
159
+ y *= 255 / y.max()
160
+ img = np.uint8(y)
161
+
162
+ if y.ndim == 3:
163
+ img = img.transpose(1, 2, 0)
164
+ img = np.concatenate([np.max(img, axis=2, keepdims=True), img], axis=2)
165
+
166
+ return img
167
+
168
+
169
+ def reduce_vocal_aggressively(X, y, softmask):
170
+ v = X - y
171
+ y_mag_tmp = np.abs(y)
172
+ v_mag_tmp = np.abs(v)
173
+
174
+ v_mask = v_mag_tmp > y_mag_tmp
175
+ y_mag = np.clip(y_mag_tmp - v_mag_tmp * v_mask * softmask, 0, np.inf)
176
+
177
+ return y_mag * np.exp(1.0j * np.angle(y))
178
+
179
+
180
+ def merge_artifacts(y_mask, thres=0.01, min_range=64, fade_size=32):
181
+ mask = y_mask
182
+
183
+ try:
184
+ if min_range < fade_size * 2:
185
+ raise ValueError("min_range must be >= fade_size * 2")
186
+
187
+ idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
188
+ start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
189
+ end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
190
+ artifact_idx = np.where(end_idx - start_idx > min_range)[0]
191
+ weight = np.zeros_like(y_mask)
192
+ if len(artifact_idx) > 0:
193
+ start_idx = start_idx[artifact_idx]
194
+ end_idx = end_idx[artifact_idx]
195
+ old_e = None
196
+ for s, e in zip(start_idx, end_idx):
197
+ if old_e is not None and s - old_e < fade_size:
198
+ s = old_e - fade_size * 2
199
+
200
+ if s != 0:
201
+ weight[:, :, s : s + fade_size] = np.linspace(0, 1, fade_size)
202
+ else:
203
+ s -= fade_size
204
+
205
+ if e != y_mask.shape[2]:
206
+ weight[:, :, e - fade_size : e] = np.linspace(1, 0, fade_size)
207
+ else:
208
+ e += fade_size
209
+
210
+ weight[:, :, s + fade_size : e - fade_size] = 1
211
+ old_e = e
212
+
213
+ v_mask = 1 - y_mask
214
+ y_mask += weight * v_mask
215
+
216
+ mask = y_mask
217
+ except Exception as e:
218
+ error_name = f"{type(e).__name__}"
219
+ traceback_text = "".join(traceback.format_tb(e.__traceback__))
220
+ message = f'{error_name}: "{e}"\n{traceback_text}'
221
+ print("Post Process Failed: ", message)
222
+
223
+ return mask
224
+
225
+
226
+ def align_wave_head_and_tail(a, b):
227
+ l = min([a[0].size, b[0].size])
228
+
229
+ return a[:l, :l], b[:l, :l]
230
+
231
+
232
+ def convert_channels(spec, mp, band):
233
+ cc = mp.param["band"][band].get("convert_channels")
234
+
235
+ if "mid_side_c" == cc:
236
+ spec_left = np.add(spec[0], spec[1] * 0.25)
237
+ spec_right = np.subtract(spec[1], spec[0] * 0.25)
238
+ elif "mid_side" == cc:
239
+ spec_left = np.add(spec[0], spec[1]) / 2
240
+ spec_right = np.subtract(spec[0], spec[1])
241
+ elif "stereo_n" == cc:
242
+ spec_left = np.add(spec[0], spec[1] * 0.25) / 0.9375
243
+ spec_right = np.add(spec[1], spec[0] * 0.25) / 0.9375
244
+ else:
245
+ return spec
246
+
247
+ return np.asfortranarray([spec_left, spec_right])
248
+
249
+
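
As a sanity check on the `mid_side` branch above and its inverse in `spectrogram_to_wave` further down: encoding produces mid = (L + R) / 2 and side = L - R, and decoding with mid + side / 2 and mid - side / 2 recovers L and R exactly (toy values here, no STFT involved).

```python
# Round-trip check for the "mid_side" channel conversion (toy values, no STFT).
import numpy as np

L, R = np.array([0.5, -0.25]), np.array([0.1, 0.3])
mid, side = (L + R) / 2, L - R            # encode, as in convert_channels
L2, R2 = mid + side / 2, mid - side / 2   # decode, as in spectrogram_to_wave
assert np.allclose(L, L2) and np.allclose(R, R2)
```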
250
+ def combine_spectrograms(specs, mp, is_v51_model=False):
251
+ l = min([specs[i].shape[2] for i in specs])
252
+ spec_c = np.zeros(shape=(2, mp.param["bins"] + 1, l), dtype=np.complex64)
253
+ offset = 0
254
+ bands_n = len(mp.param["band"])
255
+
256
+ for d in range(1, bands_n + 1):
257
+ h = mp.param["band"][d]["crop_stop"] - mp.param["band"][d]["crop_start"]
258
+ spec_c[:, offset : offset + h, :l] = specs[d][:, mp.param["band"][d]["crop_start"] : mp.param["band"][d]["crop_stop"], :l]
259
+ offset += h
260
+
261
+ if offset > mp.param["bins"]:
262
+ raise ValueError("Too much bins")
263
+
264
+ # lowpass filter
265
+
266
+ if mp.param["pre_filter_start"] > 0:
267
+ if is_v51_model:
268
+ spec_c *= get_lp_filter_mask(spec_c.shape[1], mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
269
+ else:
270
+ if bands_n == 1:
271
+ spec_c = fft_lp_filter(spec_c, mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
272
+ else:
273
+ gp = 1
274
+ for b in range(mp.param["pre_filter_start"] + 1, mp.param["pre_filter_stop"]):
275
+ g = math.pow(10, -(b - mp.param["pre_filter_start"]) * (3.5 - gp) / 20.0)
276
+ gp = g
277
+ spec_c[:, b, :] *= g
278
+
279
+ return np.asfortranarray(spec_c)
280
+
281
+
282
+ def wave_to_spectrogram(wave, hop_length, n_fft, mp, band, is_v51_model=False):
283
+
284
+ if wave.ndim == 1:
285
+ wave = np.asfortranarray([wave, wave])
286
+
287
+ if not is_v51_model:
288
+ if mp.param["reverse"]:
289
+ wave_left = np.flip(np.asfortranarray(wave[0]))
290
+ wave_right = np.flip(np.asfortranarray(wave[1]))
291
+ elif mp.param["mid_side"]:
292
+ wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2)
293
+ wave_right = np.asfortranarray(np.subtract(wave[0], wave[1]))
294
+ elif mp.param["mid_side_b2"]:
295
+ wave_left = np.asfortranarray(np.add(wave[1], wave[0] * 0.5))
296
+ wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * 0.5))
297
+ else:
298
+ wave_left = np.asfortranarray(wave[0])
299
+ wave_right = np.asfortranarray(wave[1])
300
+ else:
301
+ wave_left = np.asfortranarray(wave[0])
302
+ wave_right = np.asfortranarray(wave[1])
303
+
304
+ spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
305
+ spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)
306
+
307
+ spec = np.asfortranarray([spec_left, spec_right])
308
+
309
+ if is_v51_model:
310
+ spec = convert_channels(spec, mp, band)
311
+
312
+ return spec
313
+
314
+
315
+ def spectrogram_to_wave(spec, hop_length=1024, mp={}, band=0, is_v51_model=True):
316
+ spec_left = np.asfortranarray(spec[0])
317
+ spec_right = np.asfortranarray(spec[1])
318
+
319
+ wave_left = librosa.istft(spec_left, hop_length=hop_length)
320
+ wave_right = librosa.istft(spec_right, hop_length=hop_length)
321
+
322
+ if is_v51_model:
323
+ cc = mp.param["band"][band].get("convert_channels")
324
+ if "mid_side_c" == cc:
325
+ return np.asfortranarray([np.subtract(wave_left / 1.0625, wave_right / 4.25), np.add(wave_right / 1.0625, wave_left / 4.25)])
326
+ elif "mid_side" == cc:
327
+ return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
328
+ elif "stereo_n" == cc:
329
+ return np.asfortranarray([np.subtract(wave_left, wave_right * 0.25), np.subtract(wave_right, wave_left * 0.25)])
330
+ else:
331
+ if mp.param["reverse"]:
332
+ return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
333
+ elif mp.param["mid_side"]:
334
+ return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
335
+ elif mp.param["mid_side_b2"]:
336
+ return np.asfortranarray([np.add(wave_right / 1.25, 0.4 * wave_left), np.subtract(wave_left / 1.25, 0.4 * wave_right)])
337
+
338
+ return np.asfortranarray([wave_left, wave_right])
339
+
340
+
341
+ def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None, is_v51_model=False):
342
+ bands_n = len(mp.param["band"])
343
+ offset = 0
344
+
345
+ for d in range(1, bands_n + 1):
346
+ bp = mp.param["band"][d]
347
+ spec_s = np.zeros(shape=(2, bp["n_fft"] // 2 + 1, spec_m.shape[2]), dtype=complex)
348
+ h = bp["crop_stop"] - bp["crop_start"]
349
+ spec_s[:, bp["crop_start"] : bp["crop_stop"], :] = spec_m[:, offset : offset + h, :]
350
+
351
+ offset += h
352
+ if d == bands_n: # higher
353
+ if extra_bins_h: # if --high_end_process bypass
354
+ max_bin = bp["n_fft"] // 2
355
+ spec_s[:, max_bin - extra_bins_h : max_bin, :] = extra_bins[:, :extra_bins_h, :]
356
+ if bp["hpf_start"] > 0:
357
+ if is_v51_model:
358
+ spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
359
+ else:
360
+ spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
361
+ if bands_n == 1:
362
+ wave = spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model)
363
+ else:
364
+ wave = np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model))
365
+ else:
366
+ sr = mp.param["band"][d + 1]["sr"]
367
+ if d == 1: # lower
368
+ if is_v51_model:
369
+ spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
370
+ else:
371
+ spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])
372
+
373
+ try:
374
+ wave = librosa.resample(spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model), orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
375
+ except ValueError as e:
376
+ print(f"Error during resampling: {e}")
377
+ print(f"Spec_s shape: {spec_s.shape}, SR: {sr}, Res type: {wav_resolution}")
378
+
379
+ else: # mid
380
+ if is_v51_model:
381
+ spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
382
+ spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
383
+ else:
384
+ spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
385
+ spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])
386
+
387
+ wave2 = np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model))
388
+
389
+ try:
390
+ wave = librosa.resample(wave2, orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
391
+ except ValueError as e:
392
+ print(f"Error during resampling: {e}")
393
+ print(f"Spec_s shape: {spec_s.shape}, SR: {sr}, Res type: {wav_resolution}")
394
+
395
+ return wave
396
+
397
+
398
+ def get_lp_filter_mask(n_bins, bin_start, bin_stop):
399
+ mask = np.concatenate([np.ones((bin_start - 1, 1)), np.linspace(1, 0, bin_stop - bin_start + 1)[:, None], np.zeros((n_bins - bin_stop, 1))], axis=0)
400
+
401
+ return mask
402
+
403
+
404
+ def get_hp_filter_mask(n_bins, bin_start, bin_stop):
405
+ mask = np.concatenate([np.zeros((bin_stop + 1, 1)), np.linspace(0, 1, 1 + bin_start - bin_stop)[:, None], np.ones((n_bins - bin_start - 2, 1))], axis=0)
406
+
407
+ return mask
408
+
409
+
410
+ def fft_lp_filter(spec, bin_start, bin_stop):
411
+ g = 1.0
412
+ for b in range(bin_start, bin_stop):
413
+ g -= 1 / (bin_stop - bin_start)
414
+ spec[:, b, :] = g * spec[:, b, :]
415
+
416
+ spec[:, bin_stop:, :] *= 0
417
+
418
+ return spec
419
+
420
+
421
+ def fft_hp_filter(spec, bin_start, bin_stop):
422
+ g = 1.0
423
+ for b in range(bin_start, bin_stop, -1):
424
+ g -= 1 / (bin_start - bin_stop)
425
+ spec[:, b, :] = g * spec[:, b, :]
426
+
427
+ spec[:, 0 : bin_stop + 1, :] *= 0
428
+
429
+ return spec
430
+
431
+
432
+ def spectrogram_to_wave_old(spec, hop_length=1024):
433
+ if spec.ndim == 2:
434
+ wave = librosa.istft(spec, hop_length=hop_length)
435
+ elif spec.ndim == 3:
436
+ spec_left = np.asfortranarray(spec[0])
437
+ spec_right = np.asfortranarray(spec[1])
438
+
439
+ wave_left = librosa.istft(spec_left, hop_length=hop_length)
440
+ wave_right = librosa.istft(spec_right, hop_length=hop_length)
441
+ wave = np.asfortranarray([wave_left, wave_right])
442
+
443
+ return wave
444
+
445
+
446
+ def wave_to_spectrogram_old(wave, hop_length, n_fft):
447
+ wave_left = np.asfortranarray(wave[0])
448
+ wave_right = np.asfortranarray(wave[1])
449
+
450
+ spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
451
+ spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)
452
+
453
+ spec = np.asfortranarray([spec_left, spec_right])
454
+
455
+ return spec
456
+
457
+
458
+ def mirroring(a, spec_m, input_high_end, mp):
459
+ if "mirroring" == a:
460
+ mirror = np.flip(np.abs(spec_m[:, mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10, :]), 1)
461
+ mirror = mirror * np.exp(1.0j * np.angle(input_high_end))
462
+
463
+ return np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror)
464
+
465
+ if "mirroring2" == a:
466
+ mirror = np.flip(np.abs(spec_m[:, mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10, :]), 1)
467
+ mi = np.multiply(mirror, input_high_end * 1.7)
468
+
469
+ return np.where(np.abs(input_high_end) <= np.abs(mi), input_high_end, mi)
470
+
471
+
472
+ def adjust_aggr(mask, is_non_accom_stem, aggressiveness):
473
+ aggr = aggressiveness["value"] * 2
474
+
475
+ if aggr != 0:
476
+ if is_non_accom_stem:
477
+ aggr = 1 - aggr
478
+
479
+ if np.any(aggr > 10) or np.any(aggr < -10):
480
+ print(f"Warning: Extreme aggressiveness values detected: {aggr}")
481
+
482
+ aggr = [aggr, aggr]
483
+
484
+ if aggressiveness["aggr_correction"] is not None:
485
+ aggr[0] += aggressiveness["aggr_correction"]["left"]
486
+ aggr[1] += aggressiveness["aggr_correction"]["right"]
487
+
488
+ for ch in range(2):
489
+ mask[ch, : aggressiveness["split_bin"]] = np.power(mask[ch, : aggressiveness["split_bin"]], 1 + aggr[ch] / 3)
490
+ mask[ch, aggressiveness["split_bin"] :] = np.power(mask[ch, aggressiveness["split_bin"] :], 1 + aggr[ch])
491
+
492
+ return mask
493
+
494
+
495
+ def stft(wave, nfft, hl):
496
+ wave_left = np.asfortranarray(wave[0])
497
+ wave_right = np.asfortranarray(wave[1])
498
+ spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
499
+ spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
500
+ spec = np.asfortranarray([spec_left, spec_right])
501
+
502
+ return spec
503
+
504
+
505
+ def istft(spec, hl):
506
+ spec_left = np.asfortranarray(spec[0])
507
+ spec_right = np.asfortranarray(spec[1])
508
+ wave_left = librosa.istft(spec_left, hop_length=hl)
509
+ wave_right = librosa.istft(spec_right, hop_length=hl)
510
+ wave = np.asfortranarray([wave_left, wave_right])
511
+
512
+ return wave
513
+
514
+
515
+ def spec_effects(wave, algorithm="Default", value=None):
516
+ if np.isnan(wave).any() or np.isinf(wave).any():
517
+ print(f"Warning: Detected NaN or infinite values in wave input. Shape: {wave.shape}")
518
+
519
+ spec = [stft(wave[0], 2048, 1024), stft(wave[1], 2048, 1024)]
520
+ if algorithm == "Min_Mag":
521
+ v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0])
522
+ wave = istft(v_spec_m, 1024)
523
+ elif algorithm == "Max_Mag":
524
+ v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0])
525
+ wave = istft(v_spec_m, 1024)
526
+ elif algorithm == "Default":
527
+ wave = (wave[1] * value) + (wave[0] * (1 - value))
528
+ elif algorithm == "Invert_p":
529
+ X_mag = np.abs(spec[0])
530
+ y_mag = np.abs(spec[1])
531
+ max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
532
+ v_spec = spec[1] - max_mag * np.exp(1.0j * np.angle(spec[0]))
533
+ wave = istft(v_spec, 1024)
534
+
535
+ return wave
536
+
537
+
538
+ def spectrogram_to_wave_no_mp(spec, n_fft=2048, hop_length=1024):
539
+ wave = librosa.istft(spec, n_fft=n_fft, hop_length=hop_length)
540
+
541
+ if wave.ndim == 1:
542
+ wave = np.asfortranarray([wave, wave])
543
+
544
+ return wave
545
+
546
+
547
+ def wave_to_spectrogram_no_mp(wave):
548
+
549
+ spec = librosa.stft(wave, n_fft=2048, hop_length=1024)
550
+
551
+ if spec.ndim == 1:
552
+ spec = np.asfortranarray([spec, spec])
553
+
554
+ return spec
555
+
556
+
557
+ def invert_audio(specs, invert_p=True):
558
+
559
+ ln = min([specs[0].shape[2], specs[1].shape[2]])
560
+ specs[0] = specs[0][:, :, :ln]
561
+ specs[1] = specs[1][:, :, :ln]
562
+
563
+ if invert_p:
564
+ X_mag = np.abs(specs[0])
565
+ y_mag = np.abs(specs[1])
566
+ max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
567
+ v_spec = specs[1] - max_mag * np.exp(1.0j * np.angle(specs[0]))
568
+ else:
569
+ specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2)
570
+ v_spec = specs[0] - specs[1]
571
+
572
+ return v_spec
573
+
574
+
575
+ def invert_stem(mixture, stem):
576
+ mixture = wave_to_spectrogram_no_mp(mixture)
577
+ stem = wave_to_spectrogram_no_mp(stem)
578
+ output = spectrogram_to_wave_no_mp(invert_audio([mixture, stem]))
579
+
580
+ return -output.T
581
+
582
+
583
+ def ensembling(a, inputs, is_wavs=False):
584
+
585
+ for i in range(1, len(inputs)):
586
+ if i == 1:
587
+ input = inputs[0]
588
+
589
+ if is_wavs:
590
+ ln = min([input.shape[1], inputs[i].shape[1]])
591
+ input = input[:, :ln]
592
+ inputs[i] = inputs[i][:, :ln]
593
+ else:
594
+ ln = min([input.shape[2], inputs[i].shape[2]])
595
+ input = input[:, :, :ln]
596
+ inputs[i] = inputs[i][:, :, :ln]
597
+
598
+ if MIN_SPEC == a:
599
+ input = np.where(np.abs(inputs[i]) <= np.abs(input), inputs[i], input)
600
+ if MAX_SPEC == a:
601
+ #input = np.array(np.where(np.greater_equal(np.abs(inputs[i]), np.abs(input)), inputs[i], input), dtype=object)
602
+ input = np.where(np.abs(inputs[i]) >= np.abs(input), inputs[i], input)
603
+ #max_spec = np.array([np.where(np.greater_equal(np.abs(inputs[i]), np.abs(input)), s, specs[0]) for s in specs[1:]], dtype=object)[-1]
604
+
605
+ # linear_ensemble
606
+ # input = ensemble_wav(inputs, split_size=1)
607
+
608
+ return input
609
+
610
+
611
+ def ensemble_for_align(waves):
612
+
613
+ specs = []
614
+
615
+ for wav in waves:
616
+ spec = wave_to_spectrogram_no_mp(wav.T)
617
+ specs.append(spec)
618
+
619
+ wav_aligned = spectrogram_to_wave_no_mp(ensembling(MIN_SPEC, specs)).T
620
+ wav_aligned = match_array_shapes(wav_aligned, waves[1], is_swap=True)
621
+
622
+ return wav_aligned
623
+
624
+
625
+ def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save_path, is_wave=False, is_array=False):
626
+
627
+ wavs_ = []
628
+
629
+ if algorithm == AVERAGE:
630
+ output = average_audio(audio_input)
631
+ samplerate = 44100
632
+ else:
633
+ specs = []
634
+
635
+ for i in range(len(audio_input)):
636
+ wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100)
637
+ wavs_.append(wave)
638
+ spec = wave if is_wave else wave_to_spectrogram_no_mp(wave)
639
+ specs.append(spec)
640
+
641
+ wave_shapes = [w.shape[1] for w in wavs_]
642
+ target_shape = wavs_[wave_shapes.index(max(wave_shapes))]
643
+
644
+ if is_wave:
645
+ output = ensembling(algorithm, specs, is_wavs=True)
646
+ else:
647
+ output = spectrogram_to_wave_no_mp(ensembling(algorithm, specs))
648
+
649
+ output = to_shape(output, target_shape.shape)
650
+
651
+ sf.write(save_path, normalize(output.T, is_normalization), samplerate, subtype=wav_type_set)
652
+
653
+
654
+ def to_shape(x, target_shape):
655
+ padding_list = []
656
+ for x_dim, target_dim in zip(x.shape, target_shape):
657
+ pad_value = target_dim - x_dim
658
+ pad_tuple = (0, pad_value)
659
+ padding_list.append(pad_tuple)
660
+
661
+ return np.pad(x, tuple(padding_list), mode="constant")
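Side note (sketch, not part of the uploaded file): to_shape zero-pads the end of each dimension up to the target shape, which is how shorter stems are matched to the longest one before mixing. The equivalent numpy call for a stereo pair:

    import numpy as np

    short = np.ones((2, 100))                                   # stereo, 100 samples
    longer = np.ones((2, 120))                                  # stereo, 120 samples
    padded = np.pad(short, ((0, 0), (0, 20)), mode="constant")  # same result as to_shape(short, longer.shape)
    assert padded.shape == longer.shape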
662
+
663
+
664
+ def to_shape_minimize(x: np.ndarray, target_shape):
665
+
666
+ padding_list = []
667
+ for x_dim, target_dim in zip(x.shape, target_shape):
668
+ pad_value = target_dim - x_dim
669
+ pad_tuple = (0, pad_value)
670
+ padding_list.append(pad_tuple)
671
+
672
+ return np.pad(x, tuple(padding_list), mode="constant")
673
+
674
+
675
+ def detect_leading_silence(audio, sr, silence_threshold=0.007, frame_length=1024):
676
+ """
677
+ Detect silence at the beginning of an audio signal.
678
+
679
+ :param audio: np.array, audio signal
680
+ :param sr: int, sample rate
681
+ :param silence_threshold: float, magnitude threshold below which is considered silence
682
+ :param frame_length: int, the number of samples to consider for each check
683
+
684
+ :return: float, duration of the leading silence in milliseconds
685
+ """
686
+
687
+ if len(audio.shape) == 2:
688
+ # If stereo, pick the channel with more energy to determine the silence
689
+ channel = np.argmax(np.sum(np.abs(audio), axis=1))
690
+ audio = audio[channel]
691
+
692
+ for i in range(0, len(audio), frame_length):
693
+ if np.max(np.abs(audio[i : i + frame_length])) > silence_threshold:
694
+ return (i / sr) * 1000
695
+
696
+ return (len(audio) / sr) * 1000
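Hypothetical usage sketch (assumes detect_leading_silence is imported from this module): half a second of silence followed by a tone at 44.1 kHz is reported as the start of the first 1024-sample frame that contains signal.

    import numpy as np

    sr = 44100
    audio = np.concatenate([np.zeros(sr // 2), 0.1 * np.ones(sr)])
    print(detect_leading_silence(audio, sr))   # ~487.6 ms, the start of the frame where sound first appears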
697
+
698
+
699
+ def adjust_leading_silence(target_audio, reference_audio, silence_threshold=0.01, frame_length=1024):
700
+ """
701
+ Adjust the leading silence of the target_audio to match the leading silence of the reference_audio.
702
+
703
+ :param target_audio: np.array, audio signal that will have its silence adjusted
704
+ :param reference_audio: np.array, audio signal used as a reference
705
+ (a 44100 Hz sample rate is assumed when reporting the silence difference)
706
+ :param silence_threshold: float, magnitude threshold below which is considered silence
707
+ :param frame_length: int, the number of samples to consider for each check
708
+
709
+ :return: np.array, target_audio adjusted to have the same leading silence as reference_audio
710
+ """
711
+
712
+ def find_silence_end(audio):
713
+ if len(audio.shape) == 2:
714
+ # If stereo, pick the channel with more energy to determine the silence
715
+ channel = np.argmax(np.sum(np.abs(audio), axis=1))
716
+ audio_mono = audio[channel]
717
+ else:
718
+ audio_mono = audio
719
+
720
+ for i in range(0, len(audio_mono), frame_length):
721
+ if np.max(np.abs(audio_mono[i : i + frame_length])) > silence_threshold:
722
+ return i
723
+ return len(audio_mono)
724
+
725
+ ref_silence_end = find_silence_end(reference_audio)
726
+ target_silence_end = find_silence_end(target_audio)
727
+ silence_difference = ref_silence_end - target_silence_end
728
+
729
+ try:
730
+ ref_silence_end_p = (ref_silence_end / 44100) * 1000
731
+ target_silence_end_p = (target_silence_end / 44100) * 1000
732
+ silence_difference_p = ref_silence_end_p - target_silence_end_p
733
+ print("silence_difference: ", silence_difference_p)
734
+ except Exception as e:
735
+ pass
736
+
737
+ if silence_difference > 0: # Add silence to target_audio
738
+ if len(target_audio.shape) == 2: # stereo
739
+ silence_to_add = np.zeros((target_audio.shape[0], silence_difference))
740
+ else: # mono
741
+ silence_to_add = np.zeros(silence_difference)
742
+ return np.hstack((silence_to_add, target_audio))
743
+ elif silence_difference < 0: # Remove silence from target_audio
744
+ if len(target_audio.shape) == 2: # stereo
745
+ return target_audio[:, -silence_difference:]
746
+ else: # mono
747
+ return target_audio[-silence_difference:]
748
+ else: # No adjustment needed
749
+ return target_audio
750
+
751
+
752
+ def match_array_shapes(array_1: np.ndarray, array_2: np.ndarray, is_swap=False):
753
+
754
+ if is_swap:
755
+ array_1, array_2 = array_1.T, array_2.T
756
+
757
+ # print("before", array_1.shape, array_2.shape)
758
+ if array_1.shape[1] > array_2.shape[1]:
759
+ array_1 = array_1[:, : array_2.shape[1]]
760
+ elif array_1.shape[1] < array_2.shape[1]:
761
+ padding = array_2.shape[1] - array_1.shape[1]
762
+ array_1 = np.pad(array_1, ((0, 0), (0, padding)), "constant", constant_values=0)
763
+
764
+ # print("after", array_1.shape, array_2.shape)
765
+
766
+ if is_swap:
767
+ array_1, array_2 = array_1.T, array_2.T
768
+
769
+ return array_1
770
+
771
+
772
+ def match_mono_array_shapes(array_1: np.ndarray, array_2: np.ndarray):
773
+
774
+ if len(array_1) > len(array_2):
775
+ array_1 = array_1[: len(array_2)]
776
+ elif len(array_1) < len(array_2):
777
+ padding = len(array_2) - len(array_1)
778
+ array_1 = np.pad(array_1, (0, padding), "constant", constant_values=0)
779
+
780
+ return array_1
781
+
782
+
783
+ def change_pitch_semitones(y, sr, semitone_shift):
784
+ factor = 2 ** (semitone_shift / 12) # Convert semitone shift to factor for resampling
785
+ y_pitch_tuned = []
786
+ for y_channel in y:
787
+ y_pitch_tuned.append(librosa.resample(y_channel, orig_sr=sr, target_sr=sr * factor, res_type=wav_resolution_float_resampling))
788
+ y_pitch_tuned = np.array(y_pitch_tuned)
789
+ new_sr = sr * factor
790
+ return y_pitch_tuned, new_sr
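Worked example (not part of the file): the resampling factor follows the equal-temperament relation 2 ** (semitones / 12), so a shift of -12 semitones gives a factor of 0.5; the signal is resampled to half as many samples, and when it is later handled at the original 44100 Hz (as augment_audio below does when time correction is off) it plays an octave higher and twice as fast.

    print(2 ** (-12 / 12), 2 ** (7 / 12))   # 0.5 for -12 semitones, ~1.498 for +7 semitones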
791
+
792
+
793
+ def augment_audio(export_path, audio_file, rate, is_normalization, wav_type_set, save_format=None, is_pitch=False, is_time_correction=True):
794
+
795
+ wav, sr = librosa.load(audio_file, sr=44100, mono=False)
796
+
797
+ if wav.ndim == 1:
798
+ wav = np.asfortranarray([wav, wav])
799
+
800
+ if not is_time_correction:
801
+ wav_mix = change_pitch_semitones(wav, 44100, semitone_shift=-rate)[0]
802
+ else:
803
+ if is_pitch:
804
+ wav_1 = pyrb.pitch_shift(wav[0], sr, rate, rbargs=None)
805
+ wav_2 = pyrb.pitch_shift(wav[1], sr, rate, rbargs=None)
806
+ else:
807
+ wav_1 = pyrb.time_stretch(wav[0], sr, rate, rbargs=None)
808
+ wav_2 = pyrb.time_stretch(wav[1], sr, rate, rbargs=None)
809
+
810
+ if wav_1.shape > wav_2.shape:
811
+ wav_2 = to_shape(wav_2, wav_1.shape)
812
+ if wav_1.shape < wav_2.shape:
813
+ wav_1 = to_shape(wav_1, wav_2.shape)
814
+
815
+ wav_mix = np.asfortranarray([wav_1, wav_2])
816
+
817
+ sf.write(export_path, normalize(wav_mix.T, is_normalization), sr, subtype=wav_type_set)
818
+ save_format(export_path)
819
+
820
+
821
+ def average_audio(audio):
822
+
823
+ waves = []
824
+ wave_shapes = []
825
+ final_waves = []
826
+
827
+ for i in range(len(audio)):
828
+ wave = librosa.load(audio[i], sr=44100, mono=False)
829
+ waves.append(wave[0])
830
+ wave_shapes.append(wave[0].shape[1])
831
+
832
+ wave_shapes_index = wave_shapes.index(max(wave_shapes))
833
+ target_shape = waves[wave_shapes_index]
834
+ waves.pop(wave_shapes_index)
835
+ final_waves.append(target_shape)
836
+
837
+ for n_array in waves:
838
+ wav_target = to_shape(n_array, target_shape.shape)
839
+ final_waves.append(wav_target)
840
+
841
+ waves = sum(final_waves)
842
+ waves = waves / len(audio)
843
+
844
+ return waves
845
+
846
+
847
+ def average_dual_sources(wav_1, wav_2, value):
848
+
849
+ if wav_1.shape > wav_2.shape:
850
+ wav_2 = to_shape(wav_2, wav_1.shape)
851
+ if wav_1.shape < wav_2.shape:
852
+ wav_1 = to_shape(wav_1, wav_2.shape)
853
+
854
+ wave = (wav_1 * value) + (wav_2 * (1 - value))
855
+
856
+ return wave
857
+
858
+
859
+ def reshape_sources(wav_1: np.ndarray, wav_2: np.ndarray):
860
+
861
+ if wav_1.shape > wav_2.shape:
862
+ wav_2 = to_shape(wav_2, wav_1.shape)
863
+ if wav_1.shape < wav_2.shape:
864
+ ln = min([wav_1.shape[1], wav_2.shape[1]])
865
+ wav_2 = wav_2[:, :ln]
866
+
867
+ ln = min([wav_1.shape[1], wav_2.shape[1]])
868
+ wav_1 = wav_1[:, :ln]
869
+ wav_2 = wav_2[:, :ln]
870
+
871
+ return wav_2
872
+
873
+
874
+ def reshape_sources_ref(wav_1_shape, wav_2: np.ndarray):
875
+
876
+ if wav_1_shape > wav_2.shape:
877
+ wav_2 = to_shape(wav_2, wav_1_shape)
878
+
879
+ return wav_2
880
+
881
+
882
+ def combine_arrarys(audio_sources, is_swap=False):
883
+ source = np.zeros_like(max(audio_sources, key=np.size))
884
+
885
+ for v in audio_sources:
886
+ v = match_array_shapes(v, source, is_swap=is_swap)
887
+ source += v
888
+
889
+ return source
890
+
891
+
892
+ def combine_audio(paths: list, audio_file_base=None, wav_type_set="FLOAT", save_format=None):
893
+
894
+ source = combine_arrarys([load_audio(i) for i in paths])
895
+ save_path = f"{audio_file_base}_combined.wav"
896
+ sf.write(save_path, source.T, 44100, subtype=wav_type_set)
897
+ save_format(save_path)
898
+
899
+
900
+ def reduce_mix_bv(inst_source, voc_source, reduction_rate=0.9):
901
+ # Reduce the volume
902
+ inst_source = inst_source * (1 - reduction_rate)
903
+
904
+ mix_reduced = combine_arrarys([inst_source, voc_source], is_swap=True)
905
+
906
+ return mix_reduced
907
+
908
+
909
+ def organize_inputs(inputs):
910
+ input_list = {"target": None, "reference": None, "reverb": None, "inst": None}
911
+
912
+ for i in inputs:
913
+ if i.endswith("_(Vocals).wav"):
914
+ input_list["reference"] = i
915
+ elif "_RVC_" in i:
916
+ input_list["target"] = i
917
+ elif i.endswith("reverbed_stem.wav"):
918
+ input_list["reverb"] = i
919
+ elif i.endswith("_(Instrumental).wav"):
920
+ input_list["inst"] = i
921
+
922
+ return input_list
923
+
924
+
925
+ def check_if_phase_inverted(wav1, wav2, is_mono=False):
926
+ # Load the audio files
927
+ if not is_mono:
928
+ wav1 = np.mean(wav1, axis=0)
929
+ wav2 = np.mean(wav2, axis=0)
930
+
931
+ # Compute the correlation
932
+ correlation = np.corrcoef(wav1[:1000], wav2[:1000])
933
+
934
+ return correlation[0, 1] < 0
935
+
936
+
937
+ def align_audio(
938
+ file1,
939
+ file2,
940
+ file2_aligned,
941
+ file_subtracted,
942
+ wav_type_set,
943
+ is_save_aligned,
944
+ command_Text,
945
+ save_format,
946
+ align_window: list,
947
+ align_intro_val: list,
948
+ db_analysis: tuple,
949
+ set_progress_bar,
950
+ phase_option,
951
+ phase_shifts,
952
+ is_match_silence,
953
+ is_spec_match,
954
+ ):
955
+
956
+ global progress_value
957
+ progress_value = 0
958
+ is_mono = False
959
+
960
+ def get_diff(a, b):
961
+ corr = np.correlate(a, b, "full")
962
+ diff = corr.argmax() - (b.shape[0] - 1)
963
+
964
+ return diff
965
+
966
+ def progress_bar(length):
967
+ global progress_value
968
+ progress_value += 1
969
+
970
+ if (0.90 / length * progress_value) >= 0.9:
971
+ length = progress_value + 1
972
+
973
+ set_progress_bar(0.1, (0.9 / length * progress_value))
974
+
975
+ # read tracks
976
+
977
+ if file1.endswith(".mp3") and is_macos:
978
+ length1 = rerun_mp3(file1)
979
+ wav1, sr1 = librosa.load(file1, duration=length1, sr=44100, mono=False)
980
+ else:
981
+ wav1, sr1 = librosa.load(file1, sr=44100, mono=False)
982
+
983
+ if file2.endswith(".mp3") and is_macos:
984
+ length2 = rerun_mp3(file2)
985
+ wav2, sr2 = librosa.load(file2, duration=length2, sr=44100, mono=False)
986
+ else:
987
+ wav2, sr2 = librosa.load(file2, sr=44100, mono=False)
988
+
989
+ if wav1.ndim == 1 and wav2.ndim == 1:
990
+ is_mono = True
991
+ elif wav1.ndim == 1:
992
+ wav1 = np.asfortranarray([wav1, wav1])
993
+ elif wav2.ndim == 1:
994
+ wav2 = np.asfortranarray([wav2, wav2])
995
+
996
+ # Check if phase is inverted
997
+ if phase_option == AUTO_PHASE:
998
+ if check_if_phase_inverted(wav1, wav2, is_mono=is_mono):
999
+ wav2 = -wav2
1000
+ elif phase_option == POSITIVE_PHASE:
1001
+ wav2 = +wav2
1002
+ elif phase_option == NEGATIVE_PHASE:
1003
+ wav2 = -wav2
1004
+
1005
+ if is_match_silence:
1006
+ wav2 = adjust_leading_silence(wav2, wav1)
1007
+
1008
+ wav1_length = int(librosa.get_duration(y=wav1, sr=44100))
1009
+ wav2_length = int(librosa.get_duration(y=wav2, sr=44100))
1010
+
1011
+ if not is_mono:
1012
+ wav1 = wav1.transpose()
1013
+ wav2 = wav2.transpose()
1014
+
1015
+ wav2_org = wav2.copy()
1016
+
1017
+ command_Text("Processing files... \n")
1018
+ seconds_length = min(wav1_length, wav2_length)
1019
+
1020
+ wav2_aligned_sources = []
1021
+
1022
+ for sec_len in align_intro_val:
1023
+ # pick a position at 1 second in and get diff
1024
+ sec_seg = 1 if sec_len == 1 else int(seconds_length // sec_len)
1025
+ index = sr1 * sec_seg # 1 second in, assuming sr1 = sr2 = 44100
1026
+
1027
+ if is_mono:
1028
+ samp1, samp2 = wav1[index : index + sr1], wav2[index : index + sr1]
1029
+ diff = get_diff(samp1, samp2)
1030
+ # print(f"Estimated difference: {diff}\n")
1031
+ else:
1032
+ index = sr1 * sec_seg # 1 second in, assuming sr1 = sr2 = 44100
1033
+ samp1, samp2 = wav1[index : index + sr1, 0], wav2[index : index + sr1, 0]
1034
+ samp1_r, samp2_r = wav1[index : index + sr1, 1], wav2[index : index + sr1, 1]
1035
+ diff, diff_r = get_diff(samp1, samp2), get_diff(samp1_r, samp2_r)
1036
+ # print(f"Estimated difference Left Channel: {diff}\nEstimated difference Right Channel: {diff_r}\n")
1037
+
1038
+ # make aligned track 2
1039
+ if diff > 0:
1040
+ zeros_to_append = np.zeros(diff) if is_mono else np.zeros((diff, 2))
1041
+ wav2_aligned = np.append(zeros_to_append, wav2_org, axis=0)
1042
+ elif diff < 0:
1043
+ wav2_aligned = wav2_org[-diff:]
1044
+ else:
1045
+ wav2_aligned = wav2_org
1046
+ # command_Text(f"Audio files already aligned.\n")
1047
+
1048
+ if not any(np.array_equal(wav2_aligned, source) for source in wav2_aligned_sources):
1049
+ wav2_aligned_sources.append(wav2_aligned)
1050
+
1051
+ # print("Unique Sources: ", len(wav2_aligned_sources))
1052
+
1053
+ unique_sources = len(wav2_aligned_sources)
1054
+
1055
+ sub_mapper_big_mapper = {}
1056
+
1057
+ for s in wav2_aligned_sources:
1058
+ wav2_aligned = match_mono_array_shapes(s, wav1) if is_mono else match_array_shapes(s, wav1, is_swap=True)
1059
+
1060
+ if align_window:
1061
+ wav_sub = time_correction(
1062
+ wav1, wav2_aligned, seconds_length, align_window=align_window, db_analysis=db_analysis, progress_bar=progress_bar, unique_sources=unique_sources, phase_shifts=phase_shifts
1063
+ )
1064
+ wav_sub_size = np.abs(wav_sub).mean()
1065
+ sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size: wav_sub}}
1066
+ else:
1067
+ wav2_aligned = wav2_aligned * np.power(10, db_analysis[0] / 20)
1068
+ db_range = db_analysis[1]
1069
+
1070
+ for db_adjustment in db_range:
1071
+ # Adjust the dB of track2
1072
+ s_adjusted = wav2_aligned * (10 ** (db_adjustment / 20))
1073
+ wav_sub = wav1 - s_adjusted
1074
+ wav_sub_size = np.abs(wav_sub).mean()
1075
+ sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size: wav_sub}}
1076
+
1077
+ # print(sub_mapper_big_mapper.keys(), min(sub_mapper_big_mapper.keys()))
1078
+
1079
+ sub_mapper_value_list = list(sub_mapper_big_mapper.values())
1080
+
1081
+ if is_spec_match and len(sub_mapper_value_list) >= 2:
1082
+ # print("using spec ensemble with align")
1083
+ wav_sub = ensemble_for_align(list(sub_mapper_big_mapper.values()))
1084
+ else:
1085
+ # print("using linear ensemble with align")
1086
+ wav_sub = ensemble_wav(list(sub_mapper_big_mapper.values()))
1087
+
1088
+ # print(f"Mix Mean: {np.abs(wav1).mean()}\nInst Mean: {np.abs(wav2).mean()}")
1089
+ # print('Final: ', np.abs(wav_sub).mean())
1090
+ wav_sub = np.clip(wav_sub, -1, +1)
1091
+
1092
+ command_Text(f"Saving inverted track... ")
1093
+
1094
+ if is_save_aligned or is_spec_match:
1095
+ wav1 = match_mono_array_shapes(wav1, wav_sub) if is_mono else match_array_shapes(wav1, wav_sub, is_swap=True)
1096
+ wav2_aligned = wav1 - wav_sub
1097
+
1098
+ if is_spec_match:
1099
+ if wav1.ndim == 1 and wav2.ndim == 1:
1100
+ wav2_aligned = np.asfortranarray([wav2_aligned, wav2_aligned]).T
1101
+ wav1 = np.asfortranarray([wav1, wav1]).T
1102
+
1103
+ wav2_aligned = ensemble_for_align([wav2_aligned, wav1])
1104
+ wav_sub = wav1 - wav2_aligned
1105
+
1106
+ if is_save_aligned:
1107
+ sf.write(file2_aligned, wav2_aligned, sr1, subtype=wav_type_set)
1108
+ save_format(file2_aligned)
1109
+
1110
+ sf.write(file_subtracted, wav_sub, sr1, subtype=wav_type_set)
1111
+ save_format(file_subtracted)
1112
+
1113
+
1114
+ def phase_shift_hilbert(signal, degree):
1115
+ analytic_signal = hilbert(signal)
1116
+ return np.cos(np.radians(degree)) * analytic_signal.real - np.sin(np.radians(degree)) * analytic_signal.imag
1117
+
1118
+
1119
+ def get_phase_shifted_tracks(track, phase_shift):
1120
+ if phase_shift == 180:
1121
+ return [track, -track]
1122
+
1123
+ step = phase_shift
1124
+ end = 180 - (180 % step) if 180 % step == 0 else 181
1125
+ phase_range = range(step, end, step)
1126
+
1127
+ flipped_list = [track, -track]
1128
+ for i in phase_range:
1129
+ flipped_list.extend([phase_shift_hilbert(track, i), phase_shift_hilbert(track, -i)])
1130
+
1131
+ return flipped_list
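Sketch of the Hilbert-based phase shift above (assumes hilbert comes from scipy.signal, as the call in phase_shift_hilbert implies): shifting a sine by 90 degrees yields approximately a cosine, apart from edge effects.

    import numpy as np
    from scipy.signal import hilbert

    t = np.linspace(0, 1, 44100, endpoint=False)
    sine = np.sin(2 * np.pi * 5 * t)
    analytic = hilbert(sine)
    shifted = np.cos(np.radians(90)) * analytic.real - np.sin(np.radians(90)) * analytic.imag
    # `shifted` closely tracks np.cos(2 * np.pi * 5 * t) away from the ends of the signal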
1132
+
1133
+
1134
+ def time_correction(mix: np.ndarray, instrumental: np.ndarray, seconds_length, align_window, db_analysis, sr=44100, progress_bar=None, unique_sources=None, phase_shifts=NONE_P):
1135
+ # Function to align two tracks using cross-correlation
1136
+
1137
+ def align_tracks(track1, track2):
1138
+ # A dictionary to store each version of track2_shifted and its mean absolute value
1139
+ shifted_tracks = {}
1140
+
1141
+ # Loop to adjust dB of track2
1142
+ track2 = track2 * np.power(10, db_analysis[0] / 20)
1143
+ db_range = db_analysis[1]
1144
+
1145
+ if phase_shifts == 190:
1146
+ track2_flipped = [track2]
1147
+ else:
1148
+ track2_flipped = get_phase_shifted_tracks(track2, phase_shifts)
1149
+
1150
+ for db_adjustment in db_range:
1151
+ for t in track2_flipped:
1152
+ # Adjust the dB of track2
1153
+ track2_adjusted = t * (10 ** (db_adjustment / 20))
1154
+ corr = correlate(track1, track2_adjusted)
1155
+ delay = np.argmax(np.abs(corr)) - (len(track1) - 1)
1156
+ track2_shifted = np.roll(track2_adjusted, shift=delay)
1157
+
1158
+ # Compute the mean absolute value of track2_shifted
1159
+ track2_shifted_sub = track1 - track2_shifted
1160
+ mean_abs_value = np.abs(track2_shifted_sub).mean()
1161
+
1162
+ # Store track2_shifted and its mean absolute value in the dictionary
1163
+ shifted_tracks[mean_abs_value] = track2_shifted
1164
+
1165
+ # Return the version of track2_shifted with the smallest mean absolute value
1166
+
1167
+ return shifted_tracks[min(shifted_tracks.keys())]
1168
+
1169
+ # Make sure the audio files have the same shape
1170
+
1171
+ assert mix.shape == instrumental.shape, f"Audio files must have the same shape - Mix: {mix.shape}, Inst: {instrumental.shape}"
1172
+
1173
+ seconds_length = seconds_length // 2
1174
+
1175
+ sub_mapper = {}
1176
+
1177
+ progress_update_interval = 120
1178
+ total_iterations = 0
1179
+
1180
+ if len(align_window) > 2:
1181
+ progress_update_interval = 320
1182
+
1183
+ for secs in align_window:
1184
+ step = secs / 2
1185
+ window_size = int(sr * secs)
1186
+ step_size = int(sr * step)
1187
+
1188
+ if len(mix.shape) == 1:
1189
+ total_mono = (len(range(0, len(mix) - window_size, step_size)) // progress_update_interval) * unique_sources
1190
+ total_iterations += total_mono
1191
+ else:
1192
+ total_stereo_ = len(range(0, len(mix[:, 0]) - window_size, step_size)) * 2
1193
+ total_stereo = (total_stereo_ // progress_update_interval) * unique_sources
1194
+ total_iterations += total_stereo
1195
+
1196
+ # print(total_iterations)
1197
+
1198
+ for secs in align_window:
1199
+ sub = np.zeros_like(mix)
1200
+ divider = np.zeros_like(mix)
1201
+ step = secs / 2
1202
+ window_size = int(sr * secs)
1203
+ step_size = int(sr * step)
1204
+ window = np.hanning(window_size)
1205
+
1206
+ # For the mono case:
1207
+ if len(mix.shape) == 1:
1208
+ # The files are mono
1209
+ counter = 0
1210
+ for i in range(0, len(mix) - window_size, step_size):
1211
+ counter += 1
1212
+ if counter % progress_update_interval == 0:
1213
+ progress_bar(total_iterations)
1214
+ window_mix = mix[i : i + window_size] * window
1215
+ window_instrumental = instrumental[i : i + window_size] * window
1216
+ window_instrumental_aligned = align_tracks(window_mix, window_instrumental)
1217
+ sub[i : i + window_size] += window_mix - window_instrumental_aligned
1218
+ divider[i : i + window_size] += window
1219
+ else:
1220
+ # The files are stereo
1221
+ counter = 0
1222
+ for ch in range(mix.shape[1]):
1223
+ for i in range(0, len(mix[:, ch]) - window_size, step_size):
1224
+ counter += 1
1225
+ if counter % progress_update_interval == 0:
1226
+ progress_bar(total_iterations)
1227
+ window_mix = mix[i : i + window_size, ch] * window
1228
+ window_instrumental = instrumental[i : i + window_size, ch] * window
1229
+ window_instrumental_aligned = align_tracks(window_mix, window_instrumental)
1230
+ sub[i : i + window_size, ch] += window_mix - window_instrumental_aligned
1231
+ divider[i : i + window_size, ch] += window
1232
+
1233
+ # Normalize the result by the overlap count
1234
+ sub = np.where(divider > 1e-6, sub / divider, sub)
1235
+ sub_size = np.abs(sub).mean()
1236
+ sub_mapper = {**sub_mapper, **{sub_size: sub}}
1237
+
1238
+ # print("SUB_LEN", len(list(sub_mapper.values())))
1239
+
1240
+ sub = ensemble_wav(list(sub_mapper.values()), split_size=12)
1241
+
1242
+ return sub
1243
+
1244
+
1245
+ def ensemble_wav(waveforms, split_size=240):
1246
+ # Create a dictionary to hold the thirds of each waveform and their mean absolute values
1247
+ waveform_thirds = {i: np.array_split(waveform, split_size) for i, waveform in enumerate(waveforms)}
1248
+
1249
+ # Initialize the final waveform
1250
+ final_waveform = []
1251
+
1252
+ # For chunk
1253
+ for third_idx in range(split_size):
1254
+ # Compute the mean absolute value of each third from each waveform
1255
+ means = [np.abs(waveform_thirds[i][third_idx]).mean() for i in range(len(waveforms))]
1256
+
1257
+ # Find the index of the waveform with the lowest mean absolute value for this third
1258
+ min_index = np.argmin(means)
1259
+
1260
+ # Add the least noisy third to the final waveform
1261
+ final_waveform.append(waveform_thirds[min_index][third_idx])
1262
+
1263
+ # Concatenate all the thirds to create the final waveform
1264
+ final_waveform = np.concatenate(final_waveform)
1265
+
1266
+ return final_waveform
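Minimal sketch (assumes ensemble_wav is imported from this module): per chunk, the waveform with the smallest mean absolute value is kept, so a uniformly quieter input wins every split.

    import numpy as np

    quiet = np.full(960, 0.01)
    loud = np.full(960, 0.50)
    combined = ensemble_wav([loud, quiet], split_size=240)
    assert np.allclose(combined, quiet)   # the quieter chunks are chosen everywhere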
1267
+
1268
+
1269
+ def ensemble_wav_min(waveforms):
1270
+ for i in range(1, len(waveforms)):
1271
+ if i == 1:
1272
+ wave = waveforms[0]
1273
+
1274
+ ln = min(len(wave), len(waveforms[i]))
1275
+ wave = wave[:ln]
1276
+ waveforms[i] = waveforms[i][:ln]
1277
+
1278
+ wave = np.where(np.abs(waveforms[i]) <= np.abs(wave), waveforms[i], wave)
1279
+
1280
+ return wave
1281
+
1282
+
1283
+ def align_audio_test(wav1, wav2, sr1=44100):
1284
+ def get_diff(a, b):
1285
+ corr = np.correlate(a, b, "full")
1286
+ diff = corr.argmax() - (b.shape[0] - 1)
1287
+ return diff
1288
+
1289
+ # read tracks
1290
+ wav1 = wav1.transpose()
1291
+ wav2 = wav2.transpose()
1292
+
1293
+ # print(f"Audio file shapes: {wav1.shape} / {wav2.shape}\n")
1294
+
1295
+ wav2_org = wav2.copy()
1296
+
1297
+ # pick a position at 1 second in and get diff
1298
+ index = sr1 # *seconds_length # 1 second in, assuming sr1 = sr2 = 44100
1299
+ samp1 = wav1[index : index + sr1, 0] # currently use left channel
1300
+ samp2 = wav2[index : index + sr1, 0]
1301
+ diff = get_diff(samp1, samp2)
1302
+
1303
+ # make aligned track 2
1304
+ if diff > 0:
1305
+ wav2_aligned = np.append(np.zeros((diff, 1)), wav2_org, axis=0)
1306
+ elif diff < 0:
1307
+ wav2_aligned = wav2_org[-diff:]
1308
+ else:
1309
+ wav2_aligned = wav2_org
1310
+
1311
+ return wav2_aligned
1312
+
1313
+
1314
+ def load_audio(audio_file):
1315
+ wav, sr = librosa.load(audio_file, sr=44100, mono=False)
1316
+
1317
+ if wav.ndim == 1:
1318
+ wav = np.asfortranarray([wav, wav])
1319
+
1320
+ return wav
1321
+
1322
+
1323
+ def rerun_mp3(audio_file):
1324
+ with audioread.audio_open(audio_file) as f:
1325
+ track_length = int(f.duration)
1326
+
1327
+ return track_length
audio_separator/separator/uvr_lib_v5/stft.py ADDED
@@ -0,0 +1,126 @@
1
+ import torch
2
+
3
+
4
+ class STFT:
5
+ """
6
+ This class performs the Short-Time Fourier Transform (STFT) and its inverse (ISTFT).
7
+ These functions are essential for converting the audio between the time domain and the frequency domain,
8
+ which is a crucial aspect of audio processing in neural networks.
9
+ """
10
+
11
+ def __init__(self, logger, n_fft, hop_length, dim_f, device):
12
+ self.logger = logger
13
+ self.n_fft = n_fft
14
+ self.hop_length = hop_length
15
+ self.dim_f = dim_f
16
+ self.device = device
17
+ # Create a Hann window tensor for use in the STFT.
18
+ self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)
19
+
20
+ def __call__(self, input_tensor):
21
+ # Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA).
22
+ is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
23
+
24
+ # If on a non-standard device, temporarily move the tensor to CPU for processing.
25
+ if is_non_standard_device:
26
+ input_tensor = input_tensor.cpu()
27
+
28
+ # Transfer the pre-defined window tensor to the same device as the input tensor.
29
+ stft_window = self.hann_window.to(input_tensor.device)
30
+
31
+ # Extract batch dimensions (all dimensions except the last two which are channel and time).
32
+ batch_dimensions = input_tensor.shape[:-2]
33
+
34
+ # Extract channel and time dimensions (last two dimensions of the tensor).
35
+ channel_dim, time_dim = input_tensor.shape[-2:]
36
+
37
+ # Reshape the tensor to merge batch and channel dimensions for STFT processing.
38
+ reshaped_tensor = input_tensor.reshape([-1, time_dim])
39
+
40
+ # Perform the Short-Time Fourier Transform (STFT) on the reshaped tensor.
41
+ stft_output = torch.stft(reshaped_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True, return_complex=False)
42
+
43
+ # Rearrange the dimensions of the STFT output to bring the frequency dimension forward.
44
+ permuted_stft_output = stft_output.permute([0, 3, 1, 2])
45
+
46
+ # Reshape the output to restore the original batch and channel dimensions, while keeping the newly formed frequency and time dimensions.
47
+ final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]).reshape(
48
+ [*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]]
49
+ )
50
+
51
+ # If the original tensor was on a non-standard device, move the processed tensor back to that device.
52
+ if is_non_standard_device:
53
+ final_output = final_output.to(self.device)
54
+
55
+ # Return the transformed tensor, sliced to retain only the required frequency dimension (`dim_f`).
56
+ return final_output[..., : self.dim_f, :]
57
+
58
+ def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
59
+ """
60
+ Adds zero padding to the frequency dimension of the input tensor.
61
+ """
62
+ # Create a padding tensor for the frequency dimension
63
+ freq_padding = torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]).to(input_tensor.device)
64
+
65
+ # Concatenate the padding to the input tensor along the frequency dimension.
66
+ padded_tensor = torch.cat([input_tensor, freq_padding], -2)
67
+
68
+ return padded_tensor
69
+
70
+ def calculate_inverse_dimensions(self, input_tensor):
71
+ # Extract batch dimensions and frequency-time dimensions.
72
+ batch_dimensions = input_tensor.shape[:-3]
73
+ channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]
74
+
75
+ # Calculate the number of frequency bins for the inverse STFT.
76
+ num_freq_bins = self.n_fft // 2 + 1
77
+
78
+ return batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins
79
+
80
+ def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
81
+ """
82
+ Prepares the tensor for Inverse Short-Time Fourier Transform (ISTFT) by reshaping
83
+ and creating a complex tensor from the real and imaginary parts.
84
+ """
85
+ # Reshape the tensor to separate real and imaginary parts and prepare for ISTFT.
86
+ reshaped_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim])
87
+
88
+ # Flatten batch dimensions and rearrange for ISTFT.
89
+ flattened_tensor = reshaped_tensor.reshape([-1, 2, num_freq_bins, time_dim])
90
+
91
+ # Rearrange the dimensions of the tensor to bring the frequency dimension forward.
92
+ permuted_tensor = flattened_tensor.permute([0, 2, 3, 1])
93
+
94
+ # Combine real and imaginary parts into a complex tensor.
95
+ complex_tensor = permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j
96
+
97
+ return complex_tensor
98
+
99
+ def inverse(self, input_tensor):
100
+ # Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA).
101
+ is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
102
+
103
+ # If on a non-standard device, temporarily move the tensor to CPU for processing.
104
+ if is_non_standard_device:
105
+ input_tensor = input_tensor.cpu()
106
+
107
+ # Transfer the pre-defined Hann window tensor to the same device as the input tensor.
108
+ stft_window = self.hann_window.to(input_tensor.device)
109
+
110
+ batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)
111
+
112
+ padded_tensor = self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins)
113
+
114
+ complex_tensor = self.prepare_for_istft(padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim)
115
+
116
+ # Perform the Inverse Short-Time Fourier Transform (ISTFT).
117
+ istft_result = torch.istft(complex_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True)
118
+
119
+ # Reshape ISTFT result to restore original batch and channel dimensions.
120
+ final_output = istft_result.reshape([*batch_dimensions, 2, -1])
121
+
122
+ # If the original tensor was on a non-standard device, move the processed tensor back to that device.
123
+ if is_non_standard_device:
124
+ final_output = final_output.to(self.device)
125
+
126
+ return final_output
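A hedged round-trip sketch for the STFT helper above (assumes the package is importable as audio_separator and that the installed torch build still accepts return_complex=False):

    import logging
    import torch
    from audio_separator.separator.uvr_lib_v5.stft import STFT

    stft = STFT(logging.getLogger(__name__), n_fft=2048, hop_length=1024, dim_f=1024, device="cpu")
    audio = torch.randn(1, 2, 44100)   # (batch, channels, samples)
    spec = stft(audio)                 # (batch, channels * 2, dim_f, frames)
    restored = stft.inverse(spec)      # (batch, 2, samples'), length can differ by a frame or two
    print(spec.shape, restored.shape)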
audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py ADDED
@@ -0,0 +1,253 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from functools import partial
4
+
5
+ class STFT:
6
+ def __init__(self, n_fft, hop_length, dim_f, device):
7
+ self.n_fft = n_fft
8
+ self.hop_length = hop_length
9
+ self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
10
+ self.dim_f = dim_f
11
+ self.device = device
12
+
13
+ def __call__(self, x):
14
+
15
+ x_is_mps = not x.device.type in ["cuda", "cpu"]
16
+ if x_is_mps:
17
+ x = x.cpu()
18
+
19
+ window = self.window.to(x.device)
20
+ batch_dims = x.shape[:-2]
21
+ c, t = x.shape[-2:]
22
+ x = x.reshape([-1, t])
23
+ x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True, return_complex=False)
24
+ x = x.permute([0, 3, 1, 2])
25
+ x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
26
+
27
+ if x_is_mps:
28
+ x = x.to(self.device)
29
+
30
+ return x[..., :self.dim_f, :]
31
+
32
+ def inverse(self, x):
33
+
34
+ x_is_mps = not x.device.type in ["cuda", "cpu"]
35
+ if x_is_mps:
36
+ x = x.cpu()
37
+
38
+ window = self.window.to(x.device)
39
+ batch_dims = x.shape[:-3]
40
+ c, f, t = x.shape[-3:]
41
+ n = self.n_fft // 2 + 1
42
+ f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
43
+ x = torch.cat([x, f_pad], -2)
44
+ x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
45
+ x = x.permute([0, 2, 3, 1])
46
+ x = x[..., 0] + x[..., 1] * 1.j
47
+ x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
48
+ x = x.reshape([*batch_dims, 2, -1])
49
+
50
+ if x_is_mps:
51
+ x = x.to(self.device)
52
+
53
+ return x
54
+
55
+ def get_norm(norm_type):
56
+ def norm(c, norm_type):
57
+ if norm_type == 'BatchNorm':
58
+ return nn.BatchNorm2d(c)
59
+ elif norm_type == 'InstanceNorm':
60
+ return nn.InstanceNorm2d(c, affine=True)
61
+ elif 'GroupNorm' in norm_type:
62
+ g = int(norm_type.replace('GroupNorm', ''))
63
+ return nn.GroupNorm(num_groups=g, num_channels=c)
64
+ else:
65
+ return nn.Identity()
66
+
67
+ return partial(norm, norm_type=norm_type)
68
+
69
+
70
+ def get_act(act_type):
71
+ if act_type == 'gelu':
72
+ return nn.GELU()
73
+ elif act_type == 'relu':
74
+ return nn.ReLU()
75
+ elif act_type[:3] == 'elu':
76
+ alpha = float(act_type.replace('elu', ''))
77
+ return nn.ELU(alpha)
78
+ else:
79
+ raise ValueError(f"Unknown activation type: {act_type}")
80
+
81
+
82
+ class Upscale(nn.Module):
83
+ def __init__(self, in_c, out_c, scale, norm, act):
84
+ super().__init__()
85
+ self.conv = nn.Sequential(
86
+ norm(in_c),
87
+ act,
88
+ nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
89
+ )
90
+
91
+ def forward(self, x):
92
+ return self.conv(x)
93
+
94
+
95
+ class Downscale(nn.Module):
96
+ def __init__(self, in_c, out_c, scale, norm, act):
97
+ super().__init__()
98
+ self.conv = nn.Sequential(
99
+ norm(in_c),
100
+ act,
101
+ nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
102
+ )
103
+
104
+ def forward(self, x):
105
+ return self.conv(x)
106
+
107
+
108
+ class TFC_TDF(nn.Module):
109
+ def __init__(self, in_c, c, l, f, bn, norm, act):
110
+ super().__init__()
111
+
112
+ self.blocks = nn.ModuleList()
113
+ for i in range(l):
114
+ block = nn.Module()
115
+
116
+ block.tfc1 = nn.Sequential(
117
+ norm(in_c),
118
+ act,
119
+ nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
120
+ )
121
+ block.tdf = nn.Sequential(
122
+ norm(c),
123
+ act,
124
+ nn.Linear(f, f // bn, bias=False),
125
+ norm(c),
126
+ act,
127
+ nn.Linear(f // bn, f, bias=False),
128
+ )
129
+ block.tfc2 = nn.Sequential(
130
+ norm(c),
131
+ act,
132
+ nn.Conv2d(c, c, 3, 1, 1, bias=False),
133
+ )
134
+ block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)
135
+
136
+ self.blocks.append(block)
137
+ in_c = c
138
+
139
+ def forward(self, x):
140
+ for block in self.blocks:
141
+ s = block.shortcut(x)
142
+ x = block.tfc1(x)
143
+ x = x + block.tdf(x)
144
+ x = block.tfc2(x)
145
+ x = x + s
146
+ return x
147
+
148
+
149
+ class TFC_TDF_net(nn.Module):
150
+ def __init__(self, config, device):
151
+ super().__init__()
152
+ self.config = config
153
+ self.device = device
154
+
155
+ norm = get_norm(norm_type=config.model.norm)
156
+ act = get_act(act_type=config.model.act)
157
+
158
+ self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
159
+ self.num_subbands = config.model.num_subbands
160
+
161
+ dim_c = self.num_subbands * config.audio.num_channels * 2
162
+ n = config.model.num_scales
163
+ scale = config.model.scale
164
+ l = config.model.num_blocks_per_scale
165
+ c = config.model.num_channels
166
+ g = config.model.growth
167
+ bn = config.model.bottleneck_factor
168
+ f = config.audio.dim_f // self.num_subbands
169
+
170
+ self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
171
+
172
+ self.encoder_blocks = nn.ModuleList()
173
+ for i in range(n):
174
+ block = nn.Module()
175
+ block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act)
176
+ block.downscale = Downscale(c, c + g, scale, norm, act)
177
+ f = f // scale[1]
178
+ c += g
179
+ self.encoder_blocks.append(block)
180
+
181
+ self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act)
182
+
183
+ self.decoder_blocks = nn.ModuleList()
184
+ for i in range(n):
185
+ block = nn.Module()
186
+ block.upscale = Upscale(c, c - g, scale, norm, act)
187
+ f = f * scale[1]
188
+ c -= g
189
+ block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act)
190
+ self.decoder_blocks.append(block)
191
+
192
+ self.final_conv = nn.Sequential(
193
+ nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
194
+ act,
195
+ nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
196
+ )
197
+
198
+ self.stft = STFT(config.audio.n_fft, config.audio.hop_length, config.audio.dim_f, self.device)
199
+
200
+ def cac2cws(self, x):
201
+ k = self.num_subbands
202
+ b, c, f, t = x.shape
203
+ x = x.reshape(b, c, k, f // k, t)
204
+ x = x.reshape(b, c * k, f // k, t)
205
+ return x
206
+
207
+ def cws2cac(self, x):
208
+ k = self.num_subbands
209
+ b, c, f, t = x.shape
210
+ x = x.reshape(b, c // k, k, f, t)
211
+ x = x.reshape(b, c // k, f * k, t)
212
+ return x
213
+
214
+ def forward(self, x):
215
+
216
+ x = self.stft(x)
217
+
218
+ mix = x = self.cac2cws(x)
219
+
220
+ first_conv_out = x = self.first_conv(x)
221
+
222
+ x = x.transpose(-1, -2)
223
+
224
+ encoder_outputs = []
225
+ for block in self.encoder_blocks:
226
+ x = block.tfc_tdf(x)
227
+ encoder_outputs.append(x)
228
+ x = block.downscale(x)
229
+
230
+ x = self.bottleneck_block(x)
231
+
232
+ for block in self.decoder_blocks:
233
+ x = block.upscale(x)
234
+ x = torch.cat([x, encoder_outputs.pop()], 1)
235
+ x = block.tfc_tdf(x)
236
+
237
+ x = x.transpose(-1, -2)
238
+
239
+ x = x * first_conv_out # reduce artifacts
240
+
241
+ x = self.final_conv(torch.cat([mix, x], 1))
242
+
243
+ x = self.cws2cac(x)
244
+
245
+ if self.num_target_instruments > 1:
246
+ b, c, f, t = x.shape
247
+ x = x.reshape(b, self.num_target_instruments, -1, f, t)
248
+
249
+ x = self.stft.inverse(x)
250
+
251
+ return x
252
+
253
+
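Shape-check sketch for the building blocks above (assumes the module import path; sizes are arbitrary). Note that the TDF linear layers act on the last axis, so frequency must come last, which is why TFC_TDF_net transposes before the encoder.

    import torch
    from audio_separator.separator.uvr_lib_v5.tfc_tdf_v3 import TFC_TDF, get_act, get_norm

    norm = get_norm("InstanceNorm")
    act = get_act("gelu")
    block = TFC_TDF(in_c=8, c=16, l=2, f=64, bn=4, norm=norm, act=act)
    x = torch.randn(1, 8, 100, 64)   # (batch, channels, time, freq) with freq on the last axis
    print(block(x).shape)            # torch.Size([1, 16, 100, 64])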
audio_separator/separator/uvr_lib_v5/vr_network/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # VR init.
audio_separator/separator/uvr_lib_v5/vr_network/layers.py ADDED
@@ -0,0 +1,294 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ from audio_separator.separator.uvr_lib_v5 import spec_utils
6
+
7
+
8
+ class Conv2DBNActiv(nn.Module):
9
+ """
10
+ This class implements a convolutional layer followed by batch normalization and an activation function.
11
+ It is a common pattern in deep learning for processing images or feature maps. The convolutional layer
12
+ applies a set of learnable filters to the input. Batch normalization then normalizes the output of the
13
+ convolution, and finally, an activation function introduces non-linearity to the model, allowing it to
14
+ learn more complex patterns.
15
+
16
+ Attributes:
17
+ conv (nn.Sequential): A sequential container of Conv2d, BatchNorm2d, and an activation layer.
18
+
19
+ Args:
20
+ num_input_channels (int): Number of input channels.
21
+ nin (int): Number of input channels.
22
+ nout (int): Number of output channels.
23
+ ksize (int, optional): Size of the kernel. Defaults to 3.
24
+ stride (int, optional): Stride of the convolution. Defaults to 1.
25
+ pad (int, optional): Padding added to all sides of the input. Defaults to 1.
26
+ dilation (int, optional): Spacing between kernel elements. Defaults to 1.
27
+ activ (callable, optional): The activation function to use. Defaults to nn.ReLU.
28
+
29
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
30
+ super(Conv2DBNActiv, self).__init__()
31
+
32
+ # The nn.Sequential container allows us to stack the Conv2d, BatchNorm2d, and activation layers
33
+ # into a single module, simplifying the forward pass.
34
+ self.conv = nn.Sequential(nn.Conv2d(nin, nout, kernel_size=ksize, stride=stride, padding=pad, dilation=dilation, bias=False), nn.BatchNorm2d(nout), activ())
35
+
36
+ def __call__(self, input_tensor):
37
+ # Defines the computation performed at every call.
38
+ # Simply passes the input through the sequential container.
39
+ return self.conv(input_tensor)
40
+
41
+
42
+ class SeperableConv2DBNActiv(nn.Module):
43
+ """
44
+ This class implements a separable convolutional layer followed by batch normalization and an activation function.
45
+ Separable convolutions are a type of convolution that splits the convolution operation into two simpler operations:
46
+ a depthwise convolution and a pointwise convolution. This can reduce the number of parameters and computational cost,
47
+ making the network more efficient while maintaining similar performance.
48
+
49
+ The depthwise convolution applies a single filter per input channel (input depth). The pointwise convolution,
50
+ which follows, applies a 1x1 convolution to combine the outputs of the depthwise convolution across channels.
51
+ Batch normalization is then applied to stabilize learning and reduce internal covariate shift. Finally,
52
+ an activation function introduces non-linearity, allowing the network to learn complex patterns.
53
+ Attributes:
54
+ conv (nn.Sequential): A sequential container of depthwise Conv2d, pointwise Conv2d, BatchNorm2d, and an activation layer.
55
+
56
+ Args:
57
+ nin (int): Number of input channels.
58
+ nout (int): Number of output channels.
59
+ ksize (int, optional): Size of the kernel for the depthwise convolution. Defaults to 3.
60
+ stride (int, optional): Stride of the convolution. Defaults to 1.
61
+ pad (int, optional): Padding added to all sides of the input for the depthwise convolution. Defaults to 1.
62
+ dilation (int, optional): Spacing between kernel elements for the depthwise convolution. Defaults to 1.
63
+ activ (callable, optional): The activation function to use. Defaults to nn.ReLU.
64
+ """
65
+
66
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
67
+ super(SeperableConv2DBNActiv, self).__init__()
68
+
69
+ # Initialize the sequential container with the depthwise convolution.
70
+ # The number of groups in the depthwise convolution is set to num_input_channels, which means each input channel is treated separately.
71
+ # The pointwise convolution then combines these separate channels into num_output_channels channels.
72
+ # Batch normalization is applied to the output of the pointwise convolution.
73
+ # Finally, the activation function is applied to introduce non-linearity.
74
+ self.conv = nn.Sequential(
75
+ nn.Conv2d(
76
+ nin,
77
+ nin, # For depthwise convolution, in_channels = out_channels = num_input_channels
78
+ kernel_size=ksize,
79
+ stride=stride,
80
+ padding=pad,
81
+ dilation=dilation,
82
+ groups=nin, # This makes it a depthwise convolution
83
+ bias=False, # Bias is not used because it will be handled by BatchNorm2d
84
+ ),
85
+ nn.Conv2d(
86
+ nin,
87
+ nout, # Pointwise convolution to combine channels
88
+ kernel_size=1, # Kernel size of 1 for pointwise convolution
89
+ bias=False, # Bias is not used because it will be handled by BatchNorm2d
90
+ ),
91
+ nn.BatchNorm2d(nout), # Normalize the output of the pointwise convolution
92
+ activ(), # Apply the activation function
93
+ )
94
+
95
+ def __call__(self, input_tensor):
96
+ # Pass the input through the sequential container.
97
+ # This performs the depthwise convolution, followed by the pointwise convolution,
98
+ # batch normalization, and finally applies the activation function.
99
+ return self.conv(input_tensor)
100
+
101
+
102
+ class Encoder(nn.Module):
103
+ """
104
+ The Encoder class is a part of the neural network architecture that is responsible for processing the input data.
105
+ It consists of two convolutional layers, each followed by batch normalization and an activation function.
106
+ The purpose of the Encoder is to transform the input data into a higher-level, abstract representation.
107
+ This is achieved by applying filters (through convolutions) that can capture patterns or features in the data.
108
+ The Encoder can be thought of as a feature extractor that prepares the data for further processing by the network.
109
+ Attributes:
110
+ conv1 (Conv2DBNActiv): The first convolutional layer in the encoder.
111
+ conv2 (Conv2DBNActiv): The second convolutional layer in the encoder.
112
+
113
+ Args:
114
+ nin (int): Number of input channels for the first convolutional layer.
115
+ nout (int): Number of output channels for both convolutional layers.
116
+ ksize (int): Kernel size for the convolutional layers.
117
+ stride (int): Stride for the second convolution (conv1 always uses a stride of 1).
118
+ pad (int): Padding added to all sides of the input for the convolutional layers.
119
+ activ (callable): The activation function to use after each convolutional layer. Defaults to nn.LeakyReLU.
120
+ """
121
+
122
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
123
+ super(Encoder, self).__init__()
124
+
125
+ # The first convolutional layer takes the input and applies a convolution,
126
+ # followed by batch normalization and an activation function specified by `activation_function`.
127
+ # This layer is responsible for capturing the initial set of features from the input data.
128
+ self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
129
+
130
+ # The second convolutional layer further processes the output from the first layer,
131
+ # applying another set of convolution, batch normalization, and activation.
132
+ # This layer helps in capturing more complex patterns in the data by building upon the initial features extracted by conv1.
133
+ self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ)
134
+
135
+ def __call__(self, input_tensor):
136
+ # The input data `input_tensor` is passed through the first convolutional layer.
137
+ # The output of this layer serves as a 'skip connection' that can be used later in the network to preserve spatial information.
138
+ skip = self.conv1(input_tensor)
139
+
140
+ # The output from the first layer is then passed through the second convolutional layer.
141
+ # This processed data `hidden` is the final output of the Encoder, representing the abstracted features of the input.
142
+ hidden = self.conv2(skip)
143
+
144
+ # The Encoder returns two outputs: `hidden`, the abstracted feature representation, and `skip`, the intermediate representation from conv1.
145
+ return hidden, skip
146
+
147
+
148
+ class Decoder(nn.Module):
149
+ """
150
+ The Decoder class is part of the neural network architecture, specifically designed to perform the inverse operation of an encoder.
151
+ Its main role is to reconstruct or generate data from encoded representations, which is crucial in tasks like image segmentation or audio processing.
152
+ This class uses upsampling, convolution, optional dropout for regularization, and concatenation of skip connections to achieve its goal.
153
+
154
+ Attributes:
155
+ conv (Conv2DBNActiv): A convolutional layer with batch normalization and an activation function.
156
+ dropout (nn.Dropout2d or None): An optional dropout layer for regularization to prevent overfitting.
157
+
158
+ Args:
159
+ nin (int): Number of input channels for the convolutional layer.
160
+ nout (int): Number of output channels for the convolutional layer.
161
+ ksize (int): Kernel size for the convolutional layer.
162
+ stride (int): Stride for the convolutional operations.
163
+ pad (int): Padding added to all sides of the input for the convolutional layer.
164
+ activ (callable): The activation function to use after the convolutional layer.
165
+ dropout (bool): Whether to include a dropout layer for regularization.
166
+ """
167
+
168
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
169
+ super(Decoder, self).__init__()
170
+
171
+ # Initialize the convolutional layer with specified parameters.
172
+ self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
173
+
174
+ # Initialize the dropout layer if include_dropout is set to True
175
+ self.dropout = nn.Dropout2d(0.1) if dropout else None
176
+
177
+ def __call__(self, input_tensor, skip=None):
178
+ # Upsample the input tensor to a higher resolution using bilinear interpolation.
179
+ input_tensor = F.interpolate(input_tensor, scale_factor=2, mode="bilinear", align_corners=True)
180
+ # If a skip connection is provided, crop it to match the size of input_tensor and concatenate them along the channel dimension.
181
+ if skip is not None:
182
+ skip = spec_utils.crop_center(skip, input_tensor) # Crop skip_connection to match input_tensor's dimensions.
183
+ input_tensor = torch.cat([input_tensor, skip], dim=1) # Concatenate input_tensor and skip_connection along the channel dimension.
184
+
185
+ # Pass the concatenated tensor (or just input_tensor if no skip_connection is provided) through the convolutional layer.
186
+ output_tensor = self.conv(input_tensor)
187
+
188
+ # If dropout is enabled, apply it to the output of the convolutional layer.
189
+ if self.dropout is not None:
190
+ output_tensor = self.dropout(output_tensor)
191
+
192
+ # Return the final output tensor.
193
+ return output_tensor
194
+
195
+
196
+ class ASPPModule(nn.Module):
197
+ """
198
+ Atrous Spatial Pyramid Pooling (ASPP) Module is designed for capturing multi-scale context by applying
199
+ atrous convolution at multiple rates. This is particularly useful in segmentation tasks where capturing
200
+ objects at various scales is beneficial. The module applies several parallel dilated convolutions with
201
+ different dilation rates to the input feature map, allowing it to efficiently capture information at
202
+ multiple scales.
203
+
204
+ Attributes:
205
+ conv1 (nn.Sequential): Applies adaptive average pooling followed by a 1x1 convolution.
206
+ nn_architecture (int): Identifier for the neural network architecture being used.
207
+ six_layer (list): List containing architecture identifiers that require six layers.
208
+ seven_layer (list): List containing architecture identifiers that require seven layers.
209
+ conv2-conv7 (nn.Module): Convolutional layers with varying dilation rates for multi-scale feature extraction.
210
+ bottleneck (nn.Sequential): A 1x1 convolutional layer that combines all features followed by dropout for regularization.
211
+ """
212
+
213
+ def __init__(self, nn_architecture, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU):
214
+ """
215
+ Initializes the ASPP module with specified parameters.
216
+
217
+ Args:
218
+ nn_architecture (int): Identifier for the neural network architecture.
219
+ nin (int): Number of input channels.
220
+ nout (int): Number of output channels.
221
+ dilations (tuple): Tuple of dilation rates for the atrous convolutions.
222
+ activ (callable): Activation function to use after convolutional layers.
223
+ """
224
+ super(ASPPModule, self).__init__()
225
+
226
+ # Adaptive average pooling reduces the spatial dimensions to 1x1, focusing on global context,
227
+ # followed by a 1x1 convolution to project back to the desired channel dimension.
228
+ self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1, None)), Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ))
229
+
230
+ self.nn_architecture = nn_architecture
231
+ # Architecture identifiers for models requiring additional layers.
232
+ self.six_layer = [129605]
233
+ self.seven_layer = [537238, 537227, 33966]
234
+
235
+ # Extra convolutional layer used for six and seven layer configurations.
236
+ extra_conv = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
237
+
238
+ # Standard 1x1 convolution for channel reduction.
239
+ self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
240
+
241
+ # Separable convolutions with different dilation rates for multi-scale feature extraction.
242
+ self.conv3 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[0], dilations[0], activ=activ)
243
+ self.conv4 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[1], dilations[1], activ=activ)
244
+ self.conv5 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
245
+
246
+ # Depending on the architecture, include the extra convolutional layers.
247
+ if self.nn_architecture in self.six_layer:
248
+ self.conv6 = extra_conv
249
+ nin_x = 6
250
+ elif self.nn_architecture in self.seven_layer:
251
+ self.conv6 = extra_conv
252
+ self.conv7 = extra_conv
253
+ nin_x = 7
254
+ else:
255
+ nin_x = 5
256
+
257
+ # Bottleneck layer combines all the multi-scale features into the desired number of output channels.
258
+ self.bottleneck = nn.Sequential(Conv2DBNActiv(nin * nin_x, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1))
259
+
260
+ def forward(self, input_tensor):
261
+ """
262
+ Forward pass of the ASPP module.
263
+
264
+ Args:
265
+ input_tensor (Tensor): Input tensor.
266
+
267
+ Returns:
268
+ Tensor: Output tensor after applying ASPP.
269
+ """
270
+ _, _, h, w = input_tensor.size()
271
+
272
+ # Apply the first convolutional sequence and upsample to the original resolution.
273
+ feat1 = F.interpolate(self.conv1(input_tensor), size=(h, w), mode="bilinear", align_corners=True)
274
+
275
+ # Apply the remaining convolutions directly on the input.
276
+ feat2 = self.conv2(input_tensor)
277
+ feat3 = self.conv3(input_tensor)
278
+ feat4 = self.conv4(input_tensor)
279
+ feat5 = self.conv5(input_tensor)
280
+
281
+ # Concatenate features from all layers. Depending on the architecture, include the extra features.
282
+ if self.nn_architecture in self.six_layer:
283
+ feat6 = self.conv6(input_tensor)
284
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6), dim=1)
285
+ elif self.nn_architecture in self.seven_layer:
286
+ feat6 = self.conv6(input_tensor)
287
+ feat7 = self.conv7(input_tensor)
288
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1)
289
+ else:
290
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
291
+
292
+ # Apply the bottleneck layer to combine and reduce the channel dimensions.
293
+ bottleneck_output = self.bottleneck(out)
294
+ return bottleneck_output
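A minimal shape-check sketch for the ASPPModule above (not part of this commit; the nn_architecture value 0 is arbitrary, chosen only so the default five-branch path is taken, and the import path assumes the package layout of this upload):

    import torch
    from audio_separator.separator.uvr_lib_v5.vr_network.layers import ASPPModule

    aspp = ASPPModule(nn_architecture=0, nin=16, nout=32).eval()  # any id outside six_layer/seven_layer -> 5 branches
    x = torch.randn(1, 16, 64, 128)  # (batch, channels, frequency_bins, frames)
    with torch.no_grad():
        y = aspp(x)
    print(y.shape)  # torch.Size([1, 32, 64, 128]): five 16-channel branches feed the 80 -> 32 bottleneck
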
audio_separator/separator/uvr_lib_v5/vr_network/layers_new.py ADDED
@@ -0,0 +1,149 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from audio_separator.separator.uvr_lib_v5 import spec_utils
+
+
+ class Conv2DBNActiv(nn.Module):
+     """
+     Conv2DBNActiv Class:
+     This class implements a convolutional layer followed by batch normalization and an activation function.
+     It is a fundamental building block for constructing neural networks, especially useful in image and audio processing tasks.
+     The class encapsulates the pattern of applying a convolution, normalizing the output, and then applying a non-linear activation.
+     """
+
+     def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
+         super(Conv2DBNActiv, self).__init__()
+
+         # Sequential model combining Conv2D, BatchNorm, and activation function into a single module
+         self.conv = nn.Sequential(nn.Conv2d(nin, nout, kernel_size=ksize, stride=stride, padding=pad, dilation=dilation, bias=False), nn.BatchNorm2d(nout), activ())
+
+     def __call__(self, input_tensor):
+         # Forward pass through the sequential model
+         return self.conv(input_tensor)
+
+
+ class Encoder(nn.Module):
+     """
+     Encoder Class:
+     This class defines an encoder module typically used in autoencoder architectures.
+     It consists of two convolutional layers, each followed by batch normalization and an activation function.
+     """
+
+     def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
+         super(Encoder, self).__init__()
+
+         # First convolutional layer of the encoder
+         self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
+         # Second convolutional layer of the encoder
+         self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
+
+     def __call__(self, input_tensor):
+         # Applying the first and then the second convolutional layers
+         hidden = self.conv1(input_tensor)
+         hidden = self.conv2(hidden)
+
+         return hidden
+
+
+ class Decoder(nn.Module):
+     """
+     Decoder Class:
+     This class defines a decoder module, the counterpart of the Encoder class in autoencoder architectures.
+     It applies a convolutional layer followed by batch normalization and an activation function, with an optional dropout layer for regularization.
+     """
+
+     def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
+         super(Decoder, self).__init__()
+         # Convolutional layer with optional dropout for regularization
+         self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
+         # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
+         self.dropout = nn.Dropout2d(0.1) if dropout else None
+
+     def __call__(self, input_tensor, skip=None):
+         # Upsample, optionally merge the skip connection, then apply the convolution and optional dropout
+         input_tensor = F.interpolate(input_tensor, scale_factor=2, mode="bilinear", align_corners=True)
+
+         if skip is not None:
+             skip = spec_utils.crop_center(skip, input_tensor)
+             input_tensor = torch.cat([input_tensor, skip], dim=1)
+
+         hidden = self.conv1(input_tensor)
+         # hidden = self.conv2(hidden)
+
+         if self.dropout is not None:
+             hidden = self.dropout(hidden)
+
+         return hidden
+
+
+ class ASPPModule(nn.Module):
+     """
+     ASPPModule Class:
+     This class implements the Atrous Spatial Pyramid Pooling (ASPP) module, which is useful for semantic image segmentation tasks.
+     It captures multi-scale contextual information by applying convolutions at multiple dilation rates.
+     """
+
+     def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
+         super(ASPPModule, self).__init__()
+
+         # Global context convolution captures the overall context
+         self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1, None)), Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ))
+         self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
+         self.conv3 = Conv2DBNActiv(nin, nout, 3, 1, dilations[0], dilations[0], activ=activ)
+         self.conv4 = Conv2DBNActiv(nin, nout, 3, 1, dilations[1], dilations[1], activ=activ)
+         self.conv5 = Conv2DBNActiv(nin, nout, 3, 1, dilations[2], dilations[2], activ=activ)
+         self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
+         self.dropout = nn.Dropout2d(0.1) if dropout else None
+
+     def forward(self, input_tensor):
+         _, _, h, w = input_tensor.size()
+
+         # Upsample global context to match input size and combine with local and multi-scale features
+         feat1 = F.interpolate(self.conv1(input_tensor), size=(h, w), mode="bilinear", align_corners=True)
+         feat2 = self.conv2(input_tensor)
+         feat3 = self.conv3(input_tensor)
+         feat4 = self.conv4(input_tensor)
+         feat5 = self.conv5(input_tensor)
+         out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
+         out = self.bottleneck(out)
+
+         if self.dropout is not None:
+             out = self.dropout(out)
+
+         return out
+
+
+ class LSTMModule(nn.Module):
+     """
+     LSTMModule Class:
+     This class defines a module that combines convolutional feature extraction with a bidirectional LSTM for sequence modeling.
+     It is useful for tasks that require understanding temporal dynamics in data, such as speech and audio processing.
+     """
+
+     def __init__(self, nin_conv, nin_lstm, nout_lstm):
+         super(LSTMModule, self).__init__()
+         # Convolutional layer for initial feature extraction
+         self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
+
+         # Bidirectional LSTM for capturing temporal dynamics
+         self.lstm = nn.LSTM(input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True)
+
+         # Dense layer for output dimensionality matching
+         self.dense = nn.Sequential(nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU())
+
+     def forward(self, input_tensor):
+         N, _, nbins, nframes = input_tensor.size()
+
+         # Extract features and prepare for LSTM
+         hidden = self.conv(input_tensor)[:, 0]  # N, nbins, nframes
+         hidden = hidden.permute(2, 0, 1)  # nframes, N, nbins
+         hidden, _ = self.lstm(hidden)
+
+         # Apply dense layer and reshape to match expected output format
+         hidden = self.dense(hidden.reshape(-1, hidden.size()[-1]))  # nframes * N, nbins
+         hidden = hidden.reshape(nframes, N, 1, nbins)
+         hidden = hidden.permute(1, 2, 3, 0)
+
+         return hidden
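A small shape-walkthrough sketch for the LSTMModule above (illustrative values; nin_lstm has to match the n_bins dimension of the input spectrogram):

    import torch
    from audio_separator.separator.uvr_lib_v5.vr_network.layers_new import LSTMModule

    lstm = LSTMModule(nin_conv=8, nin_lstm=128, nout_lstm=256).eval()
    x = torch.randn(2, 8, 128, 100)  # (batch, channels, n_bins, n_frames)
    with torch.no_grad():
        y = lstm(x)
    print(y.shape)  # torch.Size([2, 1, 128, 100]): one channel, same bins/frames as the input
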
audio_separator/separator/uvr_lib_v5/vr_network/model_param_init.py ADDED
@@ -0,0 +1,71 @@
+ import json
+
+ default_param = {}
+ default_param["bins"] = -1
+ default_param["unstable_bins"] = -1  # training only
+ default_param["stable_bins"] = -1  # training only
+ default_param["sr"] = 44100
+ default_param["pre_filter_start"] = -1
+ default_param["pre_filter_stop"] = -1
+ default_param["band"] = {}
+
+ N_BINS = "n_bins"
+
+
+ def int_keys(d):
+     """
+     Converts string keys that represent integers into actual integer keys.
+
+     This function is particularly useful when dealing with JSON data that may represent
+     integer keys as strings due to the nature of JSON encoding. By converting these keys
+     back to integers, it ensures that the data can be used in a manner consistent with
+     its original representation, especially in contexts where the distinction between
+     string and integer keys is important.
+
+     Args:
+         d (list of tuple): A list of (key, value) pairs where the keys are strings
+             that may represent integers.
+
+     Returns:
+         dict: A dictionary with keys converted to integers where applicable.
+     """
+     # Initialize an empty dictionary to hold the converted key-value pairs.
+     result_dict = {}
+     # Iterate through each key-value pair in the input list.
+     for key, value in d:
+         # Check if the key is a digit (i.e., represents an integer).
+         if key.isdigit():
+             # Convert the key from a string to an integer.
+             key = int(key)
+         result_dict[key] = value
+     return result_dict
+
+
+ class ModelParameters(object):
+     """
+     A class to manage model parameters, including loading from a configuration file.
+
+     Attributes:
+         param (dict): Dictionary holding all parameters for the model.
+     """
+
+     def __init__(self, config_path=""):
+         """
+         Initializes the ModelParameters object by loading parameters from a JSON configuration file.
+
+         Args:
+             config_path (str): Path to the JSON configuration file.
+         """
+
+         # Load parameters from the given configuration file path.
+         with open(config_path, "r") as f:
+             self.param = json.loads(f.read(), object_pairs_hook=int_keys)
+
+         # Ensure certain parameters are set to False if not specified in the configuration.
+         for k in ["mid_side", "mid_side_b", "mid_side_b2", "stereo_w", "stereo_n", "reverse"]:
+             if k not in self.param:
+                 self.param[k] = False
+
+         # If 'n_bins' is specified in the parameters, it is used as the value for 'bins'.
+         if N_BINS in self.param:
+             self.param["bins"] = self.param[N_BINS]
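For illustration, a short sketch of what the int_keys hook does when json.loads parses a band dictionary (the JSON string below is made up):

    import json
    from audio_separator.separator.uvr_lib_v5.vr_network.model_param_init import int_keys

    raw = '{"bins": 1024, "band": {"1": {"sr": 44100}, "2": {"sr": 22050}}}'
    parsed = json.loads(raw, object_pairs_hook=int_keys)
    print(list(parsed["band"].keys()))  # [1, 2] -- band indices come back as integers; non-numeric keys stay strings
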
audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr16000_hl512.json ADDED
@@ -0,0 +1,19 @@
+ {
+     "bins": 1024,
+     "unstable_bins": 0,
+     "reduction_bins": 0,
+     "band": {
+         "1": {
+             "sr": 16000,
+             "hl": 512,
+             "n_fft": 2048,
+             "crop_start": 0,
+             "crop_stop": 1024,
+             "hpf_start": -1,
+             "res_type": "sinc_best"
+         }
+     },
+     "sr": 16000,
+     "pre_filter_start": 1023,
+     "pre_filter_stop": 1024
+ }
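As a usage sketch, a parameter file like the one above can be loaded through ModelParameters (assuming the script is run from the repository root so the relative path resolves):

    from audio_separator.separator.uvr_lib_v5.vr_network.model_param_init import ModelParameters

    mp = ModelParameters("audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr16000_hl512.json")
    band = mp.param["band"][1]  # integer key courtesy of int_keys
    print(mp.param["sr"], band["n_fft"], band["hl"])  # 16000 2048 512
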
audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr32000_hl512.json ADDED
@@ -0,0 +1,19 @@
+ {
+     "bins": 1024,
+     "unstable_bins": 0,
+     "reduction_bins": 0,
+     "band": {
+         "1": {
+             "sr": 32000,
+             "hl": 512,
+             "n_fft": 2048,
+             "crop_start": 0,
+             "crop_stop": 1024,
+             "hpf_start": -1,
+             "res_type": "kaiser_fast"
+         }
+     },
+     "sr": 32000,
+     "pre_filter_start": 1000,
+     "pre_filter_stop": 1021
+ }
audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr33075_hl384.json ADDED
@@ -0,0 +1,19 @@
+ {
+     "bins": 1024,
+     "unstable_bins": 0,
+     "reduction_bins": 0,
+     "band": {
+         "1": {
+             "sr": 33075,
+             "hl": 384,
+             "n_fft": 2048,
+             "crop_start": 0,
+             "crop_stop": 1024,
+             "hpf_start": -1,
+             "res_type": "sinc_best"
+         }
+     },
+     "sr": 33075,
+     "pre_filter_start": 1000,
+     "pre_filter_stop": 1021
+ }