ASesYusuf1's picture
Upload 131 files
01f8b5b verified
#!/usr/bin/env python
import argparse
import logging
import json
import sys
import os
from importlib import metadata
# Lower-case file extensions collected when a directory is given as input.
_AUDIO_FILE_EXTENSIONS = (".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff", ".ac3")  # Add other formats if needed


def _get_package_version():
    """Return the installed audio-separator version, or "unknown" if it is not installed.

    Without the fallback, running from a source checkout that has no installed
    distribution metadata would crash before even ``--help`` could be shown.
    """
    try:
        return metadata.distribution("audio-separator").version
    except metadata.PackageNotFoundError:
        return "unknown"


def _str_to_bool(value):
    """Parse a command-line boolean such as true/false, yes/no, on/off or 1/0.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``,
    so every non-empty string would enable the option.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "on", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r} (expected true or false)")


def _build_parser(package_version):
    """Build the CLI argument parser.

    Args:
        package_version: Version string reported by ``--version``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Separate audio file into different stems.", formatter_class=lambda prog: argparse.RawTextHelpFormatter(prog, max_help_position=60))
    # SUPPRESS leaves ``audio_files`` unset on the namespace when no positional paths are given.
    parser.add_argument("audio_files", nargs="*", help="The audio file paths or directory to separate, in any common format.", default=argparse.SUPPRESS)

    version_help = "Show the program's version number and exit."
    debug_help = "Enable debug logging, equivalent to --log_level=debug."
    env_info_help = "Print environment information and exit."
    list_models_help = "List all supported models and exit. Use --list_filter to filter/sort the list and --list_limit to show only top N results."
    log_level_help = "Log level, e.g. info, debug, warning (default: %(default)s)."

    info_params = parser.add_argument_group("Info and Debugging")
    info_params.add_argument("-v", "--version", action="version", version=f"%(prog)s {package_version}", help=version_help)
    info_params.add_argument("-d", "--debug", action="store_true", help=debug_help)
    info_params.add_argument("-e", "--env_info", action="store_true", help=env_info_help)
    info_params.add_argument("-l", "--list_models", action="store_true", help=list_models_help)
    info_params.add_argument("--log_level", default="info", help=log_level_help)
    info_params.add_argument("--list_filter", help="Filter and sort the model list by 'name', 'filename', or any stem e.g. vocals, instrumental, drums")
    info_params.add_argument("--list_limit", type=int, help="Limit the number of models shown")
    info_params.add_argument("--list_format", choices=["pretty", "json"], default="pretty", help="Format for listing models: 'pretty' for formatted output, 'json' for raw JSON dump")

    model_filename_help = "Model to use for separation (default: %(default)s). Example: -m 2_HP-UVR.pth"
    output_format_help = "Output format for separated files, any common format (default: %(default)s). Example: --output_format=MP3"
    output_bitrate_help = "Output bitrate for separated files, any ffmpeg-compatible bitrate (default: %(default)s). Example: --output_bitrate=320k"
    output_dir_help = "Directory to write output files (default: <current dir>). Example: --output_dir=/app/separated"
    model_file_dir_help = "Model files directory (default: %(default)s or AUDIO_SEPARATOR_MODEL_DIR env var if set). Example: --model_file_dir=/app/models"
    download_model_only_help = "Download a single model file only, without performing separation."

    io_params = parser.add_argument_group("Separation I/O Params")
    io_params.add_argument("-m", "--model_filename", default="model_bs_roformer_ep_317_sdr_12.9755.ckpt", help=model_filename_help)
    io_params.add_argument("--output_format", default="FLAC", help=output_format_help)
    io_params.add_argument("--output_bitrate", default=None, help=output_bitrate_help)
    io_params.add_argument("--output_dir", default=None, help=output_dir_help)
    io_params.add_argument("--model_file_dir", default="/tmp/audio-separator-models/", help=model_file_dir_help)
    io_params.add_argument("--download_model_only", action="store_true", help=download_model_only_help)

    invert_spect_help = "Invert secondary stem using spectrogram (default: %(default)s). Example: --invert_spect"
    normalization_help = "Max peak amplitude to normalize input and output audio to (default: %(default)s). Example: --normalization=0.7"
    amplification_help = "Min peak amplitude to amplify input and output audio to (default: %(default)s). Example: --amplification=0.4"
    single_stem_help = "Output only single stem, e.g. Instrumental, Vocals, Drums, Bass, Guitar, Piano, Other. Example: --single_stem=Instrumental"
    sample_rate_help = "Modify the sample rate of the output audio (default: %(default)s). Example: --sample_rate=44100"
    use_soundfile_help = "Use soundfile to write audio output (default: %(default)s). Example: --use_soundfile"
    use_autocast_help = "Use PyTorch autocast for faster inference (default: %(default)s). Do not use for CPU inference. Example: --use_autocast"
    custom_output_names_help = 'Custom names for all output files in JSON format (default: %(default)s). Example: --custom_output_names=\'{"Vocals": "vocals_output", "Drums": "drums_output"}\''

    common_params = parser.add_argument_group("Common Separation Parameters")
    common_params.add_argument("--invert_spect", action="store_true", help=invert_spect_help)
    common_params.add_argument("--normalization", type=float, default=0.9, help=normalization_help)
    common_params.add_argument("--amplification", type=float, default=0.0, help=amplification_help)
    common_params.add_argument("--single_stem", default=None, help=single_stem_help)
    common_params.add_argument("--sample_rate", type=int, default=44100, help=sample_rate_help)
    common_params.add_argument("--use_soundfile", action="store_true", help=use_soundfile_help)
    common_params.add_argument("--use_autocast", action="store_true", help=use_autocast_help)
    common_params.add_argument("--custom_output_names", type=json.loads, default=None, help=custom_output_names_help)

    mdx_segment_size_help = "Larger consumes more resources, but may give better results (default: %(default)s). Example: --mdx_segment_size=256"
    mdx_overlap_help = "Amount of overlap between prediction windows, 0.001-0.999. Higher is better but slower (default: %(default)s). Example: --mdx_overlap=0.25"
    mdx_batch_size_help = "Larger consumes more RAM but may process slightly faster (default: %(default)s). Example: --mdx_batch_size=4"
    mdx_hop_length_help = "Usually called stride in neural networks, only change if you know what you're doing (default: %(default)s). Example: --mdx_hop_length=1024"
    mdx_enable_denoise_help = "Enable denoising during separation (default: %(default)s). Example: --mdx_enable_denoise"

    mdx_params = parser.add_argument_group("MDX Architecture Parameters")
    mdx_params.add_argument("--mdx_segment_size", type=int, default=256, help=mdx_segment_size_help)
    mdx_params.add_argument("--mdx_overlap", type=float, default=0.25, help=mdx_overlap_help)
    mdx_params.add_argument("--mdx_batch_size", type=int, default=1, help=mdx_batch_size_help)
    mdx_params.add_argument("--mdx_hop_length", type=int, default=1024, help=mdx_hop_length_help)
    mdx_params.add_argument("--mdx_enable_denoise", action="store_true", help=mdx_enable_denoise_help)

    vr_batch_size_help = "Number of batches to process at a time. Higher = more RAM, slightly faster processing (default: %(default)s). Example: --vr_batch_size=16"
    vr_window_size_help = "Balance quality and speed. 1024 = fast but lower, 320 = slower but better quality. (default: %(default)s). Example: --vr_window_size=320"
    vr_aggression_help = "Intensity of primary stem extraction, -100 - 100. Typically, 5 for vocals & instrumentals (default: %(default)s). Example: --vr_aggression=2"
    vr_enable_tta_help = "Enable Test-Time-Augmentation; slow but improves quality (default: %(default)s). Example: --vr_enable_tta"
    vr_high_end_process_help = "Mirror the missing frequency range of the output (default: %(default)s). Example: --vr_high_end_process"
    vr_enable_post_process_help = "Identify leftover artifacts within vocal output; may improve separation for some songs (default: %(default)s). Example: --vr_enable_post_process"
    vr_post_process_threshold_help = "Threshold for post_process feature: 0.1-0.3 (default: %(default)s). Example: --vr_post_process_threshold=0.1"

    vr_params = parser.add_argument_group("VR Architecture Parameters")
    vr_params.add_argument("--vr_batch_size", type=int, default=1, help=vr_batch_size_help)
    vr_params.add_argument("--vr_window_size", type=int, default=512, help=vr_window_size_help)
    vr_params.add_argument("--vr_aggression", type=int, default=5, help=vr_aggression_help)
    vr_params.add_argument("--vr_enable_tta", action="store_true", help=vr_enable_tta_help)
    vr_params.add_argument("--vr_high_end_process", action="store_true", help=vr_high_end_process_help)
    vr_params.add_argument("--vr_enable_post_process", action="store_true", help=vr_enable_post_process_help)
    vr_params.add_argument("--vr_post_process_threshold", type=float, default=0.2, help=vr_post_process_threshold_help)

    demucs_segment_size_help = "Size of segments into which the audio is split, 1-100. Higher = slower but better quality (default: %(default)s). Example: --demucs_segment_size=256"
    demucs_shifts_help = "Number of predictions with random shifts, higher = slower but better quality (default: %(default)s). Example: --demucs_shifts=4"
    demucs_overlap_help = "Overlap between prediction windows, 0.001-0.999. Higher = slower but better quality (default: %(default)s). Example: --demucs_overlap=0.25"
    demucs_segments_enabled_help = "Enable segment-wise processing (default: %(default)s). Example: --demucs_segments_enabled=False"

    demucs_params = parser.add_argument_group("Demucs Architecture Parameters")
    demucs_params.add_argument("--demucs_segment_size", type=str, default="Default", help=demucs_segment_size_help)
    demucs_params.add_argument("--demucs_shifts", type=int, default=2, help=demucs_shifts_help)
    demucs_params.add_argument("--demucs_overlap", type=float, default=0.25, help=demucs_overlap_help)
    # _str_to_bool (not type=bool) so that "--demucs_segments_enabled=False" actually disables the feature.
    demucs_params.add_argument("--demucs_segments_enabled", type=_str_to_bool, default=True, help=demucs_segments_enabled_help)

    mdxc_segment_size_help = "Larger consumes more resources, but may give better results (default: %(default)s). Example: --mdxc_segment_size=256"
    mdxc_override_model_segment_size_help = "Override model default segment size instead of using the model default value. Example: --mdxc_override_model_segment_size"
    mdxc_overlap_help = "Amount of overlap between prediction windows, 2-50. Higher is better but slower (default: %(default)s). Example: --mdxc_overlap=8"
    mdxc_batch_size_help = "Larger consumes more RAM but may process slightly faster (default: %(default)s). Example: --mdxc_batch_size=4"
    mdxc_pitch_shift_help = "Shift audio pitch by a number of semitones while processing. May improve output for deep/high vocals. (default: %(default)s). Example: --mdxc_pitch_shift=2"

    mdxc_params = parser.add_argument_group("MDXC Architecture Parameters")
    mdxc_params.add_argument("--mdxc_segment_size", type=int, default=256, help=mdxc_segment_size_help)
    mdxc_params.add_argument("--mdxc_override_model_segment_size", action="store_true", help=mdxc_override_model_segment_size_help)
    mdxc_params.add_argument("--mdxc_overlap", type=int, default=8, help=mdxc_overlap_help)
    mdxc_params.add_argument("--mdxc_batch_size", type=int, default=1, help=mdxc_batch_size_help)
    mdxc_params.add_argument("--mdxc_pitch_shift", type=int, default=0, help=mdxc_pitch_shift_help)

    return parser


def _print_model_table(models):
    """Print models as an aligned table of filename, architecture, stems and friendly name.

    Args:
        models: Mapping of model filename -> info dict with "Type", "Stems" (iterable of str)
            and "Name" keys, as returned by ``Separator.get_simplified_model_list``.
            May be empty (e.g. when a filter matches nothing); only the header is printed then.
    """
    # default=0 keeps max() from raising ValueError on an empty model list.
    filename_width = max(len("Model Filename"), max((len(filename) for filename in models), default=0))
    arch_width = max(len("Arch"), max((len(info["Type"]) for info in models.values()), default=0))
    stems_width = max(len("Output Stems (SDR)"), max((len(", ".join(info["Stems"])) for info in models.values()), default=0))
    name_width = max(len("Friendly Name"), max((len(info["Name"]) for info in models.values()), default=0))

    # Width of the dashed rule lines: all column widths plus a fixed 15 characters of slack for inter-column spacing.
    total_width = filename_width + arch_width + stems_width + name_width + 15

    print("-" * total_width)
    print(f"{'Model Filename':<{filename_width}} {'Arch':<{arch_width}} {'Output Stems (SDR)':<{stems_width}} {'Friendly Name'}")
    print("-" * total_width)
    for filename, info in models.items():
        stems = ", ".join(info["Stems"])
        print(f"{filename:<{filename_width}} {info['Type']:<{arch_width}} {stems:<{stems_width}} {info['Name']}")


def _collect_audio_files(paths):
    """Expand the given input paths into a flat list of audio file paths.

    Directories are walked recursively (in sorted order, for deterministic output) and
    files whose extension case-insensitively matches ``_AUDIO_FILE_EXTENSIONS`` are kept.
    Any non-directory path is passed through unchanged, without checking that it exists.
    """
    audio_files = []
    for path in paths:
        if not os.path.isdir(path):
            audio_files.append(path)
            continue
        for root, dirs, files in os.walk(path):
            # Sorting dirs in place makes os.walk descend in a stable order.
            dirs.sort()
            for file_name in sorted(files):
                # Case-insensitive so that e.g. "TRACK.WAV" is picked up too.
                if file_name.lower().endswith(_AUDIO_FILE_EXTENSIONS):
                    audio_files.append(os.path.join(root, file_name))
    return audio_files


def main(argv=None):
    """Main entry point for the CLI.

    Parses arguments, then either prints environment info, lists models, downloads a
    single model, or separates every input audio file with the selected model.

    Args:
        argv: Optional list of argument strings to parse instead of ``sys.argv[1:]``;
            useful when invoking the CLI programmatically.

    Exits via ``sys.exit`` (status 0) after the info-only modes, with status 1 when no
    input files are given or found, and with status 2 on invalid arguments.
    """
    logger = logging.getLogger(__name__)
    log_formatter = logging.Formatter(fmt="%(asctime)s.%(msecs)03d - %(levelname)s - %(module)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    # Avoid stacking duplicate handlers (and duplicated log lines) if main() runs more than once in a process.
    if not logger.handlers:
        log_handler = logging.StreamHandler()
        log_handler.setFormatter(log_formatter)
        logger.addHandler(log_handler)

    package_version = _get_package_version()
    parser = _build_parser(package_version)
    args = parser.parse_args(argv)

    if args.debug:
        log_level = logging.DEBUG
    else:
        # Unknown names (or non-level attributes such as BASIC_FORMAT) are rejected with a usage error instead of a traceback.
        log_level = getattr(logging, args.log_level.upper(), None)
        if not isinstance(log_level, int):
            parser.error(f"invalid --log_level {args.log_level!r}; expected one of debug, info, warning, error, critical")
    logger.setLevel(log_level)

    # Imported only after argument parsing, so argument errors are reported before the separator package is loaded.
    from audio_separator.separator import Separator

    if args.env_info:
        Separator()
        sys.exit(0)

    if args.list_models:
        separator = Separator(info_only=True)
        if args.list_format == "json":
            model_list = separator.list_supported_model_files()
            print(json.dumps(model_list, indent=2))
        else:
            models = separator.get_simplified_model_list(filter_sort_by=args.list_filter)
            # Apply limit if specified (non-positive values mean "no limit").
            if args.list_limit and args.list_limit > 0:
                models = dict(list(models.items())[: args.list_limit])
            _print_model_table(models)
        sys.exit(0)

    if args.download_model_only:
        logger.info(f"Separator version {package_version} downloading model {args.model_filename} to directory {args.model_file_dir}")
        separator = Separator(log_formatter=log_formatter, log_level=log_level, model_file_dir=args.model_file_dir)
        separator.download_model_and_data(args.model_filename)
        logger.info(f"Model {args.model_filename} downloaded successfully.")
        sys.exit(0)

    # audio_files is absent (default=SUPPRESS) when no positional paths were supplied.
    if not hasattr(args, "audio_files"):
        parser.print_help()
        sys.exit(1)

    audio_files = _collect_audio_files(args.audio_files)
    if not audio_files:
        logger.error("No valid audio files found in the specified path(s).")
        sys.exit(1)

    logger.info(f"Separator version {package_version} beginning with input file(s): {', '.join(audio_files)}")

    separator = Separator(
        log_formatter=log_formatter,
        log_level=log_level,
        model_file_dir=args.model_file_dir,
        output_dir=args.output_dir,
        output_format=args.output_format,
        output_bitrate=args.output_bitrate,
        normalization_threshold=args.normalization,
        amplification_threshold=args.amplification,
        output_single_stem=args.single_stem,
        invert_using_spec=args.invert_spect,
        sample_rate=args.sample_rate,
        use_soundfile=args.use_soundfile,
        use_autocast=args.use_autocast,
        mdx_params={
            "hop_length": args.mdx_hop_length,
            "segment_size": args.mdx_segment_size,
            "overlap": args.mdx_overlap,
            "batch_size": args.mdx_batch_size,
            "enable_denoise": args.mdx_enable_denoise,
        },
        vr_params={
            "batch_size": args.vr_batch_size,
            "window_size": args.vr_window_size,
            "aggression": args.vr_aggression,
            "enable_tta": args.vr_enable_tta,
            "enable_post_process": args.vr_enable_post_process,
            "post_process_threshold": args.vr_post_process_threshold,
            "high_end_process": args.vr_high_end_process,
        },
        demucs_params={
            "segment_size": args.demucs_segment_size,
            "shifts": args.demucs_shifts,
            "overlap": args.demucs_overlap,
            "segments_enabled": args.demucs_segments_enabled,
        },
        mdxc_params={
            "segment_size": args.mdxc_segment_size,
            "batch_size": args.mdxc_batch_size,
            "overlap": args.mdxc_overlap,
            "override_model_segment_size": args.mdxc_override_model_segment_size,
            "pitch_shift": args.mdxc_pitch_shift,
        },
    )

    separator.load_model(model_filename=args.model_filename)

    for audio_file in audio_files:
        output_files = separator.separate(audio_file, custom_output_names=args.custom_output_names)
        logger.info(f"Separation complete! Output file(s): {' '.join(output_files)}")