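"""Integration tests for the audio-separator CLI.

Each test runs "audio-separator -m <model> <input>" as a subprocess, asserts that the
expected stem files are created and non-empty, and compares each stem's waveform and
spectrogram against reference images in tests/inputs/reference.

Set the environment variable SKIP_AUDIO_VALIDATION=1 to skip the image-comparison step.
"""
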
import os
import subprocess
import pytest
from pathlib import Path
from tests.utils import generate_reference_images, compare_images


@pytest.fixture(name="input_file")
def fixture_input_file():
"""Fixture providing the test input audio file path."""
    return "tests/inputs/mardy20s.flac"


@pytest.fixture(name="reference_dir")
def fixture_reference_dir():
"""Fixture providing the reference images directory path."""
    return "tests/inputs/reference"


@pytest.fixture(name="cleanup_output_files")
def fixture_cleanup_output_files():
"""Fixture to clean up output files before and after test."""
# This list will be populated by the test functions
output_files = []
# Yield to allow the test to run and add files to the list
yield output_files
# Clean up output files after test
    for file in output_files:
        if os.path.exists(file):
            print(f"Cleaning up test output file: {file}")
            os.remove(file)


def run_separation_test(model, audio_path, expected_files):
"""Helper function to run a separation test with a specific model."""
# Clean up any existing output files before the test
for file in expected_files:
if os.path.exists(file):
print(f"Deleting existing test output file {file}")
os.remove(file)
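    # For illustration, with the fixtures used in these tests the command below ends up being e.g.:
    #   audio-separator -m UVR-MDX-NET-Inst_HQ_4.onnx tests/inputs/mardy20s.flac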
    # Run the CLI command; check=False because the return code is asserted manually below
    result = subprocess.run(
        ["audio-separator", "-m", model, audio_path],
        capture_output=True,
        text=True,
        check=False,
    )
# Check that the command completed successfully
assert result.returncode == 0, f"Command failed with output: {result.stderr}"
# Check that the output files were created
for file in expected_files:
assert os.path.exists(file), f"Output file {file} was not created"
assert os.path.getsize(file) > 0, f"Output file {file} is empty"
    return result


def validate_audio_output(output_file, reference_dir, waveform_threshold=0.999, spectrogram_threshold=None):
    """Validate an audio output file by comparing its waveform and spectrogram with reference images.

    Args:
        output_file: Path to the audio output file
        reference_dir: Directory containing reference images
        waveform_threshold: Minimum similarity required for waveform images (0.0-1.0)
        spectrogram_threshold: Minimum similarity required for spectrogram images (0.0-1.0); defaults to waveform_threshold if None

    Returns:
        Tuple of booleans: (waveform_match, spectrogram_match)
    """
# If spectrogram threshold not specified, use the same as waveform threshold
if spectrogram_threshold is None:
spectrogram_threshold = waveform_threshold
# Create temporary directory for generated images
temp_dir = os.path.join(os.path.dirname(output_file), "temp_images")
os.makedirs(temp_dir, exist_ok=True)
# Generate waveform and spectrogram images for the output file
output_filename = os.path.basename(output_file)
name_without_ext = os.path.splitext(output_filename)[0]
# Generate actual images
actual_waveform_path, actual_spectrogram_path = generate_reference_images(output_file, temp_dir, prefix="actual_")
# Path to expected reference images
expected_waveform_path = os.path.join(reference_dir, f"expected_{name_without_ext}_waveform.png")
expected_spectrogram_path = os.path.join(reference_dir, f"expected_{name_without_ext}_spectrogram.png")
# Check if reference images exist
if not os.path.exists(expected_waveform_path) or not os.path.exists(expected_spectrogram_path):
print(f"Warning: Reference images not found for {output_file}")
print(f"Expected: {expected_waveform_path} and {expected_spectrogram_path}")
return False, False
# Compare waveform images
waveform_similarity, waveform_match = compare_images(expected_waveform_path, actual_waveform_path, min_similarity_threshold=waveform_threshold)
# Compare spectrogram images
spectrogram_similarity, spectrogram_match = compare_images(expected_spectrogram_path, actual_spectrogram_path, min_similarity_threshold=spectrogram_threshold)
print(f"Validation results for {output_file}:\n")
print(f" Waveform similarity: {waveform_similarity:.4f} (match: {waveform_match}, threshold: {waveform_threshold:.2f})\n")
print(f" Spectrogram similarity: {spectrogram_similarity:.4f} (match: {spectrogram_match}, threshold: {spectrogram_threshold:.2f})\n")
# Cleanup temp images (optional, uncomment if needed)
# os.remove(actual_waveform_path)
# os.remove(actual_spectrogram_path)
    return waveform_match, spectrogram_match


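# Example (useful when debugging a single stem outside of pytest; the file below is one of the
# outputs produced by the kuielab_b_vocals model, and the thresholds match the defaults used here):
#   wf_ok, spec_ok = validate_audio_output(
#       "mardy20s_(Vocals)_kuielab_b_vocals.flac",
#       "tests/inputs/reference",
#       waveform_threshold=0.90,
#       spectrogram_threshold=0.80,
#   )
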
# Default similarity threshold to use for most models
DEFAULT_SIMILARITY_THRESHOLDS = (0.90, 0.80) # (waveform_threshold, spectrogram_threshold)
# Model-specific similarity thresholds
# Use lower thresholds for models that show more variation between runs
MODEL_SIMILARITY_THRESHOLDS = {
    # Format: (waveform_threshold, spectrogram_threshold)
    "htdemucs_6s.yaml": (0.90, 0.70),  # Demucs multi-stem outputs (e.g. "Other" and "Piano") vary noticeably more between runs
}
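# A hypothetical example of adding a looser override for another model:
#   MODEL_SIMILARITY_THRESHOLDS["some_other_model.onnx"] = (0.85, 0.65)
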
# Parameterized test for multiple models
MODEL_PARAMS = [
# (model_filename, expected_output_filenames)
("kuielab_b_vocals.onnx", ["mardy20s_(Instrumental)_kuielab_b_vocals.flac", "mardy20s_(Vocals)_kuielab_b_vocals.flac"]),
("MGM_MAIN_v4.pth", ["mardy20s_(Instrumental)_MGM_MAIN_v4.flac", "mardy20s_(Vocals)_MGM_MAIN_v4.flac"]),
("UVR-MDX-NET-Inst_HQ_4.onnx", ["mardy20s_(Instrumental)_UVR-MDX-NET-Inst_HQ_4.flac", "mardy20s_(Vocals)_UVR-MDX-NET-Inst_HQ_4.flac"]),
("2_HP-UVR.pth", ["mardy20s_(Instrumental)_2_HP-UVR.flac", "mardy20s_(Vocals)_2_HP-UVR.flac"]),
(
"htdemucs_6s.yaml",
[
"mardy20s_(Vocals)_htdemucs_6s.flac",
"mardy20s_(Drums)_htdemucs_6s.flac",
"mardy20s_(Bass)_htdemucs_6s.flac",
"mardy20s_(Other)_htdemucs_6s.flac",
"mardy20s_(Guitar)_htdemucs_6s.flac",
"mardy20s_(Piano)_htdemucs_6s.flac",
],
),
("model_bs_roformer_ep_937_sdr_10.5309.ckpt", ["mardy20s_(Drum-Bass)_model_bs_roformer_ep_937_sdr_10.flac", "mardy20s_(No Drum-Bass)_model_bs_roformer_ep_937_sdr_10.flac"]),
("model_bs_roformer_ep_317_sdr_12.9755.ckpt", ["mardy20s_(Instrumental)_model_bs_roformer_ep_317_sdr_12.flac", "mardy20s_(Vocals)_model_bs_roformer_ep_317_sdr_12.flac"]),
]


@pytest.mark.parametrize("model,expected_files", MODEL_PARAMS)
def test_model_separation(model, expected_files, input_file, reference_dir, cleanup_output_files):
"""Parameterized test for multiple model files."""
# Add files to the cleanup list
cleanup_output_files.extend(expected_files)
# Run the test
run_separation_test(model, input_file, expected_files)
# Validate the output audio files
print(f"\nValidating output files for model {model}...")
# Get model-specific similarity threshold or use default
threshold = MODEL_SIMILARITY_THRESHOLDS.get(model, DEFAULT_SIMILARITY_THRESHOLDS)
    # Unpack thresholds; both the default and the per-model overrides are (waveform, spectrogram) tuples
    waveform_threshold, spectrogram_threshold = threshold
print(f"Using thresholds - waveform: {waveform_threshold}, spectrogram: {spectrogram_threshold} for model {model}")
for output_file in expected_files:
        # Skip image-based validation when the SKIP_AUDIO_VALIDATION environment variable is set to "1"
if os.environ.get("SKIP_AUDIO_VALIDATION") == "1":
print(f"Skipping audio validation for {output_file} (SKIP_AUDIO_VALIDATION=1)")
continue
waveform_match, spectrogram_match = validate_audio_output(output_file, reference_dir, waveform_threshold=waveform_threshold, spectrogram_threshold=spectrogram_threshold)
# Assert that the output matches the reference
assert waveform_match, f"Waveform for {output_file} does not match the reference"
assert spectrogram_match, f"Spectrogram for {output_file} does not match the reference"