File size: 8,043 Bytes
01f8b5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import os
import subprocess
import pytest
from pathlib import Path
from tests.utils import generate_reference_images, compare_images


@pytest.fixture(name="input_file")
def fixture_input_file():
    """Return the relative path of the audio file used as separation input."""
    audio_path = "tests/inputs/mardy20s.flac"
    return audio_path


@pytest.fixture(name="reference_dir")
def fixture_reference_dir():
    """Return the relative path of the directory holding the reference images."""
    reference_path = "tests/inputs/reference"
    return reference_path


@pytest.fixture(name="cleanup_output_files")
def fixture_cleanup_output_files():
    """Yield a list of file paths that are deleted once the test has finished.

    The test registers the files it expects to produce by mutating the
    yielded list. After the test, every registered path that still exists on
    disk is removed.
    """
    tracked_paths = []

    # Hand the (initially empty) list to the test; it fills it in while running.
    yield tracked_paths

    # Teardown: delete whichever registered files are still present.
    for path in tracked_paths:
        if not os.path.exists(path):
            continue
        print(f"Test output file exists: {path}")
        os.remove(path)


def run_separation_test(model, audio_path, expected_files):
    """Run the ``audio-separator`` CLI for one model and check the files it writes.

    Args:
        model: Model filename passed to the CLI through ``-m``.
        audio_path: Path of the audio file to separate.
        expected_files: Paths the command is expected to create.

    Returns:
        The ``subprocess.CompletedProcess`` describing the CLI invocation.

    Raises:
        AssertionError: If the command exits non-zero, or an expected file is
            missing or empty afterwards.
    """
    # Remove leftovers first so a file from an earlier run cannot satisfy the checks below.
    for stale in expected_files:
        if not os.path.exists(stale):
            continue
        print(f"Deleting existing test output file {stale}")
        os.remove(stale)

    command = ["audio-separator", "-m", model, audio_path]
    # check=False on purpose: the exit status is asserted right below so stderr ends up in the failure message.
    completed = subprocess.run(command, capture_output=True, text=True, check=False)

    assert completed.returncode == 0, f"Command failed with output: {completed.stderr}"

    # Every expected output must exist and contain data.
    for produced in expected_files:
        assert os.path.exists(produced), f"Output file {produced} was not created"
        assert os.path.getsize(produced) > 0, f"Output file {produced} is empty"

    return completed


def validate_audio_output(output_file, reference_dir, waveform_threshold=0.999, spectrogram_threshold=None):
    """Check an audio output file against stored waveform and spectrogram reference images.

    Images for ``output_file`` are generated into a ``temp_images`` directory
    next to it and compared with ``expected_<name>_waveform.png`` and
    ``expected_<name>_spectrogram.png`` from ``reference_dir``, where
    ``<name>`` is the output file's base name without its extension.

    Args:
        output_file: Path to the audio output file
        reference_dir: Directory containing reference images
        waveform_threshold: Minimum similarity required for waveform images (0.0-1.0)
        spectrogram_threshold: Minimum similarity for spectrogram images (0.0-1.0); falls back to
            waveform_threshold when None

    Returns:
        Tuple of booleans: (waveform_match, spectrogram_match). Both are False when either
        reference image is missing.
    """
    # An unspecified spectrogram threshold inherits the waveform threshold.
    spectrogram_threshold = waveform_threshold if spectrogram_threshold is None else spectrogram_threshold

    # Generated images go into a "temp_images" folder alongside the output file.
    temp_dir = os.path.join(os.path.dirname(output_file), "temp_images")
    os.makedirs(temp_dir, exist_ok=True)

    # Base name without extension; it is the key used in the reference image filenames.
    stem = os.path.splitext(os.path.basename(output_file))[0]

    # Render the images for the file under test.
    actual_waveform_path, actual_spectrogram_path = generate_reference_images(output_file, temp_dir, prefix="actual_")

    # Locations of the reference images to compare against.
    expected_waveform_path = os.path.join(reference_dir, f"expected_{stem}_waveform.png")
    expected_spectrogram_path = os.path.join(reference_dir, f"expected_{stem}_spectrogram.png")

    # Without both reference images there is nothing to compare: report a mismatch for both.
    references_present = os.path.exists(expected_waveform_path) and os.path.exists(expected_spectrogram_path)
    if not references_present:
        print(f"Warning: Reference images not found for {output_file}")
        print(f"Expected: {expected_waveform_path} and {expected_spectrogram_path}")
        return False, False

    # Each comparison yields a (similarity, match) pair.
    waveform_similarity, waveform_match = compare_images(expected_waveform_path, actual_waveform_path, min_similarity_threshold=waveform_threshold)
    spectrogram_similarity, spectrogram_match = compare_images(expected_spectrogram_path, actual_spectrogram_path, min_similarity_threshold=spectrogram_threshold)

    print(f"Validation results for {output_file}:\n")
    print(f"  Waveform similarity: {waveform_similarity:.4f} (match: {waveform_match}, threshold: {waveform_threshold:.2f})\n")
    print(f"  Spectrogram similarity: {spectrogram_similarity:.4f} (match: {spectrogram_match}, threshold: {spectrogram_threshold:.2f})\n")

    # NOTE: the generated images are not deleted here; they remain in temp_dir after the call.

    return waveform_match, spectrogram_match


# Fallback (waveform_threshold, spectrogram_threshold) pair for models without an entry of their own below
DEFAULT_SIMILARITY_THRESHOLDS = (0.90, 0.80)

# Per-model overrides, keyed by model filename.
# Each value is a (waveform_threshold, spectrogram_threshold) pair; lower values
# are used for models that show more variation between runs.
MODEL_SIMILARITY_THRESHOLDS = {
    # Demucs multi-stem output (e.g. "Other" and "Piano") is a lot more variable
    "htdemucs_6s.yaml": (0.90, 0.70),
}


# Cases for the parameterized separation test.
# Each entry is (model_filename, expected_output_filenames).
MODEL_PARAMS = [
    (
        "kuielab_b_vocals.onnx",
        [
            "mardy20s_(Instrumental)_kuielab_b_vocals.flac",
            "mardy20s_(Vocals)_kuielab_b_vocals.flac",
        ],
    ),
    (
        "MGM_MAIN_v4.pth",
        [
            "mardy20s_(Instrumental)_MGM_MAIN_v4.flac",
            "mardy20s_(Vocals)_MGM_MAIN_v4.flac",
        ],
    ),
    (
        "UVR-MDX-NET-Inst_HQ_4.onnx",
        [
            "mardy20s_(Instrumental)_UVR-MDX-NET-Inst_HQ_4.flac",
            "mardy20s_(Vocals)_UVR-MDX-NET-Inst_HQ_4.flac",
        ],
    ),
    (
        "2_HP-UVR.pth",
        [
            "mardy20s_(Instrumental)_2_HP-UVR.flac",
            "mardy20s_(Vocals)_2_HP-UVR.flac",
        ],
    ),
    (
        "htdemucs_6s.yaml",
        [
            "mardy20s_(Vocals)_htdemucs_6s.flac",
            "mardy20s_(Drums)_htdemucs_6s.flac",
            "mardy20s_(Bass)_htdemucs_6s.flac",
            "mardy20s_(Other)_htdemucs_6s.flac",
            "mardy20s_(Guitar)_htdemucs_6s.flac",
            "mardy20s_(Piano)_htdemucs_6s.flac",
        ],
    ),
    # For the two roformer checkpoints the expected names end in "sdr_10" / "sdr_12",
    # i.e. without the fractional part that appears in the model filename.
    (
        "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
        [
            "mardy20s_(Drum-Bass)_model_bs_roformer_ep_937_sdr_10.flac",
            "mardy20s_(No Drum-Bass)_model_bs_roformer_ep_937_sdr_10.flac",
        ],
    ),
    (
        "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
        [
            "mardy20s_(Instrumental)_model_bs_roformer_ep_317_sdr_12.flac",
            "mardy20s_(Vocals)_model_bs_roformer_ep_317_sdr_12.flac",
        ],
    ),
]


@pytest.mark.parametrize("model,expected_files", MODEL_PARAMS)
def test_model_separation(model, expected_files, input_file, reference_dir, cleanup_output_files):
    """Separate the input file with one model and compare each output to its reference images.

    Setting the environment variable ``SKIP_AUDIO_VALIDATION=1`` skips the
    image comparison; the CLI run and the output-file existence checks still
    happen.
    """
    # Register the outputs so the cleanup fixture deletes them after the test.
    cleanup_output_files.extend(expected_files)

    # Run the CLI and check that the expected files were produced.
    run_separation_test(model, input_file, expected_files)

    print(f"\nValidating output files for model {model}...")

    # Model-specific (waveform, spectrogram) thresholds, falling back to the defaults.
    waveform_threshold, spectrogram_threshold = MODEL_SIMILARITY_THRESHOLDS.get(model, DEFAULT_SIMILARITY_THRESHOLDS)

    print(f"Using thresholds - waveform: {waveform_threshold}, spectrogram: {spectrogram_threshold} for model {model}")

    for output_file in expected_files:
        if os.environ.get("SKIP_AUDIO_VALIDATION") == "1":
            # Validation disabled through the environment: report it and move on.
            print(f"Skipping audio validation for {output_file} (SKIP_AUDIO_VALIDATION=1)")
        else:
            matches = validate_audio_output(
                output_file,
                reference_dir,
                waveform_threshold=waveform_threshold,
                spectrogram_threshold=spectrogram_threshold,
            )
            waveform_match, spectrogram_match = matches

            # Both comparisons must pass for the output to be accepted.
            assert waveform_match, f"Waveform for {output_file} does not match the reference"
            assert spectrogram_match, f"Spectrogram for {output_file} does not match the reference"