"""Module for separating audio sources using MDX architecture models."""
import os
import platform

import torch
import onnx
import onnxruntime as ort
import numpy as np
import onnx2torch
from tqdm import tqdm

from audio_separator.separator.uvr_lib_v5 import spec_utils
from audio_separator.separator.uvr_lib_v5.stft import STFT
from audio_separator.separator.common_separator import CommonSeparator
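# Illustrative usage sketch (not part of the original module): in the audio-separator package the
# Separator class normally builds these config dicts, so the exact keys and values shown below are
# assumptions for illustration only, mirroring the parameters read in MDXSeparator.__init__:
#
#   separator = MDXSeparator(
#       common_config=common_config,  # logger, torch device, model path/data, output options, etc.
#       arch_config={"segment_size": 256, "overlap": 0.25, "batch_size": 1, "hop_length": 1024, "enable_denoise": False},
#   )
#   output_files = separator.separate("input_audio.wav")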
class MDXSeparator(CommonSeparator):
    """
    MDXSeparator is responsible for separating audio sources using MDX models.
    It initializes with configuration parameters and prepares the model for separation tasks.
    """
    def __init__(self, common_config, arch_config):
        # Any configuration values which can be shared between architectures should be set already in CommonSeparator,
        # e.g. user-specified functionality choices (self.output_single_stem) or common model parameters (self.primary_stem_name)
        super().__init__(config=common_config)

        # Initializing user-configurable parameters, passed through with the MDX params from the CLI or Separator instance

        # Pick a segment size to balance speed, resource use, and quality:
        # - Smaller sizes consume fewer resources.
        # - Larger sizes consume more resources, but may provide better results.
        # - The default size is 256; quality can change based on your pick.
        self.segment_size = arch_config.get("segment_size")

        # This option controls the amount of overlap between prediction windows.
        # - Higher values can provide better results, but will lead to longer processing times.
        # - For non-MDX23C models, you can choose a value between 0.001 and 0.999.
        self.overlap = arch_config.get("overlap")
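        # For example (chunk_size value assumed for illustration): with chunk_size=261120 and
        # overlap=0.25, demix() advances by step = int((1 - 0.25) * 261120) = 195840 samples per
        # window, so consecutive windows share roughly a quarter of their samples.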
        # Number of batches to be processed at a time.
        # - Higher values mean more RAM usage but slightly faster processing times.
        # - Lower values mean less RAM usage but slightly longer processing times.
        # - Batch size value has no effect on output quality.
        # BATCH_SIZE = ('1', '2', '3', '4', '5', '6', '7', '8', '9', '10')
        self.batch_size = arch_config.get("batch_size", 1)

        # hop_length is equivalent to the more commonly used term "stride" in convolutional neural networks.
        # In machine learning, particularly in convolutional neural networks (CNNs), "stride" refers to the
        # number of pixels by which the filter (kernel) moves across the input image at each step of the
        # convolution operation, a fundamental building block of CNNs used primarily in computer vision.
        # The choice of stride affects the model in several ways:
        # - Output size: a larger stride results in a smaller output spatial dimension.
        # - Computational efficiency: increasing the stride decreases the computational load.
        # - Field of view: a higher stride means that each step of the filter takes into account a wider area of the input image.
        #   This can be beneficial when the model needs to capture more global features rather than focusing on finer details.
        self.hop_length = arch_config.get("hop_length")
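        # Rough example (sample values assumed for illustration): a 10-second clip at 44.1 kHz has
        # 441000 samples per channel; with hop_length=1024 the STFT produces roughly 430 frames,
        # whereas hop_length=512 would produce about twice as many, trading speed for time resolution.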
        # If enabled, model will be run twice to reduce noise in output audio.
        self.enable_denoise = arch_config.get("enable_denoise")

        self.logger.debug(f"MDX arch params: batch_size={self.batch_size}, segment_size={self.segment_size}")
        self.logger.debug(f"MDX arch params: overlap={self.overlap}, hop_length={self.hop_length}, enable_denoise={self.enable_denoise}")

        # Initializing model-specific parameters from model_data JSON
        self.compensate = self.model_data["compensate"]
        self.dim_f = self.model_data["mdx_dim_f_set"]
        self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
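        # For example, a model whose JSON sets mdx_dim_t_set to 8 gets dim_t = 2 ** 8 = 256
        # (the value 8 here is only illustrative of a typical setting, not a guarantee).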
        self.n_fft = self.model_data["mdx_n_fft_scale_set"]
        self.config_yaml = self.model_data.get("config_yaml", None)

        self.logger.debug(f"MDX arch params: compensate={self.compensate}, dim_f={self.dim_f}, dim_t={self.dim_t}, n_fft={self.n_fft}")
        self.logger.debug(f"MDX arch params: config_yaml={self.config_yaml}")

        # In UVR, these variables are set but either aren't useful or are better handled in audio-separator.
        # Leaving these comments explaining to help myself or future developers understand why these aren't in audio-separator.

        # "chunks" is not actually used for anything in UVR...
        # self.chunks = 0

        # "adjust" is hard-coded to 1 in UVR, and only used as a multiplier in run_model, so it does nothing.
        # self.adjust = 1

        # "hop" is hard-coded to 1024 in UVR. We have a "hop_length" parameter instead
        # self.hop = 1024

        # "margin" maps to sample rate and is set from the GUI in UVR (default: 44100). We have a "sample_rate" parameter instead.
        # self.margin = 44100

        # "dim_c" is hard-coded to 4 in UVR, seems to be a parameter for the number of channels, and is only used for checkpoint models.
        # We haven't implemented support for the checkpoint models here, so we're not using it.
        # self.dim_c = 4

        self.load_model()

        self.n_bins = 0
        self.trim = 0
        self.chunk_size = 0
        self.gen_size = 0
        self.stft = None

        self.primary_source = None
        self.secondary_source = None
        self.audio_file_path = None
        self.audio_file_base = None
    def load_model(self):
        """
        Load the model into memory from file on disk, initialize it with config from the model data,
        and prepare for inferencing using hardware accelerated Torch device.
        """
        self.logger.debug("Loading ONNX model for inference...")

        if self.segment_size == self.dim_t:
            ort_session_options = ort.SessionOptions()
            if self.log_level > 10:
                ort_session_options.log_severity_level = 3
            else:
                ort_session_options.log_severity_level = 0

            ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
            self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
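            # Note: the lambda above takes a torch tensor spectrogram, moves it to CPU as a NumPy
            # array for ONNX Runtime, and returns the session's first output.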
            self.logger.debug("Model loaded successfully using ONNXruntime inferencing session.")
        else:
            if platform.system() == 'Windows':
                onnx_model = onnx.load(self.model_path)
                self.model_run = onnx2torch.convert(onnx_model)
            else:
                self.model_run = onnx2torch.convert(self.model_path)

            self.model_run.to(self.torch_device).eval()
            self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t; processing may be slower.")
    def separate(self, audio_file_path, custom_output_names=None):
        """
        Separates the audio file into primary and secondary sources based on the model's configuration.
        It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.

        Args:
            audio_file_path (str): The path to the audio file to be processed.
            custom_output_names (dict, optional): Custom names for the output files. Defaults to None.

        Returns:
            list: A list of paths to the output files generated by the separation process.
        """
        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

        # Prepare the mix for processing
        self.logger.debug(f"Preparing mix for input audio file {self.audio_file_path}...")
        mix = self.prepare_mix(self.audio_file_path)

        self.logger.debug("Normalizing mix before demixing...")
        mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)

        # Start the demixing process
        source = self.demix(mix)
        self.logger.debug("Demixing completed.")

        # In UVR, the source is cached here if it's a vocal split model, but we're not supporting that yet

        # Initialize the list for output files
        output_files = []
        self.logger.debug("Processing output files...")

        # Normalize and transpose the primary source if it's not already an array
        if not isinstance(self.primary_source, np.ndarray):
            self.logger.debug("Normalizing primary source...")
            self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T

        # Process the secondary source if not already an array
        if not isinstance(self.secondary_source, np.ndarray):
            self.logger.debug("Producing secondary source: demixing in match_mix mode")
            raw_mix = self.demix(mix, is_match_mix=True)

            if self.invert_using_spec:
                self.logger.debug("Inverting secondary stem using spectrogram as invert_using_spec is set to True")
                self.secondary_source = spec_utils.invert_stem(raw_mix, source)
            else:
                self.logger.debug("Inverting secondary stem by subtracting the transposed demixed stem from the transposed original mix")
                self.secondary_source = mix.T - source.T

        # Save and process the secondary stem if needed
        if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
            self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)

            self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
            self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
            output_files.append(self.secondary_stem_output_path)

        # Save and process the primary stem if needed
        if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
            self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)

            if not isinstance(self.primary_source, np.ndarray):
                self.primary_source = source.T

            self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
            self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
            output_files.append(self.primary_stem_output_path)

        # Not yet implemented from UVR features:
        # self.process_vocal_split_chain(secondary_sources)
        # self.logger.debug("Vocal split chain processed.")

        return output_files
    def initialize_model_settings(self):
        """
        This function sets up the necessary parameters for the model, like the number of frequency bins (n_bins), the trimming size (trim),
        the size of each audio chunk (chunk_size), and the window function for spectral transformations (window).
        It ensures that the model is configured with the correct settings for processing the audio data.
        """
        self.logger.debug("Initializing model settings...")

        # n_bins is half the FFT size plus one (self.n_fft // 2 + 1).
        self.n_bins = self.n_fft // 2 + 1

        # trim is half the FFT size (self.n_fft // 2).
        self.trim = self.n_fft // 2

        # chunk_size is hop_length times (segment_size - 1).
        self.chunk_size = self.hop_length * (self.segment_size - 1)

        # gen_size is the chunk size minus twice the trim size.
        self.gen_size = self.chunk_size - 2 * self.trim
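        # Worked example (values assumed for illustration, typical of MDX-Net models): with n_fft=6144,
        # hop_length=1024 and segment_size=256, this gives n_bins=3073, trim=3072,
        # chunk_size = 1024 * 255 = 261120 and gen_size = 261120 - 6144 = 254976.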
        self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)

        self.logger.debug(f"Model input params: n_fft={self.n_fft} hop_length={self.hop_length} dim_f={self.dim_f}")
        self.logger.debug(f"Model settings: n_bins={self.n_bins}, trim={self.trim}, chunk_size={self.chunk_size}, gen_size={self.gen_size}")
    def initialize_mix(self, mix, is_ckpt=False):
        """
        After prepare_mix segments the audio, initialize_mix further processes each segment.
        It ensures each audio segment is in the correct format for the model, applies necessary padding,
        and converts the segments into tensors for processing with the model.
        This step is essential for preparing the audio data in a format that the neural network can process.
        """
        # Log the initialization of the mix and whether checkpoint mode is used
        self.logger.debug(f"Initializing mix with is_ckpt={is_ckpt}. Initial mix shape: {mix.shape}")

        # Ensure the mix is a 2-channel (stereo) audio signal
        if mix.shape[0] != 2:
            error_message = f"Expected a 2-channel audio signal, but got {mix.shape[0]} channels"
            self.logger.error(error_message)
            raise ValueError(error_message)

        # If in checkpoint mode, process the mix differently
        if is_ckpt:
            self.logger.debug("Processing in checkpoint mode...")
            # Calculate padding based on the generation size and trim
            pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
            self.logger.debug(f"Padding calculated: {pad}")
            # Add padding at the beginning and the end of the mix
            mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
            # Determine the number of chunks based on the mixture's length
            num_chunks = mixture.shape[-1] // self.gen_size
            self.logger.debug(f"Mixture shape after padding: {mixture.shape}, Number of chunks: {num_chunks}")
            # Split the mixture into chunks
            mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
        else:
            # If not in checkpoint mode, process normally
            self.logger.debug("Processing in non-checkpoint mode...")
            mix_waves = []
            n_sample = mix.shape[1]
            # Calculate necessary padding to make the total length divisible by the generation size
            pad = self.gen_size - n_sample % self.gen_size
            self.logger.debug(f"Number of samples: {n_sample}, Padding calculated: {pad}")
            # Apply padding to the mix
            mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
            self.logger.debug(f"Shape of mix after padding: {mix_p.shape}")

            # Process the mix in chunks
            i = 0
            while i < n_sample + pad:
                waves = np.array(mix_p[:, i : i + self.chunk_size])
                mix_waves.append(waves)
                self.logger.debug(f"Processed chunk {len(mix_waves)}: Start {i}, End {i + self.chunk_size}")
                i += self.gen_size

        # Convert the list of wave chunks into a tensor for processing on the specified device
        mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
        self.logger.debug(f"Converted mix_waves to tensor. Tensor shape: {mix_waves_tensor.shape}")

        return mix_waves_tensor, pad
    def demix(self, mix, is_match_mix=False):
        """
        Demixes the input mix into its constituent sources. If is_match_mix is True, the function adjusts the processing
        to better match the mix, affecting chunk sizes and overlaps. The demixing process involves padding the mix,
        processing it in chunks, applying windowing for overlaps, and accumulating the results to separate the sources.
        """
        self.logger.debug(f"Starting demixing process with is_match_mix: {is_match_mix}...")
        self.initialize_model_settings()

        # Preserves the original mix for later use.
        # In UVR, this is used for the pitch fix and VR denoise processes, which aren't yet implemented here.
        org_mix = mix
        self.logger.debug(f"Original mix stored. Shape: {org_mix.shape}")

        # Initializes a list to store the separated waveforms.
        tar_waves_ = []

        # Handling different chunk sizes and overlaps based on the matching requirement.
        if is_match_mix:
            # Sets a smaller chunk size specifically for matching the mix.
            chunk_size = self.hop_length * (self.segment_size - 1)
            # Sets a small overlap for the chunks.
            overlap = 0.02
            self.logger.debug(f"Chunk size for matching mix: {chunk_size}, Overlap: {overlap}")
        else:
            # Uses the regular chunk size defined in model settings.
            chunk_size = self.chunk_size
            # Uses the overlap specified in the model settings.
            overlap = self.overlap
            self.logger.debug(f"Standard chunk size: {chunk_size}, Overlap: {overlap}")

        # Calculates the generated size after subtracting the trim from both ends of the chunk.
        gen_size = chunk_size - 2 * self.trim
        self.logger.debug(f"Generated size calculated: {gen_size}")

        # Calculates padding to make the mix length a multiple of the generated size.
        pad = gen_size + self.trim - ((mix.shape[-1]) % gen_size)
        # Prepares the mixture with padding at the beginning and the end.
        mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
        self.logger.debug(f"Mixture prepared with padding. Mixture shape: {mixture.shape}")

        # Calculates the step size for processing chunks based on the overlap.
        step = int((1 - overlap) * chunk_size)
        self.logger.debug(f"Step size for processing chunks: {step} as overlap is set to {overlap}.")

        # Initializes arrays to store the results and to account for overlap.
        result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
        divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
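        # "result" accumulates the (optionally windowed) model outputs, while "divider" accumulates
        # the corresponding window weights per sample; dividing the two after the loop renormalizes
        # the regions where overlapping chunks were summed more than once.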
        # Initializes counters for processing chunks.
        total = 0
        total_chunks = (mixture.shape[-1] + step - 1) // step
        self.logger.debug(f"Total chunks to process: {total_chunks}")

        # Processes each chunk of the mixture.
        for i in tqdm(range(0, mixture.shape[-1], step)):
            total += 1
            start = i
            end = min(i + chunk_size, mixture.shape[-1])
            self.logger.debug(f"Processing chunk {total}/{total_chunks}: Start {start}, End {end}")

            # Handles windowing for overlapping chunks.
            chunk_size_actual = end - start
            window = None
            if overlap != 0:
                window = np.hanning(chunk_size_actual)
                window = np.tile(window[None, None, :], (1, 2, 1))
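                # The Hann window tapers the chunk towards zero at both edges, so overlapping
                # chunks cross-fade into each other instead of producing audible seams at the joins.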
                self.logger.debug("Window applied to the chunk.")

            # Zero-pad the chunk to prepare it for processing.
            mix_part_ = mixture[:, start:end]
            if end != i + chunk_size:
                pad_size = (i + chunk_size) - end
                mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)

            # Converts the chunk to a tensor for processing.
            mix_part = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device)
            # Splits the chunk into smaller batches if necessary.
            mix_waves = mix_part.split(self.batch_size)
            total_batches = len(mix_waves)
            self.logger.debug(f"Mix part split into batches. Number of batches: {total_batches}")

            with torch.no_grad():
                # Processes each batch in the chunk.
                batches_processed = 0
                for mix_wave in mix_waves:
                    batches_processed += 1
                    self.logger.debug(f"Processing mix_wave batch {batches_processed}/{total_batches}")

                    # Runs the model to separate the sources.
                    tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)

                    # Applies windowing if needed and accumulates the results.
                    if window is not None:
                        tar_waves[..., :chunk_size_actual] *= window
                        divider[..., start:end] += window
                    else:
                        divider[..., start:end] += 1

                    result[..., start:end] += tar_waves[..., : end - start]

        # Normalizes the results by the divider to account for overlap.
        self.logger.debug("Normalizing result by dividing result by divider.")
        tar_waves = result / divider
        tar_waves_.append(tar_waves)

        # Reshapes the results to match the original dimensions.
        tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim : -self.trim]
        tar_waves = np.concatenate(tar_waves_, axis=-1)[:, : mix.shape[-1]]

        # Extracts the source from the results.
        source = tar_waves[:, 0:None]
        self.logger.debug(f"Concatenated tar_waves. Shape: {tar_waves.shape}")

        # TODO: In UVR, pitch changing happens here. Consider implementing this as a feature.

        # Compensates the source if not matching the mix.
        if not is_match_mix:
            source *= self.compensate
            self.logger.debug("Not in match-mix mode; compensate multiplier applied.")

        # TODO: In UVR, VR denoise model gets applied here. Consider implementing this as a feature.

        self.logger.debug("Demixing process completed.")
        return source
    def run_model(self, mix, is_match_mix=False):
        """
        Processes the input mix through the model to separate the sources.
        Applies STFT, handles spectrum modifications, and runs the model for source separation.
        """
        # Applying the STFT to the mix. The mix is moved to the specified device (e.g., GPU) before processing.
        # self.logger.debug(f"Running STFT on the mix. Mix shape before STFT: {mix.shape}")
        spek = self.stft(mix.to(self.torch_device))
        self.logger.debug(f"STFT applied on mix. Spectrum shape: {spek.shape}")

        # Zeroing out the first 3 bins of the spectrum. This is often done to reduce low-frequency noise.
        spek[:, :, :3, :] *= 0
        # self.logger.debug("First 3 bins of the spectrum zeroed out.")

        # Handling the case where the mix needs to be matched (is_match_mix = True)
        if is_match_mix:
            # self.logger.debug("Match mix mode is enabled. Converting spectrum to NumPy array.")
            spec_pred = spek.cpu().numpy()
            self.logger.debug("is_match_mix: spectrum prediction obtained directly from STFT output.")
        else:
            # If denoising is enabled, the model is run on both the negative and positive spectrums.
            if self.enable_denoise:
                # Assuming spek is a tensor and self.model_run can process it directly
                spec_pred_neg = self.model_run(-spek)  # Ensure this line correctly negates spek and runs the model
                spec_pred_pos = self.model_run(spek)
                # Ensure both spec_pred_neg and spec_pred_pos are tensors before applying operations
                spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)  # [invalid-unary-operand-type]
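                # The idea behind this trick: the negated pass is flipped back (* -0.5) and averaged
                # with the positive pass (* 0.5), so components that follow the sign flip survive the
                # average while noise that does not flip sign consistently tends to cancel out.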
                self.logger.debug("Model run on both negative and positive spectrums for denoising.")
            else:
                spec_pred = self.model_run(spek)
                self.logger.debug("Model run on the spectrum without denoising.")

        # Applying the inverse STFT to convert the spectrum back to the time domain.
        result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
        self.logger.debug(f"Inverse STFT applied. Returning result with shape: {result.shape}")

        return result