"""Module for separating audio sources using VR architecture models.""" | |
import os | |
import math | |
import torch | |
import librosa | |
import numpy as np | |
from tqdm import tqdm | |
# Check if we really need the rerun_mp3 function, remove if not | |
import audioread | |
from audio_separator.separator.common_separator import CommonSeparator | |
from audio_separator.separator.uvr_lib_v5 import spec_utils | |
from audio_separator.separator.uvr_lib_v5.vr_network import nets | |
from audio_separator.separator.uvr_lib_v5.vr_network import nets_new | |
from audio_separator.separator.uvr_lib_v5.vr_network.model_param_init import ModelParameters | |


class VRSeparator(CommonSeparator):
    """
    VRSeparator is responsible for separating audio sources using VR models.
    It initializes with configuration parameters and prepares the model for separation tasks.
    """

    def __init__(self, common_config, arch_config: dict):
        # Any configuration values which can be shared between architectures should be set already in CommonSeparator,
        # e.g. user-specified functionality choices (self.output_single_stem) or common model parameters (self.primary_stem_name)
        super().__init__(config=common_config)

        # Model data is basic overview metadata about the model, e.g. which stem is primary and whether it's a karaoke model
        # It's loaded in from model_data_new.json in Separator.load_model and there are JSON examples in that method
        # The instance variable self.model_data is passed through from Separator and set in CommonSeparator
        self.logger.debug(f"Model data: {self.model_data}")

        # Most of the VR models use the same number of output channels, but the VR 5.1 models have specific values set in the model_data JSON
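        # model_capacity is a (nout, nout_lstm) tuple; for VR 5.1 models it is
        # passed into CascadedNet in separate() below when the network is built.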
        self.model_capacity = 32, 128
        self.is_vr_51_model = False

        if "nout" in self.model_data.keys() and "nout_lstm" in self.model_data.keys():
            self.model_capacity = self.model_data["nout"], self.model_data["nout_lstm"]
            self.is_vr_51_model = True

        # Model params are additional technical parameter values from JSON files in separator/uvr_lib_v5/vr_network/modelparams/*.json,
        # with filenames referenced by the model_data["vr_model_param"] value
        package_root_filepath = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        vr_params_json_dir = os.path.join(package_root_filepath, "uvr_lib_v5", "vr_network", "modelparams")
        vr_params_json_filename = f"{self.model_data['vr_model_param']}.json"
        vr_params_json_filepath = os.path.join(vr_params_json_dir, vr_params_json_filename)
        self.model_params = ModelParameters(vr_params_json_filepath)

        self.logger.debug(f"Model params: {self.model_params.param}")
        # Arch Config is the VR architecture specific user configuration options, which should all be configurable by the user
        # either by their Separator class instantiation or by passing in a CLI parameter.
        # While there are similarities between architectures for some of these (e.g. batch_size), they are deliberately configured
        # this way as they have architecture-specific default values.

        # This option performs Test-Time-Augmentation to improve the separation quality.
        # Note: Having this selected will increase the time it takes to complete a conversion
        self.enable_tta = arch_config.get("enable_tta", False)

        # This option can potentially identify leftover instrumental artifacts within the vocal outputs; it may improve the separation of some songs.
        # Note: Selecting this option can adversely affect the conversion process, depending on the track. Because of this, it is only recommended as a last resort.
        self.enable_post_process = arch_config.get("enable_post_process", False)

        # post_process_threshold values = ('0.1', '0.2', '0.3')
        self.post_process_threshold = arch_config.get("post_process_threshold", 0.2)

        # Number of batches to be processed at a time.
        # - Higher values mean more RAM usage but slightly faster processing times.
        # - Lower values mean less RAM usage but slightly longer processing times.
        # - Batch size value has no effect on output quality.
        # Andrew note: for some reason, lower batch sizes seem to cause broken output for VR arch; need to investigate why
        self.batch_size = arch_config.get("batch_size", 1)

        # Select window size to balance quality and speed:
        # - 1024 - Quick but lesser quality.
        # - 512 - Medium speed and quality.
        # - 320 - Takes longer but may offer better quality.
        self.window_size = arch_config.get("window_size", 512)

        # The application will mirror the missing frequency range of the output.
        self.high_end_process = arch_config.get("high_end_process", False)
        self.input_high_end_h = None
        self.input_high_end = None

        # Adjust the intensity of primary stem extraction:
        # - Ranges from -100 to 100.
        # - Bigger values mean deeper extractions.
        # - Typically, it's set to 5 for vocals & instrumentals.
        # - Values beyond 5 might muddy the sound for non-vocal models.
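        # Internally the user-facing value is scaled into a 0.0-1.0 factor
        # (e.g. the default 5 becomes 0.05) and packaged with the band split
        # point and any per-model correction for spec_utils.adjust_aggr during inference.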
        self.aggression = float(int(arch_config.get("aggression", 5)) / 100)

        self.aggressiveness = {"value": self.aggression, "split_bin": self.model_params.param["band"][1]["crop_stop"], "aggr_correction": self.model_params.param.get("aggr_correction")}

        self.model_samplerate = self.model_params.param["sr"]

        self.logger.debug(f"VR arch params: enable_tta={self.enable_tta}, enable_post_process={self.enable_post_process}, post_process_threshold={self.post_process_threshold}")
        self.logger.debug(f"VR arch params: batch_size={self.batch_size}, window_size={self.window_size}")
        self.logger.debug(f"VR arch params: high_end_process={self.high_end_process}, aggression={self.aggression}")
        self.logger.debug(f"VR arch params: is_vr_51_model={self.is_vr_51_model}, model_samplerate={self.model_samplerate}, model_capacity={self.model_capacity}")

        self.model_run = lambda *args, **kwargs: self.logger.error("Model run method is not initialised yet.")

        # This should go away once we refactor to remove soundfile.write and replace with pydub like we did for the MDX rewrite
        self.wav_subtype = "PCM_16"

        self.logger.info("VR Separator initialisation complete")

    def separate(self, audio_file_path, custom_output_names=None):
        """
        Separates the audio file into primary and secondary sources based on the model's configuration.
        It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.

        Args:
            audio_file_path (str): The path to the audio file to be processed.
            custom_output_names (dict, optional): Custom names for the output files. Defaults to None.

        Returns:
            list: A list of paths to the output files generated by the separation process.
        """
        self.primary_source = None
        self.secondary_source = None

        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

        self.logger.debug(f"Starting separation for input audio file {self.audio_file_path}...")
        nn_arch_sizes = [31191, 33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227]  # default
        vr_5_1_models = [56817, 218409]
        model_size = math.ceil(os.stat(self.model_path).st_size / 1024)
        nn_arch_size = min(nn_arch_sizes, key=lambda x: abs(x - model_size))
        self.logger.debug(f"Model size determined: {model_size}, NN architecture size: {nn_arch_size}")

        if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
            self.logger.debug("Using CascadedNet for VR 5.1 model...")
            self.model_run = nets_new.CascadedNet(self.model_params.param["bins"] * 2, nn_arch_size, nout=self.model_capacity[0], nout_lstm=self.model_capacity[1])
            self.is_vr_51_model = True
        else:
            self.logger.debug("Determining model capacity...")
            self.model_run = nets.determine_model_capacity(self.model_params.param["bins"] * 2, nn_arch_size)

        self.model_run.load_state_dict(torch.load(self.model_path, map_location="cpu"))
        self.model_run.to(self.torch_device)
        self.logger.debug("Model loaded and moved to device.")

        y_spec, v_spec = self.inference_vr(self.loading_mix(), self.torch_device, self.aggressiveness)
        self.logger.debug("Inference completed.")

        # Sanitize y_spec and v_spec to replace NaN and infinite values
        y_spec = np.nan_to_num(y_spec, nan=0.0, posinf=0.0, neginf=0.0)
        v_spec = np.nan_to_num(v_spec, nan=0.0, posinf=0.0, neginf=0.0)
        self.logger.debug("Sanitization completed. Replaced NaN and infinite values in y_spec and v_spec.")

        # After inference_vr call
        self.logger.debug(f"Inference VR completed. y_spec shape: {y_spec.shape}, v_spec shape: {v_spec.shape}")
        self.logger.debug(f"y_spec stats - min: {np.min(y_spec)}, max: {np.max(y_spec)}, isnan: {np.isnan(y_spec).any()}, isinf: {np.isinf(y_spec).any()}")
        self.logger.debug(f"v_spec stats - min: {np.min(v_spec)}, max: {np.max(v_spec)}, isnan: {np.isnan(v_spec).any()}, isinf: {np.isinf(v_spec).any()}")
        # Not yet implemented from UVR features:
        #
        # if not self.is_vocal_split_model:
        #     self.cache_source((y_spec, v_spec))
        #
        # if self.is_secondary_model_activated and self.secondary_model:
        #     self.logger.debug("Processing secondary model...")
        #     self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(
        #         self.secondary_model, self.process_data, main_process_method=self.process_method, main_model_primary=self.primary_stem
        #     )

        # Initialize the list for output files
        output_files = []
        self.logger.debug("Processing output files...")

        # Note: logic similar to the following should probably be added to the other architectures
        # Check if output_single_stem is set to a value that would result in no output files
        if self.output_single_stem and (self.output_single_stem.lower() != self.primary_stem_name.lower() and self.output_single_stem.lower() != self.secondary_stem_name.lower()):
            self.logger.warning(f"The output_single_stem setting '{self.output_single_stem}' does not match any of the output files: '{self.primary_stem_name}' and '{self.secondary_stem_name}'. For this model '{self.model_name}', the output_single_stem setting will be ignored and all output files will be saved.")
            # Reset output_single_stem to None so both stems are saved (done after logging, so the warning still shows the rejected value)
            self.output_single_stem = None
        # Save and process the primary stem if needed
        if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
            self.logger.debug(f"Processing primary stem: {self.primary_stem_name}")
            if not isinstance(self.primary_source, np.ndarray):
                self.logger.debug(f"Preparing to convert spectrogram to waveform. Spec shape: {y_spec.shape}")

                self.primary_source = self.spec_to_wav(y_spec).T
                self.logger.debug("Converting primary source spectrogram to waveform.")
                if not self.model_samplerate == 44100:
                    self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
                    self.logger.debug("Resampling primary source to 44100Hz.")

            self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)

            self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
            self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
            output_files.append(self.primary_stem_output_path)

        # Save and process the secondary stem if needed
        if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
            self.logger.debug(f"Processing secondary stem: {self.secondary_stem_name}")
            if not isinstance(self.secondary_source, np.ndarray):
                self.logger.debug(f"Preparing to convert spectrogram to waveform. Spec shape: {v_spec.shape}")

                self.secondary_source = self.spec_to_wav(v_spec).T
                self.logger.debug("Converting secondary source spectrogram to waveform.")
                if not self.model_samplerate == 44100:
                    self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
                    self.logger.debug("Resampling secondary source to 44100Hz.")

            self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)

            self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
            self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
            output_files.append(self.secondary_stem_output_path)

        # Not yet implemented from UVR features:
        # self.process_vocal_split_chain(secondary_sources)
        # self.logger.debug("Vocal split chain processed.")

        return output_files

    def loading_mix(self):
        X_wave, X_spec_s = {}, {}

        bands_n = len(self.model_params.param["band"])

        audio_file = spec_utils.write_array_to_mem(self.audio_file_path, subtype=self.wav_subtype)
        is_mp3 = audio_file.endswith(".mp3") if isinstance(audio_file, str) else False

        self.logger.debug(f"loading_mix iterating through {bands_n} bands")
        for d in tqdm(range(bands_n, 0, -1)):
            bp = self.model_params.param["band"][d]

            wav_resolution = bp["res_type"]
            if self.torch_device_mps is not None:
                wav_resolution = "polyphase"

            if d == bands_n:  # high-end band
                X_wave[d], _ = librosa.load(audio_file, sr=bp["sr"], mono=False, dtype=np.float32, res_type=wav_resolution)
                X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp["hl"], bp["n_fft"], self.model_params, band=d, is_v51_model=self.is_vr_51_model)

                if not np.any(X_wave[d]) and is_mp3:
                    X_wave[d] = rerun_mp3(audio_file, bp["sr"])

                if X_wave[d].ndim == 1:
                    X_wave[d] = np.asarray([X_wave[d], X_wave[d]])

            else:  # lower bands
                X_wave[d] = librosa.resample(X_wave[d + 1], orig_sr=self.model_params.param["band"][d + 1]["sr"], target_sr=bp["sr"], res_type=wav_resolution)
                X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp["hl"], bp["n_fft"], self.model_params, band=d, is_v51_model=self.is_vr_51_model)
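            # When high_end_process is enabled, keep a copy of the frequency
            # bins above the model's crop point (plus the pre-filter range) so
            # spec_to_wav can later mirror them back into the output.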
            if d == bands_n and self.high_end_process:
                self.input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (self.model_params.param["pre_filter_stop"] - self.model_params.param["pre_filter_start"])
                self.input_high_end = X_spec_s[d][:, bp["n_fft"] // 2 - self.input_high_end_h : bp["n_fft"] // 2, :]

        X_spec = spec_utils.combine_spectrograms(X_spec_s, self.model_params, is_v51_model=self.is_vr_51_model)

        del X_wave, X_spec_s, audio_file

        return X_spec

    def inference_vr(self, X_spec, device, aggressiveness):
        def _execute(X_mag_pad, roi_size):
            X_dataset = []
            patches = (X_mag_pad.shape[2] - 2 * self.model_run.offset) // roi_size

            self.logger.debug(f"inference_vr appending to X_dataset for each of {patches} patches")
            for i in tqdm(range(patches)):
                start = i * roi_size
                X_mag_window = X_mag_pad[:, :, start : start + self.window_size]
                X_dataset.append(X_mag_window)

            # Number of forward batches for this pass; TTA triggers a whole second pass via a separate _execute call
            total_iterations = math.ceil(patches / self.batch_size)
            self.logger.debug(f"inference_vr iterating through {total_iterations} batches, batch_size = {self.batch_size}")

            X_dataset = np.asarray(X_dataset)

            self.model_run.eval()
            with torch.no_grad():
                mask = []
                for i in tqdm(range(0, patches, self.batch_size)):
                    X_batch = X_dataset[i : i + self.batch_size]
                    X_batch = torch.from_numpy(X_batch).to(device)

                    pred = self.model_run.predict_mask(X_batch)

                    if not pred.size()[3] > 0:
                        raise ValueError("Window size error: h1_shape[3] must be greater than h2_shape[3]")

                    pred = pred.detach().cpu().numpy()
                    pred = np.concatenate(pred, axis=2)
                    mask.append(pred)

                if len(mask) == 0:
                    raise ValueError("Window size error: h1_shape[3] must be greater than h2_shape[3]")

                mask = np.concatenate(mask, axis=2)

            return mask

        def postprocess(mask, X_mag, X_phase):
            is_non_accom_stem = False
            for stem in CommonSeparator.NON_ACCOM_STEMS:
                if stem == self.primary_stem_name:
                    is_non_accom_stem = True

            mask = spec_utils.adjust_aggr(mask, is_non_accom_stem, aggressiveness)

            if self.enable_post_process:
                mask = spec_utils.merge_artifacts(mask, thres=self.post_process_threshold)
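            # Apply the complementary mask to the complex spectrogram: the
            # primary stem keeps mask * magnitude, the secondary stem gets the
            # remainder, so y_spec + v_spec reconstructs the input exactly.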
            y_spec = mask * X_mag * np.exp(1.0j * X_phase)
            v_spec = (1 - mask) * X_mag * np.exp(1.0j * X_phase)

            return y_spec, v_spec

        X_mag, X_phase = spec_utils.preprocess(X_spec)

        n_frame = X_mag.shape[2]
        pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset)

        X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
        X_mag_pad /= X_mag_pad.max()

        mask = _execute(X_mag_pad, roi_size)
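
        # Test-Time Augmentation: run inference a second time with the input
        # shifted by half a window so patch boundaries land in different
        # places, then average the two masks; this smooths seam artifacts at
        # the cost of roughly doubling inference time.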
        if self.enable_tta:
            pad_l += roi_size // 2
            pad_r += roi_size // 2

            X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
            X_mag_pad /= X_mag_pad.max()

            mask_tta = _execute(X_mag_pad, roi_size)
            mask_tta = mask_tta[:, :, roi_size // 2 :]
            mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5
        else:
            mask = mask[:, :, :n_frame]

        y_spec, v_spec = postprocess(mask, X_mag, X_phase)

        return y_spec, v_spec

    def spec_to_wav(self, spec):
        if self.high_end_process and isinstance(self.input_high_end, np.ndarray) and self.input_high_end_h:
            input_high_end_ = spec_utils.mirroring("mirroring", spec, self.input_high_end, self.model_params)
            wav = spec_utils.cmb_spectrogram_to_wave(spec, self.model_params, self.input_high_end_h, input_high_end_, is_v51_model=self.is_vr_51_model)
        else:
            wav = spec_utils.cmb_spectrogram_to_wave(spec, self.model_params, is_v51_model=self.is_vr_51_model)

        return wav


# Check if we really need the rerun_mp3 function, refactor or remove if not
def rerun_mp3(audio_file, sample_rate=44100):
    with audioread.audio_open(audio_file) as f:
        track_length = int(f.duration)

    return librosa.load(audio_file, duration=track_length, mono=False, sr=sample_rate)[0]
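

# Illustrative usage sketch (not part of this module's API): VRSeparator is not
# normally instantiated directly -- the Separator facade loads the model data
# and dispatches to this class. Assuming a VR-architecture model file, typical
# usage through the public API looks like:
#
#     from audio_separator.separator import Separator
#
#     separator = Separator()
#     separator.load_model(model_filename="2_HP-UVR.pth")  # any VR arch model
#     output_files = separator.separate("mix.flac")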