Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import sys | |
from pathlib import Path | |
import torch | |
import numpy as np | |
from audio_separator.separator.common_separator import CommonSeparator | |
from audio_separator.separator.uvr_lib_v5.demucs.apply import apply_model, demucs_segments | |
from audio_separator.separator.uvr_lib_v5.demucs.hdemucs import HDemucs | |
from audio_separator.separator.uvr_lib_v5.demucs.pretrained import get_model as get_demucs_model | |
from audio_separator.separator.uvr_lib_v5 import spec_utils | |
# Source names for a 4-stem Demucs model.
DEMUCS_4_SOURCE = ["drums", "bass", "other", "vocals"]

# Stem name -> row index into the separated-sources array, for models producing 2 stems.
DEMUCS_2_SOURCE_MAPPER = {
    CommonSeparator.INST_STEM: 0,
    CommonSeparator.VOCAL_STEM: 1,
}

# Stem name -> row index, for models producing 4 stems.
# Bass comes before drums here (unlike DEMUCS_4_SOURCE) because DemucsSeparator.demix_demucs
# swaps rows 0 and 1 of the model output before returning it.
DEMUCS_4_SOURCE_MAPPER = {
    CommonSeparator.BASS_STEM: 0,
    CommonSeparator.DRUM_STEM: 1,
    CommonSeparator.OTHER_STEM: 2,
    CommonSeparator.VOCAL_STEM: 3,
}

# Stem name -> row index, for models producing 6 stems: the 4-stem layout plus guitar and piano.
DEMUCS_6_SOURCE_MAPPER = {
    **DEMUCS_4_SOURCE_MAPPER,
    CommonSeparator.GUITAR_STEM: 4,
    CommonSeparator.PIANO_STEM: 5,
}
class DemucsSeparator(CommonSeparator):
    """
    DemucsSeparator is responsible for separating audio sources using Demucs models.

    The constructor reads the Demucs-specific options from ``arch_config``. ``separate`` loads the
    model referenced by ``self.model_path``, runs inference on the prepared mix and writes one
    output file per selected stem.

    NOTE(review): ``self.logger``, ``self.model_path``, ``self.torch_device``, ``self.output_single_stem``,
    ``prepare_mix``, ``get_stem_output_path``, ``final_process`` and ``clear_gpu_cache`` are not defined in
    this class; they are expected to come from ``CommonSeparator``.
    """

    def __init__(self, common_config, arch_config):
        """
        Initialise the separator.

        Args:
            common_config: Configuration shared between architectures; forwarded to
                ``CommonSeparator.__init__`` as ``config``.
            arch_config (dict): Demucs-specific options. Recognised keys (with defaults):
                ``segment_size`` ("Default"), ``shifts`` (2), ``overlap`` (0.25), ``segments_enabled`` (True).
        """
        # Any configuration values which can be shared between architectures should be set already in CommonSeparator,
        # e.g. user-specified functionality choices (self.output_single_stem) or common model parameters (self.primary_stem_name)
        super().__init__(config=common_config)

        # Initializing user-configurable parameters, passed through from the CLI or Separator instance

        # Adjust segments to manage RAM or V-RAM usage:
        # - Smaller sizes consume less resources.
        # - Bigger sizes consume more resources, but may provide better results.
        # - "Default" picks the optimal size.
        # Values offered by UVR: "Default", 1, 5, 10, 15, ..., 95, 100
        self.segment_size = arch_config.get("segment_size", "Default")

        # Performs multiple predictions with random shifts of the input and averages them.
        # The higher number of shifts, the longer the prediction will take.
        # Not recommended unless you have a GPU.
        # Values offered by UVR: 0 to 20
        self.shifts = arch_config.get("shifts", 2)

        # This option controls the amount of overlap between prediction windows.
        # - Higher values can provide better results, but will lead to longer processing times.
        # - You can choose between 0.001-0.999
        # Values offered by UVR: 0.25, 0.50, 0.75, 0.99
        self.overlap = arch_config.get("overlap", 0.25)

        # Enables "Segments". Deselecting this option is only recommended for those with powerful PCs.
        self.segments_enabled = arch_config.get("segments_enabled", True)

        self.logger.debug(f"Demucs arch params: segment_size={self.segment_size}, segments_enabled={self.segments_enabled}")
        self.logger.debug(f"Demucs arch params: shifts={self.shifts}, overlap={self.overlap}")

        # Default to the 4-stem layout; separate() re-selects the map from the number of stems actually produced.
        self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER

        self.audio_file_path = None
        self.audio_file_base = None

        self.demucs_model_instance = None

        # Add uvr_lib_v5 folder to system path so pytorch serialization can find the demucs module.
        # Guarded so that creating several separators does not keep growing sys.path with duplicates.
        current_dir = os.path.dirname(__file__)
        uvr_lib_v5_path = os.path.join(current_dir, "..", "uvr_lib_v5")
        if uvr_lib_v5_path not in sys.path:
            sys.path.insert(0, uvr_lib_v5_path)

        self.logger.info("Demucs Separator initialisation complete")

    def separate(self, audio_file_path, custom_output_names=None):
        """
        Separates the audio file into its component stems using the Demucs model.

        Args:
            audio_file_path (str): The path to the audio file to be processed.
            custom_output_names (dict, optional): Custom names for the output files. Defaults to None.

        Returns:
            list: A list of paths to the output files generated by the separation process.
        """
        self.logger.debug("Starting separation process...")

        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

        # Prepare the mix for processing
        self.logger.debug("Preparing mix...")
        mix = self.prepare_mix(self.audio_file_path)
        self.logger.debug(f"Mix prepared for demixing. Shape: {mix.shape}")

        self.logger.debug("Loading model for demixing...")
        model_name = os.path.splitext(os.path.basename(self.model_path))[0]
        model_repo = Path(os.path.dirname(self.model_path))

        try:
            # The model is only ever referenced through self.demucs_model_instance (no local alias),
            # so resetting the attribute below really drops the last reference before the cache is cleared.
            self.demucs_model_instance = get_demucs_model(name=model_name, repo=model_repo)
            self.demucs_model_instance = demucs_segments(self.segment_size, self.demucs_model_instance)
            self.demucs_model_instance.to(self.torch_device)
            self.demucs_model_instance.eval()
            self.logger.debug("Model loaded and set to evaluation mode.")

            source = self.demix_demucs(mix)
        finally:
            # Release the model even when loading or inference fails, and keep the attribute
            # defined (as in __init__) rather than deleting it from the instance.
            self.demucs_model_instance = None
            self.clear_gpu_cache()
            self.logger.debug("Model and GPU cache cleared after demixing.")

        output_files = []
        self.logger.debug("Processing output files...")

        if isinstance(source, np.ndarray):
            source_length = len(source)
            self.logger.debug(f"Processing source array, source length is {source_length}")

            # Pick the stem layout from the number of stems produced; anything other than 2 or 6 uses the 4-stem map.
            source_maps_by_stem_count = {2: DEMUCS_2_SOURCE_MAPPER, 6: DEMUCS_6_SOURCE_MAPPER}
            self.demucs_source_map = source_maps_by_stem_count.get(source_length, DEMUCS_4_SOURCE_MAPPER)
            self.logger.debug(f"Setting source map to {len(self.demucs_source_map)}-stem...")

        self.logger.debug("Processing for all stems...")
        for stem_name, stem_index in self.demucs_source_map.items():
            if self.output_single_stem is not None and stem_name.lower() != self.output_single_stem.lower():
                self.logger.debug(f"Skipping writing stem {stem_name} as output_single_stem is set to {self.output_single_stem}...")
                continue

            stem_path = self.get_stem_output_path(stem_name, custom_output_names)
            # The stem is transposed before being handed to final_process.
            self.final_process(stem_path, source[stem_index].T, stem_name)
            output_files.append(stem_path)

        return output_files

    def demix_demucs(self, mix):
        """
        Demixes the input mix using the demucs model.

        The mix is normalised with the mean and standard deviation of its channel-average,
        passed through ``apply_model`` with the configured shifts/overlap/segment options, and the
        result is de-normalised again.

        Args:
            mix: Audio data convertible by ``torch.tensor``; the mean over axis 0 is used as the
                normalisation reference.

        Returns:
            numpy.ndarray: The separated sources, with rows 0 and 1 swapped relative to the model output.
        """
        self.logger.debug("Starting demixing process in demix_demucs...")

        mix = torch.tensor(mix, dtype=torch.float32)
        ref = mix.mean(0)
        ref_mean = ref.mean()
        ref_std = ref.std()
        # A silent (or single-sample) input has a zero (or NaN) standard deviation; dividing by it
        # would turn every output sample into NaN, so fall back to a neutral scale of 1.
        if not (ref_std > 0):
            ref_std = torch.tensor(1.0, dtype=mix.dtype)
        mix = (mix - ref_mean) / ref_std

        with torch.no_grad():
            self.logger.debug("Running model inference...")
            sources = apply_model(
                model=self.demucs_model_instance,
                mix=mix[None],  # add a leading batch dimension
                shifts=self.shifts,
                split=self.segments_enabled,
                overlap=self.overlap,
                static_shifts=1 if self.shifts == 0 else self.shifts,
                set_progress_bar=None,
                device=self.torch_device,
                progress=True,
            )[0]

        # Undo the input normalisation and move the result to a NumPy array on the CPU.
        sources = (sources * ref_std + ref_mean).cpu().numpy()
        # Swap the first two stems so the row order matches the *_SOURCE_MAPPER index tables.
        sources[[0, 1]] = sources[[1, 0]]

        return sources