"""Tests for the ``audio-separator`` command-line interface (``audio_separator.utils.cli.main``)."""

import json
import logging
import subprocess
from unittest import mock
from unittest.mock import MagicMock, mock_open, patch

import pytest

from audio_separator.utils.cli import main


@pytest.fixture
def common_expected_args():
    """Keyword arguments the CLI is expected to pass to ``Separator`` when no optional flags are given.

    Tests that exercise a single flag build a new dict from this one, overriding only
    the entry that flag should change.
    """
    return {
        "log_formatter": mock.ANY,
        "log_level": logging.INFO,
        "model_file_dir": "/tmp/audio-separator-models/",
        "output_dir": None,
        "output_format": "FLAC",
        "output_bitrate": None,
        "normalization_threshold": 0.9,
        "amplification_threshold": 0.0,
        "output_single_stem": None,
        "invert_using_spec": False,
        "sample_rate": 44100,
        "use_autocast": False,
        "use_soundfile": False,
        "mdx_params": {"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
        "vr_params": {
            "batch_size": 1,
            "window_size": 512,
            "aggression": 5,
            "enable_tta": False,
            "enable_post_process": False,
            "post_process_threshold": 0.2,
            "high_end_process": False,
        },
        "demucs_params": {"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
        "mdxc_params": {"segment_size": 256, "batch_size": 1, "overlap": 8, "override_model_segment_size": False, "pitch_shift": 0},
    }


def _run_main_with_mocked_separator(test_args):
    """Run ``main()`` with ``sys.argv`` set to *test_args* and ``Separator`` replaced by a mock.

    The mocked instance's ``separate`` returns a single dummy output path. Returns the
    mocked ``Separator`` class; its ``return_value`` is the instance ``main()`` worked with.
    """
    with patch("sys.argv", test_args), patch("audio_separator.separator.Separator") as mock_separator:
        mock_separator.return_value.separate.return_value = ["output_file.mp3"]
        main()
    return mock_separator


def test_cli_version_subprocess():
    """``--version`` and its short form ``-v`` exit with 0 and print the program name.

    Runs the installed entry point via ``poetry run``, so poetry must be available.
    """
    for flag in ("--version", "-v"):
        result = subprocess.run(["poetry", "run", "audio-separator", flag], capture_output=True, text=True)
        assert result.returncode == 0
        assert "audio-separator" in result.stdout


def test_cli_no_args():
    """Running the CLI without arguments exits with 1 and prints usage to stdout."""
    result = subprocess.run(["poetry", "run", "audio-separator"], capture_output=True, text=True)
    assert result.returncode == 1
    assert "usage:" in result.stdout


def test_cli_multiple_filenames():
    """Each positional audio file is separated, and the log mentions both inputs."""
    test_args = ["cli.py", "test1.mp3", "test2.mp3"]
    # Mock open() so nothing touches the filesystem.
    mock_file = mock_open()
    # Mock logger so the CLI's info-level messages can be inspected.
    mock_logger = MagicMock()

    with patch("sys.argv", test_args), patch("builtins.open", mock_file), patch(
        "audio_separator.separator.Separator.separate"
    ) as mock_separate, patch("audio_separator.separator.Separator.load_model"), patch(
        "logging.getLogger", return_value=mock_logger
    ):
        mock_separate.return_value = ["output_file1.mp3", "output_file2.mp3"]
        main()

    # separate() is called once per input file.
    assert mock_separate.call_count == 2

    # First positional argument of every logger.info call; calls without one are skipped.
    log_messages = [str(logged.args[0]) for logged in mock_logger.info.call_args_list if logged.args]
    assert any("test1.mp3" in msg and "test2.mp3" in msg for msg in log_messages)
    assert any("Separation complete" in msg for msg in log_messages)


def test_cli_with_audio_file():
    """A single audio file with ``--model_filename`` loads that model and separates once."""
    test_args = ["cli.py", "test_audio.mp3", "--model_filename=UVR-MDX-NET-Inst_HQ_4.onnx"]
    # load_model is patched as well as separate so the test never fetches or loads a real model.
    with patch("audio_separator.separator.Separator.separate") as mock_separate, patch(
        "audio_separator.separator.Separator.load_model"
    ) as mock_load_model, patch("sys.argv", test_args):
        mock_separate.return_value = ["output_file.mp3"]
        main()

    mock_load_model.assert_called_once_with(model_filename="UVR-MDX-NET-Inst_HQ_4.onnx")
    mock_separate.assert_called_once()


def test_cli_invalid_log_level():
    """An unrecognised ``--log_level`` value makes ``main()`` raise ``AttributeError``."""
    test_args = ["cli.py", "test_audio.mp3", "--log_level=invalid"]
    with patch("sys.argv", test_args):
        with pytest.raises(AttributeError):
            main()


def test_cli_model_filename_argument(common_expected_args):
    """``--model_filename`` is forwarded to ``load_model``; constructor args stay at their defaults."""
    mock_separator = _run_main_with_mocked_separator(["cli.py", "test_audio.mp3", "--model_filename=Custom_Model.onnx"])

    mock_separator.assert_called_once_with(**common_expected_args)
    mock_separator.return_value.load_model.assert_called_once_with(model_filename="Custom_Model.onnx")


def test_cli_output_dir_argument(common_expected_args):
    """``--output_dir`` sets ``output_dir`` on ``Separator``."""
    mock_separator = _run_main_with_mocked_separator(["cli.py", "test_audio.mp3", "--output_dir=/custom/output/dir"])

    expected_args = {**common_expected_args, "output_dir": "/custom/output/dir"}
    mock_separator.assert_called_once_with(**expected_args)


def test_cli_output_format_argument(common_expected_args):
    """``--output_format`` sets ``output_format`` on ``Separator``."""
    mock_separator = _run_main_with_mocked_separator(["cli.py", "test_audio.mp3", "--output_format=MP3"])

    expected_args = {**common_expected_args, "output_format": "MP3"}
    mock_separator.assert_called_once_with(**expected_args)


def test_cli_normalization_threshold_argument(common_expected_args):
    """``--normalization`` sets ``normalization_threshold`` on ``Separator``."""
    mock_separator = _run_main_with_mocked_separator(["cli.py", "test_audio.mp3", "--normalization=0.75"])

    expected_args = {**common_expected_args, "normalization_threshold": 0.75}
    mock_separator.assert_called_once_with(**expected_args)


def test_cli_amplification_threshold_argument(common_expected_args):
    """``--amplification`` sets ``amplification_threshold`` on ``Separator``."""
    mock_separator = _run_main_with_mocked_separator(["cli.py", "test_audio.mp3", "--amplification=0.75"])

    expected_args = {**common_expected_args, "amplification_threshold": 0.75}
    mock_separator.assert_called_once_with(**expected_args)


def test_cli_single_stem_argument(common_expected_args):
    """``--single_stem`` sets ``output_single_stem`` on ``Separator``."""
    mock_separator = _run_main_with_mocked_separator(["cli.py", "test_audio.mp3", "--single_stem=instrumental"])

    expected_args = {**common_expected_args, "output_single_stem": "instrumental"}
    mock_separator.assert_called_once_with(**expected_args)


def test_cli_invert_spectrogram_argument(common_expected_args):
    """``--invert_spect`` sets ``invert_using_spec`` to True on ``Separator``."""
    mock_separator = _run_main_with_mocked_separator(["cli.py", "test_audio.mp3", "--invert_spect"])

    expected_args = {**common_expected_args, "invert_using_spec": True}
    mock_separator.assert_called_once_with(**expected_args)


def test_cli_use_autocast_argument(common_expected_args):
    """``--use_autocast`` sets ``use_autocast`` to True on ``Separator``."""
    mock_separator = _run_main_with_mocked_separator(["cli.py", "test_audio.mp3", "--use_autocast"])

    expected_args = {**common_expected_args, "use_autocast": True}
    mock_separator.assert_called_once_with(**expected_args)


def test_cli_custom_output_names_argument(common_expected_args):
    """``--custom_output_names`` (a JSON object) is decoded and passed to ``separate``."""
    custom_names = {
        "Vocals": "vocals_output",
        "Instrumental": "instrumental_output",
    }
    test_args = ["cli.py", "test_audio.mp3", f"--custom_output_names={json.dumps(custom_names)}"]
    mock_separator = _run_main_with_mocked_separator(test_args)

    mock_separator.assert_called_once_with(**common_expected_args)
    mock_separator.return_value.separate.assert_called_once_with("test_audio.mp3", custom_output_names=custom_names)


def test_cli_demucs_output_names_argument(common_expected_args):
    """Six-stem custom output names are passed to ``separate`` when a Demucs model is selected."""
    demucs_output_names = {
        "Vocals": "vocals_output",
        "Drums": "drums_output",
        "Bass": "bass_output",
        "Other": "other_output",
        "Guitar": "guitar_output",
        "Piano": "piano_output",
    }
    test_args = [
        "cli.py",
        "test_audio.mp3",
        f"--custom_output_names={json.dumps(demucs_output_names)}",
        "--model_filename=htdemucs_6s.yaml",
    ]
    mock_separator = _run_main_with_mocked_separator(test_args)

    mock_separator.assert_called_once_with(**common_expected_args)
    mock_separator.return_value.load_model.assert_called_once_with(model_filename="htdemucs_6s.yaml")
    mock_separator.return_value.separate.assert_called_once_with("test_audio.mp3", custom_output_names=demucs_output_names)