Spaces:
Running
on
Zero
Running
on
Zero
""" This file contains the CommonSeparator class, common to all architecture-specific Separator classes. """ | |
from logging import Logger | |
import os | |
import re | |
import gc | |
import numpy as np | |
import librosa | |
import torch | |
from pydub import AudioSegment | |
import soundfile as sf | |
from audio_separator.separator.uvr_lib_v5 import spec_utils | |
class CommonSeparator:
    """
    Base class holding the methods and attributes shared by all architecture-specific Separator classes.

    Subclasses are expected to override `separate`; everything else (stem naming, source caching,
    mix preparation, audio writing, cache clearing and output path building) is provided here.
    """

    # Stem names
    ALL_STEMS = "All Stems"
    VOCAL_STEM = "Vocals"
    INST_STEM = "Instrumental"
    OTHER_STEM = "Other"
    BASS_STEM = "Bass"
    DRUM_STEM = "Drums"
    GUITAR_STEM = "Guitar"
    PIANO_STEM = "Piano"
    SYNTH_STEM = "Synthesizer"
    STRINGS_STEM = "Strings"
    WOODWINDS_STEM = "Woodwinds"
    BRASS_STEM = "Brass"
    WIND_INST_STEM = "Wind Inst"

    # "Everything except X" stem names
    NO_OTHER_STEM = "No Other"
    NO_BASS_STEM = "No Bass"
    NO_DRUM_STEM = "No Drums"
    NO_GUITAR_STEM = "No Guitar"
    NO_PIANO_STEM = "No Piano"
    NO_SYNTH_STEM = "No Synthesizer"
    NO_STRINGS_STEM = "No Strings"
    NO_WOODWINDS_STEM = "No Woodwinds"
    NO_WIND_INST_STEM = "No Wind Inst"
    NO_BRASS_STEM = "No Brass"

    PRIMARY_STEM = "Primary Stem"
    SECONDARY_STEM = "Secondary Stem"

    LEAD_VOCAL_STEM = "lead_only"
    BV_VOCAL_STEM = "backing_only"
    LEAD_VOCAL_STEM_I = "with_lead_vocals"
    BV_VOCAL_STEM_I = "with_backing_vocals"
    LEAD_VOCAL_STEM_LABEL = "Lead Vocals"
    BV_VOCAL_STEM_LABEL = "Backing Vocals"

    # Prefix used to derive a complementary stem name (e.g. "Drums" <-> "No Drums")
    NO_STEM = "No "

    # Explicit primary -> secondary pairs; anything not listed here is derived via the NO_STEM prefix
    STEM_PAIR_MAPPER = {VOCAL_STEM: INST_STEM, INST_STEM: VOCAL_STEM, LEAD_VOCAL_STEM: BV_VOCAL_STEM, BV_VOCAL_STEM: LEAD_VOCAL_STEM, PRIMARY_STEM: SECONDARY_STEM}

    NON_ACCOM_STEMS = (VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM, SYNTH_STEM, STRINGS_STEM, WOODWINDS_STEM, BRASS_STEM, WIND_INST_STEM)

    def __init__(self, config):
        """
        Copies the common settings out of `config` and derives the primary/secondary stem names.

        Args:
            config: Mapping read with `.get()`; keys that are absent simply yield None attributes.
        """
        self.logger: Logger = config.get("logger")
        self.log_level: int = config.get("log_level")

        # Inferencing device / acceleration config
        self.torch_device = config.get("torch_device")
        self.torch_device_cpu = config.get("torch_device_cpu")
        self.torch_device_mps = config.get("torch_device_mps")
        self.onnx_execution_provider = config.get("onnx_execution_provider")

        # Model data
        self.model_name = config.get("model_name")
        self.model_path = config.get("model_path")
        self.model_data = config.get("model_data")

        # Output directory and format
        self.output_dir = config.get("output_dir")
        self.output_format = config.get("output_format")
        self.output_bitrate = config.get("output_bitrate")

        # Functional options which are applicable to all architectures and the user may tweak to affect the output
        self.normalization_threshold = config.get("normalization_threshold")
        self.amplification_threshold = config.get("amplification_threshold")
        self.enable_denoise = config.get("enable_denoise")
        self.output_single_stem = config.get("output_single_stem")
        self.invert_using_spec = config.get("invert_using_spec")
        self.sample_rate = config.get("sample_rate")
        self.use_soundfile = config.get("use_soundfile")

        # Model specific properties.
        # A missing model_data (or a null "training" entry) falls back to the defaults instead of raising.
        model_data = self.model_data or {}
        training = model_data.get("training") or {}
        instruments = training.get("instruments")

        self.primary_stem_name = None
        self.secondary_stem_name = None

        # Prefer the instruments list from the training config when it is present and non-empty
        if instruments:
            self.primary_stem_name = instruments[0]
            self.secondary_stem_name = instruments[1] if len(instruments) > 1 else self.secondary_stem(self.primary_stem_name)

        if self.primary_stem_name is None:
            self.primary_stem_name = model_data.get("primary_stem", "Vocals")
            self.secondary_stem_name = self.secondary_stem(self.primary_stem_name)

        self.is_karaoke = model_data.get("is_karaoke", False)
        self.is_bv_model = model_data.get("is_bv_model", False)
        self.bv_model_rebalance = model_data.get("is_bv_model_rebalanced", 0)

        self.logger.debug(f"Common params: model_name={self.model_name}, model_path={self.model_path}")
        self.logger.debug(f"Common params: output_dir={self.output_dir}, output_format={self.output_format}")
        self.logger.debug(f"Common params: normalization_threshold={self.normalization_threshold}, amplification_threshold={self.amplification_threshold}")
        self.logger.debug(f"Common params: enable_denoise={self.enable_denoise}, output_single_stem={self.output_single_stem}")
        self.logger.debug(f"Common params: invert_using_spec={self.invert_using_spec}, sample_rate={self.sample_rate}")
        self.logger.debug(f"Common params: primary_stem_name={self.primary_stem_name}, secondary_stem_name={self.secondary_stem_name}")
        self.logger.debug(f"Common params: is_karaoke={self.is_karaoke}, is_bv_model={self.is_bv_model}, bv_model_rebalance={self.bv_model_rebalance}")

        # File-specific variables which need to be cleared between processing different audio inputs
        self.audio_file_path = None
        self.audio_file_base = None

        self.primary_source = None
        self.secondary_source = None

        self.primary_stem_output_path = None
        self.secondary_stem_output_path = None

        self.cached_sources_map = {}

    def secondary_stem(self, primary_stem: str):
        """
        Determines the secondary stem name based on the primary stem name.

        Known pairs come from STEM_PAIR_MAPPER; otherwise the "No " prefix is removed when present
        or added when absent. A falsy `primary_stem` is treated as the bare "No " prefix.
        """
        primary_stem = primary_stem if primary_stem else self.NO_STEM

        if primary_stem in self.STEM_PAIR_MAPPER:
            return self.STEM_PAIR_MAPPER[primary_stem]

        if self.NO_STEM in primary_stem:
            return primary_stem.replace(self.NO_STEM, "")

        return f"{self.NO_STEM}{primary_stem}"

    def separate(self, audio_file_path):
        """
        Placeholder method for separating audio sources. Should be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("This method should be overridden by subclasses.")

    def final_process(self, stem_path, source, stem_name):
        """
        Finalizes the processing of a stem by writing the audio to a file.

        Returns:
            A single-entry dict mapping `stem_name` to `source`.
        """
        self.logger.debug(f"Finalizing {stem_name} stem processing and writing audio...")
        self.write_audio(stem_path, source)

        return {stem_name: source}

    def cached_sources_clear(self):
        """
        Empties the per-architecture source cache so no stale data is carried over
        between processing sessions or batches of audio files.
        """
        self.cached_sources_map = {}

    def cached_source_callback(self, model_architecture, model_name=None):
        """
        Retrieves the model and sources from the cache based on the architecture and model name.

        Args:
            model_architecture: Key of the per-architecture cache to search.
            model_name: Model name to look for. A string matches any cached string key that
                contains it; any other value (including None) must equal the cached key.

        Returns:
            A `(model, sources)` tuple; `(None, None)` when nothing matches, including when
            nothing has been cached for `model_architecture` yet. If several keys match,
            the last one in insertion order wins.
        """
        model, sources = None, None

        # .get() so that an architecture with no cache entry yields (None, None) rather than KeyError
        mapper = self.cached_sources_map.get(model_architecture, {})

        for key, value in mapper.items():
            if isinstance(model_name, str) and isinstance(key, str):
                matched = model_name in key
            else:
                # Substring tests are only defined for strings; fall back to equality otherwise
                matched = key == model_name

            if matched:
                model, sources = key, value

        return model, sources

    def cached_model_source_holder(self, model_architecture, sources, model_name=None):
        """
        Stores `sources` under `model_name` in the cache for `model_architecture`,
        keeping any entries already cached for that architecture.
        """
        existing = self.cached_sources_map.get(model_architecture, {})
        self.cached_sources_map[model_architecture] = {**existing, model_name: sources}

    def prepare_mix(self, mix):
        """
        Prepares the mix for processing.

        Args:
            mix: Either a numpy array, which is transposed, or anything `librosa.load`
                accepts, which is loaded at `self.sample_rate` without down-mixing.

        Returns:
            The prepared array; a 1-D (mono) mix is duplicated into two identical channels.

        Raises:
            ValueError: if audio loaded from a file contains only zeros or no samples.
        """
        loaded_from_file = not isinstance(mix, np.ndarray)

        if loaded_from_file:
            audio_path = mix
            self.logger.debug(f"Loading audio from file: {audio_path}")
            mix, sr = librosa.load(audio_path, mono=False, sr=self.sample_rate)
            self.logger.debug(f"Audio loaded. Sample rate: {sr}, Audio shape: {mix.shape}")

            # Validate every file input (str or path-like), not only plain strings
            if not np.any(mix):
                error_msg = f"Audio file {audio_path} is empty or not valid"
                self.logger.error(error_msg)
                raise ValueError(error_msg)

            self.logger.debug("Audio file is valid and contains data.")
        else:
            # Transpose the mix if it's already an ndarray (expected shape: [channels, samples])
            self.logger.debug("Transposing the provided mix array.")
            mix = mix.T
            self.logger.debug(f"Transposed mix shape: {mix.shape}")

        # Ensure the mix is in stereo format
        if mix.ndim == 1:
            self.logger.debug("Mix is mono. Converting to stereo.")
            mix = np.asfortranarray([mix, mix])
            self.logger.debug("Converted to stereo mix.")

        self.logger.debug("Mix preparation completed.")
        return mix

    def write_audio(self, stem_path: str, stem_source):
        """
        Writes the separated audio source to a file using pydub or soundfile.

        The writer is chosen by `self.use_soundfile`. The duration of the input file
        (`self.audio_file_path`) is looked up only so that it can be logged.

        Pydub supports a much wider range of audio formats and produces better encoded lossy files for some formats.
        Soundfile is intended for very large files (longer than 1 hour), as pydub has memory issues with large files:
        https://github.com/jiaaro/pydub/issues/135
        """
        try:
            duration_seconds = librosa.get_duration(path=self.audio_file_path)
        except TypeError:
            # Older librosa releases only know the (since deprecated) `filename` keyword
            duration_seconds = librosa.get_duration(filename=self.audio_file_path)

        duration_hours = duration_seconds / 3600
        self.logger.info(f"Audio duration is {duration_hours:.2f} hours ({duration_seconds:.2f} seconds).")

        if self.use_soundfile:
            self.logger.warning("Using soundfile for writing.")
            self.write_audio_soundfile(stem_path, stem_source)
        else:
            self.logger.info("Using pydub for writing.")
            self.write_audio_pydub(stem_path, stem_source)

    def _prepare_stem_for_writing(self, stem_path: str, stem_source):
        """
        Shared pre-write step: normalize the stem, skip silent stems, resolve the output path.

        Returns:
            `(stem_path, stem_source)` ready for writing, or None when the stem is empty or
            near-silent and should not be written at all.
        """
        # NOTE(review): spec_utils.normalize is defined elsewhere; presumably it scales the wave
        # with respect to the given peak thresholds - confirm against uvr_lib_v5.
        stem_source = spec_utils.normalize(wave=stem_source, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)

        # np.max raises on a zero-size array, so test for emptiness explicitly first
        if stem_source.size == 0 or np.max(np.abs(stem_source)) < 1e-6:
            self.logger.warning("Warning: stem_source array is near-silent or empty.")
            return None

        # If output_dir is specified, create it and join it with stem_path
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
            stem_path = os.path.join(self.output_dir, stem_path)

        return stem_path, stem_source

    def write_audio_pydub(self, stem_path: str, stem_source):
        """
        Writes the separated audio source to a file using pydub (ffmpeg).

        The stem is indexed as `stem_source[:, 0]` / `stem_source[:, 1]`, i.e. it must be 2-D with
        the two channels along the second axis. Nothing is written when the stem is near-silent.
        Errors of type IOError/ValueError are logged, not raised.
        """
        self.logger.debug(f"Entering write_audio_pydub with stem_path: {stem_path}")

        prepared = self._prepare_stem_for_writing(stem_path, stem_source)
        if prepared is None:
            return
        stem_path, stem_source = prepared

        self.logger.debug(f"Audio data shape before processing: {stem_source.shape}")
        self.logger.debug(f"Data type before conversion: {stem_source.dtype}")

        # Ensure the audio data is in the correct format (int16)
        if stem_source.dtype != np.int16:
            if np.issubdtype(stem_source.dtype, np.floating):
                # Clamp first: samples outside [-1, 1] would otherwise wrap around in int16
                stem_source = np.clip(stem_source, -1.0, 1.0)
            stem_source = (stem_source * 32767).astype(np.int16)
            self.logger.debug("Converted stem_source to int16.")

        # Interleave the stereo channels as L R L R ... which is what raw PCM bytes require
        stem_source_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
        stem_source_interleaved[0::2] = stem_source[:, 0]  # Left channel
        stem_source_interleaved[1::2] = stem_source[:, 1]  # Right channel

        self.logger.debug(f"Interleaved audio data shape: {stem_source_interleaved.shape}")

        # Create a pydub AudioSegment
        try:
            audio_segment = AudioSegment(stem_source_interleaved.tobytes(), frame_rate=self.sample_rate, sample_width=stem_source.dtype.itemsize, channels=2)
            self.logger.debug("Created AudioSegment successfully.")
        except (IOError, ValueError) as e:
            self.logger.error(f"Specific error creating AudioSegment: {e}")
            return

        # Determine file format based on the file extension
        file_format = os.path.splitext(stem_path)[1][1:].lower()

        # Some extensions don't match the container format name, so translate them
        file_format = {"m4a": "mp4", "mka": "matroska"}.get(file_format, file_format)

        # Set the bitrate to 320k for mp3 files if output_bitrate is not specified
        bitrate = "320k" if file_format == "mp3" and self.output_bitrate is None else self.output_bitrate

        # Export using the determined format
        try:
            audio_segment.export(stem_path, format=file_format, bitrate=bitrate)
            self.logger.debug(f"Exported audio file successfully to {stem_path}")
        except (IOError, ValueError) as e:
            self.logger.error(f"Error exporting audio file: {e}")

    def write_audio_soundfile(self, stem_path: str, stem_source):
        """
        Writes the separated audio source to a file using the soundfile library.

        Applies the same normalization, silence check and `output_dir` handling as
        `write_audio_pydub`. Any error raised while writing is logged, not raised.
        """
        self.logger.debug(f"Entering write_audio_soundfile with stem_path: {stem_path}")

        prepared = self._prepare_stem_for_writing(stem_path, stem_source)
        if prepared is None:
            return
        stem_path, stem_source = prepared

        # soundfile takes 2-D data as (frames, channels) and treats 1-D data as mono, so the
        # array is passed through as-is (only made C-contiguous) instead of being flattened
        # into an interleaved 1-D buffer.
        stem_source = np.ascontiguousarray(stem_source)

        self.logger.debug(f"Audio data shape before writing: {stem_source.shape}")

        # Save audio using soundfile; no subtype is passed, so the library default for the format is used
        try:
            sf.write(stem_path, stem_source, self.sample_rate)
            self.logger.debug(f"Exported audio file successfully to {stem_path}")
        except Exception as e:
            # Deliberately best-effort: a failed export is reported but does not abort processing
            self.logger.error(f"Error exporting audio file: {e}")

    def clear_gpu_cache(self):
        """
        Runs garbage collection and clears the MPS or CUDA cache to free up memory.

        The device *type* is compared, so indexed devices such as "cuda:0" and devices
        given as plain strings are handled as well.
        """
        self.logger.debug("Running garbage collection...")
        gc.collect()

        if self.torch_device is None:
            return

        device_type = torch.device(self.torch_device).type

        if device_type == "mps":
            self.logger.debug("Clearing MPS cache...")
            torch.mps.empty_cache()
        elif device_type == "cuda":
            self.logger.debug("Clearing CUDA cache...")
            torch.cuda.empty_cache()

    def clear_file_specific_paths(self):
        """
        Clears the file-specific variables which need to be cleared between processing different audio inputs.
        """
        self.logger.info("Clearing input audio file paths, sources and stems...")

        self.audio_file_path = None
        self.audio_file_base = None

        self.primary_source = None
        self.secondary_source = None

        self.primary_stem_output_path = None
        self.secondary_stem_output_path = None

    def sanitize_filename(self, filename):
        """
        Cleans the filename by replacing invalid characters with underscores.

        Runs of underscores are collapsed into one, and leading/trailing
        underscores, dots and spaces are stripped.
        """
        sanitized = re.sub(r'[<>:"/\\|?*]', '_', filename)
        sanitized = re.sub(r'_+', '_', sanitized)
        sanitized = sanitized.strip('_. ')
        return sanitized

    def get_stem_output_path(self, stem_name, custom_output_names):
        """
        Gets the output file name for a stem.

        Args:
            stem_name: Name of the stem being written.
            custom_output_names: Optional mapping of stem name to custom base name;
                stem names are compared case-insensitively.

        Returns:
            "<custom name>.<format>" when a custom name is given for the stem, otherwise
            "<audio base>_(<stem>)_<model name>.<format>". All parts are sanitized and the
            format is lower-cased.
        """
        extension = self.output_format.lower()

        if custom_output_names:
            # Lower-case the keys for a case-insensitive comparison
            custom_output_names_lower = {k.lower(): v for k, v in custom_output_names.items()}
            stem_name_lower = stem_name.lower()
            if stem_name_lower in custom_output_names_lower:
                sanitized_custom_name = self.sanitize_filename(custom_output_names_lower[stem_name_lower])
                return f"{sanitized_custom_name}.{extension}"

        sanitized_audio_base = self.sanitize_filename(self.audio_file_base)
        sanitized_stem_name = self.sanitize_filename(stem_name)
        sanitized_model_name = self.sanitize_filename(self.model_name)

        return f"{sanitized_audio_base}_({sanitized_stem_name})_{sanitized_model_name}.{extension}"