Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Upload separate.py
Browse files- separate.py +355 -872
 
    	
        separate.py
    CHANGED
    
    | 
         @@ -6,15 +6,13 @@ from demucs.model_v2 import auto_load_demucs_model_v2 
     | 
|
| 6 | 
         
             
            from demucs.pretrained import get_model as _gm
         
     | 
| 7 | 
         
             
            from demucs.utils import apply_model_v1
         
     | 
| 8 | 
         
             
            from demucs.utils import apply_model_v2
         
     | 
| 9 | 
         
            -
            from lib_v5.tfc_tdf_v3 import TFC_TDF_net, STFT
         
     | 
| 10 | 
         
             
            from lib_v5 import spec_utils
         
     | 
| 11 | 
         
             
            from lib_v5.vr_network import nets
         
     | 
| 12 | 
         
             
            from lib_v5.vr_network import nets_new
         
     | 
| 13 | 
         
            -
            from lib_v5.vr_network.model_param_init import ModelParameters
         
     | 
| 14 | 
         
             
            from pathlib import Path
         
     | 
| 15 | 
         
             
            from gui_data.constants import *
         
     | 
| 16 | 
         
             
            from gui_data.error_handling import *
         
     | 
| 17 | 
         
            -
            from scipy import signal
         
     | 
| 18 | 
         
             
            import audioread
         
     | 
| 19 | 
         
             
            import gzip
         
     | 
| 20 | 
         
             
            import librosa
         
     | 
| 
         @@ -26,85 +24,31 @@ import torch 
     | 
|
| 26 | 
         
             
            import warnings
         
     | 
| 27 | 
         
             
            import pydub
         
     | 
| 28 | 
         
             
            import soundfile as sf
         
     | 
| 
         | 
|
| 29 | 
         
             
            import lib_v5.mdxnet as MdxnetSet
         
     | 
| 30 | 
         
            -
             
     | 
| 31 | 
         
            -
            #import random
         
     | 
| 32 | 
         
            -
            from onnx import load
         
     | 
| 33 | 
         
            -
            from onnx2pytorch import ConvertModel
         
     | 
| 34 | 
         
            -
            import gc
         
     | 
| 35 | 
         
            -
             
         
     | 
| 36 | 
         
             
            if TYPE_CHECKING:
         
     | 
| 37 | 
         
             
                from UVR import ModelData
         
     | 
| 38 | 
         | 
| 39 | 
         
            -
            # if not is_macos:
         
     | 
| 40 | 
         
            -
            #     import torch_directml
         
     | 
| 41 | 
         
            -
             
     | 
| 42 | 
         
            -
            mps_available = torch.backends.mps.is_available() if is_macos else False
         
     | 
| 43 | 
         
            -
            cuda_available = torch.cuda.is_available()
         
     | 
| 44 | 
         
            -
             
     | 
| 45 | 
         
            -
            # def get_gpu_info():
         
     | 
| 46 | 
         
            -
            #     directml_device, directml_available = DIRECTML_DEVICE, False
         
     | 
| 47 | 
         
            -
                
         
     | 
| 48 | 
         
            -
            #     if not is_macos:
         
     | 
| 49 | 
         
            -
            #         directml_available = torch_directml.is_available()
         
     | 
| 50 | 
         
            -
             
     | 
| 51 | 
         
            -
            #         if directml_available:
         
     | 
| 52 | 
         
            -
            #             directml_device = str(torch_directml.device()).partition(":")[0]
         
     | 
| 53 | 
         
            -
             
     | 
| 54 | 
         
            -
            #     return directml_device, directml_available
         
     | 
| 55 | 
         
            -
             
     | 
| 56 | 
         
            -
            # DIRECTML_DEVICE, directml_available = get_gpu_info()
         
     | 
| 57 | 
         
            -
             
     | 
| 58 | 
         
            -
            def clear_gpu_cache():
         
     | 
| 59 | 
         
            -
                gc.collect()
         
     | 
| 60 | 
         
            -
                if is_macos:
         
     | 
| 61 | 
         
            -
                    torch.mps.empty_cache()
         
     | 
| 62 | 
         
            -
                else:
         
     | 
| 63 | 
         
            -
                    torch.cuda.empty_cache()
         
     | 
| 64 | 
         
            -
             
     | 
| 65 | 
         
             
            warnings.filterwarnings("ignore")
         
     | 
| 66 | 
         
             
            cpu = torch.device('cpu')
         
     | 
| 67 | 
         | 
| 68 | 
         
             
            class SeperateAttributes:
         
     | 
| 69 | 
         
            -
                def __init__(self, model_data: ModelData, 
         
     | 
| 70 | 
         
            -
                             process_data: dict, 
         
     | 
| 71 | 
         
            -
                             main_model_primary_stem_4_stem=None, 
         
     | 
| 72 | 
         
            -
                             main_process_method=None, 
         
     | 
| 73 | 
         
            -
                             is_return_dual=True, 
         
     | 
| 74 | 
         
            -
                             main_model_primary=None, 
         
     | 
| 75 | 
         
            -
                             vocal_stem_path=None, 
         
     | 
| 76 | 
         
            -
                             master_inst_source=None,
         
     | 
| 77 | 
         
            -
                             master_vocal_source=None):
         
     | 
| 78 | 
         | 
| 79 | 
         
             
                    self.list_all_models: list
         
     | 
| 80 | 
         
             
                    self.process_data = process_data
         
     | 
| 81 | 
         
             
                    self.progress_value = 0
         
     | 
| 82 | 
         
             
                    self.set_progress_bar = process_data['set_progress_bar']
         
     | 
| 83 | 
         
             
                    self.write_to_console = process_data['write_to_console']
         
     | 
| 84 | 
         
            -
                     
     | 
| 85 | 
         
            -
             
     | 
| 86 | 
         
            -
                        self.audio_file_base_voc_split = lambda stem, split:os.path.join(self.export_path, f'{self.audio_file_base.replace("_(Vocals)", "")}_({stem}_{split}).wav')
         
     | 
| 87 | 
         
            -
                    else:
         
     | 
| 88 | 
         
            -
                        self.audio_file = process_data['audio_file']
         
     | 
| 89 | 
         
            -
                        self.audio_file_base = process_data['audio_file_base']
         
     | 
| 90 | 
         
            -
                        self.audio_file_base_voc_split = None
         
     | 
| 91 | 
         
             
                    self.export_path = process_data['export_path']
         
     | 
| 92 | 
         
             
                    self.cached_source_callback = process_data['cached_source_callback']
         
     | 
| 93 | 
         
             
                    self.cached_model_source_holder = process_data['cached_model_source_holder']
         
     | 
| 94 | 
         
             
                    self.is_4_stem_ensemble = process_data['is_4_stem_ensemble']
         
     | 
| 95 | 
         
             
                    self.list_all_models = process_data['list_all_models']
         
     | 
| 96 | 
         
             
                    self.process_iteration = process_data['process_iteration']
         
     | 
| 97 | 
         
            -
                    self.is_return_dual = is_return_dual
         
     | 
| 98 | 
         
            -
                    self.is_pitch_change = model_data.is_pitch_change
         
     | 
| 99 | 
         
            -
                    self.semitone_shift = model_data.semitone_shift
         
     | 
| 100 | 
         
            -
                    self.is_match_frequency_pitch = model_data.is_match_frequency_pitch
         
     | 
| 101 | 
         
            -
                    self.overlap = model_data.overlap
         
     | 
| 102 | 
         
            -
                    self.overlap_mdx = model_data.overlap_mdx
         
     | 
| 103 | 
         
            -
                    self.overlap_mdx23 = model_data.overlap_mdx23
         
     | 
| 104 | 
         
            -
                    self.is_mdx_combine_stems = model_data.is_mdx_combine_stems
         
     | 
| 105 | 
         
            -
                    self.is_mdx_c = model_data.is_mdx_c
         
     | 
| 106 | 
         
            -
                    self.mdx_c_configs = model_data.mdx_c_configs
         
     | 
| 107 | 
         
            -
                    self.mdxnet_stem_select = model_data.mdxnet_stem_select
         
     | 
| 108 | 
         
             
                    self.mixer_path = model_data.mixer_path
         
     | 
| 109 | 
         
             
                    self.model_samplerate = model_data.model_samplerate
         
     | 
| 110 | 
         
             
                    self.model_capacity = model_data.model_capacity
         
     | 
| 
         @@ -126,11 +70,9 @@ class SeperateAttributes: 
     | 
|
| 126 | 
         
             
                    self.is_ensemble_mode = model_data.is_ensemble_mode
         
     | 
| 127 | 
         
             
                    self.secondary_model = model_data.secondary_model #
         
     | 
| 128 | 
         
             
                    self.primary_model_primary_stem = model_data.primary_model_primary_stem
         
     | 
| 129 | 
         
            -
                    self.primary_stem_native = model_data.primary_stem_native
         
     | 
| 130 | 
         
             
                    self.primary_stem = model_data.primary_stem #
         
     | 
| 131 | 
         
             
                    self.secondary_stem = model_data.secondary_stem #
         
     | 
| 132 | 
         
             
                    self.is_invert_spec = model_data.is_invert_spec #
         
     | 
| 133 | 
         
            -
                    self.is_deverb_vocals = model_data.is_deverb_vocals
         
     | 
| 134 | 
         
             
                    self.is_mixer_mode = model_data.is_mixer_mode #
         
     | 
| 135 | 
         
             
                    self.secondary_model_scale = model_data.secondary_model_scale #
         
     | 
| 136 | 
         
             
                    self.is_demucs_pre_proc_model_inst_mix = model_data.is_demucs_pre_proc_model_inst_mix #
         
     | 
| 
         @@ -140,87 +82,49 @@ class SeperateAttributes: 
     | 
|
| 140 | 
         
             
                    self.secondary_source = None
         
     | 
| 141 | 
         
             
                    self.secondary_source_primary = None
         
     | 
| 142 | 
         
             
                    self.secondary_source_secondary = None
         
     | 
| 143 | 
         
            -
                    self.main_model_primary_stem_4_stem = main_model_primary_stem_4_stem
         
     | 
| 144 | 
         
            -
                    self.main_model_primary = main_model_primary
         
     | 
| 145 | 
         
            -
                    self.ensemble_primary_stem = model_data.ensemble_primary_stem
         
     | 
| 146 | 
         
            -
                    self.is_multi_stem_ensemble = model_data.is_multi_stem_ensemble
         
     | 
| 147 | 
         
            -
                    self.is_other_gpu = False
         
     | 
| 148 | 
         
            -
                    self.is_deverb = True
         
     | 
| 149 | 
         
            -
                    self.DENOISER_MODEL = model_data.DENOISER_MODEL
         
     | 
| 150 | 
         
            -
                    self.DEVERBER_MODEL = model_data.DEVERBER_MODEL
         
     | 
| 151 | 
         
            -
                    self.is_source_swap = False
         
     | 
| 152 | 
         
            -
                    self.vocal_split_model = model_data.vocal_split_model
         
     | 
| 153 | 
         
            -
                    self.is_vocal_split_model = model_data.is_vocal_split_model
         
     | 
| 154 | 
         
            -
                    self.master_vocal_path = None
         
     | 
| 155 | 
         
            -
                    self.set_master_inst_source = None
         
     | 
| 156 | 
         
            -
                    self.master_inst_source = master_inst_source
         
     | 
| 157 | 
         
            -
                    self.master_vocal_source = master_vocal_source
         
     | 
| 158 | 
         
            -
                    self.is_save_inst_vocal_splitter = isinstance(master_inst_source, np.ndarray) and model_data.is_save_inst_vocal_splitter
         
     | 
| 159 | 
         
            -
                    self.is_inst_only_voc_splitter = model_data.is_inst_only_voc_splitter
         
     | 
| 160 | 
         
            -
                    self.is_karaoke = model_data.is_karaoke
         
     | 
| 161 | 
         
            -
                    self.is_bv_model = model_data.is_bv_model
         
     | 
| 162 | 
         
            -
                    self.is_bv_model_rebalenced = model_data.bv_model_rebalance and self.is_vocal_split_model
         
     | 
| 163 | 
         
            -
                    self.is_sec_bv_rebalance = model_data.is_sec_bv_rebalance
         
     | 
| 164 | 
         
            -
                    self.stem_path_init = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
         
     | 
| 165 | 
         
            -
                    self.deverb_vocal_opt = model_data.deverb_vocal_opt
         
     | 
| 166 | 
         
            -
                    self.is_save_vocal_only = model_data.is_save_vocal_only
         
     | 
| 167 | 
         
            -
                    self.device = cpu
         
     | 
| 168 | 
         
            -
                    self.run_type = ['CPUExecutionProvider']
         
     | 
| 169 | 
         
            -
                    self.is_opencl = False
         
     | 
| 170 | 
         
            -
                    self.device_set = model_data.device_set
         
     | 
| 171 | 
         
            -
                    self.is_use_opencl = model_data.is_use_opencl
         
     | 
| 172 | 
         
            -
                    
         
     | 
| 173 | 
         
            -
                    if self.is_inst_only_voc_splitter or self.is_sec_bv_rebalance:
         
     | 
| 174 | 
         
            -
                        self.is_primary_stem_only = False
         
     | 
| 175 | 
         
            -
                        self.is_secondary_stem_only = False
         
     | 
| 176 | 
         
            -
                    
         
     | 
| 177 | 
         
            -
                    if main_model_primary and self.is_multi_stem_ensemble:
         
     | 
| 178 | 
         
            -
                        self.primary_stem, self.secondary_stem = main_model_primary, secondary_stem(main_model_primary)
         
     | 
| 179 | 
         | 
| 180 | 
         
            -
                    if  
     | 
| 181 | 
         
            -
                        if  
     | 
| 182 | 
         
            -
                             
     | 
| 183 | 
         
            -
             
     | 
| 184 | 
         
            -
             
     | 
| 185 | 
         
            -
             
     | 
| 186 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 187 | 
         | 
| 188 | 
         
            -
             
     | 
| 189 | 
         
            -
                             
     | 
| 190 | 
         
            -
                             
     | 
| 191 | 
         
            -
                            if cuda_available:# and not self.is_use_opencl:
         
     | 
| 192 | 
         
            -
                                self.device = CUDA_DEVICE if not device_prefix else f'{device_prefix}:{self.device_set}'
         
     | 
| 193 | 
         
            -
                                self.run_type = ['CUDAExecutionProvider']
         
     | 
| 194 | 
         | 
| 195 | 
         
             
                    if model_data.process_method == MDX_ARCH_TYPE:
         
     | 
| 196 | 
         
             
                        self.is_mdx_ckpt = model_data.is_mdx_ckpt
         
     | 
| 197 | 
         
             
                        self.primary_model_name, self.primary_sources = self.cached_source_callback(MDX_ARCH_TYPE, model_name=self.model_basename)
         
     | 
| 198 | 
         
            -
                        self.is_denoise = model_data.is_denoise 
     | 
| 199 | 
         
            -
                        self.is_denoise_model = model_data.is_denoise_model#
         
     | 
| 200 | 
         
            -
                        self.is_mdx_c_seg_def = model_data.is_mdx_c_seg_def#
         
     | 
| 201 | 
         
             
                        self.mdx_batch_size = model_data.mdx_batch_size
         
     | 
| 202 | 
         
             
                        self.compensate = model_data.compensate
         
     | 
| 203 | 
         
            -
                        self. 
     | 
| 204 | 
         
            -
                        
         
     | 
| 205 | 
         
            -
                        if self.is_mdx_c:
         
     | 
| 206 | 
         
            -
                            if not self.is_4_stem_ensemble:
         
     | 
| 207 | 
         
            -
                                self.primary_stem = model_data.ensemble_primary_stem if process_data['is_ensemble_master'] else model_data.primary_stem
         
     | 
| 208 | 
         
            -
                                self.secondary_stem = model_data.ensemble_secondary_stem if process_data['is_ensemble_master'] else model_data.secondary_stem
         
     | 
| 209 | 
         
            -
                        else:
         
     | 
| 210 | 
         
            -
                            self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2**model_data.mdx_dim_t_set
         
     | 
| 211 | 
         
            -
                            
         
     | 
| 212 | 
         
            -
                        self.check_label_secondary_stem_runs()
         
     | 
| 213 | 
         
             
                        self.n_fft = model_data.mdx_n_fft_scale_set
         
     | 
| 214 | 
         
             
                        self.chunks = model_data.chunks
         
     | 
| 215 | 
         
             
                        self.margin = model_data.margin
         
     | 
| 216 | 
         
             
                        self.adjust = 1
         
     | 
| 217 | 
         
             
                        self.dim_c = 4
         
     | 
| 218 | 
         
             
                        self.hop = 1024
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 219 | 
         | 
| 220 | 
         
             
                    if model_data.process_method == DEMUCS_ARCH_TYPE:
         
     | 
| 221 | 
         
             
                        self.demucs_stems = model_data.demucs_stems if not main_process_method in [MDX_ARCH_TYPE, VR_ARCH_TYPE] else None
         
     | 
| 222 | 
         
             
                        self.secondary_model_4_stem = model_data.secondary_model_4_stem
         
     | 
| 223 | 
         
             
                        self.secondary_model_4_stem_scale = model_data.secondary_model_4_stem_scale
         
     | 
| 
         | 
|
| 
         | 
|
| 224 | 
         
             
                        self.is_chunk_demucs = model_data.is_chunk_demucs
         
     | 
| 225 | 
         
             
                        self.segment = model_data.segment
         
     | 
| 226 | 
         
             
                        self.demucs_version = model_data.demucs_version
         
     | 
| 
         @@ -229,37 +133,28 @@ class SeperateAttributes: 
     | 
|
| 229 | 
         
             
                        self.is_demucs_combine_stems = model_data.is_demucs_combine_stems
         
     | 
| 230 | 
         
             
                        self.demucs_stem_count = model_data.demucs_stem_count
         
     | 
| 231 | 
         
             
                        self.pre_proc_model = model_data.pre_proc_model
         
     | 
| 232 | 
         
            -
                        self.device = cpu if self.is_other_gpu and not self.demucs_version in [DEMUCS_V3, DEMUCS_V4] else self.device
         
     | 
| 233 | 
         
            -
             
     | 
| 234 | 
         
            -
                        self.primary_stem = model_data.ensemble_primary_stem if process_data['is_ensemble_master'] else model_data.primary_stem
         
     | 
| 235 | 
         
            -
                        self.secondary_stem = model_data.ensemble_secondary_stem if process_data['is_ensemble_master'] else model_data.secondary_stem
         
     | 
| 236 | 
         
            -
             
     | 
| 237 | 
         
            -
                        if (self.is_multi_stem_ensemble or self.is_4_stem_ensemble) and not self.is_secondary_model:
         
     | 
| 238 | 
         
            -
                            self.is_return_dual = False
         
     | 
| 239 | 
         | 
| 240 | 
         
            -
                        if self.is_multi_stem_ensemble and main_model_primary:
         
     | 
| 241 | 
         
            -
                            self.is_4_stem_ensemble = False
         
     | 
| 242 | 
         
            -
                            if main_model_primary in self.demucs_source_map.keys():
         
     | 
| 243 | 
         
            -
                                self.primary_stem = main_model_primary
         
     | 
| 244 | 
         
            -
                                self.secondary_stem = secondary_stem(main_model_primary)
         
     | 
| 245 | 
         
            -
                            elif secondary_stem(main_model_primary) in self.demucs_source_map.keys():
         
     | 
| 246 | 
         
            -
                                self.primary_stem = secondary_stem(main_model_primary)
         
     | 
| 247 | 
         
            -
                                self.secondary_stem = main_model_primary
         
     | 
| 248 | 
         
            -
             
     | 
| 249 | 
         
             
                        if self.is_secondary_model and not process_data['is_ensemble_master']:
         
     | 
| 250 | 
         
             
                            if not self.demucs_stem_count == 2 and model_data.primary_model_primary_stem == INST_STEM:
         
     | 
| 251 | 
         
             
                                self.primary_stem = VOCAL_STEM
         
     | 
| 252 | 
         
             
                                self.secondary_stem = INST_STEM
         
     | 
| 253 | 
         
             
                            else:
         
     | 
| 254 | 
         
             
                                self.primary_stem = model_data.primary_model_primary_stem
         
     | 
| 255 | 
         
            -
                                self.secondary_stem =  
     | 
| 256 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 257 | 
         
             
                        self.shifts = model_data.shifts
         
     | 
| 258 | 
         
             
                        self.is_split_mode = model_data.is_split_mode if not self.demucs_version == DEMUCS_V4 else True
         
     | 
| 
         | 
|
| 259 | 
         
             
                        self.primary_model_name, self.primary_sources = self.cached_source_callback(DEMUCS_ARCH_TYPE, model_name=self.model_basename)
         
     | 
| 260 | 
         | 
| 261 | 
         
             
                    if model_data.process_method == VR_ARCH_TYPE:
         
     | 
| 262 | 
         
            -
                        self.check_label_secondary_stem_runs()
         
     | 
| 263 | 
         
             
                        self.primary_model_name, self.primary_sources = self.cached_source_callback(VR_ARCH_TYPE, model_name=self.model_basename)
         
     | 
| 264 | 
         
             
                        self.mp = model_data.vr_model_param
         
     | 
| 265 | 
         
             
                        self.high_end_process = model_data.is_high_end_process
         
     | 
| 
         @@ -269,44 +164,28 @@ class SeperateAttributes: 
     | 
|
| 269 | 
         
             
                        self.batch_size = model_data.batch_size
         
     | 
| 270 | 
         
             
                        self.window_size = model_data.window_size
         
     | 
| 271 | 
         
             
                        self.input_high_end_h = None
         
     | 
| 272 | 
         
            -
                        self.input_high_end = None
         
     | 
| 273 | 
         
             
                        self.post_process_threshold = model_data.post_process_threshold
         
     | 
| 274 | 
         
             
                        self.aggressiveness = {'value': model_data.aggression_setting, 
         
     | 
| 275 | 
         
             
                                               'split_bin': self.mp.param['band'][1]['crop_stop'], 
         
     | 
| 276 | 
         
             
                                               'aggr_correction': self.mp.param.get('aggr_correction')}
         
     | 
| 277 | 
         | 
| 278 | 
         
            -
                def check_label_secondary_stem_runs(self):
         
     | 
| 279 | 
         
            -
             
     | 
| 280 | 
         
            -
                    # For ensemble master that's not a 4-stem ensemble, and not mdx_c
         
     | 
| 281 | 
         
            -
                    if self.process_data['is_ensemble_master'] and not self.is_4_stem_ensemble and not self.is_mdx_c:
         
     | 
| 282 | 
         
            -
                        if self.ensemble_primary_stem != self.primary_stem:
         
     | 
| 283 | 
         
            -
                            self.is_primary_stem_only, self.is_secondary_stem_only = self.is_secondary_stem_only, self.is_primary_stem_only
         
     | 
| 284 | 
         
            -
                        
         
     | 
| 285 | 
         
            -
                    # For secondary models
         
     | 
| 286 | 
         
            -
                    if self.is_pre_proc_model or self.is_secondary_model:
         
     | 
| 287 | 
         
            -
                        self.is_primary_stem_only = False
         
     | 
| 288 | 
         
            -
                        self.is_secondary_stem_only = False
         
     | 
| 289 | 
         
            -
                        
         
     | 
| 290 | 
         
             
                def start_inference_console_write(self):
         
     | 
| 291 | 
         
            -
                     
     | 
| 
         | 
|
| 292 | 
         
             
                        self.write_to_console(INFERENCE_STEP_2_SEC(self.process_method, self.model_basename))
         
     | 
| 293 | 
         | 
| 294 | 
         
             
                    if self.is_pre_proc_model:
         
     | 
| 295 | 
         
             
                        self.write_to_console(INFERENCE_STEP_2_PRE(self.process_method, self.model_basename))
         
     | 
| 296 | 
         
            -
                        
         
     | 
| 297 | 
         
            -
                    if self.is_vocal_split_model:
         
     | 
| 298 | 
         
            -
                        self.write_to_console(INFERENCE_STEP_2_VOC_S(self.process_method, self.model_basename))
         
     | 
| 299 | 
         | 
| 300 | 
         
             
                def running_inference_console_write(self, is_no_write=False):
         
     | 
| 
         | 
|
| 301 | 
         
             
                    self.write_to_console(DONE, base_text='') if not is_no_write else None
         
     | 
| 302 | 
         
             
                    self.set_progress_bar(0.05) if not is_no_write else None
         
     | 
| 303 | 
         | 
| 304 | 
         
            -
                    if self.is_secondary_model and not self.is_pre_proc_model 
     | 
| 305 | 
         
             
                        self.write_to_console(INFERENCE_STEP_1_SEC)
         
     | 
| 306 | 
         
             
                    elif self.is_pre_proc_model:
         
     | 
| 307 | 
         
             
                        self.write_to_console(INFERENCE_STEP_1_PRE)
         
     | 
| 308 | 
         
            -
                    elif self.is_vocal_split_model:
         
     | 
| 309 | 
         
            -
                        self.write_to_console(INFERENCE_STEP_1_VOC_S)
         
     | 
| 310 | 
         
             
                    else:
         
     | 
| 311 | 
         
             
                        self.write_to_console(INFERENCE_STEP_1)
         
     | 
| 312 | 
         | 
| 
         @@ -319,14 +198,19 @@ class SeperateAttributes: 
     | 
|
| 319 | 
         | 
| 320 | 
         
             
                        self.set_progress_bar(0.1, (0.8/length*self.progress_value))
         
     | 
| 321 | 
         | 
| 322 | 
         
            -
                def load_cached_sources(self):
         
     | 
| 323 | 
         | 
| 324 | 
         
             
                    if self.is_secondary_model and not self.is_pre_proc_model:
         
     | 
| 325 | 
         
             
                        self.write_to_console(INFERENCE_STEP_2_SEC_CACHED_MODOEL(self.process_method, self.model_basename))
         
     | 
| 326 | 
         
             
                    elif self.is_pre_proc_model:
         
     | 
| 327 | 
         
             
                        self.write_to_console(INFERENCE_STEP_2_PRE_CACHED_MODOEL(self.process_method, self.model_basename))
         
     | 
| 328 | 
         
             
                    else:
         
     | 
| 329 | 
         
            -
                        self.write_to_console(INFERENCE_STEP_2_PRIMARY_CACHED 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 330 | 
         | 
| 331 | 
         
             
                def cache_source(self, secondary_sources):
         
     | 
| 332 | 
         | 
| 
         @@ -341,142 +225,49 @@ class SeperateAttributes: 
     | 
|
| 341 | 
         | 
| 342 | 
         
             
                        if self.process_method == DEMUCS_ARCH_TYPE:
         
     | 
| 343 | 
         
             
                            self.cached_model_source_holder(DEMUCS_ARCH_TYPE, secondary_sources, self.model_basename)
         
     | 
| 344 | 
         
            -
             
     | 
| 345 | 
         
            -
                def  
     | 
| 346 | 
         
            -
             
     | 
| 347 | 
         
            -
                    def is_valid_vocal_split_condition(master_vocal_source):
         
     | 
| 348 | 
         
            -
                        """Checks if conditions for vocal split processing are met."""
         
     | 
| 349 | 
         
            -
                        conditions = [
         
     | 
| 350 | 
         
            -
                            isinstance(master_vocal_source, np.ndarray),
         
     | 
| 351 | 
         
            -
                            self.vocal_split_model,
         
     | 
| 352 | 
         
            -
                            not self.is_ensemble_mode,
         
     | 
| 353 | 
         
            -
                            not self.is_karaoke,
         
     | 
| 354 | 
         
            -
                            not self.is_bv_model
         
     | 
| 355 | 
         
            -
                        ]
         
     | 
| 356 | 
         
            -
                        return all(conditions)
         
     | 
| 357 | 
         
            -
                    
         
     | 
| 358 | 
         
            -
                    # Retrieve sources from the dictionary with default fallbacks
         
     | 
| 359 | 
         
            -
                    master_inst_source = sources.get(INST_STEM, None)
         
     | 
| 360 | 
         
            -
                    master_vocal_source = sources.get(VOCAL_STEM, None)
         
     | 
| 361 | 
         
            -
             
     | 
| 362 | 
         
            -
                    # Process the vocal split chain if conditions are met
         
     | 
| 363 | 
         
            -
                    if is_valid_vocal_split_condition(master_vocal_source):
         
     | 
| 364 | 
         
            -
                        process_chain_model(
         
     | 
| 365 | 
         
            -
                            self.vocal_split_model,
         
     | 
| 366 | 
         
            -
                            self.process_data,
         
     | 
| 367 | 
         
            -
                            vocal_stem_path=self.master_vocal_path,
         
     | 
| 368 | 
         
            -
                            master_vocal_source=master_vocal_source,
         
     | 
| 369 | 
         
            -
                            master_inst_source=master_inst_source
         
     | 
| 370 | 
         
            -
                        )
         
     | 
| 371 | 
         
            -
              
         
     | 
| 372 | 
         
            -
                def process_secondary_stem(self, stem_source, secondary_model_source=None, model_scale=None):
         
     | 
| 373 | 
         
             
                    if not self.is_secondary_model:
         
     | 
| 374 | 
         
            -
                        if self.is_secondary_model_activated 
     | 
| 375 | 
         
            -
                             
     | 
| 376 | 
         
            -
             
     | 
| 377 | 
         
            -
             
     | 
| 378 | 
         
            -
                    return stem_source
         
     | 
| 379 | 
         
            -
                
         
     | 
| 380 | 
         
            -
                def final_process(self, stem_path, source, secondary_source, stem_name, samplerate):
         
     | 
| 381 | 
         
            -
                    source = self.process_secondary_stem(source, secondary_source)
         
     | 
| 382 | 
         
            -
                    self.write_audio(stem_path, source, samplerate, stem_name=stem_name)
         
     | 
| 383 | 
         
            -
                    
         
     | 
| 384 | 
         
            -
                    return {stem_name: source}
         
     | 
| 385 | 
         
            -
                
         
     | 
| 386 | 
         
            -
                def write_audio(self, stem_path: str, stem_source, samplerate, stem_name=None):
         
     | 
| 387 | 
         
            -
                    
         
     | 
| 388 | 
         
            -
                    def save_audio_file(path, source):
         
     | 
| 389 | 
         
            -
                        source = spec_utils.normalize(source, self.is_normalization)
         
     | 
| 390 | 
         
            -
                        sf.write(path, source, samplerate, subtype=self.wav_type_set)
         
     | 
| 391 | 
         
            -
             
     | 
| 392 | 
         
            -
                        if is_not_ensemble:
         
     | 
| 393 | 
         
            -
                            save_format(path, self.save_format, self.mp3_bit_set)
         
     | 
| 394 | 
         
            -
             
     | 
| 395 | 
         
            -
                    def save_voc_split_instrumental(stem_name, stem_source, is_inst_invert=False):
         
     | 
| 396 | 
         
            -
                        inst_stem_name = "Instrumental (With Lead Vocals)" if stem_name == LEAD_VOCAL_STEM else "Instrumental (With Backing Vocals)"
         
     | 
| 397 | 
         
            -
                        inst_stem_path_name = LEAD_VOCAL_STEM_I if stem_name == LEAD_VOCAL_STEM else BV_VOCAL_STEM_I
         
     | 
| 398 | 
         
            -
                        inst_stem_path = self.audio_file_base_voc_split(INST_STEM, inst_stem_path_name)
         
     | 
| 399 | 
         
            -
                        stem_source = -stem_source if is_inst_invert else stem_source
         
     | 
| 400 | 
         
            -
                        inst_stem_source = spec_utils.combine_arrarys([self.master_inst_source, stem_source], is_swap=True)
         
     | 
| 401 | 
         
            -
                        save_with_message(inst_stem_path, inst_stem_name, inst_stem_source)
         
     | 
| 402 | 
         
            -
             
     | 
| 403 | 
         
            -
                    def save_voc_split_vocal(stem_name, stem_source):
         
     | 
| 404 | 
         
            -
                        voc_split_stem_name = LEAD_VOCAL_STEM_LABEL if stem_name == LEAD_VOCAL_STEM else BV_VOCAL_STEM_LABEL
         
     | 
| 405 | 
         
            -
                        voc_split_stem_path = self.audio_file_base_voc_split(VOCAL_STEM, stem_name)
         
     | 
| 406 | 
         
            -
                        save_with_message(voc_split_stem_path, voc_split_stem_name, stem_source)
         
     | 
| 407 | 
         
            -
             
     | 
| 408 | 
         
            -
                    def save_with_message(stem_path, stem_name, stem_source):
         
     | 
| 409 | 
         
            -
                        is_deverb = self.is_deverb_vocals and (
         
     | 
| 410 | 
         
            -
                            self.deverb_vocal_opt == stem_name or
         
     | 
| 411 | 
         
            -
                            (self.deverb_vocal_opt == 'ALL' and 
         
     | 
| 412 | 
         
            -
                            (stem_name == VOCAL_STEM or stem_name == LEAD_VOCAL_STEM_LABEL or stem_name == BV_VOCAL_STEM_LABEL)))
         
     | 
| 413 | 
         
            -
             
     | 
| 414 | 
         
            -
                        self.write_to_console(f'{SAVING_STEM[0]}{stem_name}{SAVING_STEM[1]}')
         
     | 
| 415 | 
         | 
| 416 | 
         
            -
                         
     | 
| 417 | 
         
            -
             
     | 
| 418 | 
         | 
| 419 | 
         
            -
                        save_audio_file(stem_path, stem_source)
         
     | 
| 420 | 
         
             
                        self.write_to_console(DONE, base_text='')
         
     | 
| 421 | 
         
            -
                        
         
     | 
| 422 | 
         
            -
                    def deverb_vocals(stem_path:str, stem_source):
         
     | 
| 423 | 
         
            -
                        self.write_to_console(INFERENCE_STEP_DEVERBING, base_text='')
         
     | 
| 424 | 
         
            -
                        stem_source_deverbed, stem_source_2 = vr_denoiser(stem_source, self.device, is_deverber=True, model_path=self.DEVERBER_MODEL)
         
     | 
| 425 | 
         
            -
                        save_audio_file(stem_path.replace(".wav", "_deverbed.wav"), stem_source_deverbed)
         
     | 
| 426 | 
         
            -
                        save_audio_file(stem_path.replace(".wav", "_reverb_only.wav"), stem_source_2)
         
     | 
| 427 | 
         
            -
                        
         
     | 
| 428 | 
         
            -
                    is_bv_model_lead = (self.is_bv_model_rebalenced and self.is_vocal_split_model and stem_name == LEAD_VOCAL_STEM)
         
     | 
| 429 | 
         
            -
                    is_bv_rebalance_lead = (self.is_bv_model_rebalenced and self.is_vocal_split_model and stem_name == BV_VOCAL_STEM)
         
     | 
| 430 | 
         
            -
                    is_no_vocal_save = self.is_inst_only_voc_splitter and (stem_name == VOCAL_STEM or stem_name == BV_VOCAL_STEM or stem_name == LEAD_VOCAL_STEM) or is_bv_model_lead
         
     | 
| 431 | 
         
            -
                    is_not_ensemble = (not self.is_ensemble_mode or self.is_vocal_split_model)
         
     | 
| 432 | 
         
            -
                    is_do_not_save_inst = (self.is_save_vocal_only and self.is_sec_bv_rebalance and stem_name == INST_STEM)
         
     | 
| 433 | 
         
            -
             
     | 
| 434 | 
         
            -
                    if is_bv_rebalance_lead:
         
     | 
| 435 | 
         
            -
                        master_voc_source = spec_utils.match_array_shapes(self.master_vocal_source, stem_source, is_swap=True)
         
     | 
| 436 | 
         
            -
                        bv_rebalance_lead_source = stem_source-master_voc_source
         
     | 
| 437 | 
         
            -
                        
         
     | 
| 438 | 
         
            -
                    if not is_bv_model_lead and not is_do_not_save_inst:
         
     | 
| 439 | 
         
            -
                        if self.is_vocal_split_model or not self.is_secondary_model:
         
     | 
| 440 | 
         
            -
                            if self.is_vocal_split_model and not self.is_inst_only_voc_splitter:
         
     | 
| 441 | 
         
            -
                                save_voc_split_vocal(stem_name, stem_source)
         
     | 
| 442 | 
         
            -
                                if is_bv_rebalance_lead:
         
     | 
| 443 | 
         
            -
                                    save_voc_split_vocal(LEAD_VOCAL_STEM, bv_rebalance_lead_source)
         
     | 
| 444 | 
         
            -
                            else:
         
     | 
| 445 | 
         
            -
                                if not is_no_vocal_save:
         
     | 
| 446 | 
         
            -
                                    save_with_message(stem_path, stem_name, stem_source)
         
     | 
| 447 | 
         
            -
                                
         
     | 
| 448 | 
         
            -
                            if self.is_save_inst_vocal_splitter and not self.is_save_vocal_only:
         
     | 
| 449 | 
         
            -
                                save_voc_split_instrumental(stem_name, stem_source)
         
     | 
| 450 | 
         
            -
                                if is_bv_rebalance_lead:
         
     | 
| 451 | 
         
            -
                                    save_voc_split_instrumental(LEAD_VOCAL_STEM, bv_rebalance_lead_source, is_inst_invert=True)
         
     | 
| 452 | 
         
            -
             
     | 
| 453 | 
         
            -
                            self.set_progress_bar(0.95)
         
     | 
| 454 | 
         | 
| 455 | 
         
            -
             
     | 
| 456 | 
         
            -
             
     | 
| 457 | 
         
            -
             
     | 
| 458 | 
         
            -
             
     | 
| 459 | 
         
            -
             
     | 
| 460 | 
         
            -
             
     | 
| 461 | 
         
            -
             
     | 
| 462 | 
         
            -
             
     | 
| 463 | 
         
            -
             
     | 
| 464 | 
         
            -
             
     | 
| 465 | 
         
            -
             
     | 
| 466 | 
         
            -
             
     | 
| 467 | 
         
            -
             
     | 
| 468 | 
         
            -
                         
     | 
| 469 | 
         
            -
             
     | 
| 470 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 471 | 
         | 
| 472 | 
         
             
            class SeperateMDX(SeperateAttributes):        
         
     | 
| 473 | 
         | 
| 474 | 
         
             
                def seperate(self):
         
     | 
| 475 | 
         
             
                    samplerate = 44100
         
     | 
| 476 | 
         
            -
             
     | 
| 477 | 
         
            -
                    if self.primary_model_name == self.model_basename and  
     | 
| 478 | 
         
            -
                         
     | 
| 479 | 
         
            -
                        self.load_cached_sources()
         
     | 
| 480 | 
         
             
                    else:
         
     | 
| 481 | 
         
             
                        self.start_inference_console_write()
         
     | 
| 482 | 
         | 
| 
         @@ -486,145 +277,105 @@ class SeperateMDX(SeperateAttributes): 
     | 
|
| 486 | 
         
             
                            separator = MdxnetSet.ConvTDFNet(**model_params)
         
     | 
| 487 | 
         
             
                            self.model_run = separator.load_from_checkpoint(self.model_path).to(self.device).eval()
         
     | 
| 488 | 
         
             
                        else:
         
     | 
| 489 | 
         
            -
                             
     | 
| 490 | 
         
            -
             
     | 
| 491 | 
         
            -
                                self.model_run = lambda spek:ort_.run(None, {'input': spek.cpu().numpy()})[0]
         
     | 
| 492 | 
         
            -
                            else:
         
     | 
| 493 | 
         
            -
                                self.model_run = ConvertModel(load(self.model_path))
         
     | 
| 494 | 
         
            -
                                self.model_run.to(self.device).eval()
         
     | 
| 495 | 
         | 
| 
         | 
|
| 496 | 
         
             
                        self.running_inference_console_write()
         
     | 
| 497 | 
         
            -
                         
     | 
| 498 | 
         
            -
                        
         
     | 
| 499 | 
         
            -
                        source = self. 
     | 
| 500 | 
         
            -
                        
         
     | 
| 501 | 
         
            -
                        if not self.is_vocal_split_model:
         
     | 
| 502 | 
         
            -
                            self.cache_source((mix, source))
         
     | 
| 503 | 
         
             
                        self.write_to_console(DONE, base_text='')            
         
     | 
| 504 | 
         | 
| 505 | 
         
            -
                     
     | 
| 506 | 
         
            -
             
     | 
| 507 | 
         
            -
             
     | 
| 508 | 
         
            -
                        self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method, main_model_primary=self.primary_stem)
         
     | 
| 509 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 510 | 
         
             
                    if not self.is_primary_stem_only:
         
     | 
| 
         | 
|
| 511 | 
         
             
                        secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
         
     | 
| 512 | 
         
             
                        if not isinstance(self.secondary_source, np.ndarray):
         
     | 
| 513 | 
         
            -
                            raw_mix = self. 
     | 
| 514 | 
         
            -
                            self.secondary_source = spec_utils. 
     | 
| 515 | 
         | 
| 516 | 
         
            -
             
     | 
| 517 | 
         
            -
             
     | 
| 518 | 
         
            -
             
     | 
| 519 | 
         
            -
             
     | 
| 520 | 
         | 
| 521 | 
         
            -
                         
     | 
| 522 | 
         
            -
             
     | 
| 523 | 
         
            -
                            
         
     | 
| 524 | 
         
            -
                        self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, samplerate)
         
     | 
| 525 | 
         
            -
                    
         
     | 
| 526 | 
         
            -
                    clear_gpu_cache()
         
     | 
| 527 | 
         | 
| 
         | 
|
| 528 | 
         
             
                    secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
         
     | 
| 529 | 
         
            -
                    
         
     | 
| 530 | 
         
            -
                    self.process_vocal_split_chain(secondary_sources)
         
     | 
| 531 | 
         | 
| 532 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 533 | 
         
             
                        return secondary_sources
         
     | 
| 534 | 
         | 
| 535 | 
         
             
                def initialize_model_settings(self):
         
     | 
| 536 | 
         
             
                    self.n_bins = self.n_fft//2+1
         
     | 
| 537 | 
         
             
                    self.trim = self.n_fft//2
         
     | 
| 538 | 
         
            -
                    self.chunk_size = self.hop * (self. 
     | 
| 
         | 
|
| 
         | 
|
| 539 | 
         
             
                    self.gen_size = self.chunk_size-2*self.trim
         
     | 
| 540 | 
         
            -
                    self.stft = STFT(self.n_fft, self.hop, self.dim_f, self.device)
         
     | 
| 541 | 
         
            -
             
     | 
| 542 | 
         
            -
                def demix(self, mix, is_match_mix=False):
         
     | 
| 543 | 
         
            -
                    self.initialize_model_settings()
         
     | 
| 544 | 
         
            -
                    
         
     | 
| 545 | 
         
            -
                    org_mix = mix
         
     | 
| 546 | 
         
            -
                    tar_waves_ = []
         
     | 
| 547 | 
         | 
| 548 | 
         
            -
             
     | 
| 549 | 
         
            -
             
     | 
| 550 | 
         
            -
                         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 551 | 
         
             
                    else:
         
     | 
| 552 | 
         
            -
                         
     | 
| 553 | 
         
            -
                         
     | 
| 554 | 
         
            -
                        
         
     | 
| 555 | 
         
            -
                         
     | 
| 556 | 
         
            -
             
     | 
| 557 | 
         
            -
             
     | 
| 558 | 
         
            -
             
     | 
| 559 | 
         
            -
             
     | 
| 560 | 
         
            -
             
     | 
| 561 | 
         
            -
             
     | 
| 562 | 
         
            -
             
     | 
| 563 | 
         
            -
                    step = self.chunk_size - self.n_fft if overlap == DEFAULT else int((1 - overlap) * chunk_size)
         
     | 
| 564 | 
         
            -
                    result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
         
     | 
| 565 | 
         
            -
                    divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
         
     | 
| 566 | 
         
            -
                    total = 0
         
     | 
| 567 | 
         
            -
                    total_chunks = (mixture.shape[-1] + step - 1) // step
         
     | 
| 568 | 
         
            -
             
     | 
| 569 | 
         
            -
                    for i in range(0, mixture.shape[-1], step):
         
     | 
| 570 | 
         
            -
                        total += 1
         
     | 
| 571 | 
         
            -
                        start = i
         
     | 
| 572 | 
         
            -
                        end = min(i + chunk_size, mixture.shape[-1])
         
     | 
| 573 | 
         
            -
             
     | 
| 574 | 
         
            -
                        chunk_size_actual = end - start
         
     | 
| 575 | 
         
            -
             
     | 
| 576 | 
         
            -
                        if overlap == 0:
         
     | 
| 577 | 
         
            -
                            window = None
         
     | 
| 578 | 
         
            -
                        else:
         
     | 
| 579 | 
         
            -
                            window = np.hanning(chunk_size_actual)
         
     | 
| 580 | 
         
            -
                            window = np.tile(window[None, None, :], (1, 2, 1))
         
     | 
| 581 | 
         
            -
             
     | 
| 582 | 
         
            -
                        mix_part_ = mixture[:, start:end]
         
     | 
| 583 | 
         
            -
                        if end != i + chunk_size:
         
     | 
| 584 | 
         
            -
                            pad_size = (i + chunk_size) - end
         
     | 
| 585 | 
         
            -
                            mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype='float32')), axis=-1)
         
     | 
| 586 | 
         | 
| 587 | 
         
            -
             
     | 
| 588 | 
         
            -
             
     | 
| 589 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 590 | 
         
             
                        with torch.no_grad():
         
     | 
| 591 | 
         
             
                            for mix_wave in mix_waves:
         
     | 
| 592 | 
         
            -
                                self.running_inference_progress_bar( 
     | 
| 593 | 
         
            -
             
     | 
| 594 | 
         
            -
                                 
     | 
| 595 | 
         
            -
             
     | 
| 596 | 
         
            -
             
     | 
| 597 | 
         
            -
             
     | 
| 598 | 
         
            -
             
     | 
| 599 | 
         
            -
             
     | 
| 600 | 
         
            -
             
     | 
| 601 | 
         
            -
             
     | 
| 602 | 
         
            -
                                result[..., start:end] += tar_waves[..., :end-start]
         
     | 
| 603 | 
         
            -
                        
         
     | 
| 604 | 
         
            -
                    tar_waves = result / divider
         
     | 
| 605 | 
         
            -
                    tar_waves_.append(tar_waves)
         
     | 
| 606 | 
         
            -
             
     | 
| 607 | 
         
            -
                    tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim:-self.trim]
         
     | 
| 608 | 
         
            -
                    tar_waves = np.concatenate(tar_waves_, axis=-1)[:, :mix.shape[-1]]
         
     | 
| 609 | 
         | 
| 610 | 
         
            -
                     
     | 
| 611 | 
         
            -
             
     | 
| 612 | 
         
            -
                    if self.is_pitch_change and not is_match_mix:
         
     | 
| 613 | 
         
            -
                        source = self.pitch_fix(source, sr_pitched, org_mix)
         
     | 
| 614 | 
         
            -
             
     | 
| 615 | 
         
            -
                    source = source if is_match_mix else source*self.compensate
         
     | 
| 616 | 
         
            -
             
     | 
| 617 | 
         
            -
                    if self.is_denoise_model and not is_match_mix:
         
     | 
| 618 | 
         
            -
                        if NO_STEM in self.primary_stem_native or self.primary_stem_native == INST_STEM:
         
     | 
| 619 | 
         
            -
                            if org_mix.shape[1] != source.shape[1]:
         
     | 
| 620 | 
         
            -
                                source = spec_utils.match_array_shapes(source, org_mix)
         
     | 
| 621 | 
         
            -
                            source = org_mix - vr_denoiser(org_mix-source, self.device, model_path=self.DENOISER_MODEL)
         
     | 
| 622 | 
         
            -
                        else:
         
     | 
| 623 | 
         
            -
                            source = vr_denoiser(source, self.device, model_path=self.DENOISER_MODEL)
         
     | 
| 624 | 
         
            -
             
     | 
| 625 | 
         
            -
                    return source
         
     | 
| 626 | 
         | 
| 627 | 
         
            -
                def run_model(self, mix, is_match_mix=False):
         
     | 
| 628 | 
         | 
| 629 | 
         
             
                    spek = self.stft(mix.to(self.device))*self.adjust
         
     | 
| 630 | 
         
             
                    spek[:, :, :3, :] *= 0 
         
     | 
| 
         @@ -634,189 +385,58 @@ class SeperateMDX(SeperateAttributes): 
     | 
|
| 634 | 
         
             
                    else:
         
     | 
| 635 | 
         
             
                        spec_pred = -self.model_run(-spek)*0.5+self.model_run(spek)*0.5 if self.is_denoise else self.model_run(spek)
         
     | 
| 636 | 
         | 
| 637 | 
         
            -
                     
     | 
| 638 | 
         
            -
             
     | 
| 639 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 640 | 
         | 
| 641 | 
         
             
                def seperate(self):
         
     | 
| 642 | 
         
            -
                    samplerate = 44100
         
     | 
| 643 | 
         
            -
                    sources = None
         
     | 
| 644 | 
         
            -
             
     | 
| 645 | 
         
            -
                    if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, tuple):
         
     | 
| 646 | 
         
            -
                        mix, sources = self.primary_sources
         
     | 
| 647 | 
         
            -
                        self.load_cached_sources()
         
     | 
| 648 | 
         
            -
                    else:
         
     | 
| 649 | 
         
            -
                        self.start_inference_console_write()
         
     | 
| 650 | 
         
            -
                        self.running_inference_console_write()
         
     | 
| 651 | 
         
            -
                        mix = prepare_mix(self.audio_file)
         
     | 
| 652 | 
         
            -
                        sources = self.demix(mix)
         
     | 
| 653 | 
         
            -
                        if not self.is_vocal_split_model:
         
     | 
| 654 | 
         
            -
                            self.cache_source((mix, sources))
         
     | 
| 655 | 
         
            -
                        self.write_to_console(DONE, base_text='')
         
     | 
| 656 | 
         | 
| 657 | 
         
            -
                    stem_list = [self.mdx_c_configs.training.target_instrument] if self.mdx_c_configs.training.target_instrument else [i for i in self.mdx_c_configs.training.instruments]
         
     | 
| 658 | 
         
            -
             
     | 
| 659 | 
         
            -
                    if self.is_secondary_model:
         
     | 
| 660 | 
         
            -
                        if self.is_pre_proc_model:
         
     | 
| 661 | 
         
            -
                            self.mdxnet_stem_select = stem_list[0]
         
     | 
| 662 | 
         
            -
                        else:
         
     | 
| 663 | 
         
            -
                            self.mdxnet_stem_select = self.main_model_primary_stem_4_stem if self.main_model_primary_stem_4_stem else self.primary_model_primary_stem
         
     | 
| 664 | 
         
            -
                        self.primary_stem = self.mdxnet_stem_select
         
     | 
| 665 | 
         
            -
                        self.secondary_stem = secondary_stem(self.mdxnet_stem_select)
         
     | 
| 666 | 
         
            -
                        self.is_primary_stem_only, self.is_secondary_stem_only = False, False
         
     | 
| 667 | 
         
            -
             
     | 
| 668 | 
         
            -
                    is_all_stems = self.mdxnet_stem_select == ALL_STEMS
         
     | 
| 669 | 
         
            -
                    is_not_ensemble_master = not self.process_data['is_ensemble_master']
         
     | 
| 670 | 
         
            -
                    is_not_single_stem = not len(stem_list) <= 2
         
     | 
| 671 | 
         
            -
                    is_not_secondary_model = not self.is_secondary_model
         
     | 
| 672 | 
         
            -
                    is_ensemble_4_stem = self.is_4_stem_ensemble and is_not_single_stem
         
     | 
| 673 | 
         
            -
             
     | 
| 674 | 
         
            -
                    if (is_all_stems and is_not_ensemble_master and is_not_single_stem and is_not_secondary_model) or is_ensemble_4_stem and not self.is_pre_proc_model:
         
     | 
| 675 | 
         
            -
                        for stem in stem_list:
         
     | 
| 676 | 
         
            -
                            primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({stem}).wav')
         
     | 
| 677 | 
         
            -
                            self.primary_source = sources[stem].T
         
     | 
| 678 | 
         
            -
                            self.write_audio(primary_stem_path, self.primary_source, samplerate, stem_name=stem)
         
     | 
| 679 | 
         
            -
                            
         
     | 
| 680 | 
         
            -
                            if stem == VOCAL_STEM and not self.is_sec_bv_rebalance:
         
     | 
| 681 | 
         
            -
                                self.process_vocal_split_chain({VOCAL_STEM:stem})
         
     | 
| 682 | 
         
            -
                    else:
         
     | 
| 683 | 
         
            -
                        if len(stem_list) == 1:
         
     | 
| 684 | 
         
            -
                            source_primary = sources  
         
     | 
| 685 | 
         
            -
                        else:
         
     | 
| 686 | 
         
            -
                            source_primary = sources[stem_list[0]] if self.is_multi_stem_ensemble and len(stem_list) == 2 else sources[self.mdxnet_stem_select]
         
     | 
| 687 | 
         
            -
                        if self.is_secondary_model_activated and self.secondary_model:
         
     | 
| 688 | 
         
            -
                            self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, 
         
     | 
| 689 | 
         
            -
                                                                                                                     self.process_data, 
         
     | 
| 690 | 
         
            -
                                                                                                                     main_process_method=self.process_method, 
         
     | 
| 691 | 
         
            -
                                                                                                                     main_model_primary=self.primary_stem)
         
     | 
| 692 | 
         
            -
             
     | 
| 693 | 
         
            -
                        if not self.is_primary_stem_only:
         
     | 
| 694 | 
         
            -
                            secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
         
     | 
| 695 | 
         
            -
                            if not isinstance(self.secondary_source, np.ndarray):
         
     | 
| 696 | 
         
            -
                                
         
     | 
| 697 | 
         
            -
                                if self.is_mdx_combine_stems and len(stem_list) >= 2:
         
     | 
| 698 | 
         
            -
                                    if len(stem_list) == 2:
         
     | 
| 699 | 
         
            -
                                        secondary_source = sources[self.secondary_stem]
         
     | 
| 700 | 
         
            -
                                    else:
         
     | 
| 701 | 
         
            -
                                        sources.pop(self.primary_stem)
         
     | 
| 702 | 
         
            -
                                        next_stem = next(iter(sources))
         
     | 
| 703 | 
         
            -
                                        secondary_source = np.zeros_like(sources[next_stem])
         
     | 
| 704 | 
         
            -
                                        for v in sources.values():
         
     | 
| 705 | 
         
            -
                                            secondary_source += v
         
     | 
| 706 | 
         
            -
                                            
         
     | 
| 707 | 
         
            -
                                    self.secondary_source = secondary_source.T 
         
     | 
| 708 | 
         
            -
                                else:
         
     | 
| 709 | 
         
            -
                                    self.secondary_source, raw_mix = source_primary, self.match_frequency_pitch(mix)
         
     | 
| 710 | 
         
            -
                                    self.secondary_source = spec_utils.to_shape(self.secondary_source, raw_mix.shape)
         
     | 
| 711 | 
         
            -
                                
         
     | 
| 712 | 
         
            -
                                    if self.is_invert_spec:
         
     | 
| 713 | 
         
            -
                                        self.secondary_source = spec_utils.invert_stem(raw_mix, self.secondary_source)
         
     | 
| 714 | 
         
            -
                                    else:
         
     | 
| 715 | 
         
            -
                                        self.secondary_source = (-self.secondary_source.T+raw_mix.T)
         
     | 
| 716 | 
         
            -
                                        
         
     | 
| 717 | 
         
            -
                            self.secondary_source_map = self.final_process(secondary_stem_path, self.secondary_source, self.secondary_source_secondary, self.secondary_stem, samplerate)    
         
     | 
| 718 | 
         
            -
             
     | 
| 719 | 
         
            -
                        if not self.is_secondary_stem_only:
         
     | 
| 720 | 
         
            -
                            primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
         
     | 
| 721 | 
         
            -
                            if not isinstance(self.primary_source, np.ndarray):
         
     | 
| 722 | 
         
            -
                                self.primary_source = source_primary.T
         
     | 
| 723 | 
         
            -
             
     | 
| 724 | 
         
            -
                            self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, samplerate)
         
     | 
| 725 | 
         
            -
             
     | 
| 726 | 
         
            -
                    clear_gpu_cache()
         
     | 
| 727 | 
         
            -
                    
         
     | 
| 728 | 
         
            -
                    secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
         
     | 
| 729 | 
         
            -
                    self.process_vocal_split_chain(secondary_sources)
         
     | 
| 730 | 
         
            -
                    
         
     | 
| 731 | 
         
            -
                    if self.is_secondary_model or self.is_pre_proc_model:
         
     | 
| 732 | 
         
            -
                        return secondary_sources
         
     | 
| 733 | 
         
            -
             
     | 
| 734 | 
         
            -
                def demix(self, mix):
         
     | 
| 735 | 
         
            -
                    sr_pitched = 441000
         
     | 
| 736 | 
         
            -
                    org_mix = mix
         
     | 
| 737 | 
         
            -
                    if self.is_pitch_change:
         
     | 
| 738 | 
         
            -
                        mix, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)
         
     | 
| 739 | 
         
            -
             
     | 
| 740 | 
         
            -
                    model = TFC_TDF_net(self.mdx_c_configs, device=self.device)
         
     | 
| 741 | 
         
            -
                    model.load_state_dict(torch.load(self.model_path, map_location=cpu))
         
     | 
| 742 | 
         
            -
                    model.to(self.device).eval()
         
     | 
| 743 | 
         
            -
                    mix = torch.tensor(mix, dtype=torch.float32)
         
     | 
| 744 | 
         
            -
             
     | 
| 745 | 
         
            -
                    try:
         
     | 
| 746 | 
         
            -
                        S = model.num_target_instruments
         
     | 
| 747 | 
         
            -
                    except Exception as e:
         
     | 
| 748 | 
         
            -
                        S = model.module.num_target_instruments
         
     | 
| 749 | 
         
            -
             
     | 
| 750 | 
         
            -
                    mdx_segment_size = self.mdx_c_configs.inference.dim_t if self.is_mdx_c_seg_def else self.mdx_segment_size
         
     | 
| 751 | 
         
            -
                    
         
     | 
| 752 | 
         
            -
                    batch_size = self.mdx_batch_size
         
     | 
| 753 | 
         
            -
                    chunk_size = self.mdx_c_configs.audio.hop_length * (mdx_segment_size - 1)
         
     | 
| 754 | 
         
            -
                    overlap = self.overlap_mdx23
         
     | 
| 755 | 
         
            -
             
     | 
| 756 | 
         
            -
                    hop_size = chunk_size // overlap
         
     | 
| 757 | 
         
            -
                    mix_shape = mix.shape[1]
         
     | 
| 758 | 
         
            -
                    pad_size = hop_size - (mix_shape - chunk_size) % hop_size
         
     | 
| 759 | 
         
            -
                    mix = torch.cat([torch.zeros(2, chunk_size - hop_size), mix, torch.zeros(2, pad_size + chunk_size - hop_size)], 1)
         
     | 
| 760 | 
         
            -
             
     | 
| 761 | 
         
            -
                    chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1)
         
     | 
| 762 | 
         
            -
                    batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]
         
     | 
| 763 | 
         
            -
                    
         
     | 
| 764 | 
         
            -
                    X = torch.zeros(S, *mix.shape) if S > 1 else torch.zeros_like(mix)
         
     | 
| 765 | 
         
            -
                    X = X.to(self.device)
         
     | 
| 766 | 
         
            -
             
     | 
| 767 | 
         
            -
                    with torch.no_grad():
         
     | 
| 768 | 
         
            -
                        cnt = 0
         
     | 
| 769 | 
         
            -
                        for batch in batches:
         
     | 
| 770 | 
         
            -
                            self.running_inference_progress_bar(len(batches))
         
     | 
| 771 | 
         
            -
                            x = model(batch.to(self.device))
         
     | 
| 772 | 
         
            -
                            
         
     | 
| 773 | 
         
            -
                            for w in x:
         
     | 
| 774 | 
         
            -
                                X[..., cnt * hop_size : cnt * hop_size + chunk_size] += w
         
     | 
| 775 | 
         
            -
                                cnt += 1
         
     | 
| 776 | 
         
            -
             
     | 
| 777 | 
         
            -
                    estimated_sources = X[..., chunk_size - hop_size:-(pad_size + chunk_size - hop_size)] / overlap
         
     | 
| 778 | 
         
            -
                    del X
         
     | 
| 779 | 
         
            -
                    pitch_fix = lambda s:self.pitch_fix(s, sr_pitched, org_mix)
         
     | 
| 780 | 
         
            -
             
     | 
| 781 | 
         
            -
                    if S > 1:
         
     | 
| 782 | 
         
            -
                        sources = {k: pitch_fix(v) if self.is_pitch_change else v for k, v in zip(self.mdx_c_configs.training.instruments, estimated_sources.cpu().detach().numpy())}
         
     | 
| 783 | 
         
            -
                        del estimated_sources
         
     | 
| 784 | 
         
            -
                        if self.is_denoise_model:
         
     | 
| 785 | 
         
            -
                            if VOCAL_STEM in sources.keys() and INST_STEM in sources.keys():
         
     | 
| 786 | 
         
            -
                                sources[VOCAL_STEM] = vr_denoiser(sources[VOCAL_STEM], self.device, model_path=self.DENOISER_MODEL)
         
     | 
| 787 | 
         
            -
                                if sources[VOCAL_STEM].shape[1] != org_mix.shape[1]:
         
     | 
| 788 | 
         
            -
                                    sources[VOCAL_STEM] = spec_utils.match_array_shapes(sources[VOCAL_STEM], org_mix)
         
     | 
| 789 | 
         
            -
                                sources[INST_STEM] = org_mix - sources[VOCAL_STEM]
         
     | 
| 790 | 
         
            -
                                        
         
     | 
| 791 | 
         
            -
                        return sources
         
     | 
| 792 | 
         
            -
                    else:
         
     | 
| 793 | 
         
            -
                        est_s = estimated_sources.cpu().detach().numpy()
         
     | 
| 794 | 
         
            -
                        del estimated_sources
         
     | 
| 795 | 
         
            -
                        return pitch_fix(est_s) if self.is_pitch_change else est_s
         
     | 
| 796 | 
         
            -
             
     | 
| 797 | 
         
            -
            class SeperateDemucs(SeperateAttributes):
         
     | 
| 798 | 
         
            -
                def seperate(self):
         
     | 
| 799 | 
         
             
                    samplerate = 44100
         
     | 
| 800 | 
         
             
                    source = None
         
     | 
| 801 | 
         
             
                    model_scale = None
         
     | 
| 802 | 
         
             
                    stem_source = None
         
     | 
| 803 | 
         
             
                    stem_source_secondary = None
         
     | 
| 804 | 
         
             
                    inst_mix = None
         
     | 
| 
         | 
|
| 
         | 
|
| 805 | 
         
             
                    inst_source = None
         
     | 
| 806 | 
         
             
                    is_no_write = False
         
     | 
| 807 | 
         
             
                    is_no_piano_guitar = False
         
     | 
| 808 | 
         
            -
             
     | 
| 809 | 
         
            -
                    
         
     | 
| 810 | 
         
            -
             
     | 
| 
         | 
|
| 811 | 
         
             
                        source = self.primary_sources
         
     | 
| 812 | 
         
            -
                        self.load_cached_sources()
         
     | 
| 813 | 
         
             
                    else:
         
     | 
| 814 | 
         
             
                        self.start_inference_console_write()
         
     | 
| 815 | 
         
            -
                        is_no_cache = True
         
     | 
| 816 | 
         | 
| 817 | 
         
            -
             
     | 
| 818 | 
         
            -
             
     | 
| 819 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 820 | 
         
             
                        if self.demucs_version == DEMUCS_V1:
         
     | 
| 821 | 
         
             
                            if str(self.model_path).endswith(".gz"):
         
     | 
| 822 | 
         
             
                                self.model_path = gzip.open(self.model_path, "rb")
         
     | 
| 
         @@ -842,23 +462,26 @@ class SeperateDemucs(SeperateAttributes): 
     | 
|
| 842 | 
         
             
                                is_no_write = True
         
     | 
| 843 | 
         
             
                                self.write_to_console(DONE, base_text='')
         
     | 
| 844 | 
         
             
                                mix_no_voc = process_secondary_model(self.pre_proc_model, self.process_data, is_pre_proc_model=True)
         
     | 
| 845 | 
         
            -
                                inst_mix = prepare_mix(mix_no_voc[INST_STEM])
         
     | 
| 846 | 
         
             
                                self.process_iteration()
         
     | 
| 847 | 
         
             
                                self.running_inference_console_write(is_no_write=is_no_write)
         
     | 
| 848 | 
         
             
                                inst_source = self.demix_demucs(inst_mix)
         
     | 
| 
         | 
|
| 849 | 
         
             
                                self.process_iteration()
         
     | 
| 850 | 
         | 
| 851 | 
         
             
                        self.running_inference_console_write(is_no_write=is_no_write) if not self.pre_proc_model else None
         
     | 
| 
         | 
|
| 852 | 
         | 
| 853 | 
         
             
                        if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and self.pre_proc_model:
         
     | 
| 854 | 
         
             
                            source = self.primary_sources
         
     | 
| 855 | 
         
             
                        else:
         
     | 
| 856 | 
         
             
                            source = self.demix_demucs(mix)
         
     | 
| 
         | 
|
| 857 | 
         | 
| 858 | 
         
             
                        self.write_to_console(DONE, base_text='')
         
     | 
| 859 | 
         | 
| 860 | 
         
             
                        del self.demucs
         
     | 
| 861 | 
         
            -
                         
     | 
| 862 | 
         | 
| 863 | 
         
             
                    if isinstance(inst_source, np.ndarray):
         
     | 
| 864 | 
         
             
                        source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[VOCAL_STEM]], source[self.demucs_source_map[VOCAL_STEM]])
         
     | 
| 
         @@ -866,7 +489,6 @@ class SeperateDemucs(SeperateAttributes): 
     | 
|
| 866 | 
         
             
                        source = inst_source
         
     | 
| 867 | 
         | 
| 868 | 
         
             
                    if isinstance(source, np.ndarray):
         
     | 
| 869 | 
         
            -
                        
         
     | 
| 870 | 
         
             
                        if len(source) == 2:
         
     | 
| 871 | 
         
             
                            self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER
         
     | 
| 872 | 
         
             
                        else:
         
     | 
| 
         @@ -881,40 +503,46 @@ class SeperateDemucs(SeperateAttributes): 
     | 
|
| 881 | 
         
             
                                    other_source += i
         
     | 
| 882 | 
         
             
                                source_reshape = spec_utils.reshape_sources(source[self.demucs_source_map[OTHER_STEM]], other_source)
         
     | 
| 883 | 
         
             
                                source[self.demucs_source_map[OTHER_STEM]] = source_reshape
         
     | 
| 884 | 
         
            -
             
     | 
| 885 | 
         
            -
                    if not self. 
     | 
| 886 | 
         
             
                        self.cache_source(source)
         
     | 
| 887 | 
         
            -
             
     | 
| 888 | 
         
            -
                    if (self.demucs_stems == ALL_STEMS and not self.process_data['is_ensemble_master']) or self.is_4_stem_ensemble and not self.is_return_dual:
         
     | 
| 889 | 
         
             
                        for stem_name, stem_value in self.demucs_source_map.items():
         
     | 
| 890 | 
         
             
                            if self.is_secondary_model_activated and not self.is_secondary_model and not stem_value >= 4:
         
     | 
| 891 | 
         
             
                                if self.secondary_model_4_stem[stem_value]:
         
     | 
| 892 | 
         
             
                                    model_scale = self.secondary_model_4_stem_scale[stem_value]
         
     | 
| 893 | 
         
            -
                                    stem_source_secondary = process_secondary_model(self.secondary_model_4_stem[stem_value], self.process_data, main_model_primary_stem_4_stem=stem_name,  
     | 
| 894 | 
         
             
                                    if isinstance(stem_source_secondary, np.ndarray):
         
     | 
| 895 | 
         
            -
                                        stem_source_secondary = stem_source_secondary[1 if self.secondary_model_4_stem[stem_value].demucs_stem_count == 2 else stem_value] 
     | 
| 
         | 
|
| 896 | 
         
             
                                    elif type(stem_source_secondary) is dict:
         
     | 
| 897 | 
         
             
                                        stem_source_secondary = stem_source_secondary[stem_name]
         
     | 
| 898 | 
         | 
| 899 | 
         
             
                            stem_source_secondary = None if stem_value >= 4 else stem_source_secondary
         
     | 
| 
         | 
|
| 900 | 
         
             
                            stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({stem_name}).wav')
         
     | 
| 901 | 
         
            -
                            stem_source = source[stem_value].T
         
     | 
| 902 | 
         
            -
                            
         
     | 
| 903 | 
         
            -
             
     | 
| 904 | 
         
            -
                            self.write_audio(stem_path, stem_source, samplerate, stem_name=stem_name)
         
     | 
| 905 | 
         
            -
                            
         
     | 
| 906 | 
         
            -
                            if stem_name == VOCAL_STEM and not self.is_sec_bv_rebalance:
         
     | 
| 907 | 
         
            -
                                self.process_vocal_split_chain({VOCAL_STEM:stem_source})
         
     | 
| 908 | 
         
            -
                            
         
     | 
| 909 | 
         
             
                        if self.is_secondary_model:    
         
     | 
| 910 | 
         
             
                            return source
         
     | 
| 911 | 
         
             
                    else:
         
     | 
| 912 | 
         
            -
                        if self.is_secondary_model_activated 
     | 
| 
         | 
|
| 913 | 
         
             
                                self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)
         
     | 
| 914 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 915 | 
         
             
                        if not self.is_primary_stem_only:
         
     | 
| 916 | 
         
             
                            def secondary_save(sec_stem_name, source, raw_mixture=None, is_inst_mixture=False):
         
     | 
| 917 | 
         
             
                                secondary_source = self.secondary_source if not is_inst_mixture else None
         
     | 
| 
         | 
|
| 918 | 
         
             
                                secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({sec_stem_name}).wav')
         
     | 
| 919 | 
         
             
                                secondary_source_secondary = None
         
     | 
| 920 | 
         | 
| 
         @@ -930,12 +558,12 @@ class SeperateDemucs(SeperateAttributes): 
     | 
|
| 930 | 
         
             
                                        secondary_source = np.zeros_like(source[0])
         
     | 
| 931 | 
         
             
                                        for i in source:
         
     | 
| 932 | 
         
             
                                            secondary_source += i
         
     | 
| 933 | 
         
            -
                                        secondary_source = secondary_source.T
         
     | 
| 934 | 
         
             
                                    else:
         
     | 
| 935 | 
         
             
                                        if not isinstance(raw_mixture, np.ndarray):
         
     | 
| 936 | 
         
            -
                                            raw_mixture = prepare_mix(self.audio_file)
         
     | 
| 937 | 
         | 
| 938 | 
         
            -
                                        secondary_source = source[self.demucs_source_map[self.primary_stem]]
         
     | 
| 939 | 
         | 
| 940 | 
         
             
                                        if self.is_invert_spec:
         
     | 
| 941 | 
         
             
                                            secondary_source = spec_utils.invert_stem(raw_mixture, secondary_source)
         
     | 
| 
         @@ -946,90 +574,86 @@ class SeperateDemucs(SeperateAttributes): 
     | 
|
| 946 | 
         
             
                                if not is_inst_mixture:
         
     | 
| 947 | 
         
             
                                    self.secondary_source = secondary_source
         
     | 
| 948 | 
         
             
                                    secondary_source_secondary = self.secondary_source_secondary
         
     | 
| 949 | 
         
            -
                                    self.secondary_source = self.process_secondary_stem(secondary_source, secondary_source_secondary)
         
     | 
| 950 | 
         
             
                                    self.secondary_source_map = {self.secondary_stem: self.secondary_source}
         
     | 
| 951 | 
         | 
| 952 | 
         
            -
                                self.write_audio(secondary_stem_path, secondary_source, samplerate,  
     | 
| 953 | 
         | 
| 954 | 
         
            -
                            secondary_save(self.secondary_stem, source, raw_mixture= 
     | 
| 955 | 
         | 
| 956 | 
         
             
                            if self.is_demucs_pre_proc_model_inst_mix and self.pre_proc_model and not self.is_4_stem_ensemble:
         
     | 
| 957 | 
         
            -
                                secondary_save(f"{self.secondary_stem} {INST_STEM}", source, raw_mixture= 
     | 
| 958 | 
         
            -
             
     | 
| 959 | 
         
            -
                        if not self.is_secondary_stem_only:
         
     | 
| 960 | 
         
            -
                            primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
         
     | 
| 961 | 
         
            -
                            if not isinstance(self.primary_source, np.ndarray):
         
     | 
| 962 | 
         
            -
                                self.primary_source = source[self.demucs_source_map[self.primary_stem]].T
         
     | 
| 963 | 
         
            -
                            
         
     | 
| 964 | 
         
            -
                            self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, samplerate)
         
     | 
| 965 | 
         | 
| 966 | 
         
             
                        secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
         
     | 
| 967 | 
         
            -
             
     | 
| 968 | 
         
            -
                        self. 
     | 
| 969 | 
         | 
| 970 | 
         
             
                        if self.is_secondary_model:    
         
     | 
| 971 | 
         
             
                            return secondary_sources
         
     | 
| 972 | 
         | 
| 973 | 
         
             
                def demix_demucs(self, mix):
         
     | 
| 974 | 
         
            -
                    
         
     | 
| 975 | 
         
            -
                    org_mix = mix
         
     | 
| 976 | 
         
            -
                    
         
     | 
| 977 | 
         
            -
                    if self.is_pitch_change:
         
     | 
| 978 | 
         
            -
                        mix, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)
         
     | 
| 979 | 
         
            -
                    
         
     | 
| 980 | 
         
             
                    processed = {}
         
     | 
| 981 | 
         
            -
             
     | 
| 982 | 
         
            -
                     
     | 
| 983 | 
         
            -
             
     | 
| 984 | 
         
            -
                     
     | 
| 985 | 
         
            -
             
     | 
| 986 | 
         
            -
             
     | 
| 987 | 
         
            -
                         
     | 
| 988 | 
         
            -
             
     | 
| 989 | 
         
            -
             
     | 
| 990 | 
         
            -
             
     | 
| 991 | 
         
            -
             
     | 
| 992 | 
         
            -
             
     | 
| 993 | 
         
            -
                         
     | 
| 994 | 
         
            -
                             
     | 
| 995 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 996 | 
         
             
                                                        self.shifts,
         
     | 
| 997 | 
         
             
                                                        self.is_split_mode,
         
     | 
| 998 | 
         
             
                                                        self.overlap,
         
     | 
| 999 | 
         
            -
                                                         
     | 
| 1000 | 
         
            -
             
     | 
| 1001 | 
         
            -
             
     | 
| 1002 | 
         
            -
             
     | 
| 1003 | 
         
            -
             
     | 
| 1004 | 
         
            -
             
     | 
| 1005 | 
         
            -
             
     | 
| 1006 | 
         
            -
             
     | 
| 1007 | 
         
            -
             
     | 
| 1008 | 
         
            -
             
     | 
| 1009 | 
         
            -
             
     | 
| 1010 | 
         
            -
             
     | 
| 1011 | 
         
            -
                    sources[[0,1]] = sources[[1,0]]
         
     | 
| 1012 | 
         
            -
                    processed[mix] = sources[:,:,0:None].copy()
         
     | 
| 1013 | 
         
            -
                    sources = list(processed.values())
         
     | 
| 1014 | 
         
            -
                    sources = [s[:,:,0:None] for s in sources]
         
     | 
| 1015 | 
         
            -
                    #sources = [self.pitch_fix(s[:,:,0:None], sr_pitched, org_mix) if self.is_pitch_change else s[:,:,0:None] for s in sources]
         
     | 
| 1016 | 
         
             
                    sources = np.concatenate(sources, axis=-1)
         
     | 
| 1017 | 
         
            -
                                 
         
     | 
| 1018 | 
         
            -
                    if self.is_pitch_change:
         
     | 
| 1019 | 
         
            -
                        sources = np.stack([self.pitch_fix(stem, sr_pitched, org_mix) for stem in sources])
         
     | 
| 1020 | 
         | 
| 1021 | 
         
             
                    return sources
         
     | 
| 1022 | 
         | 
| 1023 | 
         
             
            class SeperateVR(SeperateAttributes):        
         
     | 
| 1024 | 
         | 
| 1025 | 
         
             
                def seperate(self):
         
     | 
| 1026 | 
         
            -
                    if self.primary_model_name == self.model_basename and  
     | 
| 1027 | 
         
            -
                         
     | 
| 1028 | 
         
            -
                        self.load_cached_sources()
         
     | 
| 1029 | 
         
             
                    else:
         
     | 
| 1030 | 
         
             
                        self.start_inference_console_write()
         
     | 
| 1031 | 
         
            -
             
     | 
| 1032 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 1033 | 
         | 
| 1034 | 
         
             
                        nn_arch_sizes = [
         
     | 
| 1035 | 
         
             
                            31191, # default
         
     | 
| 
         @@ -1039,11 +663,7 @@ class SeperateVR(SeperateAttributes): 
     | 
|
| 1039 | 
         
             
                        nn_arch_size = min(nn_arch_sizes, key=lambda x:abs(x-model_size))
         
     | 
| 1040 | 
         | 
| 1041 | 
         
             
                        if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
         
     | 
| 1042 | 
         
            -
                            self.model_run = nets_new.CascadedNet(self.mp.param['bins'] * 2, 
         
     | 
| 1043 | 
         
            -
                                                                  nn_arch_size, 
         
     | 
| 1044 | 
         
            -
                                                                  nout=self.model_capacity[0], 
         
     | 
| 1045 | 
         
            -
                                                                  nout_lstm=self.model_capacity[1])
         
     | 
| 1046 | 
         
            -
                            self.is_vr_51_model = True
         
     | 
| 1047 | 
         
             
                        else:
         
     | 
| 1048 | 
         
             
                            self.model_run = nets.determine_model_capacity(self.mp.param['bins'] * 2, nn_arch_size)
         
     | 
| 1049 | 
         | 
| 
         @@ -1053,36 +673,41 @@ class SeperateVR(SeperateAttributes): 
     | 
|
| 1053 | 
         
             
                        self.running_inference_console_write()
         
     | 
| 1054 | 
         | 
| 1055 | 
         
             
                        y_spec, v_spec = self.inference_vr(self.loading_mix(), device, self.aggressiveness)
         
     | 
| 1056 | 
         
            -
                        if not self.is_vocal_split_model:
         
     | 
| 1057 | 
         
            -
                            self.cache_source((y_spec, v_spec))
         
     | 
| 1058 | 
         
             
                        self.write_to_console(DONE, base_text='')
         
     | 
| 1059 | 
         | 
| 1060 | 
         
            -
                    if self.is_secondary_model_activated 
     | 
| 1061 | 
         
            -
                         
     | 
| 
         | 
|
| 1062 | 
         | 
| 1063 | 
         
             
                    if not self.is_secondary_stem_only:
         
     | 
| 
         | 
|
| 1064 | 
         
             
                        primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
         
     | 
| 1065 | 
         
             
                        if not isinstance(self.primary_source, np.ndarray):
         
     | 
| 1066 | 
         
            -
                            self.primary_source = self.spec_to_wav(y_spec).T
         
     | 
| 1067 | 
         
             
                            if not self.model_samplerate == 44100:
         
     | 
| 1068 | 
         
             
                                self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
         
     | 
| 1069 | 
         | 
| 1070 | 
         
            -
                        self.primary_source_map = self. 
     | 
| 
         | 
|
| 
         | 
|
| 1071 | 
         | 
| 1072 | 
         
             
                    if not self.is_primary_stem_only:
         
     | 
| 
         | 
|
| 1073 | 
         
             
                        secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
         
     | 
| 1074 | 
         
             
                        if not isinstance(self.secondary_source, np.ndarray):
         
     | 
| 1075 | 
         
            -
                            self.secondary_source = self.spec_to_wav(v_spec) 
     | 
| 
         | 
|
| 1076 | 
         
             
                            if not self.model_samplerate == 44100:
         
     | 
| 1077 | 
         
             
                                self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
         
     | 
| 1078 | 
         | 
| 1079 | 
         
            -
                        self.secondary_source_map = self. 
     | 
| 
         | 
|
| 
         | 
|
| 1080 | 
         | 
| 1081 | 
         
            -
                     
     | 
| 1082 | 
         
             
                    secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
         
     | 
| 1083 | 
         
            -
                    
         
     | 
| 1084 | 
         
            -
             
     | 
| 1085 | 
         
            -
                    
         
     | 
| 1086 | 
         
             
                    if self.is_secondary_model:
         
     | 
| 1087 | 
         
             
                        return secondary_sources
         
     | 
| 1088 | 
         | 
| 
         @@ -1092,9 +717,6 @@ class SeperateVR(SeperateAttributes): 
     | 
|
| 1092 | 
         | 
| 1093 | 
         
             
                    bands_n = len(self.mp.param['band'])
         
     | 
| 1094 | 
         | 
| 1095 | 
         
            -
                    audio_file = spec_utils.write_array_to_mem(self.audio_file, subtype=self.wav_type_set)
         
     | 
| 1096 | 
         
            -
                    is_mp3 = audio_file.endswith('.mp3') if isinstance(audio_file, str) else False
         
     | 
| 1097 | 
         
            -
             
     | 
| 1098 | 
         
             
                    for d in range(bands_n, 0, -1):        
         
     | 
| 1099 | 
         
             
                        bp = self.mp.param['band'][d]
         
     | 
| 1100 | 
         | 
| 
         @@ -1104,25 +726,26 @@ class SeperateVR(SeperateAttributes): 
     | 
|
| 1104 | 
         
             
                            wav_resolution = bp['res_type']
         
     | 
| 1105 | 
         | 
| 1106 | 
         
             
                        if d == bands_n: # high-end band
         
     | 
| 1107 | 
         
            -
                            X_wave[d], _ = librosa.load(audio_file, bp['sr'], False, dtype=np.float32, res_type=wav_resolution)
         
     | 
| 1108 | 
         
            -
                            X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp['hl'], bp['n_fft'], self.mp, band=d, is_v51_model=self.is_vr_51_model)
         
     | 
| 1109 | 
         | 
| 1110 | 
         
            -
                            if not np.any(X_wave[d]) and  
     | 
| 1111 | 
         
            -
                                X_wave[d] = rerun_mp3(audio_file, bp['sr'])
         
     | 
| 1112 | 
         | 
| 1113 | 
         
             
                            if X_wave[d].ndim == 1:
         
     | 
| 1114 | 
         
             
                                X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
         
     | 
| 1115 | 
         
             
                        else: # lower bands
         
     | 
| 1116 | 
         
             
                            X_wave[d] = librosa.resample(X_wave[d+1], self.mp.param['band'][d+1]['sr'], bp['sr'], res_type=wav_resolution)
         
     | 
| 1117 | 
         
            -
                             
     | 
| 1118 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 1119 | 
         
             
                        if d == bands_n and self.high_end_process != 'none':
         
     | 
| 1120 | 
         
             
                            self.input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (self.mp.param['pre_filter_stop'] - self.mp.param['pre_filter_start'])
         
     | 
| 1121 | 
         
             
                            self.input_high_end = X_spec_s[d][:, bp['n_fft']//2-self.input_high_end_h:bp['n_fft']//2, :]
         
     | 
| 1122 | 
         | 
| 1123 | 
         
            -
                    X_spec = spec_utils.combine_spectrograms(X_spec_s, self.mp 
     | 
| 1124 | 
         | 
| 1125 | 
         
            -
                    del X_wave, X_spec_s 
     | 
| 1126 | 
         | 
| 1127 | 
         
             
                    return X_spec
         
     | 
| 1128 | 
         | 
| 
         @@ -1160,6 +783,7 @@ class SeperateVR(SeperateAttributes): 
     | 
|
| 1160 | 
         
             
                        return mask
         
     | 
| 1161 | 
         | 
| 1162 | 
         
             
                    def postprocess(mask, X_mag, X_phase):
         
     | 
| 
         | 
|
| 1163 | 
         
             
                        is_non_accom_stem = False
         
     | 
| 1164 | 
         
             
                        for stem in NON_ACCOM_STEMS:
         
     | 
| 1165 | 
         
             
                            if stem == self.primary_stem:
         
     | 
| 
         @@ -1174,7 +798,6 @@ class SeperateVR(SeperateAttributes): 
     | 
|
| 1174 | 
         
             
                        v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
         
     | 
| 1175 | 
         | 
| 1176 | 
         
             
                        return y_spec, v_spec
         
     | 
| 1177 | 
         
            -
                    
         
     | 
| 1178 | 
         
             
                    X_mag, X_phase = spec_utils.preprocess(X_spec)
         
     | 
| 1179 | 
         
             
                    n_frame = X_mag.shape[2]
         
     | 
| 1180 | 
         
             
                    pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset)
         
     | 
| 
         @@ -1198,77 +821,35 @@ class SeperateVR(SeperateAttributes): 
     | 
|
| 1198 | 
         
             
                    return y_spec, v_spec
         
     | 
| 1199 | 
         | 
| 1200 | 
         
             
                def spec_to_wav(self, spec):
         
     | 
| 1201 | 
         
            -
                     
     | 
| 
         | 
|
| 1202 | 
         
             
                        input_high_end_ = spec_utils.mirroring(self.high_end_process, spec, self.input_high_end, self.mp)
         
     | 
| 1203 | 
         
            -
                        wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp, self.input_high_end_h, input_high_end_ 
     | 
| 1204 | 
         
             
                    else:
         
     | 
| 1205 | 
         
            -
                        wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp 
     | 
| 1206 | 
         | 
| 1207 | 
         
             
                    return wav
         
     | 
| 1208 | 
         
            -
             
     | 
| 1209 | 
         
            -
            def process_secondary_model(secondary_model: ModelData, 
         
     | 
| 1210 | 
         
            -
                                        process_data, 
         
     | 
| 1211 | 
         
            -
                                        main_model_primary_stem_4_stem=None, 
         
     | 
| 1212 | 
         
            -
                                        is_source_load=False, 
         
     | 
| 1213 | 
         
            -
                                        main_process_method=None, 
         
     | 
| 1214 | 
         
            -
                                        is_pre_proc_model=False, 
         
     | 
| 1215 | 
         
            -
                                        is_return_dual=True, 
         
     | 
| 1216 | 
         
            -
                                        main_model_primary=None):
         
     | 
| 1217 | 
         | 
| 1218 | 
         
             
                if not is_pre_proc_model:
         
     | 
| 1219 | 
         
             
                    process_iteration = process_data['process_iteration']
         
     | 
| 1220 | 
         
             
                    process_iteration()
         
     | 
| 1221 | 
         | 
| 1222 | 
         
             
                if secondary_model.process_method == VR_ARCH_TYPE:
         
     | 
| 1223 | 
         
            -
                    seperator = SeperateVR(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method 
     | 
| 1224 | 
         
             
                if secondary_model.process_method == MDX_ARCH_TYPE:
         
     | 
| 1225 | 
         
            -
                     
     | 
| 1226 | 
         
            -
                        seperator = SeperateMDXC(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method, is_return_dual=is_return_dual, main_model_primary=main_model_primary)
         
     | 
| 1227 | 
         
            -
                    else:
         
     | 
| 1228 | 
         
            -
                        seperator = SeperateMDX(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method, main_model_primary=main_model_primary)
         
     | 
| 1229 | 
         
             
                if secondary_model.process_method == DEMUCS_ARCH_TYPE:
         
     | 
| 1230 | 
         
            -
                    seperator = SeperateDemucs(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method 
     | 
| 1231 | 
         | 
| 1232 | 
         
             
                secondary_sources = seperator.seperate()
         
     | 
| 1233 | 
         | 
| 1234 | 
         
            -
                if type(secondary_sources) is dict and not  
     | 
| 1235 | 
         
            -
                    return gather_sources(secondary_model.primary_model_primary_stem,  
     | 
| 1236 | 
         
             
                else:
         
     | 
| 1237 | 
         
             
                    return secondary_sources
         
     | 
| 1238 | 
         | 
| 1239 | 
         
            -
            def process_chain_model(secondary_model: ModelData, 
         
     | 
| 1240 | 
         
            -
                                    process_data, 
         
     | 
| 1241 | 
         
            -
                                    vocal_stem_path, 
         
     | 
| 1242 | 
         
            -
                                    master_vocal_source, 
         
     | 
| 1243 | 
         
            -
                                    master_inst_source=None):
         
     | 
| 1244 | 
         
            -
                
         
     | 
| 1245 | 
         
            -
                process_iteration = process_data['process_iteration']
         
     | 
| 1246 | 
         
            -
                process_iteration()
         
     | 
| 1247 | 
         
            -
                
         
     | 
| 1248 | 
         
            -
                if secondary_model.bv_model_rebalance:
         
     | 
| 1249 | 
         
            -
                    vocal_source = spec_utils.reduce_mix_bv(master_inst_source, master_vocal_source, reduction_rate=secondary_model.bv_model_rebalance)
         
     | 
| 1250 | 
         
            -
                else:
         
     | 
| 1251 | 
         
            -
                    vocal_source = master_vocal_source
         
     | 
| 1252 | 
         
            -
                
         
     | 
| 1253 | 
         
            -
                vocal_stem_path = [vocal_source, os.path.splitext(os.path.basename(vocal_stem_path))[0]]
         
     | 
| 1254 | 
         
            -
             
     | 
| 1255 | 
         
            -
                if secondary_model.process_method == VR_ARCH_TYPE:
         
     | 
| 1256 | 
         
            -
                    seperator = SeperateVR(secondary_model, process_data, vocal_stem_path=vocal_stem_path, master_inst_source=master_inst_source, master_vocal_source=master_vocal_source)
         
     | 
| 1257 | 
         
            -
                if secondary_model.process_method == MDX_ARCH_TYPE:
         
     | 
| 1258 | 
         
            -
                    if secondary_model.is_mdx_c:
         
     | 
| 1259 | 
         
            -
                        seperator = SeperateMDXC(secondary_model, process_data, vocal_stem_path=vocal_stem_path, master_inst_source=master_inst_source, master_vocal_source=master_vocal_source)
         
     | 
| 1260 | 
         
            -
                    else:
         
     | 
| 1261 | 
         
            -
                        seperator = SeperateMDX(secondary_model, process_data, vocal_stem_path=vocal_stem_path, master_inst_source=master_inst_source, master_vocal_source=master_vocal_source)
         
     | 
| 1262 | 
         
            -
                if secondary_model.process_method == DEMUCS_ARCH_TYPE:
         
     | 
| 1263 | 
         
            -
                    seperator = SeperateDemucs(secondary_model, process_data, vocal_stem_path=vocal_stem_path, master_inst_source=master_inst_source, master_vocal_source=master_vocal_source)
         
     | 
| 1264 | 
         
            -
                    
         
     | 
| 1265 | 
         
            -
                secondary_sources = seperator.seperate()
         
     | 
| 1266 | 
         
            -
                
         
     | 
| 1267 | 
         
            -
                if type(secondary_sources) is dict:
         
     | 
| 1268 | 
         
            -
                    return secondary_sources
         
     | 
| 1269 | 
         
            -
                else:
         
     | 
| 1270 | 
         
            -
                    return None
         
     | 
| 1271 | 
         
            -
                
         
     | 
| 1272 | 
         
             
            def gather_sources(primary_stem_name, secondary_stem_name, secondary_sources: dict):
         
     | 
| 1273 | 
         | 
| 1274 | 
         
             
                source_primary = False
         
     | 
| 
         @@ -1282,23 +863,53 @@ def gather_sources(primary_stem_name, secondary_stem_name, secondary_sources: di 
     | 
|
| 1282 | 
         | 
| 1283 | 
         
             
                return source_primary, source_secondary
         
     | 
| 1284 | 
         | 
| 1285 | 
         
            -
            def prepare_mix(mix):
         
     | 
| 1286 | 
         
            -
             
     | 
| 1287 | 
         
             
                audio_path = mix
         
     | 
| 
         | 
|
| 1288 | 
         | 
| 1289 | 
         
             
                if not isinstance(mix, np.ndarray):
         
     | 
| 1290 | 
         
            -
                    mix,  
     | 
| 1291 | 
         
             
                else:
         
     | 
| 1292 | 
         
             
                    mix = mix.T
         
     | 
| 1293 | 
         | 
| 1294 | 
         
            -
                if  
     | 
| 1295 | 
         
            -
                     
     | 
| 1296 | 
         
            -
                        mix = rerun_mp3(audio_path)
         
     | 
| 1297 | 
         | 
| 1298 | 
         
             
                if mix.ndim == 1:
         
     | 
| 1299 | 
         
             
                    mix = np.asfortranarray([mix,mix])
         
     | 
| 1300 | 
         | 
| 1301 | 
         
            -
                 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 1302 | 
         | 
| 1303 | 
         
             
            def rerun_mp3(audio_file, sample_rate=44100):
         
     | 
| 1304 | 
         | 
| 
         @@ -1323,137 +934,9 @@ def save_format(audio_path, save_format, mp3_bit_set): 
     | 
|
| 1323 | 
         | 
| 1324 | 
         
             
                    if save_format == MP3:
         
     | 
| 1325 | 
         
             
                        audio_path_mp3 = audio_path.replace(".wav", ".mp3")
         
     | 
| 1326 | 
         
            -
                         
     | 
| 1327 | 
         
            -
                            musfile.export(audio_path_mp3, format="mp3", bitrate=mp3_bit_set, codec="libmp3lame")
         
     | 
| 1328 | 
         
            -
                        except Exception as e:
         
     | 
| 1329 | 
         
            -
                            print(e)
         
     | 
| 1330 | 
         
            -
                            musfile.export(audio_path_mp3, format="mp3", bitrate=mp3_bit_set)
         
     | 
| 1331 | 
         | 
| 1332 | 
         
             
                    try:
         
     | 
| 1333 | 
         
             
                        os.remove(audio_path)
         
     | 
| 1334 | 
         
             
                    except Exception as e:
         
     | 
| 1335 | 
         
             
                        print(e)
         
     | 
| 1336 | 
         
            -
                        
         
     | 
| 1337 | 
         
            -
            def pitch_shift(mix):
         
     | 
| 1338 | 
         
            -
                new_sr = 31183
         
     | 
| 1339 | 
         
            -
             
     | 
| 1340 | 
         
            -
                # Resample audio file
         
     | 
| 1341 | 
         
            -
                resampled_audio = signal.resample_poly(mix, new_sr, 44100)
         
     | 
| 1342 | 
         
            -
                
         
     | 
| 1343 | 
         
            -
                return resampled_audio
         
     | 
| 1344 | 
         
            -
             
     | 
| 1345 | 
         
            -
            def list_to_dictionary(lst):
         
     | 
| 1346 | 
         
            -
                dictionary = {item: index for index, item in enumerate(lst)}
         
     | 
| 1347 | 
         
            -
                return dictionary
         
     | 
| 1348 | 
         
            -
             
     | 
| 1349 | 
         
            -
            def vr_denoiser(X, device, hop_length=1024, n_fft=2048, cropsize=256, is_deverber=False, model_path=None):
         
     | 
| 1350 | 
         
            -
                batchsize = 4
         
     | 
| 1351 | 
         
            -
             
     | 
| 1352 | 
         
            -
                if is_deverber:
         
     | 
| 1353 | 
         
            -
                    nout, nout_lstm = 64, 128
         
     | 
| 1354 | 
         
            -
                    mp = ModelParameters(os.path.join('lib_v5', 'vr_network', 'modelparams', '4band_v3.json'))
         
     | 
| 1355 | 
         
            -
                    n_fft = mp.param['bins'] * 2
         
     | 
| 1356 | 
         
            -
                else:
         
     | 
| 1357 | 
         
            -
                    mp = None
         
     | 
| 1358 | 
         
            -
                    hop_length=1024
         
     | 
| 1359 | 
         
            -
                    nout, nout_lstm = 16, 128
         
     | 
| 1360 | 
         
            -
                
         
     | 
| 1361 | 
         
            -
                model = nets_new.CascadedNet(n_fft, nout=nout, nout_lstm=nout_lstm)
         
     | 
| 1362 | 
         
            -
                model.load_state_dict(torch.load(model_path, map_location=cpu))
         
     | 
| 1363 | 
         
            -
                model.to(device)
         
     | 
| 1364 | 
         
            -
             
     | 
| 1365 | 
         
            -
                if mp is None:
         
     | 
| 1366 | 
         
            -
                    X_spec = spec_utils.wave_to_spectrogram_old(X, hop_length, n_fft)
         
     | 
| 1367 | 
         
            -
                else:
         
     | 
| 1368 | 
         
            -
                    X_spec = loading_mix(X.T, mp)
         
     | 
| 1369 | 
         
            -
               
         
     | 
| 1370 | 
         
            -
                #PreProcess
         
     | 
| 1371 | 
         
            -
                X_mag = np.abs(X_spec)
         
     | 
| 1372 | 
         
            -
                X_phase = np.angle(X_spec)
         
     | 
| 1373 | 
         
            -
             
     | 
| 1374 | 
         
            -
                #Sep
         
     | 
| 1375 | 
         
            -
                n_frame = X_mag.shape[2]
         
     | 
| 1376 | 
         
            -
                pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, cropsize, model.offset)
         
     | 
| 1377 | 
         
            -
                X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
         
     | 
| 1378 | 
         
            -
                X_mag_pad /= X_mag_pad.max()
         
     | 
| 1379 | 
         
            -
             
     | 
| 1380 | 
         
            -
                X_dataset = []
         
     | 
| 1381 | 
         
            -
                patches = (X_mag_pad.shape[2] - 2 * model.offset) // roi_size
         
     | 
| 1382 | 
         
            -
                for i in range(patches):
         
     | 
| 1383 | 
         
            -
                    start = i * roi_size
         
     | 
| 1384 | 
         
            -
                    X_mag_crop = X_mag_pad[:, :, start:start + cropsize]
         
     | 
| 1385 | 
         
            -
                    X_dataset.append(X_mag_crop)
         
     | 
| 1386 | 
         
            -
             
     | 
| 1387 | 
         
            -
                X_dataset = np.asarray(X_dataset)
         
     | 
| 1388 | 
         
            -
             
     | 
| 1389 | 
         
            -
                model.eval()
         
     | 
| 1390 | 
         
            -
                
         
     | 
| 1391 | 
         
            -
                with torch.no_grad():
         
     | 
| 1392 | 
         
            -
                    mask = []
         
     | 
| 1393 | 
         
            -
                    # To reduce the overhead, dataloader is not used.
         
     | 
| 1394 | 
         
            -
                    for i in range(0, patches, batchsize):
         
     | 
| 1395 | 
         
            -
                        X_batch = X_dataset[i: i + batchsize]
         
     | 
| 1396 | 
         
            -
                        X_batch = torch.from_numpy(X_batch).to(device)
         
     | 
| 1397 | 
         
            -
             
     | 
| 1398 | 
         
            -
                        pred = model.predict_mask(X_batch)
         
     | 
| 1399 | 
         
            -
             
     | 
| 1400 | 
         
            -
                        pred = pred.detach().cpu().numpy()
         
     | 
| 1401 | 
         
            -
                        pred = np.concatenate(pred, axis=2)
         
     | 
| 1402 | 
         
            -
                        mask.append(pred)
         
     | 
| 1403 | 
         
            -
             
     | 
| 1404 | 
         
            -
                    mask = np.concatenate(mask, axis=2)
         
     | 
| 1405 | 
         
            -
                
         
     | 
| 1406 | 
         
            -
                mask = mask[:, :, :n_frame]
         
     | 
| 1407 | 
         
            -
             
     | 
| 1408 | 
         
            -
                #Post Proc
         
     | 
| 1409 | 
         
            -
                if is_deverber:
         
     | 
| 1410 | 
         
            -
                    v_spec = mask * X_mag * np.exp(1.j * X_phase)
         
     | 
| 1411 | 
         
            -
                    y_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
         
     | 
| 1412 | 
         
            -
                else:
         
     | 
| 1413 | 
         
            -
                    v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
         
     | 
| 1414 | 
         
            -
             
     | 
| 1415 | 
         
            -
                if mp is None:
         
     | 
| 1416 | 
         
            -
                    wave = spec_utils.spectrogram_to_wave_old(v_spec, hop_length=1024)
         
     | 
| 1417 | 
         
            -
                else:
         
     | 
| 1418 | 
         
            -
                    wave = spec_utils.cmb_spectrogram_to_wave(v_spec, mp, is_v51_model=True).T
         
     | 
| 1419 | 
         
            -
                    
         
     | 
| 1420 | 
         
            -
                wave = spec_utils.match_array_shapes(wave, X)
         
     | 
| 1421 | 
         
            -
             
     | 
| 1422 | 
         
            -
                if is_deverber:
         
     | 
| 1423 | 
         
            -
                    wave_2 = spec_utils.cmb_spectrogram_to_wave(y_spec, mp, is_v51_model=True).T
         
     | 
| 1424 | 
         
            -
                    wave_2 = spec_utils.match_array_shapes(wave_2, X)
         
     | 
| 1425 | 
         
            -
                    return wave, wave_2
         
     | 
| 1426 | 
         
            -
                else:
         
     | 
| 1427 | 
         
            -
                    return wave
         
     | 
| 1428 | 
         
            -
             
     | 
| 1429 | 
         
            -
            def loading_mix(X, mp):
                """Build the combined multi-band spectrogram for waveform ``X``.

                Bands are visited from the highest band index down to 1. The highest
                band uses ``X`` unchanged; every lower band is resampled from the
                waveform of the band directly above it to that band's ``sr``. Each
                band's waveform is converted with ``spec_utils.wave_to_spectrogram``
                and the per-band results are merged with
                ``spec_utils.combine_spectrograms``.

                Args:
                    X: Input waveform, used as-is for the highest band
                        (presumably already at that band's sample rate - TODO confirm).
                    mp: Model parameters object. ``mp.param['band']`` is indexed by the
                        band numbers ``1..len(mp.param['band'])`` and each entry
                        supplies ``sr``, ``hl``, ``n_fft`` and ``res_type``.

                Returns:
                    The value returned by ``spec_utils.combine_spectrograms``.
                """
                X_wave, X_spec_s = {}, {}

                bands_n = len(mp.param['band'])

                for d in range(bands_n, 0, -1):
                    bp = mp.param['band'][d]

                    # Only non-ARM macOS honours the band's configured resampler;
                    # every other platform is forced to 'polyphase'.
                    if OPERATING_SYSTEM == 'Darwin':
                        wav_resolution = 'polyphase' if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else bp['res_type']
                    else:
                        wav_resolution = 'polyphase'

                    if d == bands_n:  # high-end band
                        X_wave[d] = X
                    else:  # lower bands
                        # orig_sr/target_sr are passed by keyword: newer librosa
                        # releases made them keyword-only, and the keyword names are
                        # accepted by older releases as well.
                        X_wave[d] = librosa.resample(
                            X_wave[d + 1],
                            orig_sr=mp.param['band'][d + 1]['sr'],
                            target_sr=bp['sr'],
                            res_type=wav_resolution,
                        )

                    X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp['hl'], bp['n_fft'], mp, band=d, is_v51_model=True)

                X_spec = spec_utils.combine_spectrograms(X_spec_s, mp)

                # Drop the per-band intermediates before handing back the result.
                del X_wave, X_spec_s

                return X_spec
         
     | 
| 
         | 
|
| 6 | 
         
             
            from demucs.pretrained import get_model as _gm
         
     | 
| 7 | 
         
             
            from demucs.utils import apply_model_v1
         
     | 
| 8 | 
         
             
            from demucs.utils import apply_model_v2
         
     | 
| 
         | 
|
| 9 | 
         
             
            from lib_v5 import spec_utils
         
     | 
| 10 | 
         
             
            from lib_v5.vr_network import nets
         
     | 
| 11 | 
         
             
            from lib_v5.vr_network import nets_new
         
     | 
| 12 | 
         
            +
            #from lib_v5.vr_network.model_param_init import ModelParameters
         
     | 
| 13 | 
         
             
            from pathlib import Path
         
     | 
| 14 | 
         
             
            from gui_data.constants import *
         
     | 
| 15 | 
         
             
            from gui_data.error_handling import *
         
     | 
| 
         | 
|
| 16 | 
         
             
            import audioread
         
     | 
| 17 | 
         
             
            import gzip
         
     | 
| 18 | 
         
             
            import librosa
         
     | 
| 
         | 
|
| 24 | 
         
             
            import warnings
         
     | 
| 25 | 
         
             
            import pydub
         
     | 
| 26 | 
         
             
            import soundfile as sf
         
     | 
| 27 | 
         
            +
            import traceback
         
     | 
| 28 | 
         
             
            import lib_v5.mdxnet as MdxnetSet
         
     | 
| 29 | 
         
            +
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 30 | 
         
             
            if TYPE_CHECKING:
         
     | 
| 31 | 
         
             
                from UVR import ModelData
         
     | 
| 32 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 33 | 
         
             
            warnings.filterwarnings("ignore")
         
     | 
| 34 | 
         
             
            cpu = torch.device('cpu')
         
     | 
| 35 | 
         | 
| 36 | 
         
             
            class SeperateAttributes:
         
     | 
| 37 | 
         
            +
                def __init__(self, model_data: ModelData, process_data: dict, main_model_primary_stem_4_stem=None, main_process_method=None):
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 38 | 
         | 
| 39 | 
         
             
                    self.list_all_models: list
         
     | 
| 40 | 
         
             
                    self.process_data = process_data
         
     | 
| 41 | 
         
             
                    self.progress_value = 0
         
     | 
| 42 | 
         
             
                    self.set_progress_bar = process_data['set_progress_bar']
         
     | 
| 43 | 
         
             
                    self.write_to_console = process_data['write_to_console']
         
     | 
| 44 | 
         
            +
                    self.audio_file = process_data['audio_file']
         
     | 
| 45 | 
         
            +
                    self.audio_file_base = process_data['audio_file_base']
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 46 | 
         
             
                    self.export_path = process_data['export_path']
         
     | 
| 47 | 
         
             
                    self.cached_source_callback = process_data['cached_source_callback']
         
     | 
| 48 | 
         
             
                    self.cached_model_source_holder = process_data['cached_model_source_holder']
         
     | 
| 49 | 
         
             
                    self.is_4_stem_ensemble = process_data['is_4_stem_ensemble']
         
     | 
| 50 | 
         
             
                    self.list_all_models = process_data['list_all_models']
         
     | 
| 51 | 
         
             
                    self.process_iteration = process_data['process_iteration']
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 52 | 
         
             
                    self.mixer_path = model_data.mixer_path
         
     | 
| 53 | 
         
             
                    self.model_samplerate = model_data.model_samplerate
         
     | 
| 54 | 
         
             
                    self.model_capacity = model_data.model_capacity
         
     | 
| 
         | 
|
| 70 | 
         
             
                    self.is_ensemble_mode = model_data.is_ensemble_mode
         
     | 
| 71 | 
         
             
                    self.secondary_model = model_data.secondary_model #
         
     | 
| 72 | 
         
             
                    self.primary_model_primary_stem = model_data.primary_model_primary_stem
         
     | 
| 
         | 
|
| 73 | 
         
             
                    self.primary_stem = model_data.primary_stem #
         
     | 
| 74 | 
         
             
                    self.secondary_stem = model_data.secondary_stem #
         
     | 
| 75 | 
         
             
                    self.is_invert_spec = model_data.is_invert_spec #
         
     | 
| 
         | 
|
| 76 | 
         
             
                    self.is_mixer_mode = model_data.is_mixer_mode #
         
     | 
| 77 | 
         
             
                    self.secondary_model_scale = model_data.secondary_model_scale #
         
     | 
| 78 | 
         
             
                    self.is_demucs_pre_proc_model_inst_mix = model_data.is_demucs_pre_proc_model_inst_mix #
         
     | 
| 
         | 
|
| 82 | 
         
             
                    self.secondary_source = None
         
     | 
| 83 | 
         
             
                    self.secondary_source_primary = None
         
     | 
| 84 | 
         
             
                    self.secondary_source_secondary = None
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 85 | 
         | 
| 86 | 
         
            +
                    if not model_data.process_method == DEMUCS_ARCH_TYPE:
         
     | 
| 87 | 
         
            +
                        if process_data['is_ensemble_master'] and not self.is_4_stem_ensemble:
         
     | 
| 88 | 
         
            +
                            if not model_data.ensemble_primary_stem == self.primary_stem:
         
     | 
| 89 | 
         
            +
                                self.is_primary_stem_only, self.is_secondary_stem_only = self.is_secondary_stem_only, self.is_primary_stem_only
         
     | 
| 90 | 
         
            +
                        
         
     | 
| 91 | 
         
            +
                        if self.is_secondary_model and not process_data['is_ensemble_master']:
         
     | 
| 92 | 
         
            +
                            if not self.primary_model_primary_stem == self.primary_stem and not main_model_primary_stem_4_stem:
         
     | 
| 93 | 
         
            +
                                self.is_primary_stem_only, self.is_secondary_stem_only = self.is_secondary_stem_only, self.is_primary_stem_only
         
     | 
| 94 | 
         
            +
                                
         
     | 
| 95 | 
         
            +
                        if main_model_primary_stem_4_stem:
         
     | 
| 96 | 
         
            +
                            self.is_primary_stem_only = True if main_model_primary_stem_4_stem == self.primary_stem else False
         
     | 
| 97 | 
         
            +
                            self.is_secondary_stem_only = True if not main_model_primary_stem_4_stem == self.primary_stem else False
         
     | 
| 98 | 
         | 
| 99 | 
         
            +
                        if self.is_pre_proc_model:
         
     | 
| 100 | 
         
            +
                            self.is_primary_stem_only = True if self.primary_stem == INST_STEM else False
         
     | 
| 101 | 
         
            +
                            self.is_secondary_stem_only = True if self.secondary_stem == INST_STEM else False
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 102 | 
         | 
| 103 | 
         
             
                    if model_data.process_method == MDX_ARCH_TYPE:
         
     | 
| 104 | 
         
             
                        self.is_mdx_ckpt = model_data.is_mdx_ckpt
         
     | 
| 105 | 
         
             
                        self.primary_model_name, self.primary_sources = self.cached_source_callback(MDX_ARCH_TYPE, model_name=self.model_basename)
         
     | 
| 106 | 
         
            +
                        self.is_denoise = model_data.is_denoise
         
     | 
| 
         | 
|
| 
         | 
|
| 107 | 
         
             
                        self.mdx_batch_size = model_data.mdx_batch_size
         
     | 
| 108 | 
         
             
                        self.compensate = model_data.compensate
         
     | 
| 109 | 
         
            +
                        self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2**model_data.mdx_dim_t_set
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 110 | 
         
             
                        self.n_fft = model_data.mdx_n_fft_scale_set
         
     | 
| 111 | 
         
             
                        self.chunks = model_data.chunks
         
     | 
| 112 | 
         
             
                        self.margin = model_data.margin
         
     | 
| 113 | 
         
             
                        self.adjust = 1
         
     | 
| 114 | 
         
             
                        self.dim_c = 4
         
     | 
| 115 | 
         
             
                        self.hop = 1024
         
     | 
| 116 | 
         
            +
                        
         
     | 
| 117 | 
         
            +
                        if self.is_gpu_conversion >= 0 and torch.cuda.is_available():
         
     | 
| 118 | 
         
            +
                            self.device, self.run_type = torch.device('cuda:0'), ['CUDAExecutionProvider']
         
     | 
| 119 | 
         
            +
                        else:
         
     | 
| 120 | 
         
            +
                            self.device, self.run_type = torch.device('cpu'), ['CPUExecutionProvider']
         
     | 
| 121 | 
         | 
| 122 | 
         
             
                    if model_data.process_method == DEMUCS_ARCH_TYPE:
         
     | 
| 123 | 
         
             
                        self.demucs_stems = model_data.demucs_stems if not main_process_method in [MDX_ARCH_TYPE, VR_ARCH_TYPE] else None
         
     | 
| 124 | 
         
             
                        self.secondary_model_4_stem = model_data.secondary_model_4_stem
         
     | 
| 125 | 
         
             
                        self.secondary_model_4_stem_scale = model_data.secondary_model_4_stem_scale
         
     | 
| 126 | 
         
            +
                        self.primary_stem = model_data.ensemble_primary_stem if process_data['is_ensemble_master'] else model_data.primary_stem
         
     | 
| 127 | 
         
            +
                        self.secondary_stem = model_data.ensemble_secondary_stem if process_data['is_ensemble_master'] else model_data.secondary_stem
         
     | 
| 128 | 
         
             
                        self.is_chunk_demucs = model_data.is_chunk_demucs
         
     | 
| 129 | 
         
             
                        self.segment = model_data.segment
         
     | 
| 130 | 
         
             
                        self.demucs_version = model_data.demucs_version
         
     | 
| 
         | 
|
| 133 | 
         
             
                        self.is_demucs_combine_stems = model_data.is_demucs_combine_stems
         
     | 
| 134 | 
         
             
                        self.demucs_stem_count = model_data.demucs_stem_count
         
     | 
| 135 | 
         
             
                        self.pre_proc_model = model_data.pre_proc_model
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 136 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 137 | 
         
             
                        if self.is_secondary_model and not process_data['is_ensemble_master']:
         
     | 
| 138 | 
         
             
                            if not self.demucs_stem_count == 2 and model_data.primary_model_primary_stem == INST_STEM:
         
     | 
| 139 | 
         
             
                                self.primary_stem = VOCAL_STEM
         
     | 
| 140 | 
         
             
                                self.secondary_stem = INST_STEM
         
     | 
| 141 | 
         
             
                            else:
         
     | 
| 142 | 
         
             
                                self.primary_stem = model_data.primary_model_primary_stem
         
     | 
| 143 | 
         
            +
                                self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
         
     | 
| 144 | 
         
            +
                        
         
     | 
| 145 | 
         
            +
                        if self.is_chunk_demucs:
         
     | 
| 146 | 
         
            +
                            self.chunks_demucs = model_data.chunks_demucs
         
     | 
| 147 | 
         
            +
                            self.margin_demucs = model_data.margin_demucs
         
     | 
| 148 | 
         
            +
                        else:
         
     | 
| 149 | 
         
            +
                            self.chunks_demucs = 0
         
     | 
| 150 | 
         
            +
                            self.margin_demucs = 44100
         
     | 
| 151 | 
         
            +
                            
         
     | 
| 152 | 
         
             
                        self.shifts = model_data.shifts
         
     | 
| 153 | 
         
             
                        self.is_split_mode = model_data.is_split_mode if not self.demucs_version == DEMUCS_V4 else True
         
     | 
| 154 | 
         
            +
                        self.overlap = model_data.overlap
         
     | 
| 155 | 
         
             
                        self.primary_model_name, self.primary_sources = self.cached_source_callback(DEMUCS_ARCH_TYPE, model_name=self.model_basename)
         
     | 
| 156 | 
         | 
| 157 | 
         
             
                    if model_data.process_method == VR_ARCH_TYPE:
         
     | 
| 
         | 
|
| 158 | 
         
             
                        self.primary_model_name, self.primary_sources = self.cached_source_callback(VR_ARCH_TYPE, model_name=self.model_basename)
         
     | 
| 159 | 
         
             
                        self.mp = model_data.vr_model_param
         
     | 
| 160 | 
         
             
                        self.high_end_process = model_data.is_high_end_process
         
     | 
| 
         | 
|
| 164 | 
         
             
                        self.batch_size = model_data.batch_size
         
     | 
| 165 | 
         
             
                        self.window_size = model_data.window_size
         
     | 
| 166 | 
         
             
                        self.input_high_end_h = None
         
     | 
| 
         | 
|
| 167 | 
         
             
                        self.post_process_threshold = model_data.post_process_threshold
         
     | 
| 168 | 
         
             
                        self.aggressiveness = {'value': model_data.aggression_setting, 
         
     | 
| 169 | 
         
             
                                               'split_bin': self.mp.param['band'][1]['crop_stop'], 
         
     | 
| 170 | 
         
             
                                               'aggr_correction': self.mp.param.get('aggr_correction')}
         
     | 
| 171 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 172 | 
         
             
                def start_inference_console_write(self):
         
     | 
| 173 | 
         
            +
                    
         
     | 
| 174 | 
         
            +
                    if self.is_secondary_model and not self.is_pre_proc_model:
         
     | 
| 175 | 
         
             
                        self.write_to_console(INFERENCE_STEP_2_SEC(self.process_method, self.model_basename))
         
     | 
| 176 | 
         | 
| 177 | 
         
             
                    if self.is_pre_proc_model:
         
     | 
| 178 | 
         
             
                        self.write_to_console(INFERENCE_STEP_2_PRE(self.process_method, self.model_basename))
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 179 | 
         | 
| 180 | 
         
             
                def running_inference_console_write(self, is_no_write=False):
                    """Write the step-1 console message, optionally after a DONE marker.

                    Args:
                        is_no_write: When true, skip the ``DONE`` console write and the
                            progress-bar update to 0.05 that normally come first.
                    """
                    if not is_no_write:
                        self.write_to_console(DONE, base_text='')
                        self.set_progress_bar(0.05)

                    # The pre-proc flag takes precedence over the secondary-model flag.
                    if self.is_pre_proc_model:
                        step_message = INFERENCE_STEP_1_PRE
                    elif self.is_secondary_model:
                        step_message = INFERENCE_STEP_1_SEC
                    else:
                        step_message = INFERENCE_STEP_1

                    self.write_to_console(step_message)
         
     | 
| 191 | 
         | 
| 
         | 
|
| 198 | 
         | 
| 199 | 
         
             
                        self.set_progress_bar(0.1, (0.8/length*self.progress_value))
         
     | 
| 200 | 
         | 
| 201 | 
         
            +
                def load_cached_sources(self, is_4_stem_demucs=False):
                    """Write the cached-model console message and gather cached stems.

                    Args:
                        is_4_stem_demucs: When true, only the console message is written
                            and ``None`` is returned.

                    Returns:
                        ``(primary, secondary)`` as unpacked from ``gather_sources`` for
                        this instance's primary/secondary stem and ``primary_sources``,
                        or ``None`` when ``is_4_stem_demucs`` is true.
                    """
                    # The pre-proc flag takes precedence over the secondary-model flag.
                    if self.is_pre_proc_model:
                        notice = INFERENCE_STEP_2_PRE_CACHED_MODOEL(self.process_method, self.model_basename)
                    elif self.is_secondary_model:
                        notice = INFERENCE_STEP_2_SEC_CACHED_MODOEL(self.process_method, self.model_basename)
                    else:
                        notice = INFERENCE_STEP_2_PRIMARY_CACHED

                    self.write_to_console(notice)

                    if is_4_stem_demucs:
                        return None

                    cached_primary, cached_secondary = gather_sources(
                        self.primary_stem, self.secondary_stem, self.primary_sources
                    )

                    return cached_primary, cached_secondary
         
     | 
| 214 | 
         | 
| 215 | 
         
             
                def cache_source(self, secondary_sources):
         
     | 
| 216 | 
         | 
| 
         | 
|
| 225 | 
         | 
| 226 | 
         
             
                        if self.process_method == DEMUCS_ARCH_TYPE:
         
     | 
| 227 | 
         
             
                            self.cached_model_source_holder(DEMUCS_ARCH_TYPE, secondary_sources, self.model_basename)
         
     | 
| 228 | 
         
            +
                            
         
     | 
| 229 | 
         
            +
                def write_audio(self, stem_path, stem_source, samplerate, secondary_model_source=None, model_scale=None):
         
     | 
| 230 | 
         
            +
                            
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 231 | 
         
             
                    if not self.is_secondary_model:
         
     | 
| 232 | 
         
            +
                        if self.is_secondary_model_activated:
         
     | 
| 233 | 
         
            +
                            if isinstance(secondary_model_source, np.ndarray):
         
     | 
| 234 | 
         
            +
                                secondary_model_scale = model_scale if model_scale else self.secondary_model_scale
         
     | 
| 235 | 
         
            +
                                stem_source = spec_utils.average_dual_sources(stem_source, secondary_model_source, secondary_model_scale)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 236 | 
         | 
| 237 | 
         
            +
                        sf.write(stem_path, stem_source, samplerate, subtype=self.wav_type_set)
         
     | 
| 238 | 
         
            +
                        save_format(stem_path, self.save_format, self.mp3_bit_set) if not self.is_ensemble_mode else None
         
     | 
| 239 | 
         | 
| 
         | 
|
| 240 | 
         
             
                        self.write_to_console(DONE, base_text='')
         
     | 
| 241 | 
         
            +
                        self.set_progress_bar(0.95)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 242 | 
         | 
| 243 | 
         
            +
                def run_mixer(self, mix, sources):
                    """Optionally post-process separated sources with the mixer network.

                    The mixer only runs when ``self.is_mixer_mode`` is truthy and exactly
                    four sources are supplied; otherwise ``sources`` is returned as-is.
                    This is best effort: any exception raised while mixing is printed
                    and the unmodified ``sources`` are returned.

                    Args:
                        mix: The raw mixture; converted to a float32 tensor.
                        sources: The separated sources; converted to a tensor.

                    Returns:
                        ``np.array`` of the mixer output, or ``sources`` unchanged.
                    """
                    try:
                        if not (self.is_mixer_mode and len(sources) == 4):
                            return sources

                        mixer = MdxnetSet.Mixer(self.device, self.mixer_path).eval()
                        with torch.no_grad():
                            mix_tensor = torch.tensor(mix, dtype=torch.float32)
                            source_tensor = torch.tensor(sources).detach()
                            # The mixture is appended after the sources along dim 0.
                            mixer_input = torch.cat([source_tensor, mix_tensor.unsqueeze(0)], 0)
                            mixed = mixer(mixer_input)
                        return np.array(mixed)
                    except Exception as e:
                        error_name = f'{type(e).__name__}'
                        traceback_text = ''.join(traceback.format_tb(e.__traceback__))
                        message = f'{error_name}: "{e}"\n{traceback_text}"'
                        print('Mixer Failed: ', message)
                        return sources
         
     | 
| 263 | 
         | 
| 264 | 
         
             
            class SeperateMDX(SeperateAttributes):        
         
     | 
| 265 | 
         | 
| 266 | 
         
             
                def seperate(self):
         
     | 
| 267 | 
         
             
                    samplerate = 44100
         
     | 
| 268 | 
         
            +
                      
         
     | 
| 269 | 
         
            +
                    if self.primary_model_name == self.model_basename and self.primary_sources:
         
     | 
| 270 | 
         
            +
                        self.primary_source, self.secondary_source = self.load_cached_sources()
         
     | 
| 
         | 
|
| 271 | 
         
             
                    else:
         
     | 
| 272 | 
         
             
                        self.start_inference_console_write()
         
     | 
| 273 | 
         | 
| 
         | 
|
| 277 | 
         
             
                            separator = MdxnetSet.ConvTDFNet(**model_params)
         
     | 
| 278 | 
         
             
                            self.model_run = separator.load_from_checkpoint(self.model_path).to(self.device).eval()
         
     | 
| 279 | 
         
             
                        else:
         
     | 
| 280 | 
         
            +
                            ort_ = ort.InferenceSession(self.model_path, providers=self.run_type)
         
     | 
| 281 | 
         
            +
                            self.model_run = lambda spek:ort_.run(None, {'input': spek.cpu().numpy()})[0]
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 282 | 
         | 
| 283 | 
         
            +
                        self.initialize_model_settings()
         
     | 
| 284 | 
         
             
                        self.running_inference_console_write()
         
     | 
| 285 | 
         
            +
                        mdx_net_cut = True if self.primary_stem in MDX_NET_FREQ_CUT else False
         
     | 
| 286 | 
         
            +
                        mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks, self.margin, mdx_net_cut=mdx_net_cut)
         
     | 
| 287 | 
         
            +
                        source = self.demix_base(mix, is_ckpt=self.is_mdx_ckpt)[0]
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 288 | 
         
             
                        self.write_to_console(DONE, base_text='')            
         
     | 
| 289 | 
         | 
| 290 | 
         
            +
                    if self.is_secondary_model_activated:
         
     | 
| 291 | 
         
            +
                        if self.secondary_model:
         
     | 
| 292 | 
         
            +
                            self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)
         
     | 
| 
         | 
|
| 293 | 
         | 
| 294 | 
         
            +
                    if not self.is_secondary_stem_only:
         
     | 
| 295 | 
         
            +
                        self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
         
     | 
| 296 | 
         
            +
                        primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
         
     | 
| 297 | 
         
            +
                        if not isinstance(self.primary_source, np.ndarray):
         
     | 
| 298 | 
         
            +
                            self.primary_source = spec_utils.normalize(source, self.is_normalization).T
         
     | 
| 299 | 
         
            +
                        self.primary_source_map = {self.primary_stem: self.primary_source}
         
     | 
| 300 | 
         
            +
                        self.write_audio(primary_stem_path, self.primary_source, samplerate, self.secondary_source_primary)
         
     | 
| 301 | 
         
            +
             
     | 
| 302 | 
         
             
                    if not self.is_primary_stem_only:
         
     | 
| 303 | 
         
            +
                        self.write_to_console(f'{SAVING_STEM[0]}{self.secondary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
         
     | 
| 304 | 
         
             
                        secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
         
     | 
| 305 | 
         
             
                        if not isinstance(self.secondary_source, np.ndarray):
         
     | 
| 306 | 
         
            +
                            raw_mix = self.demix_base(raw_mix, is_match_mix=True)[0] if mdx_net_cut else raw_mix
         
     | 
| 307 | 
         
            +
                            self.secondary_source, raw_mix = spec_utils.normalize_two_stem(source*self.compensate, raw_mix, self.is_normalization)
         
     | 
| 308 | 
         | 
| 309 | 
         
            +
                            if self.is_invert_spec:
         
     | 
| 310 | 
         
            +
                                self.secondary_source = spec_utils.invert_stem(raw_mix, self.secondary_source)
         
     | 
| 311 | 
         
            +
                            else:
         
     | 
| 312 | 
         
            +
                                self.secondary_source = (-self.secondary_source.T+raw_mix.T)
         
     | 
| 313 | 
         | 
| 314 | 
         
            +
                        self.secondary_source_map = {self.secondary_stem: self.secondary_source}
         
     | 
| 315 | 
         
            +
                        self.write_audio(secondary_stem_path, self.secondary_source, samplerate, self.secondary_source_secondary)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 316 | 
         | 
| 317 | 
         
            +
                    torch.cuda.empty_cache()
         
     | 
| 318 | 
         
             
                    secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
         
     | 
| 
         | 
|
| 
         | 
|
| 319 | 
         | 
| 320 | 
         
            +
                    self.cache_source(secondary_sources)
         
     | 
| 321 | 
         
            +
             
     | 
| 322 | 
         
            +
                    if self.is_secondary_model:
         
     | 
| 323 | 
         
             
                        return secondary_sources
         
     | 
| 324 | 
         | 
| 325 | 
         
             
                def initialize_model_settings(self):
                    """Derive the STFT and chunking attributes from the model parameters.

                    Reads ``n_fft``, ``hop``, ``dim_t``, ``dim_c``, ``dim_f`` and
                    ``device`` from ``self`` and sets ``n_bins``, ``trim``,
                    ``chunk_size``, ``window``, ``freq_pad`` and ``gen_size``.
                    """
                    half_fft = self.n_fft // 2
                    self.n_bins = half_fft + 1
                    self.trim = half_fft
                    self.chunk_size = self.hop * (self.dim_t - 1)

                    hann = torch.hann_window(window_length=self.n_fft, periodic=False)
                    self.window = hann.to(self.device)

                    # Zero block covering the frequency bins above dim_f; istft
                    # concatenates it back onto the model output.
                    pad_shape = [1, self.dim_c, self.n_bins - self.dim_f, self.dim_t]
                    self.freq_pad = torch.zeros(pad_shape).to(self.device)

                    # Samples kept per chunk once `trim` is dropped from both ends.
                    self.gen_size = self.chunk_size - 2 * self.trim
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 332 | 
         | 
| 333 | 
         
            +
                def initialize_mix(self, mix, is_ckpt=False):
                    """Zero-pad a two-row mixture and cut it into overlapping chunks.

                    Consecutive chunks start ``self.gen_size`` samples apart and are
                    ``self.chunk_size`` samples long.

                    Args:
                        mix: Array with two rows (it is concatenated with ``(2, k)``
                            zero blocks along axis 1).
                        is_ckpt: Selects the padding layout. When true, ``trim`` zeros
                            are prepended and ``pad`` zeros appended; otherwise
                            ``trim`` zeros are added on both sides plus ``pad`` zeros
                            at the end.

                    Returns:
                        ``(mix_waves, pad)``: the stacked chunks as a float32 tensor on
                        ``self.device``, and the number of padding samples computed
                        for the selected layout.
                    """
                    if is_ckpt:
                        pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
                        mixture = np.concatenate(
                            (np.zeros((2, self.trim), dtype='float32'), mix, np.zeros((2, pad), dtype='float32')), 1
                        )
                        num_chunks = mixture.shape[-1] // self.gen_size
                        mix_waves = [
                            mixture[:, i * self.gen_size: i * self.gen_size + self.chunk_size]
                            for i in range(num_chunks)
                        ]
                    else:
                        n_sample = mix.shape[1]
                        # When n_sample is a multiple of gen_size this is a full
                        # gen_size of padding, never zero.
                        pad = self.gen_size - n_sample % self.gen_size
                        mix_p = np.concatenate(
                            (np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1
                        )
                        mix_waves = [
                            mix_p[:, i:i + self.chunk_size]
                            for i in range(0, n_sample + pad, self.gen_size)
                        ]

                    # Stack into one ndarray first: torch.tensor() on a list of
                    # ndarrays is a slow element-wise path and emits a UserWarning.
                    mix_waves = torch.tensor(np.array(mix_waves), dtype=torch.float32).to(self.device)

                    return mix_waves, pad
         
     | 
| 353 | 
         
            +
                
         
     | 
| 354 | 
         
            +
                def demix_base(self, mix, is_ckpt=False, is_match_mix=False):
                    """Run the model over each segment of ``mix`` and join the outputs.

                    Args:
                        mix: Mapping of segment key to a segment array. The segment
                            keyed ``0`` keeps its leading edge and the last key keeps
                            its trailing edge; every other edge has ``self.margin``
                            samples removed.
                        is_ckpt: Forwarded to ``initialize_mix`` and ``run_model``;
                            also selects how the predictions are trimmed.
                        is_match_mix: Forwarded to ``run_model`` and the progress bar.

                    Returns:
                        The per-segment results concatenated along the last axis.
                        Each segment is wrapped in a one-element list, so the result
                        has a leading axis of size one.
                    """
                    segment_outputs = []
                    for seg_key in mix:
                        segment = mix[seg_key]
                        batches, pad = self.initialize_mix(segment, is_ckpt=is_ckpt)
                        batches = batches.split(self.mdx_batch_size)

                        # Checkpoint models are cut back to the segment length;
                        # otherwise the trailing padding is dropped.
                        if is_ckpt:
                            cut = segment.shape[-1]
                        else:
                            cut = -pad

                        predictions = []
                        with torch.no_grad():
                            for batch in batches:
                                self.running_inference_progress_bar(len(mix)*len(batches), is_match_mix=is_match_mix)
                                predictions.append(self.run_model(batch, is_ckpt=is_ckpt, is_match_mix=is_match_mix))

                            if is_ckpt:
                                predictions = np.vstack(predictions)[:, :, self.trim:-self.trim]
                            wave = np.concatenate(predictions, axis=-1)[:, :cut]

                            if seg_key == 0:
                                left = 0
                            else:
                                left = self.margin

                            if seg_key == list(mix.keys())[-1] or self.margin == 0:
                                right = None
                            else:
                                right = -self.margin

                            scaled = wave[:, left:right]*(1/self.adjust)
                        segment_outputs.append([scaled])

                    return np.concatenate(segment_outputs, axis=-1)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 377 | 
         | 
| 378 | 
         
            +
                def run_model(self, mix, is_ckpt=False, is_match_mix=False):
         
     | 
| 379 | 
         | 
| 380 | 
         
             
                    spek = self.stft(mix.to(self.device))*self.adjust
         
     | 
| 381 | 
         
             
                    spek[:, :, :3, :] *= 0 
         
     | 
| 
         | 
|
| 385 | 
         
             
                    else:
         
     | 
| 386 | 
         
             
                        spec_pred = -self.model_run(-spek)*0.5+self.model_run(spek)*0.5 if self.is_denoise else self.model_run(spek)
         
     | 
| 387 | 
         | 
| 388 | 
         
            +
                    if is_ckpt:
         
     | 
| 389 | 
         
            +
                        return self.istft(spec_pred).cpu().detach().numpy()
         
     | 
| 390 | 
         
            +
                    else: 
         
     | 
| 391 | 
         
            +
                        return self.istft(torch.tensor(spec_pred).to(self.device)).to(cpu)[:,:,self.trim:-self.trim].transpose(0,1).reshape(2, -1).numpy()
         
     | 
| 392 | 
         
            +
                
         
     | 
| 393 | 
         
            +
                def stft(self, x):
                    """Convert waveform chunks into the real-valued spectrogram layout.

                    The input is flattened to rows of ``self.chunk_size`` samples and
                    transformed with ``torch.stft``. Real and imaginary parts are
                    moved to a channel-like axis and folded into ``self.dim_c``
                    channels. Only the first ``self.dim_f`` frequency bins are
                    returned.
                    """
                    frames = x.reshape([-1, self.chunk_size])
                    spec = torch.stft(
                        frames,
                        n_fft=self.n_fft,
                        hop_length=self.hop,
                        window=self.window,
                        center=True,
                        return_complex=True,
                    )
                    # view_as_real appends a trailing (real, imag) axis; move it to dim 1.
                    spec = torch.view_as_real(spec).permute([0, 3, 1, 2])
                    spec = spec.reshape([-1, 2, 2, self.n_bins, self.dim_t])
                    spec = spec.reshape([-1, self.dim_c, self.n_bins, self.dim_t])
                    return spec[:, :, :self.dim_f]
         
     | 
| 400 | 
         
            +
             
     | 
| 401 | 
         
            +
                def istft(self, x, freq_pad=None):
                    """Invert the spectrogram layout produced by ``stft`` to waveforms.

                    Args:
                        x: Spectrogram tensor whose frequency axis (dim -2) is
                            completed to ``self.n_bins`` by ``freq_pad``.
                        freq_pad: Block concatenated after ``x`` along the frequency
                            axis. Defaults to ``self.freq_pad`` repeated once per
                            batch entry of ``x``.

                    Returns:
                        Tensor reshaped to ``(-1, 2, self.chunk_size)``.
                    """
                    if freq_pad is None:
                        freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1])

                    full = torch.cat([x, freq_pad], -2)
                    full = full.reshape([-1, 2, 2, self.n_bins, self.dim_t])
                    full = full.reshape([-1, 2, self.n_bins, self.dim_t])
                    # view_as_complex needs the (real, imag) pair as a contiguous last axis.
                    spec = torch.view_as_complex(full.permute([0, 2, 3, 1]).contiguous())
                    wave = torch.istft(
                        spec,
                        n_fft=self.n_fft,
                        hop_length=self.hop,
                        window=self.window,
                        center=True,
                    )
                    return wave.reshape([-1, 2, self.chunk_size])
         
     | 
| 410 | 
         
            +
             
     | 
| 411 | 
         
            +
            class SeperateDemucs(SeperateAttributes):        
         
     | 
| 412 | 
         | 
| 413 | 
         
             
                def seperate(self):
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 414 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 415 | 
         
             
                    samplerate = 44100
         
     | 
| 416 | 
         
             
                    source = None
         
     | 
| 417 | 
         
             
                    model_scale = None
         
     | 
| 418 | 
         
             
                    stem_source = None
         
     | 
| 419 | 
         
             
                    stem_source_secondary = None
         
     | 
| 420 | 
         
             
                    inst_mix = None
         
     | 
| 421 | 
         
            +
                    inst_raw_mix = None
         
     | 
| 422 | 
         
            +
                    raw_mix = None
         
     | 
| 423 | 
         
             
                    inst_source = None
         
     | 
| 424 | 
         
             
                    is_no_write = False
         
     | 
| 425 | 
         
             
                    is_no_piano_guitar = False
         
     | 
| 426 | 
         
            +
             
     | 
| 427 | 
         
            +
                    if self.primary_model_name == self.model_basename and type(self.primary_sources) is dict and not self.pre_proc_model:
         
     | 
| 428 | 
         
            +
                        self.primary_source, self.secondary_source = self.load_cached_sources()
         
     | 
| 429 | 
         
            +
                    elif self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and not self.pre_proc_model:
         
     | 
| 430 | 
         
             
                        source = self.primary_sources
         
     | 
| 431 | 
         
            +
                        self.load_cached_sources(is_4_stem_demucs=True)
         
     | 
| 432 | 
         
             
                    else:
         
     | 
| 433 | 
         
             
                        self.start_inference_console_write()
         
     | 
| 
         | 
|
| 434 | 
         | 
| 435 | 
         
            +
                        if self.is_gpu_conversion >= 0:
         
     | 
| 436 | 
         
            +
                            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 
         
     | 
| 437 | 
         
            +
                        else:
         
     | 
| 438 | 
         
            +
                            self.device = torch.device('cpu')
         
     | 
| 439 | 
         
            +
                        
         
     | 
| 440 | 
         
             
                        if self.demucs_version == DEMUCS_V1:
         
     | 
| 441 | 
         
             
                            if str(self.model_path).endswith(".gz"):
         
     | 
| 442 | 
         
             
                                self.model_path = gzip.open(self.model_path, "rb")
         
     | 
| 
         | 
|
| 462 | 
         
             
                                is_no_write = True
         
     | 
| 463 | 
         
             
                                self.write_to_console(DONE, base_text='')
         
     | 
| 464 | 
         
             
                                mix_no_voc = process_secondary_model(self.pre_proc_model, self.process_data, is_pre_proc_model=True)
         
     | 
| 465 | 
         
            +
                                inst_mix, inst_raw_mix, inst_samplerate = prepare_mix(mix_no_voc[INST_STEM], self.chunks_demucs, self.margin_demucs)
         
     | 
| 466 | 
         
             
                                self.process_iteration()
         
     | 
| 467 | 
         
             
                                self.running_inference_console_write(is_no_write=is_no_write)
         
     | 
| 468 | 
         
             
                                inst_source = self.demix_demucs(inst_mix)
         
     | 
| 469 | 
         
            +
                                inst_source = self.run_mixer(inst_raw_mix, inst_source)
         
     | 
| 470 | 
         
             
                                self.process_iteration()
         
     | 
| 471 | 
         | 
| 472 | 
         
             
                        self.running_inference_console_write(is_no_write=is_no_write) if not self.pre_proc_model else None
         
     | 
| 473 | 
         
            +
                        mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks_demucs, self.margin_demucs)
         
     | 
| 474 | 
         | 
| 475 | 
         
             
                        if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and self.pre_proc_model:
         
     | 
| 476 | 
         
             
                            source = self.primary_sources
         
     | 
| 477 | 
         
             
                        else:
         
     | 
| 478 | 
         
             
                            source = self.demix_demucs(mix)
         
     | 
| 479 | 
         
            +
                            source = self.run_mixer(raw_mix, source)
         
     | 
| 480 | 
         | 
| 481 | 
         
             
                        self.write_to_console(DONE, base_text='')
         
     | 
| 482 | 
         | 
| 483 | 
         
             
                        del self.demucs
         
     | 
| 484 | 
         
            +
                        torch.cuda.empty_cache()
         
     | 
| 485 | 
         | 
| 486 | 
         
             
                    if isinstance(inst_source, np.ndarray):
         
     | 
| 487 | 
         
             
                        source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[VOCAL_STEM]], source[self.demucs_source_map[VOCAL_STEM]])
         
     | 
| 
         | 
|
| 489 | 
         
             
                        source = inst_source
         
     | 
| 490 | 
         | 
| 491 | 
         
             
                    if isinstance(source, np.ndarray):
         
     | 
| 
         | 
|
| 492 | 
         
             
                        if len(source) == 2:
         
     | 
| 493 | 
         
             
                            self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER
         
     | 
| 494 | 
         
             
                        else:
         
     | 
| 
         | 
|
| 503 | 
         
             
                                    other_source += i
         
     | 
| 504 | 
         
             
                                source_reshape = spec_utils.reshape_sources(source[self.demucs_source_map[OTHER_STEM]], other_source)
         
     | 
| 505 | 
         
             
                                source[self.demucs_source_map[OTHER_STEM]] = source_reshape
         
     | 
| 506 | 
         
            +
             
     | 
| 507 | 
         
            +
                    if (self.demucs_stems == ALL_STEMS and not self.process_data['is_ensemble_master']) or self.is_4_stem_ensemble:
         
     | 
| 508 | 
         
             
                        self.cache_source(source)
         
     | 
| 509 | 
         
            +
                        
         
     | 
| 
         | 
|
| 510 | 
         
             
                        for stem_name, stem_value in self.demucs_source_map.items():
         
     | 
| 511 | 
         
             
                            if self.is_secondary_model_activated and not self.is_secondary_model and not stem_value >= 4:
         
     | 
| 512 | 
         
             
                                if self.secondary_model_4_stem[stem_value]:
         
     | 
| 513 | 
         
             
                                    model_scale = self.secondary_model_4_stem_scale[stem_value]
         
     | 
| 514 | 
         
            +
                                    stem_source_secondary = process_secondary_model(self.secondary_model_4_stem[stem_value], self.process_data, main_model_primary_stem_4_stem=stem_name, is_4_stem_demucs=True)
         
     | 
| 515 | 
         
             
                                    if isinstance(stem_source_secondary, np.ndarray):
         
     | 
| 516 | 
         
            +
                                        stem_source_secondary = stem_source_secondary[1 if self.secondary_model_4_stem[stem_value].demucs_stem_count == 2 else stem_value]
         
     | 
| 517 | 
         
            +
                                        stem_source_secondary = spec_utils.normalize(stem_source_secondary, self.is_normalization).T
         
     | 
| 518 | 
         
             
                                    elif type(stem_source_secondary) is dict:
         
     | 
| 519 | 
         
             
                                        stem_source_secondary = stem_source_secondary[stem_name]
         
     | 
| 520 | 
         | 
| 521 | 
         
             
                            stem_source_secondary = None if stem_value >= 4 else stem_source_secondary
         
     | 
| 522 | 
         
            +
                            self.write_to_console(f'{SAVING_STEM[0]}{stem_name}{SAVING_STEM[1]}') if not self.is_secondary_model else None
         
     | 
| 523 | 
         
             
                            stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({stem_name}).wav')
         
     | 
| 524 | 
         
            +
                            stem_source = spec_utils.normalize(source[stem_value], self.is_normalization).T
         
     | 
| 525 | 
         
            +
                            self.write_audio(stem_path, stem_source, samplerate, secondary_model_source=stem_source_secondary, model_scale=model_scale)
         
     | 
| 526 | 
         
            +
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 527 | 
         
             
                        if self.is_secondary_model:    
         
     | 
| 528 | 
         
             
                            return source
         
     | 
| 529 | 
         
             
                    else:
         
     | 
| 530 | 
         
            +
                        if self.is_secondary_model_activated:
         
     | 
| 531 | 
         
            +
                            if self.secondary_model:
         
     | 
| 532 | 
         
             
                                self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)
         
     | 
| 533 | 
         
            +
             
     | 
| 534 | 
         
            +
                        if not self.is_secondary_stem_only:
         
     | 
| 535 | 
         
            +
                            self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
         
     | 
| 536 | 
         
            +
                            primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
         
     | 
| 537 | 
         
            +
                            if not isinstance(self.primary_source, np.ndarray):
         
     | 
| 538 | 
         
            +
                                self.primary_source = spec_utils.normalize(source[self.demucs_source_map[self.primary_stem]], self.is_normalization).T
         
     | 
| 539 | 
         
            +
                            self.primary_source_map = {self.primary_stem: self.primary_source}
         
     | 
| 540 | 
         
            +
                            self.write_audio(primary_stem_path, self.primary_source, samplerate, self.secondary_source_primary)
         
     | 
| 541 | 
         
            +
             
     | 
| 542 | 
         
             
                        if not self.is_primary_stem_only:
         
     | 
| 543 | 
         
             
                            def secondary_save(sec_stem_name, source, raw_mixture=None, is_inst_mixture=False):
         
     | 
| 544 | 
         
             
                                secondary_source = self.secondary_source if not is_inst_mixture else None
         
     | 
| 545 | 
         
            +
                                self.write_to_console(f'{SAVING_STEM[0]}{sec_stem_name}{SAVING_STEM[1]}') if not self.is_secondary_model else None
         
     | 
| 546 | 
         
             
                                secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({sec_stem_name}).wav')
         
     | 
| 547 | 
         
             
                                secondary_source_secondary = None
         
     | 
| 548 | 
         | 
| 
         | 
|
| 558 | 
         
             
                                        secondary_source = np.zeros_like(source[0])
         
     | 
| 559 | 
         
             
                                        for i in source:
         
     | 
| 560 | 
         
             
                                            secondary_source += i
         
     | 
| 561 | 
         
            +
                                        secondary_source = spec_utils.normalize(secondary_source, self.is_normalization).T
         
     | 
| 562 | 
         
             
                                    else:
         
     | 
| 563 | 
         
             
                                        if not isinstance(raw_mixture, np.ndarray):
         
     | 
| 564 | 
         
            +
                                            raw_mixture = prepare_mix(self.audio_file, self.chunks_demucs, self.margin_demucs, is_missing_mix=True)
         
     | 
| 565 | 
         | 
| 566 | 
         
            +
                                        secondary_source, raw_mixture = spec_utils.normalize_two_stem(source[self.demucs_source_map[self.primary_stem]], raw_mixture, self.is_normalization)
         
     | 
| 567 | 
         | 
| 568 | 
         
             
                                        if self.is_invert_spec:
         
     | 
| 569 | 
         
             
                                            secondary_source = spec_utils.invert_stem(raw_mixture, secondary_source)
         
     | 
| 
         | 
|
| 574 | 
         
             
                                if not is_inst_mixture:
         
     | 
| 575 | 
         
             
                                    self.secondary_source = secondary_source
         
     | 
| 576 | 
         
             
                                    secondary_source_secondary = self.secondary_source_secondary
         
     | 
| 
         | 
|
| 577 | 
         
             
                                    self.secondary_source_map = {self.secondary_stem: self.secondary_source}
         
     | 
| 578 | 
         | 
| 579 | 
         
            +
                                self.write_audio(secondary_stem_path, secondary_source, samplerate, secondary_source_secondary)
         
     | 
| 580 | 
         | 
| 581 | 
         
            +
                            secondary_save(self.secondary_stem, source, raw_mixture=raw_mix)
         
     | 
| 582 | 
         | 
| 583 | 
         
             
                            if self.is_demucs_pre_proc_model_inst_mix and self.pre_proc_model and not self.is_4_stem_ensemble:
         
     | 
| 584 | 
         
            +
                                secondary_save(f"{self.secondary_stem} {INST_STEM}", source, raw_mixture=inst_raw_mix, is_inst_mixture=True)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 585 | 
         | 
| 586 | 
         
             
                        secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
         
     | 
| 587 | 
         
            +
             
     | 
| 588 | 
         
            +
                        self.cache_source(secondary_sources)
         
     | 
| 589 | 
         | 
| 590 | 
         
             
                        if self.is_secondary_model:    
         
     | 
| 591 | 
         
             
                            return secondary_sources
         
     | 
| 592 | 
         | 
| 593 | 
         
             
                def demix_demucs(self, mix):
                    """Separate every chunk in ``mix`` with the loaded Demucs model and join the results.

                    ``mix`` is a mapping of chunk key -> audio chunk (it is iterated, indexed
                    by key and queried with ``.keys()``). Each chunk is standardised with the
                    mean and standard deviation of its signal averaged over the first axis,
                    passed to the ``apply_model*`` function selected by
                    ``self.demucs_version``, de-standardised, trimmed by
                    ``self.margin_demucs`` samples on the sides that border another chunk,
                    and finally concatenated along the last axis.

                    Returns:
                        The numpy array produced by concatenating all processed chunks.

                    Raises:
                        ValueError: from ``np.concatenate`` when ``mix`` contains no chunks.
                    """
                    processed = {}

                    # In chunk mode progress is reported once per chunk below, so the model
                    # call receives no callback; otherwise the callback is forwarded to it.
                    set_progress_bar = None if self.is_chunk_demucs else self.set_progress_bar

                    # Loop-invariant: the final chunk keeps its trailing samples.
                    chunk_keys = list(mix.keys())
                    last_key = chunk_keys[-1] if chunk_keys else None

                    for nmix in mix:
                        self.progress_value += 1
                        if self.is_chunk_demucs:
                            self.set_progress_bar(0.1, (0.8/len(mix)*self.progress_value))

                        cmix = torch.tensor(mix[nmix], dtype=torch.float32)
                        ref = cmix.mean(0)
                        ref_mean, ref_std = ref.mean(), ref.std()
                        # NOTE(review): ref_std is 0 for a constant (e.g. all-zero) chunk, which
                        # makes this division produce NaN - confirm whether that can occur.
                        mix_infer = (cmix - ref_mean) / ref_std

                        with torch.no_grad():
                            if self.demucs_version == DEMUCS_V1:
                                sources = apply_model_v1(self.demucs,
                                                         mix_infer.to(self.device),
                                                         self.shifts,
                                                         self.is_split_mode,
                                                         set_progress_bar=set_progress_bar)
                            elif self.demucs_version == DEMUCS_V2:
                                sources = apply_model_v2(self.demucs,
                                                         mix_infer.to(self.device),
                                                         self.shifts,
                                                         self.is_split_mode,
                                                         self.overlap,
                                                         set_progress_bar=set_progress_bar)
                            else:
                                sources = apply_model(self.demucs,
                                                      mix_infer[None],
                                                      self.shifts,
                                                      self.is_split_mode,
                                                      self.overlap,
                                                      static_shifts=1 if self.shifts == 0 else self.shifts,
                                                      set_progress_bar=set_progress_bar,
                                                      device=self.device)[0]

                        # Undo the standardisation applied before inference.
                        sources = (sources * ref_std + ref_mean).cpu().numpy()
                        # Swap the first two entries along the leading axis of the model output.
                        sources[[0,1]] = sources[[1,0]]

                        # The chunk keyed 0 has no leading margin; the last chunk (or any chunk
                        # when the margin is 0) has no trailing margin to cut off.
                        start = 0 if nmix == 0 else self.margin_demucs
                        keep_tail = nmix == last_key or self.margin_demucs == 0
                        end = None if keep_tail else -self.margin_demucs
                        processed[nmix] = sources[:,:,start:end].copy()

                    return np.concatenate(list(processed.values()), axis=-1)
         
     | 
| 642 | 
         | 
| 643 | 
         
             
            class SeperateVR(SeperateAttributes):        
         
     | 
| 644 | 
         | 
| 645 | 
         
             
                def seperate(self):
         
     | 
| 646 | 
         
            +
                    if self.primary_model_name == self.model_basename and self.primary_sources:
         
     | 
| 647 | 
         
            +
                        self.primary_source, self.secondary_source = self.load_cached_sources()
         
     | 
| 
         | 
|
| 648 | 
         
             
                    else:
         
     | 
| 649 | 
         
             
                        self.start_inference_console_write()
         
     | 
| 650 | 
         
            +
                        if self.is_gpu_conversion >= 0:
         
     | 
| 651 | 
         
            +
                            if OPERATING_SYSTEM == 'Darwin':
         
     | 
| 652 | 
         
            +
                                device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
         
     | 
| 653 | 
         
            +
                            else:
         
     | 
| 654 | 
         
            +
                                device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 
         
     | 
| 655 | 
         
            +
                        else:
         
     | 
| 656 | 
         
            +
                            device = torch.device('cpu')
         
     | 
| 657 | 
         | 
| 658 | 
         
             
                        nn_arch_sizes = [
         
     | 
| 659 | 
         
             
                            31191, # default
         
     | 
| 
         | 
|
| 663 | 
         
             
                        nn_arch_size = min(nn_arch_sizes, key=lambda x:abs(x-model_size))
         
     | 
| 664 | 
         | 
| 665 | 
         
             
                        if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
         
     | 
| 666 | 
         
            +
                            self.model_run = nets_new.CascadedNet(self.mp.param['bins'] * 2, nn_arch_size, nout=self.model_capacity[0], nout_lstm=self.model_capacity[1])
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 667 | 
         
             
                        else:
         
     | 
| 668 | 
         
             
                            self.model_run = nets.determine_model_capacity(self.mp.param['bins'] * 2, nn_arch_size)
         
     | 
| 669 | 
         | 
| 
         | 
|
| 673 | 
         
             
                        self.running_inference_console_write()
         
     | 
| 674 | 
         | 
| 675 | 
         
             
                        y_spec, v_spec = self.inference_vr(self.loading_mix(), device, self.aggressiveness)
         
     | 
| 
         | 
|
| 
         | 
|
| 676 | 
         
             
                        self.write_to_console(DONE, base_text='')
         
     | 
| 677 | 
         | 
| 678 | 
         
            +
                    if self.is_secondary_model_activated:
         
     | 
| 679 | 
         
            +
                        if self.secondary_model:
         
     | 
| 680 | 
         
            +
                            self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)
         
     | 
| 681 | 
         | 
| 682 | 
         
             
                    if not self.is_secondary_stem_only:
         
     | 
| 683 | 
         
            +
                        self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
         
     | 
| 684 | 
         
             
                        primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
         
     | 
| 685 | 
         
             
                        if not isinstance(self.primary_source, np.ndarray):
         
     | 
| 686 | 
         
            +
                            self.primary_source = spec_utils.normalize(self.spec_to_wav(y_spec), self.is_normalization).T
         
     | 
| 687 | 
         
             
                            if not self.model_samplerate == 44100:
         
     | 
| 688 | 
         
             
                                self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
         
     | 
| 689 | 
         | 
| 690 | 
         
            +
                        self.primary_source_map = {self.primary_stem: self.primary_source}
         
     | 
| 691 | 
         
            +
                        
         
     | 
| 692 | 
         
            +
                        self.write_audio(primary_stem_path, self.primary_source, 44100, self.secondary_source_primary)
         
     | 
| 693 | 
         | 
| 694 | 
         
             
                    if not self.is_primary_stem_only:
         
     | 
| 695 | 
         
            +
                        self.write_to_console(f'{SAVING_STEM[0]}{self.secondary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
         
     | 
| 696 | 
         
             
                        secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
         
     | 
| 697 | 
         
             
                        if not isinstance(self.secondary_source, np.ndarray):
         
     | 
| 698 | 
         
            +
                            self.secondary_source = self.spec_to_wav(v_spec)
         
     | 
| 699 | 
         
            +
                            self.secondary_source = spec_utils.normalize(self.spec_to_wav(v_spec), self.is_normalization).T
         
     | 
| 700 | 
         
             
                            if not self.model_samplerate == 44100:
         
     | 
| 701 | 
         
             
                                self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
         
     | 
| 702 | 
         | 
| 703 | 
         
            +
                        self.secondary_source_map = {self.secondary_stem: self.secondary_source}
         
     | 
| 704 | 
         
            +
                        
         
     | 
| 705 | 
         
            +
                        self.write_audio(secondary_stem_path, self.secondary_source, 44100, self.secondary_source_secondary)
         
     | 
| 706 | 
         | 
| 707 | 
         
            +
                    torch.cuda.empty_cache()
         
     | 
| 708 | 
         
             
                    secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
         
     | 
| 709 | 
         
            +
                    self.cache_source(secondary_sources)
         
     | 
| 710 | 
         
            +
             
     | 
| 
         | 
|
| 711 | 
         
             
                    if self.is_secondary_model:
         
     | 
| 712 | 
         
             
                        return secondary_sources
         
     | 
| 713 | 
         | 
| 
         | 
|
| 717 | 
         | 
| 718 | 
         
             
                    bands_n = len(self.mp.param['band'])
         
     | 
| 719 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 720 | 
         
             
                    for d in range(bands_n, 0, -1):        
         
     | 
| 721 | 
         
             
                        bp = self.mp.param['band'][d]
         
     | 
| 722 | 
         | 
| 
         | 
|
| 726 | 
         
             
                            wav_resolution = bp['res_type']
         
     | 
| 727 | 
         | 
| 728 | 
         
             
                        if d == bands_n: # high-end band
         
     | 
| 729 | 
         
            +
                            X_wave[d], _ = librosa.load(self.audio_file, bp['sr'], False, dtype=np.float32, res_type=wav_resolution)
         
     | 
| 
         | 
|
| 730 | 
         | 
| 731 | 
         
            +
                            if not np.any(X_wave[d]) and self.audio_file.endswith('.mp3'):
         
     | 
| 732 | 
         
            +
                                X_wave[d] = rerun_mp3(self.audio_file, bp['sr'])
         
     | 
| 733 | 
         | 
| 734 | 
         
             
                            if X_wave[d].ndim == 1:
         
     | 
| 735 | 
         
             
                                X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
         
     | 
| 736 | 
         
             
                        else: # lower bands
         
     | 
| 737 | 
         
             
                            X_wave[d] = librosa.resample(X_wave[d+1], self.mp.param['band'][d+1]['sr'], bp['sr'], res_type=wav_resolution)
         
     | 
| 738 | 
         
            +
                            
         
     | 
| 739 | 
         
            +
                        X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(X_wave[d], bp['hl'], bp['n_fft'], self.mp.param['mid_side'], 
         
     | 
| 740 | 
         
            +
                                                                        self.mp.param['mid_side_b2'], self.mp.param['reverse'])
         
     | 
| 741 | 
         
            +
                        
         
     | 
| 742 | 
         
             
                        if d == bands_n and self.high_end_process != 'none':
         
     | 
| 743 | 
         
             
                            self.input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (self.mp.param['pre_filter_stop'] - self.mp.param['pre_filter_start'])
         
     | 
| 744 | 
         
             
                            self.input_high_end = X_spec_s[d][:, bp['n_fft']//2-self.input_high_end_h:bp['n_fft']//2, :]
         
     | 
| 745 | 
         | 
| 746 | 
         
            +
                    X_spec = spec_utils.combine_spectrograms(X_spec_s, self.mp)
         
     | 
| 747 | 
         | 
| 748 | 
         
            +
                    del X_wave, X_spec_s
         
     | 
| 749 | 
         | 
| 750 | 
         
             
                    return X_spec
         
     | 
| 751 | 
         | 
| 
         | 
|
| 783 | 
         
             
                        return mask
         
     | 
| 784 | 
         | 
| 785 | 
         
             
                    def postprocess(mask, X_mag, X_phase):
         
     | 
| 786 | 
         
            +
                        
         
     | 
| 787 | 
         
             
                        is_non_accom_stem = False
         
     | 
| 788 | 
         
             
                        for stem in NON_ACCOM_STEMS:
         
     | 
| 789 | 
         
             
                            if stem == self.primary_stem:
         
     | 
| 
         | 
|
| 798 | 
         
             
                        v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
         
     | 
| 799 | 
         | 
| 800 | 
         
             
                        return y_spec, v_spec
         
     | 
| 
         | 
|
| 801 | 
         
             
                    X_mag, X_phase = spec_utils.preprocess(X_spec)
         
     | 
| 802 | 
         
             
                    n_frame = X_mag.shape[2]
         
     | 
| 803 | 
         
             
                    pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset)
         
     | 
| 
         | 
|
| 821 | 
         
             
                    return y_spec, v_spec
         
     | 
| 822 | 
         | 
| 823 | 
         
             
                def spec_to_wav(self, spec):
                    """Convert ``spec`` to a waveform via ``spec_utils.cmb_spectrogram_to_wave``.

                    When ``self.high_end_process`` starts with ``'mirroring'``, a high-end
                    component is first built by ``spec_utils.mirroring`` and handed over
                    together with ``self.input_high_end_h``; otherwise only ``spec`` and
                    ``self.mp`` are passed.
                    """
                    if not self.high_end_process.startswith('mirroring'):
                        return spec_utils.cmb_spectrogram_to_wave(spec, self.mp)

                    mirrored_high_end = spec_utils.mirroring(self.high_end_process, spec, self.input_high_end, self.mp)
                    return spec_utils.cmb_spectrogram_to_wave(spec, self.mp, self.input_high_end_h, mirrored_high_end)
         
     | 
| 832 | 
         
            +
               
         
     | 
| 833 | 
         
            +
            def process_secondary_model(secondary_model: ModelData, process_data, main_model_primary_stem_4_stem=None, is_4_stem_demucs=False, main_process_method=None, is_pre_proc_model=False):
                """Run ``secondary_model`` through the separator class matching its process method.

                Args:
                    secondary_model: Model description; its ``process_method`` selects the
                        separator class and ``primary_model_primary_stem`` selects the stem pair.
                    process_data: Mapping passed on to the separator. Its
                        ``'process_iteration'`` entry is called first unless
                        ``is_pre_proc_model`` is true.
                    main_model_primary_stem_4_stem: Forwarded unchanged to the separator.
                    is_4_stem_demucs: When true, the separator result is returned as-is.
                    main_process_method: Forwarded unchanged to the separator.
                    is_pre_proc_model: When true, skips the iteration callback and returns
                        the separator result as-is.

                Returns:
                    The result of ``gather_sources`` when the separator returned a dict and
                    neither ``is_4_stem_demucs`` nor ``is_pre_proc_model`` is set; otherwise
                    whatever ``seperate()`` returned.

                Raises:
                    ValueError: if ``secondary_model.process_method`` is not one of the
                        VR, MDX or Demucs architecture constants.
                """
                if not is_pre_proc_model:
                    process_data['process_iteration']()

                process_method = secondary_model.process_method
                if process_method == VR_ARCH_TYPE:
                    separator_class = SeperateVR
                elif process_method == MDX_ARCH_TYPE:
                    separator_class = SeperateMDX
                elif process_method == DEMUCS_ARCH_TYPE:
                    separator_class = SeperateDemucs
                else:
                    # Fail with a clear message instead of hitting an unbound local below.
                    raise ValueError(f'Unsupported process method: {process_method!r}')

                seperator = separator_class(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method)
                secondary_sources = seperator.seperate()

                if isinstance(secondary_sources, dict) and not is_4_stem_demucs and not is_pre_proc_model:
                    primary_stem = secondary_model.primary_model_primary_stem
                    return gather_sources(primary_stem, STEM_PAIR_MAPPER[primary_stem], secondary_sources)

                return secondary_sources
         
     | 
| 852 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 853 | 
         
             
            def gather_sources(primary_stem_name, secondary_stem_name, secondary_sources: dict):
         
     | 
| 854 | 
         | 
| 855 | 
         
             
                source_primary = False
         
     | 
| 
         | 
|
| 863 | 
         | 
| 864 | 
         
             
                return source_primary, source_secondary
         
     | 
| 865 | 
         | 
| 866 | 
         
            +
            def prepare_mix(mix, chunk_set, margin_set, mdx_net_cut=False, is_missing_mix=False):
                """Load (or accept) a mixture and optionally split it into overlapping chunks.

                Args:
                    mix: Either a path readable by ``librosa.load`` or an ``np.ndarray``,
                        which is transposed before use.
                    chunk_set: Chunk length; multiplied by 44100 to obtain a size in samples.
                        ``0`` (or a size larger than the audio) yields a single chunk.
                    margin_set: Number of extra samples kept on each side of a chunk that
                        borders another chunk. Must not be zero when chunking.
                    mdx_net_cut: When true, ``raw_mix`` is returned as a single-chunk dict
                        instead of the plain array.
                    is_missing_mix: When true, only the prepared array is returned.

                Returns:
                    The prepared array if ``is_missing_mix`` is true; otherwise the tuple
                    ``(segmented_mix, raw_mix, samplerate)`` where ``segmented_mix`` maps
                    each chunk's start offset to a copy of its samples.
                """
                audio_path = mix
                samplerate = 44100

                is_array_input = isinstance(mix, np.ndarray)
                if is_array_input:
                    mix = mix.T
                else:
                    mix, samplerate = librosa.load(mix, mono=False, sr=44100)

                # Only a file path can be decoded again. An in-memory array must skip this
                # check entirely: an ndarray has no ``endswith`` to call.
                if not is_array_input and str(audio_path).endswith('.mp3') and not np.any(mix):
                    mix = rerun_mp3(audio_path)

                # Duplicate a single-channel signal so later code always sees two rows.
                if mix.ndim == 1:
                    mix = np.asfortranarray([mix,mix])

                def get_segmented_mix(chunk_set=chunk_set):
                    """Return ``{start_offset: chunk}`` for ``mix`` using the given chunk length."""
                    samples = mix.shape[-1]
                    margin = margin_set
                    chunk_size = chunk_set*44100
                    assert not margin == 0, 'margin cannot be zero!'

                    if margin > chunk_size:
                        margin = chunk_size
                    if chunk_set == 0 or samples < chunk_size:
                        chunk_size = samples

                    segmented_mix = {}
                    for counter, skip in enumerate(range(0, samples, chunk_size)):
                        # The first chunk has nothing before it, so no leading margin.
                        s_margin = 0 if counter == 0 else margin
                        end = min(skip+chunk_size+margin, samples)
                        start = skip-s_margin
                        segmented_mix[skip] = mix[:,start:end].copy()
                        # Stop as soon as a chunk (including its margin) reaches the end.
                        if end == samples:
                            break

                    return segmented_mix

                if is_missing_mix:
                    return mix

                segmented_mix = get_segmented_mix()
                raw_mix = get_segmented_mix(chunk_set=0) if mdx_net_cut else mix
                return segmented_mix, raw_mix, samplerate
         
     | 
| 913 | 
         | 
| 914 | 
         
             
            def rerun_mp3(audio_file, sample_rate=44100):
         
     | 
| 915 | 
         | 
| 
         | 
|
| 934 | 
         | 
| 935 | 
         
             
                    if save_format == MP3:
         
     | 
| 936 | 
         
             
                        audio_path_mp3 = audio_path.replace(".wav", ".mp3")
         
     | 
| 937 | 
         
            +
                        musfile.export(audio_path_mp3, format="mp3", bitrate=mp3_bit_set)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 938 | 
         | 
| 939 | 
         
             
                    try:
         
     | 
| 940 | 
         
             
                        os.remove(audio_path)
         
     | 
| 941 | 
         
             
                    except Exception as e:
         
     | 
| 942 | 
         
             
                        print(e)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         |