#!/usr/bin/env python
import argparse
import json
import logging
import os
import subprocess
import sys
import tempfile
import time
from decimal import Decimal
from json import JSONEncoder

import musdb
import museval
import numpy as np
import soundfile as sf
from audio_separator.separator import Separator

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
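# Usage (illustrative invocations based on the arguments defined in main() below;
# the script name is a placeholder — run whatever this file is saved as):
#
#   python model_evaluation.py                          # evaluate all models, up to 10 tracks each
#   python model_evaluation.py --dry-run                # preview what would be evaluated, no writes
#   python model_evaluation.py --max-tracks 5 --max-models 3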
# Custom JSON Encoder to handle Decimal types
class DecimalEncoder(JSONEncoder):
    def default(self, obj):
        if isinstance(obj, Decimal):
            return float(obj)
        return super().default(obj)
MUSDB_PATH = "/Volumes/Nomad4TBOne/python-audio-separator/tests/model-metrics/datasets/musdb18hq"
RESULTS_PATH = "/Volumes/Nomad4TBOne/python-audio-separator/tests/model-metrics/results"
COMBINED_RESULTS_PATH = "/Users/andrew/Projects/python-audio-separator/audio_separator/models-scores.json"
COMBINED_MUSEVAL_RESULTS_PATH = "/Volumes/Nomad4TBOne/python-audio-separator/tests/model-metrics/results/combined-museval-results.json"
STOP_SIGNAL_PATH = "/Volumes/Nomad4TBOne/python-audio-separator/tests/model-metrics/stop-signal"
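# Illustrative sketch of the combined museval results layout (an assumption
# inferred from the access patterns below, not a formal schema):
#
#   {
#     "<model_filename>": {
#       "<track_name>": {
#         "targets": [
#           {"name": "vocals", "metrics": {"seconds_per_minute_m3": 12.3}},
#           ...
#         ]
#       }
#     }
#   }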
def load_combined_results():
    """Load the combined museval results file"""

    # Parse numeric-looking strings back into floats (handles values that were
    # serialized from Decimal objects). Defined at function scope so the backup
    # branch below can also use it, even if the first open() fails.
    def decimal_parser(dct):
        for k, v in dct.items():
            if isinstance(v, str) and v.replace(".", "").isdigit():
                try:
                    dct[k] = float(v)
                except (ValueError, TypeError):
                    pass
        return dct

    if os.path.exists(COMBINED_MUSEVAL_RESULTS_PATH):
        logger.info("Loading combined museval results...")
        try:
            with open(COMBINED_MUSEVAL_RESULTS_PATH, "r") as f:
                return json.load(f, object_hook=decimal_parser)
        except Exception as e:
            logger.error(f"Error loading combined results: {str(e)}")
            # Try to load a backup file if it exists
            backup_path = COMBINED_MUSEVAL_RESULTS_PATH + ".backup"
            if os.path.exists(backup_path):
                logger.info("Attempting to load backup file...")
                try:
                    with open(backup_path, "r") as f:
                        return json.load(f, object_hook=decimal_parser)
                except Exception as backup_e:
                    logger.error(f"Error loading backup file: {str(backup_e)}")
            return {}
    else:
        logger.info("No combined results file found, creating new one")
        return {}
def save_combined_results(combined_results):
    """Save the combined museval results file"""
    logger.info("Saving combined museval results...")
    try:
        # Create a backup of the existing file if it exists
        if os.path.exists(COMBINED_MUSEVAL_RESULTS_PATH):
            backup_path = COMBINED_MUSEVAL_RESULTS_PATH + ".backup"
            try:
                with open(COMBINED_MUSEVAL_RESULTS_PATH, "r") as src, open(backup_path, "w") as dst:
                    dst.write(src.read())
            except Exception as e:
                logger.error(f"Error creating backup file: {str(e)}")
        # Save the new results using the custom encoder
        with open(COMBINED_MUSEVAL_RESULTS_PATH, "w") as f:
            json.dump(combined_results, f, cls=DecimalEncoder, indent=2)
        logger.info("Combined results saved successfully")
        return True
    except Exception as e:
        logger.error(f"Error saving combined results: {str(e)}")
        return False
def update_combined_results(model_name, track_name, track_data):
    """Update the combined results file with new track data"""
    try:
        # Load existing combined results
        combined_results = load_combined_results()
        # Initialize model entry if it doesn't exist
        if model_name not in combined_results:
            combined_results[model_name] = {}
        # Add or update track data
        combined_results[model_name][track_name] = track_data
        # Write updated results back to file
        save_combined_results(combined_results)
        return True
    except Exception as e:
        logger.error(f"Error updating combined results: {str(e)}")
        return False

def check_track_evaluated(model_name, track_name):
    """Check if a track has already been evaluated for a specific model"""
    combined_results = load_combined_results()
    return model_name in combined_results and track_name in combined_results[model_name]

def get_track_results(model_name, track_name):
    """Get the evaluation results for a specific track and model"""
    combined_results = load_combined_results()
    if model_name in combined_results and track_name in combined_results[model_name]:
        return combined_results[model_name][track_name]
    return None
def get_track_duration(track_path):
    """Get the duration of a track in minutes"""
    try:
        mixture_path = os.path.join(track_path, "mixture.wav")
        info = sf.info(mixture_path)
        return info.duration / 60.0  # Convert seconds to minutes
    except Exception as e:
        logger.error(f"Error getting track duration: {str(e)}")
        return 0.0
def evaluate_track(track_name, track_path, test_model, mus_db):
    """Evaluate a single track using a specific model"""
    logger.info(f"Evaluating track: {track_name} with model: {test_model}")

    # Get track duration in minutes
    track_duration_minutes = get_track_duration(track_path)
    logger.info(f"Track duration: {track_duration_minutes:.2f} minutes")

    # Initialize variables to track processing time
    processing_time = 0
    seconds_per_minute = 0

    # Create a basic result structure that will be returned even if evaluation fails
    basic_model_results = {"track_name": track_name, "scores": {}}

    # Check if evaluation results already exist in combined file
    museval_results = load_combined_results()
    if test_model in museval_results and track_name in museval_results[test_model]:
        logger.info("Found existing evaluation results in combined file...")
        track_data = museval_results[test_model][track_name]
        scores = museval.TrackStore(track_name)
        scores.scores = track_data
        # Try to extract existing speed metrics if available
        try:
            if isinstance(track_data, dict) and "targets" in track_data:
                for target in track_data["targets"]:
                    if "metrics" in target and "seconds_per_minute_m3" in target["metrics"]:
                        basic_model_results["scores"]["seconds_per_minute_m3"] = target["metrics"]["seconds_per_minute_m3"]
                        break
        except Exception:
            pass  # Ignore errors in extracting existing speed metrics
    else:
        # Expanded stem mapping to include "no-stem" outputs and custom stem formats
        stem_mapping = {
            # Standard stems
            "Vocals": "vocals",
            "Instrumental": "instrumental",
            "Drums": "drums",
            "Bass": "bass",
            "Other": "other",
            # No-stem variants
            "No Drums": "nodrums",
            "No Bass": "nobass",
            "No Other": "noother",
            # Custom stem formats (with hyphens)
            "Drum-Bass": "drumbass",
            "No Drum-Bass": "nodrumbass",
            "Vocals-Other": "vocalsother",
            "No Vocals-Other": "novocalsother",
        }

        # Create a temporary directory for separation files
        with tempfile.TemporaryDirectory() as temp_dir:
            logger.info(f"Using temporary directory: {temp_dir}")

            # Measure separation time
            start_time = time.time()

            # Perform separation
            logger.info("Performing separation...")
            separator = Separator(output_dir=temp_dir)
            separator.load_model(model_filename=test_model)
            separator.separate(os.path.join(track_path, "mixture.wav"), custom_output_names=stem_mapping)

            # Calculate processing time
            processing_time = time.time() - start_time
            seconds_per_minute = processing_time / track_duration_minutes if track_duration_minutes > 0 else 0
            logger.info(f"Separation completed in {processing_time:.2f} seconds")
            logger.info(f"Processing speed: {seconds_per_minute:.2f} seconds per minute of audio")

            # Always add the speed metric to our basic results
            basic_model_results["scores"]["seconds_per_minute_m3"] = round(seconds_per_minute, 1)

            # Check which stems were actually created
            wav_files = [f for f in os.listdir(temp_dir) if f.endswith(".wav")]
            logger.info(f"Found WAV files: {wav_files}")

            # Determine if this is a standard vocal/instrumental model that can be evaluated with museval
            standard_model = False
            if len(wav_files) == 2 and "vocals.wav" in wav_files and "instrumental.wav" in wav_files:
                standard_model = True
                logger.info("Detected standard vocals/instrumental model, will run museval evaluation")

            # If not a standard model, skip museval evaluation and just return speed metrics
            if not standard_model:
                logger.info(f"Non-standard stem configuration detected for model {test_model}, skipping museval evaluation")
                # Store the speed metric in the combined results
                if test_model not in museval_results:
                    museval_results[test_model] = {}
                # Create a minimal structure for the speed metric
                minimal_results = {"targets": [{"name": "speed_metrics_only", "metrics": {"seconds_per_minute_m3": round(seconds_per_minute, 1)}}]}
                museval_results[test_model][track_name] = minimal_results
                save_combined_results(museval_results)
                return None, basic_model_results

            # For standard models, proceed with museval evaluation
            available_stems = {
                "vocals": os.path.join(temp_dir, "vocals.wav"),
                "accompaniment": os.path.join(temp_dir, "instrumental.wav"),
            }

            # Get track from MUSDB
            track = next((t for t in mus_db if t.name == track_name), None)
            if track is None:
                raise ValueError(f"Track {track_name} not found in MUSDB18")

            # Load available stems
            estimates = {}
            for stem_name, stem_path in available_stems.items():
                audio, _ = sf.read(stem_path)
                if len(audio.shape) == 1:
                    audio = np.expand_dims(audio, axis=1)
                estimates[stem_name] = audio

            # Evaluate using museval
            logger.info(f"Evaluating stems: {list(estimates.keys())}")
            try:
                scores = museval.eval_mus_track(track, estimates, output_dir=temp_dir, mode="v4")

                # Add the speed metric to the scores
                if not hasattr(scores, "speed_metric_added"):
                    for target in scores.scores["targets"]:
                        if "metrics" not in target:
                            target["metrics"] = {}
                        target["metrics"]["seconds_per_minute_m3"] = round(seconds_per_minute, 1)
                    scores.speed_metric_added = True

                # Update the combined results file with the new evaluation
                if test_model not in museval_results:
                    museval_results[test_model] = {}
                museval_results[test_model][track_name] = scores.scores
                save_combined_results(museval_results)
            except Exception as e:
                logger.error(f"Error during museval evaluation: {str(e)}")
                logger.exception("Evaluation exception details:")
                # Return basic results with just the speed metric
                return None, basic_model_results

    try:
        # Only process museval results if we have them
        if "scores" in locals() and scores is not None:
            # Calculate aggregate scores for available stems
            results_store = museval.EvalStore()
            results_store.add_track(scores.df)
            methods = museval.MethodStore()
            methods.add_evalstore(results_store, name=test_model)
            agg_scores = methods.agg_frames_tracks_scores()

            # Return the aggregate scores in a structured format with 6 significant figures
            model_results = {"track_name": track_name, "scores": {}}
            for stem in ["vocals", "drums", "bass", "other", "accompaniment"]:
                try:
                    stem_scores = {metric: float(f"{agg_scores.loc[(test_model, stem, metric)]:.6g}") for metric in ["SDR", "SIR", "SAR", "ISR"]}
                    # Rename 'accompaniment' to 'instrumental' in the output
                    output_stem = "instrumental" if stem == "accompaniment" else stem
                    model_results["scores"][output_stem] = stem_scores
                except KeyError:
                    continue

            # Add the seconds_per_minute_m3 metric if it was calculated
            if processing_time > 0 and track_duration_minutes > 0:
                model_results["scores"]["seconds_per_minute_m3"] = round(seconds_per_minute, 1)

            return scores, (model_results if model_results["scores"] else basic_model_results)
        else:
            # If we don't have scores, just return the basic results with speed metrics
            return None, basic_model_results
    except Exception as e:
        logger.error(f"Error processing evaluation results: {str(e)}")
        logger.exception("Results processing exception details:")
        # Return basic results with just the speed metric
        return None, basic_model_results
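# For reference, evaluate_track returns (scores, model_results), where model_results
# looks roughly like this (illustrative values; shape inferred from the code above):
#
#   {"track_name": "...",
#    "scores": {"vocals": {"SDR": 7.1, "SIR": 14.2, "SAR": 7.9, "ISR": 12.5},
#               "instrumental": {...},
#               "seconds_per_minute_m3": 12.3}}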
def convert_decimal_to_float(obj):
    """Recursively converts Decimal objects to floats in a nested structure."""
    if isinstance(obj, Decimal):
        return float(obj)
    elif isinstance(obj, dict):
        return {k: convert_decimal_to_float(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_decimal_to_float(x) for x in obj]
    return obj
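# Example (illustrative): convert_decimal_to_float({"SDR": Decimal("3.2"), "runs": [Decimal("1")]})
# returns {"SDR": 3.2, "runs": [1.0]}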
def calculate_median_scores(track_scores):
    """Calculate median scores across all tracks for each stem and metric"""
    # Initialize containers for each stem's metrics
    stem_metrics = {
        "vocals": {"SDR": [], "SIR": [], "SAR": [], "ISR": []},
        "drums": {"SDR": [], "SIR": [], "SAR": [], "ISR": []},
        "bass": {"SDR": [], "SIR": [], "SAR": [], "ISR": []},
        "instrumental": {"SDR": [], "SIR": [], "SAR": [], "ISR": []},
        "seconds_per_minute_m3": [],
    }

    # Collect all scores for each stem and metric
    for track_score in track_scores:
        if track_score is not None and "scores" in track_score:
            # Process audio quality metrics
            for stem, metrics in track_score["scores"].items():
                if stem in stem_metrics and stem != "seconds_per_minute_m3":
                    for metric, value in metrics.items():
                        stem_metrics[stem][metric].append(value)
            # Process speed metric separately
            if "seconds_per_minute_m3" in track_score["scores"]:
                stem_metrics["seconds_per_minute_m3"].append(track_score["scores"]["seconds_per_minute_m3"])

    # Calculate medians for each stem and metric; only include stems that have
    # scores, and within them only metrics that have values
    median_scores = {}
    for stem, metrics in stem_metrics.items():
        if stem != "seconds_per_minute_m3" and any(metrics.values()):
            median_scores[stem] = {metric: float(f"{np.median(values):.6g}") for metric, values in metrics.items() if values}

    # Add median speed metric if available
    if stem_metrics["seconds_per_minute_m3"]:
        median_scores["seconds_per_minute_m3"] = round(np.median(stem_metrics["seconds_per_minute_m3"]), 1)

    return median_scores
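# Example (illustrative): given two track_scores entries whose vocals SDR values are
# 6.0 and 8.0, calculate_median_scores returns {"vocals": {"SDR": 7.0, ...}, ...};
# the seconds_per_minute_m3 speed metric is reduced the same way via its own median.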
def check_disk_usage(path):
    """Check inode usage and disk space on the filesystem containing path"""
    # Check disk space first
    result = subprocess.run(["df", "-h", path], capture_output=True, text=True)
    output = result.stdout
    logger.info(f"Current disk usage:\n{output}")

    # Parse the output to get disk usage percentage
    lines = output.strip().split("\n")
    if len(lines) >= 2:
        parts = lines[1].split()
        if len(parts) >= 5:
            try:
                # Extract disk usage percentage
                disk_usage_str = parts[4].rstrip("%")
                disk_usage_pct = int(disk_usage_str)
                logger.info(f"Disk usage: {disk_usage_pct}%")
                if disk_usage_pct >= 99:
                    logger.critical("CRITICAL: Disk is almost full (>=99%)! Cannot continue processing.")
                    logger.critical("Please free up disk space before continuing.")
                    sys.exit(1)
                elif disk_usage_pct > 95:
                    logger.warning(f"WARNING: High disk usage ({disk_usage_pct}%)!")
            except (ValueError, IndexError) as e:
                logger.error(f"Error parsing disk usage: {str(e)}")

    # Now check inode usage
    result = subprocess.run(["df", "-i", path], capture_output=True, text=True)
    output = result.stdout
    logger.info(f"Current inode usage:\n{output}")

    # Parse the output to get inode usage percentage
    lines = output.strip().split("\n")
    if len(lines) >= 2:
        # The second line contains the actual data
        parts = lines[1].split()
        if len(parts) >= 8:  # macOS df -i output includes iused/ifree/%iused columns
            try:
                # On macOS, the %iused column reports inode usage as a percentage
                inode_usage_str = parts[7].rstrip("%")
                inode_usage_pct = int(inode_usage_str)
                # Also extract the actual inode numbers for better reporting
                iused = int(parts[5])
                ifree = int(parts[6])
                total_inodes = iused + ifree
                # Skip inode check for exFAT or similar filesystems
                if total_inodes <= 1:
                    logger.info("Filesystem appears to be exFAT or similar (no real inode tracking). Skipping inode check.")
                    return None
                logger.info(f"Inode usage: {iused:,}/{total_inodes:,} ({inode_usage_pct}%)")
                if inode_usage_pct >= 100:
                    logger.critical("CRITICAL: Inode usage is at 100%! Cannot continue processing.")
                    logger.critical("Please free up inodes before continuing.")
                    sys.exit(1)
                elif inode_usage_pct > 90:
                    logger.warning(f"WARNING: High inode usage ({inode_usage_pct}%)!")
                return inode_usage_pct
            except (ValueError, IndexError) as e:
                logger.error(f"Error parsing inode usage: {str(e)}")
    return None
def get_evaluated_track_count(model_name, museval_results):
    """Get the number of tracks evaluated for a specific model"""
    if model_name in museval_results:
        return len(museval_results[model_name])
    return 0

def get_most_evaluated_tracks(museval_results, min_count=10):
    """Get tracks that have been evaluated for the most models"""
    track_counts = {}
    # Count how many models have evaluated each track
    for model_name, tracks in museval_results.items():
        for track_name in tracks:
            if track_name not in track_counts:
                track_counts[track_name] = 0
            track_counts[track_name] += 1
    # Sort tracks by evaluation count (descending)
    sorted_tracks = sorted(track_counts.items(), key=lambda x: x[1], reverse=True)
    # Return tracks that have been evaluated at least min_count times
    return [track for track, count in sorted_tracks if count >= min_count]
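# Example (illustrative): if museval_results shows "Track A" evaluated by 12 models
# and "Track B" by 3, get_most_evaluated_tracks(museval_results) returns ["Track A"]
# with the default min_count=10.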
def generate_summary_statistics(
    start_time, models_processed, tracks_processed, models_with_new_data, tracks_evaluated, total_processing_time, fastest_model=None, slowest_model=None, combined_results_path=None, is_dry_run=False
):
    """Generate a summary of the script's execution"""
    end_time = time.time()
    total_runtime = end_time - start_time

    # Format the runtime
    hours, remainder = divmod(total_runtime, 3600)
    minutes, seconds = divmod(remainder, 60)
    runtime_str = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"

    # Build the summary
    summary = [
        "=" * 80,
        "DRY RUN SUMMARY - PREVIEW ONLY" if is_dry_run else "EXECUTION SUMMARY",
        "=" * 80,
        f"Total runtime: {runtime_str}",
        f"Models {'that would be ' if is_dry_run else ''}processed: {models_processed}",
        f"Models {'that would receive' if is_dry_run else 'with'} new data: {len(models_with_new_data)}",
        f"Total tracks {'that would be ' if is_dry_run else ''}evaluated: {tracks_evaluated}",
        f"Average tracks per model: {tracks_evaluated / len(models_with_new_data) if models_with_new_data else 0:.2f}",
    ]
    if fastest_model:
        summary.append(f"Fastest model: {fastest_model['name']} ({fastest_model['speed']:.2f} seconds per minute)")
    if slowest_model:
        summary.append(f"Slowest model: {slowest_model['name']} ({slowest_model['speed']:.2f} seconds per minute)")
    if total_processing_time > 0:
        summary.append(f"Total audio processing time: {total_processing_time:.2f} seconds")
    if combined_results_path and os.path.exists(combined_results_path):
        file_size = os.path.getsize(combined_results_path) / (1024 * 1024)  # Size in MB
        summary.append(f"Results file size: {file_size:.2f} MB")

    # Add models with new data
    if models_with_new_data:
        summary.append(f"\nModels {'that would receive' if is_dry_run else 'with'} new evaluation data:")
        for model_name in models_with_new_data:
            summary.append(f"- {model_name}")

    # Add dry run disclaimer if needed
    if is_dry_run:
        summary.append("\nNOTE: This is a dry run summary. No actual changes were made.")
        summary.append("Run without --dry-run to perform actual evaluations.")

    summary.append("=" * 80)
    return "\n".join(summary)
def check_stop_signal():
    """Check if the stop signal file exists"""
    if os.path.exists(STOP_SIGNAL_PATH):
        logger.info("Stop signal detected at: " + STOP_SIGNAL_PATH)
        return True
    return False
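# To request a graceful shutdown while the script is running, create the stop-signal
# file from another shell, e.g.:
#
#   touch "/Volumes/Nomad4TBOne/python-audio-separator/tests/model-metrics/stop-signal"
#
# The main loop below checks for this file before each model and each track.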
def main():
    # Add command line argument parsing for dry run mode
    parser = argparse.ArgumentParser(description="Run model evaluation on MUSDB18 dataset")
    parser.add_argument("--dry-run", action="store_true", help="Run in dry-run mode (no writes)")
    parser.add_argument("--max-tracks", type=int, default=10, help="Maximum number of tracks to evaluate per model")
    parser.add_argument("--max-models", type=int, default=None, help="Maximum number of models to evaluate")
    args = parser.parse_args()

    # Remove any existing stop signal file at start
    if os.path.exists(STOP_SIGNAL_PATH):
        os.remove(STOP_SIGNAL_PATH)
        logger.info("Removed existing stop signal file")

    # Track start time for progress reporting
    start_time = time.time()

    # Statistics tracking
    models_processed = 0
    tracks_processed = 0
    models_with_new_data = set()
    total_processing_time = 0
    fastest_model = {"name": "", "speed": float("inf")}  # Initialize with infinity for comparison
    slowest_model = {"name": "", "speed": 0}  # Initialize with zero for comparison

    # Create a results cache manager
    class ResultsCache:
        def __init__(self):
            self.results = load_combined_results()
            self.last_update_time = time.time()

        def get_results(self, force=False):
            current_time = time.time()
            # Only reload from disk every 5 minutes unless forced
            if force or (current_time - self.last_update_time) > 300:
                self.results = load_combined_results()
                self.last_update_time = current_time
            return self.results

    results_cache = ResultsCache()

    # Helper function for logging with elapsed time
    def log_with_time(message, level=logging.INFO):
        elapsed = time.time() - start_time
        hours, remainder = divmod(elapsed, 3600)
        minutes, seconds = divmod(remainder, 60)
        time_str = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"
        logger.log(level, f"[{time_str}] {message}")
    if args.dry_run:
        log_with_time("*** RUNNING IN DRY-RUN MODE - NO DATA WILL BE MODIFIED ***")

    log_with_time("Starting model evaluation script...")
    os.makedirs(RESULTS_PATH, exist_ok=True)

    # Check disk space and inode usage at start
    check_disk_usage(RESULTS_PATH)

    # Load existing results if available
    combined_results = {}
    if os.path.exists(COMBINED_RESULTS_PATH):
        log_with_time("Loading existing combined results...")
        with open(COMBINED_RESULTS_PATH) as f:
            combined_results = json.load(f)

    # Get initial museval results
    museval_results = results_cache.get_results()
    log_with_time(f"Loaded combined museval results with {len(museval_results)} models")

    # Get the most commonly evaluated tracks
    common_tracks = get_most_evaluated_tracks(museval_results)
    log_with_time(f"Found {len(common_tracks)} commonly evaluated tracks")

    # Initialize MUSDB
    log_with_time("Initializing MUSDB database...")
    mus = musdb.DB(root=MUSDB_PATH, is_wav=True)

    # Create a prioritized list of tracks
    all_tracks = []
    for track in mus.tracks:
        # Check if this is a commonly evaluated track
        is_common = track.name in common_tracks
        all_tracks.append({"name": track.name, "path": os.path.dirname(track.path), "is_common": is_common})

    # Sort tracks so commonly evaluated ones come first
    all_tracks.sort(key=lambda t: 0 if t["is_common"] else 1)

    # Get list of all available models
    log_with_time("Getting list of available models...")
    separator = Separator()
    models_by_type = separator.list_supported_model_files()

    # Flatten the models list and prioritize them
    all_models = []
    for model_type, models in models_by_type.items():
        for model_name, model_info in models.items():
            filename = model_info.get("filename")
            if filename:
                # Count how many tracks have been evaluated for this model
                evaluated_count = get_evaluated_track_count(filename, museval_results)
                # Determine if this is a roformer model
                is_roformer = "roformer" in model_name.lower()
                # Add to the list with priority information
                all_models.append({"name": model_name, "filename": filename, "type": model_type, "info": model_info, "evaluated_count": evaluated_count, "is_roformer": is_roformer})

    # Sort models by priority:
    # 1. Roformer models with fewer than max_tracks evaluations
    # 2. Other models with fewer than max_tracks evaluations
    # 3. Roformer models with more evaluations
    # 4. Other models with more evaluations
    all_models.sort(
        key=lambda m: (
            0 if m["is_roformer"] and m["evaluated_count"] < args.max_tracks else 1 if not m["is_roformer"] and m["evaluated_count"] < args.max_tracks else 2 if m["is_roformer"] else 3,
            m["evaluated_count"],  # Secondary sort by number of evaluations (ascending)
        )
    )

    # Log the prioritized models
    log_with_time(f"Prioritized {len(all_models)} models for evaluation:")
    for i, model in enumerate(all_models[:10]):  # Show top 10
        log_with_time(f"{i+1}. {model['name']} ({model['filename']}) - {model['evaluated_count']} tracks evaluated, roformer: {model['is_roformer']}")
    if len(all_models) > 10:
        log_with_time(f"... and {len(all_models) - 10} more models")

    # Limit the number of models if specified
    if args.max_models:
        all_models = all_models[: args.max_models]
        log_with_time(f"Limited to {args.max_models} models for this run")
    # Process models according to priority
    model_idx = 0
    stop_requested = False
    while model_idx < len(all_models):
        # Check for stop signal before processing each model
        if check_stop_signal():
            log_with_time("Stop signal detected. Will finish current model's tracks and then exit.")
            stop_requested = True

        model = all_models[model_idx]
        model_name = model["name"]
        model_filename = model["filename"]
        model_type = model["type"]
        progress_pct = (model_idx + 1) / len(all_models) * 100
        log_with_time(f"\n=== Processing model {model_idx+1}/{len(all_models)} ({progress_pct:.1f}%): {model_name} ({model_filename}) ===")

        # Initialize model entry if it doesn't exist
        if model_filename not in combined_results:
            log_with_time(f"Initializing new entry for {model_filename}")
            combined_results[model_filename] = {"model_name": model_name, "track_scores": [], "median_scores": {}, "stems": [], "target_stem": None}

        # Try to load the model to get stem information
        try:
            separator.load_model(model_filename=model_filename)
            model_data = separator.model_instance.model_data
            # Extract stem information (similar to your existing code)
            # ... (keep your existing stem extraction logic here)
        except Exception as e:
            log_with_time(f"Error loading model {model_filename}: {str(e)}", logging.ERROR)
            logger.exception("Full exception details:")
            model_idx += 1
            continue

        # Count how many tracks have been evaluated for this model, using the cached results
        evaluated_count = get_evaluated_track_count(model_filename, results_cache.get_results())

        # Determine how many more tracks to evaluate
        tracks_to_evaluate = max(0, args.max_tracks - evaluated_count)
        if tracks_to_evaluate == 0:
            log_with_time(f"Model {model_name} already has {evaluated_count} tracks evaluated (>= {args.max_tracks}). Skipping.")
            model_idx += 1
            continue

        log_with_time(f"Will evaluate up to {tracks_to_evaluate} tracks for model {model_name}")

        # Process tracks for this model (use a per-model counter so the overall
        # tracks_processed statistic isn't reset for each model)
        model_tracks_processed = 0
        for track in all_tracks:
            # Check for stop signal before each track if we haven't already detected it
            if not stop_requested and check_stop_signal():
                log_with_time("Stop signal detected. Will finish current track and then exit.")
                stop_requested = True

            # Skip if we've processed enough tracks for this model
            if model_tracks_processed >= tracks_to_evaluate:
                break

            track_name = track["name"]
            track_path = track["path"]

            # Skip if track already evaluated for this model, using the cached results
            if model_filename in results_cache.get_results() and track_name in results_cache.get_results()[model_filename]:
                log_with_time(f"Skipping already evaluated track {track_name} for model: {model_filename}")
                continue

            log_with_time(f"Processing track: {track_name} for model: {model_filename}")

            if args.dry_run:
                log_with_time(f"[DRY RUN] Would evaluate track {track_name} with model {model_filename}")
                model_tracks_processed += 1
                models_with_new_data.add(model_filename)
                # Estimate processing time based on model type for dry run.
                # This is a rough estimate - roformer models are typically slower
                estimated_speed = 30.0  # Default estimate: 30 seconds per minute
                if "roformer" in model_name.lower():
                    estimated_speed = 45.0  # Roformer models are typically slower
                elif "umx" in model_name.lower():
                    estimated_speed = 20.0  # UMX models are typically faster
                # Update statistics with estimated values
                total_processing_time += estimated_speed
                # Track fastest and slowest models based on estimates
                if estimated_speed < fastest_model["speed"]:
                    fastest_model = {"name": model_name, "speed": estimated_speed}
                if estimated_speed > slowest_model["speed"]:
                    slowest_model = {"name": model_name, "speed": estimated_speed}
                continue

            try:
                result = evaluate_track(track_name, track_path, model_filename, mus)
                # Unpack the result safely
                if result and isinstance(result, tuple) and len(result) == 2:
                    _, model_results = result
                else:
                    model_results = None

                # Process the results if they exist and are valid
                if model_results is not None and isinstance(model_results, dict):
                    combined_results[model_filename]["track_scores"].append(model_results)
                    model_tracks_processed += 1
                    models_with_new_data.add(model_filename)

                    # Track processing time statistics - safely access nested dictionaries
                    scores = model_results.get("scores", {})
                    if isinstance(scores, dict):
                        speed = scores.get("seconds_per_minute_m3")
                        if speed is not None:
                            total_processing_time += speed  # Accumulate total processing time
                            # Track fastest and slowest models
                            if speed < fastest_model["speed"]:
                                fastest_model = {"name": model_name, "speed": speed}
                            if speed > slowest_model["speed"]:
                                slowest_model = {"name": model_name, "speed": speed}
                else:
                    log_with_time(f"Skipping model {model_filename} for track {track_name} due to no evaluatable stems or invalid results")
            except Exception as e:
                log_with_time(f"Error evaluating model {model_filename} with track {track_name}: {str(e)}", logging.ERROR)
                logger.exception("Exception details:")
                continue

            # Update and save results
            if combined_results[model_filename]["track_scores"]:
                median_scores = calculate_median_scores(combined_results[model_filename]["track_scores"])
                combined_results[model_filename]["median_scores"] = median_scores

            # Save results after each track
            if not args.dry_run:
                os.makedirs(os.path.dirname(COMBINED_RESULTS_PATH), exist_ok=True)
                with open(COMBINED_RESULTS_PATH, "w", encoding="utf-8") as f:
                    json.dump(combined_results, f, indent=2)
                log_with_time(f"Updated combined results file with {model_filename} - {track_name}")
                # Force update the cache after saving
                results_cache.get_results(force=True)
            else:
                log_with_time(f"[DRY RUN] Would have updated combined results for {model_filename} - {track_name}")

            # Check disk space periodically
            check_disk_usage(RESULTS_PATH)

        tracks_processed += model_tracks_processed
        log_with_time(f"Completed processing {model_tracks_processed} tracks for model {model_name}")

        # If stop was requested, exit after completing the current model
        if stop_requested:
            log_with_time("Stop signal processed. Generating final summary before exit.")
            break

        # If we're processing a non-roformer model, check if there are roformer models that need evaluation
        if not model["is_roformer"]:
            # Find roformer models that still need more evaluations, using the cached results
            roformer_models_needing_eval = []
            for i, m in enumerate(all_models[model_idx + 1 :], start=model_idx + 1):
                if m["is_roformer"]:
                    eval_count = get_evaluated_track_count(m["filename"], results_cache.get_results())
                    if eval_count < args.max_tracks:
                        roformer_models_needing_eval.append((i, m))

            if roformer_models_needing_eval:
                log_with_time(f"Found {len(roformer_models_needing_eval)} roformer models that still need evaluation. Reprioritizing...")
                # Move these models to the front of the remaining queue, preserving
                # their relative order. Rebuilding the tail of the list avoids the
                # index-shifting pitfalls of repeated pop/insert.
                moved_indices = {i for i, _ in roformer_models_needing_eval}
                moved_models = [m for _, m in roformer_models_needing_eval]
                remaining_models = [m for j, m in enumerate(all_models) if j > model_idx and j not in moved_indices]
                all_models[model_idx + 1 :] = moved_models + remaining_models
                log_with_time("Reprioritization complete. Continuing with highest priority model.")

        # Move to the next model
        model_idx += 1
        models_processed += 1
log_with_time("Evaluation complete") | |
# Final disk space check | |
check_disk_usage(RESULTS_PATH) | |
# Generate and display summary statistics | |
# Reset fastest/slowest models if they weren't updated | |
if fastest_model["speed"] == float("inf"): | |
fastest_model = None | |
if slowest_model["speed"] == 0: | |
slowest_model = None | |
summary = generate_summary_statistics( | |
start_time=start_time, | |
models_processed=models_processed, | |
tracks_processed=tracks_processed, | |
models_with_new_data=models_with_new_data, | |
tracks_evaluated=tracks_processed, | |
total_processing_time=total_processing_time, | |
fastest_model=fastest_model, | |
slowest_model=slowest_model, | |
combined_results_path=COMBINED_RESULTS_PATH, | |
is_dry_run=args.dry_run, | |
) | |
log_with_time("\n" + summary) | |
# Also write summary to a log file | |
summary_filename = "dry_run_summary.log" if args.dry_run else "evaluation_summary.log" | |
if stop_requested: | |
summary_filename = "stopped_" + summary_filename | |
summary_log_path = os.path.join(os.path.dirname(COMBINED_RESULTS_PATH), summary_filename) | |
with open(summary_log_path, "w") as f: | |
f.write(f"{'Dry run' if args.dry_run else 'Evaluation'} {'(stopped early)' if stop_requested else ''} completed at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n") | |
f.write(summary) | |
log_with_time(f"Summary written to {summary_log_path}") | |
# Clean up stop signal file if it exists | |
if os.path.exists(STOP_SIGNAL_PATH): | |
os.remove(STOP_SIGNAL_PATH) | |
log_with_time("Removed stop signal file") | |
return 0 if not stop_requested else 2 # Return different exit code if stopped early | |
if __name__ == "__main__": | |
exit(main()) | |