Spaces:
Running
on
Zero
Running
on
Zero
File size: 18,271 Bytes
01f8b5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 |
""" This file contains the CommonSeparator class, common to all architecture-specific Separator classes. """
from logging import Logger
import os
import re
import gc
import numpy as np
import librosa
import torch
from pydub import AudioSegment
import soundfile as sf
from audio_separator.separator.uvr_lib_v5 import spec_utils
class CommonSeparator:
"""
This class contains the common methods and attributes common to all architecture-specific Separator classes.
"""
ALL_STEMS = "All Stems"
VOCAL_STEM = "Vocals"
INST_STEM = "Instrumental"
OTHER_STEM = "Other"
BASS_STEM = "Bass"
DRUM_STEM = "Drums"
GUITAR_STEM = "Guitar"
PIANO_STEM = "Piano"
SYNTH_STEM = "Synthesizer"
STRINGS_STEM = "Strings"
WOODWINDS_STEM = "Woodwinds"
BRASS_STEM = "Brass"
WIND_INST_STEM = "Wind Inst"
NO_OTHER_STEM = "No Other"
NO_BASS_STEM = "No Bass"
NO_DRUM_STEM = "No Drums"
NO_GUITAR_STEM = "No Guitar"
NO_PIANO_STEM = "No Piano"
NO_SYNTH_STEM = "No Synthesizer"
NO_STRINGS_STEM = "No Strings"
NO_WOODWINDS_STEM = "No Woodwinds"
NO_WIND_INST_STEM = "No Wind Inst"
NO_BRASS_STEM = "No Brass"
PRIMARY_STEM = "Primary Stem"
SECONDARY_STEM = "Secondary Stem"
LEAD_VOCAL_STEM = "lead_only"
BV_VOCAL_STEM = "backing_only"
LEAD_VOCAL_STEM_I = "with_lead_vocals"
BV_VOCAL_STEM_I = "with_backing_vocals"
LEAD_VOCAL_STEM_LABEL = "Lead Vocals"
BV_VOCAL_STEM_LABEL = "Backing Vocals"
NO_STEM = "No "
STEM_PAIR_MAPPER = {VOCAL_STEM: INST_STEM, INST_STEM: VOCAL_STEM, LEAD_VOCAL_STEM: BV_VOCAL_STEM, BV_VOCAL_STEM: LEAD_VOCAL_STEM, PRIMARY_STEM: SECONDARY_STEM}
NON_ACCOM_STEMS = (VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM, SYNTH_STEM, STRINGS_STEM, WOODWINDS_STEM, BRASS_STEM, WIND_INST_STEM)
def __init__(self, config):
self.logger: Logger = config.get("logger")
self.log_level: int = config.get("log_level")
# Inferencing device / acceleration config
self.torch_device = config.get("torch_device")
self.torch_device_cpu = config.get("torch_device_cpu")
self.torch_device_mps = config.get("torch_device_mps")
self.onnx_execution_provider = config.get("onnx_execution_provider")
# Model data
self.model_name = config.get("model_name")
self.model_path = config.get("model_path")
self.model_data = config.get("model_data")
# Output directory and format
self.output_dir = config.get("output_dir")
self.output_format = config.get("output_format")
self.output_bitrate = config.get("output_bitrate")
# Functional options which are applicable to all architectures and the user may tweak to affect the output
self.normalization_threshold = config.get("normalization_threshold")
self.amplification_threshold = config.get("amplification_threshold")
self.enable_denoise = config.get("enable_denoise")
self.output_single_stem = config.get("output_single_stem")
self.invert_using_spec = config.get("invert_using_spec")
self.sample_rate = config.get("sample_rate")
self.use_soundfile = config.get("use_soundfile")
# Model specific properties
# Check if model_data has a "training" key with "instruments" list
self.primary_stem_name = None
self.secondary_stem_name = None
if "training" in self.model_data and "instruments" in self.model_data["training"]:
instruments = self.model_data["training"]["instruments"]
if instruments:
self.primary_stem_name = instruments[0]
self.secondary_stem_name = instruments[1] if len(instruments) > 1 else self.secondary_stem(self.primary_stem_name)
if self.primary_stem_name is None:
self.primary_stem_name = self.model_data.get("primary_stem", "Vocals")
self.secondary_stem_name = self.secondary_stem(self.primary_stem_name)
self.is_karaoke = self.model_data.get("is_karaoke", False)
self.is_bv_model = self.model_data.get("is_bv_model", False)
self.bv_model_rebalance = self.model_data.get("is_bv_model_rebalanced", 0)
self.logger.debug(f"Common params: model_name={self.model_name}, model_path={self.model_path}")
self.logger.debug(f"Common params: output_dir={self.output_dir}, output_format={self.output_format}")
self.logger.debug(f"Common params: normalization_threshold={self.normalization_threshold}, amplification_threshold={self.amplification_threshold}")
self.logger.debug(f"Common params: enable_denoise={self.enable_denoise}, output_single_stem={self.output_single_stem}")
self.logger.debug(f"Common params: invert_using_spec={self.invert_using_spec}, sample_rate={self.sample_rate}")
self.logger.debug(f"Common params: primary_stem_name={self.primary_stem_name}, secondary_stem_name={self.secondary_stem_name}")
self.logger.debug(f"Common params: is_karaoke={self.is_karaoke}, is_bv_model={self.is_bv_model}, bv_model_rebalance={self.bv_model_rebalance}")
# File-specific variables which need to be cleared between processing different audio inputs
self.audio_file_path = None
self.audio_file_base = None
self.primary_source = None
self.secondary_source = None
self.primary_stem_output_path = None
self.secondary_stem_output_path = None
self.cached_sources_map = {}
def secondary_stem(self, primary_stem: str):
"""Determines secondary stem name based on the primary stem name."""
primary_stem = primary_stem if primary_stem else self.NO_STEM
if primary_stem in self.STEM_PAIR_MAPPER:
secondary_stem = self.STEM_PAIR_MAPPER[primary_stem]
else:
secondary_stem = primary_stem.replace(self.NO_STEM, "") if self.NO_STEM in primary_stem else f"{self.NO_STEM}{primary_stem}"
return secondary_stem
def separate(self, audio_file_path):
"""
Placeholder method for separating audio sources. Should be overridden by subclasses.
"""
raise NotImplementedError("This method should be overridden by subclasses.")
def final_process(self, stem_path, source, stem_name):
"""
Finalizes the processing of a stem by writing the audio to a file and returning the processed source.
"""
self.logger.debug(f"Finalizing {stem_name} stem processing and writing audio...")
self.write_audio(stem_path, source)
return {stem_name: source}
def cached_sources_clear(self):
"""
Clears the cache dictionaries for VR, MDX, and Demucs models.
This function is essential for ensuring that the cache does not hold outdated or irrelevant data
between different processing sessions or when a new batch of audio files is processed.
It helps in managing memory efficiently and prevents potential errors due to stale data.
"""
self.cached_sources_map = {}
def cached_source_callback(self, model_architecture, model_name=None):
"""
Retrieves the model and sources from the cache based on the processing method and model name.
Args:
model_architecture: The architecture type (VR, MDX, or Demucs) being used for processing.
model_name: The specific model name within the architecture type, if applicable.
Returns:
A tuple containing the model and its sources if found in the cache; otherwise, None.
This function is crucial for optimizing performance by avoiding redundant processing.
If the requested model and its sources are already in the cache, they can be reused directly,
saving time and computational resources.
"""
model, sources = None, None
mapper = self.cached_sources_map[model_architecture]
for key, value in mapper.items():
if model_name in key:
model = key
sources = value
return model, sources
def cached_model_source_holder(self, model_architecture, sources, model_name=None):
"""
Update the dictionary for the given model_architecture with the new model name and its sources.
Use the model_architecture as a key to access the corresponding cache source mapper dictionary.
"""
self.cached_sources_map[model_architecture] = {**self.cached_sources_map.get(model_architecture, {}), **{model_name: sources}}
def prepare_mix(self, mix):
"""
Prepares the mix for processing. This includes loading the audio from a file if necessary,
ensuring the mix is in the correct format, and converting mono to stereo if needed.
"""
# Store the original path or the mix itself for later checks
audio_path = mix
# Check if the input is a file path (string) and needs to be loaded
if not isinstance(mix, np.ndarray):
self.logger.debug(f"Loading audio from file: {mix}")
mix, sr = librosa.load(mix, mono=False, sr=self.sample_rate)
self.logger.debug(f"Audio loaded. Sample rate: {sr}, Audio shape: {mix.shape}")
else:
# Transpose the mix if it's already an ndarray (expected shape: [channels, samples])
self.logger.debug("Transposing the provided mix array.")
mix = mix.T
self.logger.debug(f"Transposed mix shape: {mix.shape}")
# If the original input was a filepath, check if the loaded mix is empty
if isinstance(audio_path, str):
if not np.any(mix):
error_msg = f"Audio file {audio_path} is empty or not valid"
self.logger.error(error_msg)
raise ValueError(error_msg)
else:
self.logger.debug("Audio file is valid and contains data.")
# Ensure the mix is in stereo format
if mix.ndim == 1:
self.logger.debug("Mix is mono. Converting to stereo.")
mix = np.asfortranarray([mix, mix])
self.logger.debug("Converted to stereo mix.")
# Final log indicating successful preparation of the mix
self.logger.debug("Mix preparation completed.")
return mix
def write_audio(self, stem_path: str, stem_source):
"""
Writes the separated audio source to a file using pydub or soundfile
Pydub supports a much wider range of audio formats and produces better encoded lossy files for some formats.
Soundfile is used for very large files (longer than 1 hour), as pydub has memory issues with large files:
https://github.com/jiaaro/pydub/issues/135
"""
# Get the duration of the input audio file
duration_seconds = librosa.get_duration(filename=self.audio_file_path)
duration_hours = duration_seconds / 3600
self.logger.info(f"Audio duration is {duration_hours:.2f} hours ({duration_seconds:.2f} seconds).")
if self.use_soundfile:
self.logger.warning(f"Using soundfile for writing.")
self.write_audio_soundfile(stem_path, stem_source)
else:
self.logger.info(f"Using pydub for writing.")
self.write_audio_pydub(stem_path, stem_source)
def write_audio_pydub(self, stem_path: str, stem_source):
"""
Writes the separated audio source to a file using pydub (ffmpeg)
"""
self.logger.debug(f"Entering write_audio_pydub with stem_path: {stem_path}")
stem_source = spec_utils.normalize(wave=stem_source, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)
# Check if the numpy array is empty or contains very low values
if np.max(np.abs(stem_source)) < 1e-6:
self.logger.warning("Warning: stem_source array is near-silent or empty.")
return
# If output_dir is specified, create it and join it with stem_path
if self.output_dir:
os.makedirs(self.output_dir, exist_ok=True)
stem_path = os.path.join(self.output_dir, stem_path)
self.logger.debug(f"Audio data shape before processing: {stem_source.shape}")
self.logger.debug(f"Data type before conversion: {stem_source.dtype}")
# Ensure the audio data is in the correct format (e.g., int16)
if stem_source.dtype != np.int16:
stem_source = (stem_source * 32767).astype(np.int16)
self.logger.debug("Converted stem_source to int16.")
# Correctly interleave stereo channels
stem_source_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
stem_source_interleaved[0::2] = stem_source[:, 0] # Left channel
stem_source_interleaved[1::2] = stem_source[:, 1] # Right channel
self.logger.debug(f"Interleaved audio data shape: {stem_source_interleaved.shape}")
# Create a pydub AudioSegment
try:
audio_segment = AudioSegment(stem_source_interleaved.tobytes(), frame_rate=self.sample_rate, sample_width=stem_source.dtype.itemsize, channels=2)
self.logger.debug("Created AudioSegment successfully.")
except (IOError, ValueError) as e:
self.logger.error(f"Specific error creating AudioSegment: {e}")
return
# Determine file format based on the file extension
file_format = stem_path.lower().split(".")[-1]
# For m4a files, specify mp4 as the container format as the extension doesn't match the format name
if file_format == "m4a":
file_format = "mp4"
elif file_format == "mka":
file_format = "matroska"
# Set the bitrate to 320k for mp3 files if output_bitrate is not specified
bitrate = "320k" if file_format == "mp3" and self.output_bitrate is None else self.output_bitrate
# Export using the determined format
try:
audio_segment.export(stem_path, format=file_format, bitrate=bitrate)
self.logger.debug(f"Exported audio file successfully to {stem_path}")
except (IOError, ValueError) as e:
self.logger.error(f"Error exporting audio file: {e}")
def write_audio_soundfile(self, stem_path: str, stem_source):
"""
Writes the separated audio source to a file using soundfile library.
"""
self.logger.debug(f"Entering write_audio_soundfile with stem_path: {stem_path}")
# Correctly interleave stereo channels if needed
if stem_source.shape[1] == 2:
# If the audio is already interleaved, ensure it's in the correct order
# Check if the array is Fortran contiguous (column-major)
if stem_source.flags["F_CONTIGUOUS"]:
# Convert to C contiguous (row-major)
stem_source = np.ascontiguousarray(stem_source)
# Otherwise, perform interleaving
else:
stereo_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
# Left channel
stereo_interleaved[0::2] = stem_source[:, 0]
# Right channel
stereo_interleaved[1::2] = stem_source[:, 1]
stem_source = stereo_interleaved
self.logger.debug(f"Interleaved audio data shape: {stem_source.shape}")
"""
Write audio using soundfile (for formats other than M4A).
"""
# Save audio using soundfile
try:
# Specify the subtype to define the sample width
sf.write(stem_path, stem_source, self.sample_rate)
self.logger.debug(f"Exported audio file successfully to {stem_path}")
except Exception as e:
self.logger.error(f"Error exporting audio file: {e}")
def clear_gpu_cache(self):
"""
This method clears the GPU cache to free up memory.
"""
self.logger.debug("Running garbage collection...")
gc.collect()
if self.torch_device == torch.device("mps"):
self.logger.debug("Clearing MPS cache...")
torch.mps.empty_cache()
if self.torch_device == torch.device("cuda"):
self.logger.debug("Clearing CUDA cache...")
torch.cuda.empty_cache()
def clear_file_specific_paths(self):
"""
Clears the file-specific variables which need to be cleared between processing different audio inputs.
"""
self.logger.info("Clearing input audio file paths, sources and stems...")
self.audio_file_path = None
self.audio_file_base = None
self.primary_source = None
self.secondary_source = None
self.primary_stem_output_path = None
self.secondary_stem_output_path = None
def sanitize_filename(self, filename):
"""
Cleans the filename by replacing invalid characters with underscores.
"""
sanitized = re.sub(r'[<>:"/\\|?*]', '_', filename)
sanitized = re.sub(r'_+', '_', sanitized)
sanitized = sanitized.strip('_. ')
return sanitized
def get_stem_output_path(self, stem_name, custom_output_names):
"""
Gets the output path for a stem based on the stem name and custom output names.
"""
# Convert custom_output_names keys to lowercase for case-insensitive comparison
if custom_output_names:
custom_output_names_lower = {k.lower(): v for k, v in custom_output_names.items()}
stem_name_lower = stem_name.lower()
if stem_name_lower in custom_output_names_lower:
sanitized_custom_name = self.sanitize_filename(custom_output_names_lower[stem_name_lower])
return os.path.join(f"{sanitized_custom_name}.{self.output_format.lower()}")
sanitized_audio_base = self.sanitize_filename(self.audio_file_base)
sanitized_stem_name = self.sanitize_filename(stem_name)
sanitized_model_name = self.sanitize_filename(self.model_name)
filename = f"{sanitized_audio_base}_({sanitized_stem_name})_{sanitized_model_name}.{self.output_format.lower()}"
return os.path.join(filename)
|