Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python | |
import argparse | |
import logging | |
import json | |
import sys | |
import os | |
from importlib import metadata | |
# Extensions treated as audio when a directory is passed on the command line
# (compared case-insensitively). Add other formats if needed.
_AUDIO_EXTENSIONS = (".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff", ".ac3")

# Accepted spellings for boolean-valued options (compared after strip().lower()).
_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "on", "1"})
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "off", "0", ""})


def _str_to_bool(value):
    """Convert a command-line string such as "True", "false", "1" or "off" to a bool.

    argparse's ``type=bool`` calls ``bool(value)``, which is True for every
    non-empty string (including "False"). Options that take an explicit
    true/false value therefore need this converter instead.

    Args:
        value: The raw option value (a str from the command line, or an
            already-converted bool).

    Returns:
        The parsed boolean. An empty string maps to False, matching ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean spelling.
            argparse turns this into a usage error with exit status 2.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r} (expected e.g. True or False)")


def _build_parser(package_version):
    """Build the argparse parser for the CLI.

    Args:
        package_version: Version string shown by ``-v/--version``.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Separate audio file into different stems.", formatter_class=lambda prog: argparse.RawTextHelpFormatter(prog, max_help_position=60))

    # SUPPRESS leaves ``args.audio_files`` unset when no positional is given;
    # main() relies on that to decide whether to print help.
    parser.add_argument("audio_files", nargs="*", help="The audio file paths or directory to separate, in any common format.", default=argparse.SUPPRESS)

    version_help = "Show the program's version number and exit."
    debug_help = "Enable debug logging, equivalent to --log_level=debug."
    env_info_help = "Print environment information and exit."
    list_models_help = "List all supported models and exit. Use --list_filter to filter/sort the list and --list_limit to show only top N results."
    log_level_help = "Log level, e.g. info, debug, warning (default: %(default)s)."

    info_params = parser.add_argument_group("Info and Debugging")
    info_params.add_argument("-v", "--version", action="version", version=f"%(prog)s {package_version}", help=version_help)
    info_params.add_argument("-d", "--debug", action="store_true", help=debug_help)
    info_params.add_argument("-e", "--env_info", action="store_true", help=env_info_help)
    info_params.add_argument("-l", "--list_models", action="store_true", help=list_models_help)
    info_params.add_argument("--log_level", default="info", help=log_level_help)
    info_params.add_argument("--list_filter", help="Filter and sort the model list by 'name', 'filename', or any stem e.g. vocals, instrumental, drums")
    info_params.add_argument("--list_limit", type=int, help="Limit the number of models shown")
    info_params.add_argument("--list_format", choices=["pretty", "json"], default="pretty", help="Format for listing models: 'pretty' for formatted output, 'json' for raw JSON dump")

    model_filename_help = "Model to use for separation (default: %(default)s). Example: -m 2_HP-UVR.pth"
    output_format_help = "Output format for separated files, any common format (default: %(default)s). Example: --output_format=MP3"
    output_bitrate_help = "Output bitrate for separated files, any ffmpeg-compatible bitrate (default: %(default)s). Example: --output_bitrate=320k"
    output_dir_help = "Directory to write output files (default: <current dir>). Example: --output_dir=/app/separated"
    model_file_dir_help = "Model files directory (default: %(default)s or AUDIO_SEPARATOR_MODEL_DIR env var if set). Example: --model_file_dir=/app/models"
    download_model_only_help = "Download a single model file only, without performing separation."

    io_params = parser.add_argument_group("Separation I/O Params")
    io_params.add_argument("-m", "--model_filename", default="model_bs_roformer_ep_317_sdr_12.9755.ckpt", help=model_filename_help)
    io_params.add_argument("--output_format", default="FLAC", help=output_format_help)
    io_params.add_argument("--output_bitrate", default=None, help=output_bitrate_help)
    io_params.add_argument("--output_dir", default=None, help=output_dir_help)
    io_params.add_argument("--model_file_dir", default="/tmp/audio-separator-models/", help=model_file_dir_help)
    io_params.add_argument("--download_model_only", action="store_true", help=download_model_only_help)

    invert_spect_help = "Invert secondary stem using spectrogram (default: %(default)s). Example: --invert_spect"
    normalization_help = "Max peak amplitude to normalize input and output audio to (default: %(default)s). Example: --normalization=0.7"
    amplification_help = "Min peak amplitude to amplify input and output audio to (default: %(default)s). Example: --amplification=0.4"
    single_stem_help = "Output only single stem, e.g. Instrumental, Vocals, Drums, Bass, Guitar, Piano, Other. Example: --single_stem=Instrumental"
    sample_rate_help = "Modify the sample rate of the output audio (default: %(default)s). Example: --sample_rate=44100"
    use_soundfile_help = "Use soundfile to write audio output (default: %(default)s). Example: --use_soundfile"
    use_autocast_help = "Use PyTorch autocast for faster inference (default: %(default)s). Do not use for CPU inference. Example: --use_autocast"
    custom_output_names_help = 'Custom names for all output files in JSON format (default: %(default)s). Example: --custom_output_names=\'{"Vocals": "vocals_output", "Drums": "drums_output"}\''

    common_params = parser.add_argument_group("Common Separation Parameters")
    common_params.add_argument("--invert_spect", action="store_true", help=invert_spect_help)
    common_params.add_argument("--normalization", type=float, default=0.9, help=normalization_help)
    common_params.add_argument("--amplification", type=float, default=0.0, help=amplification_help)
    common_params.add_argument("--single_stem", default=None, help=single_stem_help)
    common_params.add_argument("--sample_rate", type=int, default=44100, help=sample_rate_help)
    common_params.add_argument("--use_soundfile", action="store_true", help=use_soundfile_help)
    common_params.add_argument("--use_autocast", action="store_true", help=use_autocast_help)
    common_params.add_argument("--custom_output_names", type=json.loads, default=None, help=custom_output_names_help)

    mdx_segment_size_help = "Larger consumes more resources, but may give better results (default: %(default)s). Example: --mdx_segment_size=256"
    mdx_overlap_help = "Amount of overlap between prediction windows, 0.001-0.999. Higher is better but slower (default: %(default)s). Example: --mdx_overlap=0.25"
    mdx_batch_size_help = "Larger consumes more RAM but may process slightly faster (default: %(default)s). Example: --mdx_batch_size=4"
    mdx_hop_length_help = "Usually called stride in neural networks, only change if you know what you're doing (default: %(default)s). Example: --mdx_hop_length=1024"
    mdx_enable_denoise_help = "Enable denoising during separation (default: %(default)s). Example: --mdx_enable_denoise"

    mdx_params = parser.add_argument_group("MDX Architecture Parameters")
    mdx_params.add_argument("--mdx_segment_size", type=int, default=256, help=mdx_segment_size_help)
    mdx_params.add_argument("--mdx_overlap", type=float, default=0.25, help=mdx_overlap_help)
    mdx_params.add_argument("--mdx_batch_size", type=int, default=1, help=mdx_batch_size_help)
    mdx_params.add_argument("--mdx_hop_length", type=int, default=1024, help=mdx_hop_length_help)
    mdx_params.add_argument("--mdx_enable_denoise", action="store_true", help=mdx_enable_denoise_help)

    vr_batch_size_help = "Number of batches to process at a time. Higher = more RAM, slightly faster processing (default: %(default)s). Example: --vr_batch_size=16"
    vr_window_size_help = "Balance quality and speed. 1024 = fast but lower, 320 = slower but better quality. (default: %(default)s). Example: --vr_window_size=320"
    vr_aggression_help = "Intensity of primary stem extraction, -100 - 100. Typically, 5 for vocals & instrumentals (default: %(default)s). Example: --vr_aggression=2"
    vr_enable_tta_help = "Enable Test-Time-Augmentation; slow but improves quality (default: %(default)s). Example: --vr_enable_tta"
    vr_high_end_process_help = "Mirror the missing frequency range of the output (default: %(default)s). Example: --vr_high_end_process"
    vr_enable_post_process_help = "Identify leftover artifacts within vocal output; may improve separation for some songs (default: %(default)s). Example: --vr_enable_post_process"
    vr_post_process_threshold_help = "Threshold for post_process feature: 0.1-0.3 (default: %(default)s). Example: --vr_post_process_threshold=0.1"

    vr_params = parser.add_argument_group("VR Architecture Parameters")
    vr_params.add_argument("--vr_batch_size", type=int, default=1, help=vr_batch_size_help)
    vr_params.add_argument("--vr_window_size", type=int, default=512, help=vr_window_size_help)
    vr_params.add_argument("--vr_aggression", type=int, default=5, help=vr_aggression_help)
    vr_params.add_argument("--vr_enable_tta", action="store_true", help=vr_enable_tta_help)
    vr_params.add_argument("--vr_high_end_process", action="store_true", help=vr_high_end_process_help)
    vr_params.add_argument("--vr_enable_post_process", action="store_true", help=vr_enable_post_process_help)
    vr_params.add_argument("--vr_post_process_threshold", type=float, default=0.2, help=vr_post_process_threshold_help)

    demucs_segment_size_help = "Size of segments into which the audio is split, 1-100. Higher = slower but better quality (default: %(default)s). Example: --demucs_segment_size=256"
    demucs_shifts_help = "Number of predictions with random shifts, higher = slower but better quality (default: %(default)s). Example: --demucs_shifts=4"
    demucs_overlap_help = "Overlap between prediction windows, 0.001-0.999. Higher = slower but better quality (default: %(default)s). Example: --demucs_overlap=0.25"
    demucs_segments_enabled_help = "Enable segment-wise processing (default: %(default)s). Example: --demucs_segments_enabled=False"

    demucs_params = parser.add_argument_group("Demucs Architecture Parameters")
    demucs_params.add_argument("--demucs_segment_size", type=str, default="Default", help=demucs_segment_size_help)
    demucs_params.add_argument("--demucs_shifts", type=int, default=2, help=demucs_shifts_help)
    demucs_params.add_argument("--demucs_overlap", type=float, default=0.25, help=demucs_overlap_help)
    # _str_to_bool rather than type=bool: bool("False") would be True.
    demucs_params.add_argument("--demucs_segments_enabled", type=_str_to_bool, default=True, help=demucs_segments_enabled_help)

    mdxc_segment_size_help = "Larger consumes more resources, but may give better results (default: %(default)s). Example: --mdxc_segment_size=256"
    mdxc_override_model_segment_size_help = "Override model default segment size instead of using the model default value. Example: --mdxc_override_model_segment_size"
    mdxc_overlap_help = "Amount of overlap between prediction windows, 2-50. Higher is better but slower (default: %(default)s). Example: --mdxc_overlap=8"
    mdxc_batch_size_help = "Larger consumes more RAM but may process slightly faster (default: %(default)s). Example: --mdxc_batch_size=4"
    mdxc_pitch_shift_help = "Shift audio pitch by a number of semitones while processing. May improve output for deep/high vocals. (default: %(default)s). Example: --mdxc_pitch_shift=2"

    mdxc_params = parser.add_argument_group("MDXC Architecture Parameters")
    mdxc_params.add_argument("--mdxc_segment_size", type=int, default=256, help=mdxc_segment_size_help)
    mdxc_params.add_argument("--mdxc_override_model_segment_size", action="store_true", help=mdxc_override_model_segment_size_help)
    mdxc_params.add_argument("--mdxc_overlap", type=int, default=8, help=mdxc_overlap_help)
    mdxc_params.add_argument("--mdxc_batch_size", type=int, default=1, help=mdxc_batch_size_help)
    mdxc_params.add_argument("--mdxc_pitch_shift", type=int, default=0, help=mdxc_pitch_shift_help)

    return parser


def _print_model_table(models):
    """Print the simplified model list as a fixed-width, column-aligned table.

    Args:
        models: Mapping of model filename -> info dict with at least the keys
            "Type" (str), "Stems" (iterable of str) and "Name" (str).

    An empty mapping (e.g. a --list_filter that matches nothing) prints only the
    header. Previously it raised ``ValueError`` from ``max()`` over an empty sequence.
    """
    rows = [(filename, info["Type"], ", ".join(info["Stems"]), info["Name"]) for filename, info in models.items()]

    # Each column is as wide as its header or its longest cell, whichever is larger.
    filename_width = max([len("Model Filename")] + [len(row[0]) for row in rows])
    arch_width = max([len("Arch")] + [len(row[1]) for row in rows])
    stems_width = max([len("Output Stems (SDR)")] + [len(row[2]) for row in rows])
    name_width = max([len("Friendly Name")] + [len(row[3]) for row in rows])

    # Width of the separator lines.
    total_width = filename_width + arch_width + stems_width + name_width + 15  # 15 accounts for spacing between columns

    print("-" * total_width)
    print(f"{'Model Filename':<{filename_width}} {'Arch':<{arch_width}} {'Output Stems (SDR)':<{stems_width}} {'Friendly Name'}")
    print("-" * total_width)
    for filename, arch, stems, name in rows:
        print(f"{filename:<{filename_width}} {arch:<{arch_width}} {stems:<{stems_width}} {name}")


def _collect_audio_files(paths):
    """Expand command-line *paths* into a flat list of audio file paths.

    Directories are walked recursively. Only files whose extension matches
    ``_AUDIO_EXTENSIONS`` (case-insensitively, so ``.WAV`` is included) are kept.
    Directories and files are visited in sorted order so the processing order is
    deterministic. Any non-directory path is passed through unchanged.

    Args:
        paths: Iterable of file or directory paths.

    Returns:
        List of file paths, in command-line order with directory contents expanded in place.
    """
    audio_files = []
    for path in paths:
        if not os.path.isdir(path):
            # If the path is a file, add it to the list
            audio_files.append(path)
            continue
        for root, dirs, files in os.walk(path):
            # Sorting in place makes os.walk (topdown) descend in a stable order.
            dirs.sort()
            for file_name in sorted(files):
                if file_name.lower().endswith(_AUDIO_EXTENSIONS):
                    audio_files.append(os.path.join(root, file_name))
    return audio_files


def main():
    """Main entry point for the CLI.

    Parses command-line arguments and configures this module's logger. Then it
    runs exactly one of these actions: print environment info (--env_info), list
    models (--list_models), download a model (--download_model_only), or separate
    the given audio files. The first three, plus the "no input" cases, end the
    process via ``sys.exit``. Argument errors exit through ``parser.error`` with
    status 2.
    """
    logger = logging.getLogger(__name__)
    log_handler = logging.StreamHandler()
    log_formatter = logging.Formatter(fmt="%(asctime)s.%(msecs)03d - %(levelname)s - %(module)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    log_handler.setFormatter(log_formatter)
    logger.addHandler(log_handler)

    try:
        package_version = metadata.distribution("audio-separator").version
    except metadata.PackageNotFoundError:
        # e.g. running from a source checkout without the distribution installed;
        # don't let a version lookup prevent --help or argument parsing.
        package_version = "unknown"

    parser = _build_parser(package_version)
    args = parser.parse_args()

    if args.debug:
        log_level = logging.DEBUG
    else:
        # Map e.g. "info" -> logging.INFO; reject names that aren't numeric levels
        # instead of crashing with an AttributeError traceback.
        log_level = getattr(logging, args.log_level.upper(), None)
        if not isinstance(log_level, int):
            parser.error(f"invalid --log_level value: {args.log_level!r} (expected e.g. debug, info, warning, error, critical)")
    logger.setLevel(log_level)

    # Deferred until after argument parsing (presumably so --help/--version and
    # argument errors don't need the heavy separator stack loaded).
    from audio_separator.separator import Separator

    if args.env_info:
        # Constructing a Separator with no arguments is what prints the environment info.
        separator = Separator()
        sys.exit(0)

    if args.list_models:
        separator = Separator(info_only=True)
        if args.list_format == "json":
            model_list = separator.list_supported_model_files()
            print(json.dumps(model_list, indent=2))
        else:
            models = separator.get_simplified_model_list(filter_sort_by=args.list_filter)
            # Apply limit if specified
            if args.list_limit and args.list_limit > 0:
                models = dict(list(models.items())[: args.list_limit])
            _print_model_table(models)
        sys.exit(0)

    if args.download_model_only:
        logger.info(f"Separator version {package_version} downloading model {args.model_filename} to directory {args.model_file_dir}")
        separator = Separator(log_formatter=log_formatter, log_level=log_level, model_file_dir=args.model_file_dir)
        separator.download_model_and_data(args.model_filename)
        logger.info(f"Model {args.model_filename} downloaded successfully.")
        sys.exit(0)

    # audio_files uses default=SUPPRESS, so it is absent when no paths were given.
    if not hasattr(args, "audio_files"):
        parser.print_help()
        sys.exit(1)

    # Path processing: if a directory is specified, collect all audio files from it
    audio_files = _collect_audio_files(args.audio_files)

    # If no audio files are found, log an error and exit the program
    if not audio_files:
        logger.error("No valid audio files found in the specified path(s).")
        sys.exit(1)

    logger.info(f"Separator version {package_version} beginning with input file(s): {', '.join(audio_files)}")

    separator = Separator(
        log_formatter=log_formatter,
        log_level=log_level,
        model_file_dir=args.model_file_dir,
        output_dir=args.output_dir,
        output_format=args.output_format,
        output_bitrate=args.output_bitrate,
        normalization_threshold=args.normalization,
        amplification_threshold=args.amplification,
        output_single_stem=args.single_stem,
        invert_using_spec=args.invert_spect,
        sample_rate=args.sample_rate,
        use_soundfile=args.use_soundfile,
        use_autocast=args.use_autocast,
        mdx_params={
            "hop_length": args.mdx_hop_length,
            "segment_size": args.mdx_segment_size,
            "overlap": args.mdx_overlap,
            "batch_size": args.mdx_batch_size,
            "enable_denoise": args.mdx_enable_denoise,
        },
        vr_params={
            "batch_size": args.vr_batch_size,
            "window_size": args.vr_window_size,
            "aggression": args.vr_aggression,
            "enable_tta": args.vr_enable_tta,
            "enable_post_process": args.vr_enable_post_process,
            "post_process_threshold": args.vr_post_process_threshold,
            "high_end_process": args.vr_high_end_process,
        },
        demucs_params={
            "segment_size": args.demucs_segment_size,
            "shifts": args.demucs_shifts,
            "overlap": args.demucs_overlap,
            "segments_enabled": args.demucs_segments_enabled,
        },
        mdxc_params={
            "segment_size": args.mdxc_segment_size,
            "batch_size": args.mdxc_batch_size,
            "overlap": args.mdxc_overlap,
            "override_model_segment_size": args.mdxc_override_model_segment_size,
            "pitch_shift": args.mdxc_pitch_shift,
        },
    )

    separator.load_model(model_filename=args.model_filename)

    for audio_file in audio_files:
        output_files = separator.separate(audio_file, custom_output_names=args.custom_output_names)
        logger.info(f"Separation complete! Output file(s): {' '.join(output_files)}")