Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import subprocess | |
import pytest | |
from pathlib import Path | |
from tests.utils import generate_reference_images, compare_images | |
def fixture_input_file():
    """Return the repository-relative path of the audio clip used as test input."""
    # NOTE(review): no @pytest.fixture decorator is visible on this function,
    # yet test_model_separation takes an `input_file` argument -- confirm how
    # this function is registered as that fixture.
    input_audio_path = "tests/inputs/mardy20s.flac"
    return input_audio_path
def fixture_reference_dir():
    """Return the repository-relative directory that holds the reference images."""
    # NOTE(review): no @pytest.fixture decorator is visible on this function,
    # yet test_model_separation takes a `reference_dir` argument -- confirm
    # how this function is registered as that fixture.
    reference_images_dir = "tests/inputs/reference"
    return reference_images_dir
def fixture_cleanup_output_files():
    """Hand the test a list of output paths and delete those files afterwards.

    Yields an initially empty list. The test adds the paths of the files it
    expects to produce; once the generator is resumed, every listed path that
    still exists on disk is reported on stdout and removed. Nothing is deleted
    before the yield.
    """
    # NOTE(review): no @pytest.fixture decorator is visible on this function,
    # yet test_model_separation takes a `cleanup_output_files` argument --
    # confirm how this generator is registered as that fixture.
    tracked_paths = []

    # Control passes to the test here; it fills `tracked_paths` in place.
    yield tracked_paths

    # filter() is lazy, so each path's existence is checked right before its
    # own removal rather than for the whole list up front.
    for leftover in filter(os.path.exists, tracked_paths):
        print(f"Test output file exists: {leftover}")
        os.remove(leftover)
def run_separation_test(model, audio_path, expected_files):
    """Run the ``audio-separator`` CLI for one model and check its output files.

    Args:
        model: Model filename passed to the CLI via ``-m``.
        audio_path: Path of the audio file to separate.
        expected_files: Paths of the files the command is expected to create.

    Returns:
        The ``subprocess.CompletedProcess`` of the CLI invocation.

    Raises:
        AssertionError: If the command exits non-zero, or an expected file is
            missing or empty afterwards.
    """
    # Remove leftovers from earlier runs so the existence checks below can only
    # be satisfied by files written during this invocation.
    for stale in filter(os.path.exists, expected_files):
        print(f"Deleting existing test output file {stale}")
        os.remove(stale)

    # check=False: a non-zero exit is reported through the assertion below,
    # which includes the captured stderr in its message.
    command = ["audio-separator", "-m", model, audio_path]
    completed = subprocess.run(command, capture_output=True, text=True, check=False)

    assert completed.returncode == 0, f"Command failed with output: {completed.stderr}"

    # Every expected output must exist and contain at least one byte.
    for produced in expected_files:
        assert os.path.exists(produced), f"Output file {produced} was not created"
        assert os.path.getsize(produced) > 0, f"Output file {produced} is empty"

    return completed
def validate_audio_output(output_file, reference_dir, waveform_threshold=0.999, spectrogram_threshold=None):
    """Compare an output file's waveform and spectrogram images with references.

    Images for ``output_file`` are generated into a ``temp_images`` directory
    next to it, then compared with ``expected_<name>_waveform.png`` and
    ``expected_<name>_spectrogram.png`` in ``reference_dir``, where ``<name>``
    is the output file's base name without extension. The generated images are
    not deleted by this function.

    Args:
        output_file: Path to the audio output file.
        reference_dir: Directory containing the reference images.
        waveform_threshold: Minimum similarity required for waveform images (0.0-1.0).
        spectrogram_threshold: Minimum similarity required for spectrogram
            images (0.0-1.0); falls back to ``waveform_threshold`` when None.

    Returns:
        Tuple ``(waveform_match, spectrogram_match)`` as reported by
        ``compare_images``, or ``(False, False)`` when either reference image
        is missing.
    """
    spectrogram_threshold = waveform_threshold if spectrogram_threshold is None else spectrogram_threshold

    # Scratch directory for the freshly generated images, beside the output file.
    scratch_dir = os.path.join(os.path.dirname(output_file), "temp_images")
    os.makedirs(scratch_dir, exist_ok=True)

    # Base name without extension; used to locate the matching reference images.
    stem, _extension = os.path.splitext(os.path.basename(output_file))

    # Generated before the reference check, so the "actual" images are written
    # even when the references turn out to be missing.
    actual_waveform, actual_spectrogram = generate_reference_images(output_file, scratch_dir, prefix="actual_")

    expected_waveform = os.path.join(reference_dir, f"expected_{stem}_waveform.png")
    expected_spectrogram = os.path.join(reference_dir, f"expected_{stem}_spectrogram.png")

    references_present = os.path.exists(expected_waveform) and os.path.exists(expected_spectrogram)
    if not references_present:
        print(f"Warning: Reference images not found for {output_file}")
        print(f"Expected: {expected_waveform} and {expected_spectrogram}")
        return False, False

    # compare_images returns a (similarity, match) pair for each image type.
    waveform_score, waveform_ok = compare_images(expected_waveform, actual_waveform, min_similarity_threshold=waveform_threshold)
    spectrogram_score, spectrogram_ok = compare_images(expected_spectrogram, actual_spectrogram, min_similarity_threshold=spectrogram_threshold)

    print(f"Validation results for {output_file}:\n")
    print(f" Waveform similarity: {waveform_score:.4f} (match: {waveform_ok}, threshold: {waveform_threshold:.2f})\n")
    print(f" Spectrogram similarity: {spectrogram_score:.4f} (match: {spectrogram_ok}, threshold: {spectrogram_threshold:.2f})\n")

    return waveform_ok, spectrogram_ok
# Fallback (waveform_threshold, spectrogram_threshold) pair, used for every
# model that has no entry in MODEL_SIMILARITY_THRESHOLDS.
DEFAULT_SIMILARITY_THRESHOLDS = (0.90, 0.80)

# Per-model overrides, keyed by model filename, in the same
# (waveform_threshold, spectrogram_threshold) layout. Lower values are used
# for models whose output varies more between runs.
MODEL_SIMILARITY_THRESHOLDS = {
    # Demucs multi-stem output (e.g. "Other" and "Piano") is a lot more variable.
    "htdemucs_6s.yaml": (0.90, 0.70),
}
# Test cases for the model separation test: one
# (model_filename, expected_output_filenames) tuple per model.
MODEL_PARAMS = [
    (
        "kuielab_b_vocals.onnx",
        [
            "mardy20s_(Instrumental)_kuielab_b_vocals.flac",
            "mardy20s_(Vocals)_kuielab_b_vocals.flac",
        ],
    ),
    (
        "MGM_MAIN_v4.pth",
        [
            "mardy20s_(Instrumental)_MGM_MAIN_v4.flac",
            "mardy20s_(Vocals)_MGM_MAIN_v4.flac",
        ],
    ),
    (
        "UVR-MDX-NET-Inst_HQ_4.onnx",
        [
            "mardy20s_(Instrumental)_UVR-MDX-NET-Inst_HQ_4.flac",
            "mardy20s_(Vocals)_UVR-MDX-NET-Inst_HQ_4.flac",
        ],
    ),
    (
        "2_HP-UVR.pth",
        [
            "mardy20s_(Instrumental)_2_HP-UVR.flac",
            "mardy20s_(Vocals)_2_HP-UVR.flac",
        ],
    ),
    # The only entry with six expected output files, one per stem.
    (
        "htdemucs_6s.yaml",
        [
            "mardy20s_(Vocals)_htdemucs_6s.flac",
            "mardy20s_(Drums)_htdemucs_6s.flac",
            "mardy20s_(Bass)_htdemucs_6s.flac",
            "mardy20s_(Other)_htdemucs_6s.flac",
            "mardy20s_(Guitar)_htdemucs_6s.flac",
            "mardy20s_(Piano)_htdemucs_6s.flac",
        ],
    ),
    # For the two checkpoints below, the expected filenames end at "sdr_10" /
    # "sdr_12", i.e. without the digits after the dot in the model name.
    (
        "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
        [
            "mardy20s_(Drum-Bass)_model_bs_roformer_ep_937_sdr_10.flac",
            "mardy20s_(No Drum-Bass)_model_bs_roformer_ep_937_sdr_10.flac",
        ],
    ),
    (
        "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
        [
            "mardy20s_(Instrumental)_model_bs_roformer_ep_317_sdr_12.flac",
            "mardy20s_(Vocals)_model_bs_roformer_ep_317_sdr_12.flac",
        ],
    ),
]
def test_model_separation(model, expected_files, input_file, reference_dir, cleanup_output_files):
    """Separate the input with one model and validate each expected output file.

    Args:
        model: Model filename to run.
        expected_files: Output file paths the model is expected to produce.
        input_file: Path of the input audio file.
        reference_dir: Directory containing the reference images.
        cleanup_output_files: List of paths to delete after the test; the
            expected files are added to it.
    """
    # NOTE(review): no @pytest.mark.parametrize decorator is visible here;
    # `model` and `expected_files` presumably come from MODEL_PARAMS -- confirm.

    # Register the outputs for cleanup first, so they are tracked even if the
    # separation or a later assertion fails.
    cleanup_output_files.extend(expected_files)

    run_separation_test(model, input_file, expected_files)

    print(f"\nValidating output files for model {model}...")

    # Model-specific thresholds take precedence; both sources are
    # (waveform, spectrogram) pairs.
    wave_min, spec_min = MODEL_SIMILARITY_THRESHOLDS.get(model, DEFAULT_SIMILARITY_THRESHOLDS)
    print(f"Using thresholds - waveform: {wave_min}, spectrogram: {spec_min} for model {model}")

    for output_file in expected_files:
        # SKIP_AUDIO_VALIDATION=1 turns off the image comparison; the
        # separation run and its file checks above have still taken place.
        validation_disabled = os.environ.get("SKIP_AUDIO_VALIDATION") == "1"
        if validation_disabled:
            print(f"Skipping audio validation for {output_file} (SKIP_AUDIO_VALIDATION=1)")
        else:
            waveform_ok, spectrogram_ok = validate_audio_output(output_file, reference_dir, waveform_threshold=wave_min, spectrogram_threshold=spec_min)
            assert waveform_ok, f"Waveform for {output_file} does not match the reference"
            assert spectrogram_ok, f"Spectrogram for {output_file} does not match the reference"