import os
import sys
import torch
import numpy as np
from tqdm import tqdm
from ml_collections import ConfigDict
from scipy import signal
from audio_separator.separator.common_separator import CommonSeparator
from audio_separator.separator.uvr_lib_v5 import spec_utils
from audio_separator.separator.uvr_lib_v5.tfc_tdf_v3 import TFC_TDF_net
from audio_separator.separator.uvr_lib_v5.roformer.mel_band_roformer import MelBandRoformer
from audio_separator.separator.uvr_lib_v5.roformer.bs_roformer import BSRoformer
class MDXCSeparator(CommonSeparator):
"""
MDXCSeparator is responsible for separating audio sources using MDXC models.
It initializes with configuration parameters and prepares the model for separation tasks.
"""
def __init__(self, common_config, arch_config):
# Any configuration values which can be shared between architectures should be set already in CommonSeparator,
# e.g. user-specified functionality choices (self.output_single_stem) or common model parameters (self.primary_stem_name)
super().__init__(config=common_config)
# Model data is basic overview metadata about the model, e.g. which stem is primary and whether it's a karaoke model
# It's loaded in from model_data_new.json in Separator.load_model and there are JSON examples in that method
# The instance variable self.model_data is passed through from Separator and set in CommonSeparator
self.logger.debug(f"Model data: {self.model_data}")
        # arch_config contains the MDXC architecture-specific user configuration options, all of which should be configurable by the user,
        # either when instantiating the Separator class or by passing a CLI parameter.
        # While some of these options are shared across architectures (e.g. batch_size), they are deliberately configured
        # per architecture because they have architecture-specific default values.
self.segment_size = arch_config.get("segment_size", 256)
        # Whether to use the user-configured segment size above instead of the model's own default.
        # By default, the segment size is taken from the chosen model's associated config file (YAML).
self.override_model_segment_size = arch_config.get("override_model_segment_size", False)
self.overlap = arch_config.get("overlap", 8)
self.batch_size = arch_config.get("batch_size", 1)
# Amount of pitch shift to apply during processing (this does NOT affect the pitch of the output audio):
# • Whole numbers indicate semitones.
# • Using higher pitches may cut the upper bandwidth, even in high-quality models.
        # • Raising the pitch can be better for tracks with deeper vocals.
        # • Lowering the pitch may take more processing time but works well for tracks with high-pitched vocals.
self.pitch_shift = arch_config.get("pitch_shift", 0)
self.process_all_stems = arch_config.get("process_all_stems", True)
self.logger.debug(f"MDXC arch params: batch_size={self.batch_size}, segment_size={self.segment_size}, overlap={self.overlap}")
self.logger.debug(f"MDXC arch params: override_model_segment_size={self.override_model_segment_size}, pitch_shift={self.pitch_shift}")
self.logger.debug(f"MDXC multi-stem params: process_all_stems={self.process_all_stems}")
self.is_roformer = "is_roformer" in self.model_data
self.load_model()
self.primary_source = None
self.secondary_source = None
self.audio_file_path = None
self.audio_file_base = None
self.is_primary_stem_main_target = False
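        # The primary stem is treated as the main separation target when the model targets "Vocals" directly
        # or when it was trained on more than one instrument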
if self.model_data_cfgdict.training.target_instrument == "Vocals" or len(self.model_data_cfgdict.training.instruments) > 1:
self.is_primary_stem_main_target = True
self.logger.debug(f"is_primary_stem_main_target: {self.is_primary_stem_main_target}")
self.logger.info("MDXC Separator initialisation complete")
def load_model(self):
"""
        Load the model into memory from the file on disk, initialize it with the config from the model data,
        and prepare it for inference on the hardware-accelerated Torch device.
"""
self.logger.debug("Loading checkpoint model for inference...")
self.model_data_cfgdict = ConfigDict(self.model_data)
try:
if self.is_roformer:
self.logger.debug("Loading Roformer model...")
# Determine the model type based on the configuration and instantiate it
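                # MelBandRoformer configs define "num_bands", while BSRoformer configs define "freqs_per_bands"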
if "num_bands" in self.model_data_cfgdict.model:
self.logger.debug("Loading MelBandRoformer model...")
model = MelBandRoformer(**self.model_data_cfgdict.model)
elif "freqs_per_bands" in self.model_data_cfgdict.model:
self.logger.debug("Loading BSRoformer model...")
model = BSRoformer(**self.model_data_cfgdict.model)
else:
raise ValueError("Unknown Roformer model type in the configuration.")
# Load model checkpoint
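                # weights_only=True restricts unpickling to plain tensor/primitive data, avoiding arbitrary code execution via pickle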
checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=True)
self.model_run = model if not isinstance(model, torch.nn.DataParallel) else model.module
self.model_run.load_state_dict(checkpoint)
self.model_run.to(self.torch_device).eval()
else:
self.logger.debug("Loading TFC_TDF_net model...")
self.model_run = TFC_TDF_net(self.model_data_cfgdict, device=self.torch_device)
self.logger.debug("Loading model onto cpu")
                # For some reason, loading the state dict directly onto a hardware-accelerated device causes issues,
                # so we load it onto the CPU first and then move it to the device
self.model_run.load_state_dict(torch.load(self.model_path, map_location="cpu"))
self.model_run.to(self.torch_device).eval()
except RuntimeError as e:
self.logger.error(f"Error: {e}")
self.logger.error("An error occurred while loading the model file. This often occurs when the model file is corrupt or incomplete.")
self.logger.error(f"Please try deleting the model file from {self.model_path} and run audio-separator again to re-download it.")
sys.exit(1)
def separate(self, audio_file_path, custom_output_names=None):
"""
Separates the audio file into primary and secondary sources based on the model's configuration.
It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.
Args:
audio_file_path (str): The path to the audio file to be processed.
custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
Returns:
list: A list of paths to the output files generated by the separation process.
"""
self.primary_source = None
self.secondary_source = None
self.audio_file_path = audio_file_path
self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
self.logger.debug(f"Preparing mix for input audio file {self.audio_file_path}...")
mix = self.prepare_mix(self.audio_file_path)
self.logger.debug("Normalizing mix before demixing...")
mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)
source = self.demix(mix=mix)
self.logger.debug("Demixing completed.")
output_files = []
self.logger.debug("Processing output files...")
if isinstance(source, dict):
self.logger.debug("Source is a dict, processing each stem...")
stem_list = []
if self.model_data_cfgdict.training.target_instrument:
stem_list = [self.model_data_cfgdict.training.target_instrument]
else:
stem_list = self.model_data_cfgdict.training.instruments
self.logger.debug(f"Available stems: {stem_list}")
is_multi_stem_model = len(stem_list) > 2
should_process_all_stems = self.process_all_stems and is_multi_stem_model
if should_process_all_stems:
self.logger.debug("Processing all stems from multi-stem model...")
for stem_name in stem_list:
stem_output_path = self.get_stem_output_path(stem_name, custom_output_names)
stem_source = spec_utils.normalize(
wave=source[stem_name],
max_peak=self.normalization_threshold,
min_peak=self.amplification_threshold
).T
self.logger.info(f"Saving {stem_name} stem to {stem_output_path}...")
self.final_process(stem_output_path, stem_source, stem_name)
output_files.append(stem_output_path)
else:
# Standard processing for primary and secondary stems
if not isinstance(self.primary_source, np.ndarray):
self.logger.debug(f"Normalizing primary source for primary stem {self.primary_stem_name}...")
self.primary_source = spec_utils.normalize(
wave=source[self.primary_stem_name],
max_peak=self.normalization_threshold,
min_peak=self.amplification_threshold
).T
if not isinstance(self.secondary_source, np.ndarray):
self.logger.debug(f"Normalizing secondary source for secondary stem {self.secondary_stem_name}...")
self.secondary_source = spec_utils.normalize(
wave=source[self.secondary_stem_name],
max_peak=self.normalization_threshold,
min_peak=self.amplification_threshold
).T
if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)
self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
output_files.append(self.secondary_stem_output_path)
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)
self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
output_files.append(self.primary_stem_output_path)
else:
# Handle case when source is not a dictionary (single source model)
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)
if not isinstance(self.primary_source, np.ndarray):
self.primary_source = source.T
self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
output_files.append(self.primary_stem_output_path)
return output_files
def pitch_fix(self, source, sr_pitched, orig_mix):
"""
        Shift the separated source back by self.pitch_shift semitones, reversing the shift applied to the
        mix before demixing, and match the array shape to the original mix.
Args:
source (np.ndarray): The source audio to be pitch-shifted.
sr_pitched (int): The sample rate of the pitch-shifted audio.
orig_mix (np.ndarray): The original mix, used to match the shape of the pitch-shifted audio.
Returns:
np.ndarray: The pitch-shifted source audio.
"""
source = spec_utils.change_pitch_semitones(source, sr_pitched, semitone_shift=self.pitch_shift)[0]
source = spec_utils.match_array_shapes(source, orig_mix)
return source
def overlap_add(self, result, x, weights, start, length):
"""
        Adds a window-weighted chunk of model output into the result tensor at the given start position (overlap-add).
"""
result[..., start : start + length] += x[..., :length] * weights[:length]
return result
def demix(self, mix: np.ndarray) -> dict:
"""
Demixes the input mix into primary and secondary sources using the model and model data.
Args:
mix (np.ndarray): The mix to be demixed.
Returns:
dict: A dictionary containing the demixed sources.
"""
orig_mix = mix
if self.pitch_shift != 0:
self.logger.debug(f"Shifting pitch by -{self.pitch_shift} semitones...")
mix, sample_rate = spec_utils.change_pitch_semitones(mix, self.sample_rate, semitone_shift=-self.pitch_shift)
if self.is_roformer:
# Note: Currently, for Roformer models, `batch_size` is not utilized due to negligible performance improvements.
mix = torch.tensor(mix, dtype=torch.float32)
if self.override_model_segment_size:
mdx_segment_size = self.segment_size
self.logger.debug(f"Using configured segment size: {mdx_segment_size}")
else:
mdx_segment_size = self.model_data_cfgdict.inference.dim_t
self.logger.debug(f"Using model default segment size: {mdx_segment_size}")
# num_stems aka "S" in UVR
num_stems = 1 if self.model_data_cfgdict.training.target_instrument else len(self.model_data_cfgdict.training.instruments)
self.logger.debug(f"Number of stems: {num_stems}")
# chunk_size aka "C" in UVR
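            # chunk_size is the number of audio samples corresponding to mdx_segment_size STFT frames at the model's hop length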
chunk_size = self.model_data_cfgdict.audio.hop_length * (mdx_segment_size - 1)
self.logger.debug(f"Chunk size: {chunk_size}")
step = int(self.overlap * self.model_data_cfgdict.audio.sample_rate)
self.logger.debug(f"Step: {step}")
# Create a weighting table and convert it to a PyTorch tensor
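            # The Hamming window tapers each chunk's contribution so overlapping chunks cross-fade smoothly;
            # the same weights are accumulated in `counter` below so the summed result can be normalised afterwards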
window = torch.tensor(signal.windows.hamming(chunk_size), dtype=torch.float32)
device = next(self.model_run.parameters()).device
with torch.no_grad():
req_shape = (len(self.model_data_cfgdict.training.instruments),) + tuple(mix.shape)
result = torch.zeros(req_shape, dtype=torch.float32)
counter = torch.zeros(req_shape, dtype=torch.float32)
for i in tqdm(range(0, mix.shape[1], step)):
part = mix[:, i : i + chunk_size]
length = part.shape[-1]
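                    # If this chunk would run past the end of the mix, align it to the end instead so no samples are dropped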
if i + chunk_size > mix.shape[1]:
part = mix[:, -chunk_size:]
length = chunk_size
part = part.to(device)
x = self.model_run(part.unsqueeze(0))[0]
x = x.cpu()
# Perform overlap_add on CPU
if i + chunk_size > mix.shape[1]:
# Fixed to correctly add to the end of the tensor
result = self.overlap_add(result, x, window, result.shape[-1] - chunk_size, length)
counter[..., result.shape[-1] - chunk_size :] += window[:length]
else:
result = self.overlap_add(result, x, window, i, length)
counter[..., i : i + length] += window[:length]
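                # Normalise by the accumulated window weights; the clamp guards against division by (near) zero where coverage is minimal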
inferenced_outputs = result / counter.clamp(min=1e-10)
else:
mix = torch.tensor(mix, dtype=torch.float32)
try:
num_stems = self.model_run.num_target_instruments
except AttributeError:
num_stems = self.model_run.module.num_target_instruments
self.logger.debug(f"Number of stems: {num_stems}")
if self.override_model_segment_size:
mdx_segment_size = self.segment_size
self.logger.debug(f"Using configured segment size: {mdx_segment_size}")
else:
mdx_segment_size = self.model_data_cfgdict.inference.dim_t
self.logger.debug(f"Using model default segment size: {mdx_segment_size}")
chunk_size = self.model_data_cfgdict.audio.hop_length * (mdx_segment_size - 1)
self.logger.debug(f"Chunk size: {chunk_size}")
hop_size = chunk_size // self.overlap
self.logger.debug(f"Hop size: {hop_size}")
mix_shape = mix.shape[1]
pad_size = hop_size - (mix_shape - chunk_size) % hop_size
self.logger.debug(f"Pad size: {pad_size}")
mix = torch.cat([torch.zeros(2, chunk_size - hop_size), mix, torch.zeros(2, pad_size + chunk_size - hop_size)], 1)
self.logger.debug(f"Mix shape: {mix.shape}")
chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1)
self.logger.debug(f"Chunks length: {len(chunks)} and shape: {chunks.shape}")
batches = [chunks[i : i + self.batch_size] for i in range(0, len(chunks), self.batch_size)]
self.logger.debug(f"Batch size: {self.batch_size}, number of batches: {len(batches)}")
# accumulated_outputs is used to accumulate the output from processing each batch of chunks through the model.
# It starts as a tensor of zeros and is updated in-place as the model processes each batch.
# The variable holds the combined result of all processed batches, which, after post-processing, represents the separated audio sources.
accumulated_outputs = torch.zeros(num_stems, *mix.shape) if num_stems > 1 else torch.zeros_like(mix)
with torch.no_grad():
count = 0
for batch in tqdm(batches):
# Since the model processes the audio data in batches, single_batch_result temporarily holds the model's output
# for each batch before it is accumulated into accumulated_outputs.
single_batch_result = self.model_run(batch.to(self.torch_device))
# Each individual output tensor from the current batch's processing result.
# Since single_batch_result can contain multiple output tensors (one for each piece of audio in the batch),
# individual_output is used to iterate through these tensors and accumulate them into accumulated_outputs.
for individual_output in single_batch_result:
individual_output_cpu = individual_output.cpu()
# Accumulate outputs on CPU
accumulated_outputs[..., count * hop_size : count * hop_size + chunk_size] += individual_output_cpu
count += 1
self.logger.debug("Calculating inferenced outputs based on accumulated outputs and overlap")
inferenced_outputs = accumulated_outputs[..., chunk_size - hop_size : -(pad_size + chunk_size - hop_size)] / self.overlap
self.logger.debug("Deleting accumulated outputs to free up memory")
del accumulated_outputs
if num_stems > 1 or self.is_primary_stem_main_target:
self.logger.debug("Number of stems is greater than 1 or vocals are main target, detaching individual sources and correcting pitch if necessary...")
sources = {}
# Iterates over each instrument specified in the model's configuration and its corresponding separated audio source.
# self.model_data_cfgdict.training.instruments provides the list of stems.
            # inferenced_outputs.cpu().detach().numpy() converts the separated sources tensor to a NumPy array for processing.
# Each iteration provides an instrument name ('key') and its separated audio ('value') for further processing.
for key, value in zip(self.model_data_cfgdict.training.instruments, inferenced_outputs.cpu().detach().numpy()):
self.logger.debug(f"Processing instrument: {key}")
if self.pitch_shift != 0:
self.logger.debug(f"Applying pitch correction for {key}")
sources[key] = self.pitch_fix(value, sample_rate, orig_mix)
else:
sources[key] = value
if self.is_primary_stem_main_target:
self.logger.debug(f"Primary stem: {self.primary_stem_name} is main target, detaching and matching array shapes if necessary...")
if sources[self.primary_stem_name].shape[1] != orig_mix.shape[1]:
sources[self.primary_stem_name] = spec_utils.match_array_shapes(sources[self.primary_stem_name], orig_mix)
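                # Derive the secondary stem by subtracting the separated primary stem from the original mix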
sources[self.secondary_stem_name] = orig_mix - sources[self.primary_stem_name]
self.logger.debug("Deleting inferenced outputs to free up memory")
del inferenced_outputs
self.logger.debug("Returning separated sources")
return sources
else:
self.logger.debug("Processing single source...")
if self.is_roformer:
sources = {k: v.cpu().detach().numpy() for k, v in zip([self.model_data_cfgdict.training.target_instrument], inferenced_outputs)}
inferenced_output = sources[self.model_data_cfgdict.training.target_instrument]
else:
inferenced_output = inferenced_outputs.cpu().detach().numpy()
self.logger.debug("Demix process completed for single source.")
self.logger.debug("Deleting inferenced outputs to free up memory")
del inferenced_outputs
if self.pitch_shift != 0:
self.logger.debug("Applying pitch correction for single instrument")
return self.pitch_fix(inferenced_output, sample_rate, orig_mix)
else:
self.logger.debug("Returning inferenced output for single instrument")
return inferenced_output