# scripts/run_inference.py
"""
CLI inference with preprocessing parity.
Applies: resample → baseline (deg=2) → smooth (w=11,o=2) → normalize
unless explicitly disabled via flags.

Usage (examples):
    python scripts/run_inference.py \
        --input datasets/rdwp/sta-1.txt \
        --arch figure2 \
        --weights outputs/figure2_model.pth \
        --target-len 500

    # Disable smoothing only:
    python scripts/run_inference.py --input ... --arch resnet --weights ... --disable-smooth
"""
import os | |
import sys | |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
import argparse | |
import json | |
import csv | |
import logging | |
from pathlib import Path | |
from typing import cast, Dict, List, Any | |
from torch import nn | |
import time | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from models.registry import build, choices, build_multiple, validate_model_list | |
from utils.preprocessing import preprocess_spectrum, TARGET_LENGTH | |
from utils.multifile import parse_spectrum_data, detect_file_format | |
from scripts.plot_spectrum import load_spectrum | |
from scripts.discover_raman_files import label_file | |
def parse_args():
    """Build and parse the command-line arguments for spectrum inference."""
    parser = argparse.ArgumentParser(
        description="Raman/FTIR spectrum inference with multi-model support."
    )
    parser.add_argument(
        "--input",
        required=True,
        help="Path to spectrum file (.txt, .csv, .json) or directory for batch processing.",
    )

    # Exactly one of --arch / --models selects the model(s) to run.
    model_group = parser.add_mutually_exclusive_group(required=True)
    model_group.add_argument(
        "--arch", choices=choices(), help="Single model architecture key."
    )
    model_group.add_argument(
        "--models",
        help="Comma-separated list of models for comparison (e.g., 'figure2,resnet,resnet18vision').",
    )

    parser.add_argument(
        "--weights",
        help="Path to model weights (.pth). For multi-model, use pattern with {model} placeholder.",
    )
    parser.add_argument(
        "--target-len",
        type=int,
        default=TARGET_LENGTH,
        help="Resample length (default: 500).",
    )
    parser.add_argument(
        "--modality",
        choices=["raman", "ftir"],
        default="raman",
        help="Spectroscopy modality for preprocessing (default: raman).",
    )

    # Preprocessing steps default to ON; each flag switches one step off.
    parser.add_argument(
        "--disable-baseline", action="store_true", help="Disable baseline correction."
    )
    parser.add_argument(
        "--disable-smooth",
        action="store_true",
        help="Disable Savitzky–Golay smoothing.",
    )
    parser.add_argument(
        "--disable-normalize",
        action="store_true",
        help="Disable min-max normalization.",
    )

    parser.add_argument(
        "--output",
        default=None,
        help="Output path - JSON for single file, CSV for multi-model comparison.",
    )
    parser.add_argument(
        "--output-format",
        choices=["json", "csv"],
        default="json",
        help="Output format for results.",
    )
    parser.add_argument(
        "--device",
        default="cpu",
        choices=["cpu", "cuda"],
        help="Compute device (default: cpu).",
    )
    parser.add_argument(
        "--file-format",
        choices=["auto", "txt", "csv", "json"],
        default="auto",
        help="Input file format (auto-detect by default).",
    )
    return parser.parse_args()
# ///////////////////////////////////////////////////////// | |
def _load_state_dict_safe(path: str): | |
"""Load a state dict safely across torch versions & checkpoint formats.""" | |
try: | |
obj = torch.load(path, map_location="cpu", weights_only=True) # newer torch | |
except TypeError: | |
obj = torch.load(path, map_location="cpu") # fallback for older torch | |
# Accept either a plain state_dict or a checkpoint dict that contains one | |
if isinstance(obj, dict): | |
for k in ("state_dict", "model_state_dict", "model"): | |
if k in obj and isinstance(obj[k], dict): | |
obj = obj[k] | |
break | |
if not isinstance(obj, dict): | |
raise ValueError( | |
"Loaded object is not a state_dict or checkpoint with a state_dict. " | |
f"Type={type(obj)} from file={path}" | |
) | |
# Strip DataParallel 'module.' prefixes if present | |
if any(key.startswith("module.") for key in obj.keys()): | |
obj = {key.replace("module.", "", 1): val for key, val in obj.items()} | |
return obj | |
# ///////////////////////////////////////////////////////// | |
def run_single_model_inference(
    x_raw: np.ndarray,
    y_raw: np.ndarray,
    model_name: str,
    weights_path: str,
    args: argparse.Namespace,
    device: torch.device,
) -> Dict[str, Any]:
    """Preprocess one spectrum and classify it with a single model.

    Returns a dict with the predicted index, its class name, confidence,
    per-class probabilities, raw logits, and wall-clock processing time.
    """
    t0 = time.time()

    # Preprocessing parity with training: each step runs unless the
    # corresponding --disable-* flag was given.
    _, spectrum = preprocess_spectrum(
        x_raw,
        y_raw,
        target_len=args.target_len,
        modality=args.modality,
        do_baseline=not args.disable_baseline,
        do_smooth=not args.disable_smooth,
        do_normalize=not args.disable_normalize,
        out_dtype="float32",
    )

    # Build the network and load its checkpoint (non-strict so harmless key
    # mismatches only get logged instead of raising).
    net = cast(nn.Module, build(model_name, args.target_len)).to(device)
    state = _load_state_dict_safe(weights_path)
    missing, unexpected = net.load_state_dict(state, strict=False)
    if missing or unexpected:
        logging.info(
            f"Model {model_name}: Loaded with non-strict keys. missing={len(missing)} unexpected={len(unexpected)}"
        )
    net.eval()

    # Forward pass on a (1, 1, L) tensor.
    batch = torch.from_numpy(spectrum[None, None, :]).to(device)
    with torch.no_grad():
        logits = net(batch).float().cpu()
        probs = F.softmax(logits, dim=1)

    elapsed = time.time() - t0
    prob_list = probs.numpy().ravel().tolist()
    logit_list = logits.numpy().ravel().tolist()
    predicted = int(np.argmax(prob_list))

    # Map the index onto the known class names; fall back to a generic label
    # for models that emit extra classes.
    class_names = ["Stable", "Weathered"]
    if predicted < len(class_names):
        predicted_class = class_names[predicted]
    else:
        predicted_class = f"Class_{predicted}"

    return {
        "model": model_name,
        "prediction": predicted,
        "predicted_class": predicted_class,
        "confidence": max(prob_list),
        "probs": prob_list,
        "logits": logit_list,
        "processing_time": elapsed,
    }
# ///////////////////////////////////////////////////////// | |
def run_multi_model_inference(
    x_raw: np.ndarray,
    y_raw: np.ndarray,
    model_names: List[str],
    args: argparse.Namespace,
    device: torch.device,
) -> Dict[str, Dict[str, Any]]:
    """Run inference with multiple models for comparison.

    Models whose weights file is missing, or whose inference raises, are
    skipped with a log message instead of aborting the whole comparison.
    """
    results: Dict[str, Dict[str, Any]] = {}
    for name in model_names:
        try:
            # Resolve the weights file: a '{model}' pattern, a shared explicit
            # path, or the default outputs/<name>_model.pth convention.
            if args.weights and "{model}" in args.weights:
                weights_path = args.weights.format(model=name)
            elif args.weights:
                weights_path = args.weights
            else:
                weights_path = f"outputs/{name}_model.pth"

            if not Path(weights_path).exists():
                logging.warning(f"Weights not found for {name}: {weights_path}")
                continue

            results[name] = run_single_model_inference(
                x_raw, y_raw, name, weights_path, args, device
            )
        except Exception as e:
            logging.error(f"Failed to run inference with {name}: {str(e)}")
            continue
    return results
# ///////////////////////////////////////////////////////// | |
def save_results(
    results: Dict[str, Any], output_path: Path, format: str = "json"
) -> None:
    """Save inference results to *output_path* in the requested format.

    Args:
        results: Either a single-model result dict, or a combined dict that
            contains a "models" mapping (multi-model comparison).
        output_path: Destination file; parent directories are created.
        format: "json" or "csv".

    Raises:
        ValueError: if *format* is neither "json" nor "csv".
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if format == "json":
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
    elif format == "csv":
        if "models" in results:  # Multi-model results -> one row per model
            rows = []
            for model_name, model_result in results["models"].items():
                row = {
                    "model": model_name,
                    "prediction": model_result["prediction"],
                    "predicted_class": model_result["predicted_class"],
                    "confidence": model_result["confidence"],
                    "processing_time": model_result["processing_time"],
                }
                # Flatten per-class probabilities into prob_class_<i> columns.
                for i, prob in enumerate(model_result.get("probs", [])):
                    row[f"prob_class_{i}"] = prob
                rows.append(row)
            with open(output_path, "w", newline="", encoding="utf-8") as f:
                if rows:
                    # Fieldnames = insertion-ordered union of keys across ALL
                    # rows; taking only rows[0].keys() made DictWriter raise
                    # when models had differing class counts.
                    fieldnames: Dict[str, None] = {}
                    for row in rows:
                        fieldnames.update(dict.fromkeys(row))
                    writer = csv.DictWriter(f, fieldnames=list(fieldnames))
                    writer.writeheader()
                    writer.writerows(rows)
        else:  # Single model result -> a single row
            with open(output_path, "w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=results.keys())
                writer.writeheader()
                writer.writerow(results)
    else:
        # argparse restricts --output-format to json/csv, so this only fires
        # on a programming error — fail loudly instead of silently no-op.
        raise ValueError(f"Unsupported output format: {format}")
def main():
    """CLI entry point.

    Parses arguments, loads and parses a spectrum file (with a fallback
    loader), runs single- or multi-model inference, saves results as
    JSON/CSV, and logs a summary.

    Raises:
        FileNotFoundError: if the input path does not exist.
        ValueError: if no valid models were requested or the spectrum has
            fewer than 10 points.
    """
    logging.basicConfig(level=logging.INFO, format="INFO: %(message)s")
    args = parse_args()

    # Input validation
    in_path = Path(args.input)
    if not in_path.exists():
        raise FileNotFoundError(f"Input file not found: {in_path}")

    # Determine if this is single or multi-model inference
    if args.models:
        model_names = [m.strip() for m in args.models.split(",")]
        model_names = validate_model_list(model_names)
        if not model_names:
            raise ValueError(f"No valid models found in: {args.models}")
        multi_model = True
    else:
        model_names = [args.arch]
        multi_model = False

    # NOTE(review): args.file_format is accepted but parse_spectrum_data is
    # called without it, so --file-format currently has no effect — confirm
    # whether the parser should receive it.
    try:
        # Preferred path: read the raw text and use the format-detecting parser.
        with open(in_path, "r", encoding="utf-8") as f:
            content = f.read()
        x_raw, y_raw = parse_spectrum_data(content, str(in_path))
    except Exception as e:
        logging.warning(
            f"Failed to parse with new parser, falling back to original: {e}"
        )
        x_raw, y_raw = load_spectrum(str(in_path))
    # Convert once, after whichever loader succeeded. (The previous code
    # called load_spectrum twice in the fallback branch and the second call
    # discarded the float32 conversion.)
    x_raw = np.array(x_raw, dtype=np.float32)
    y_raw = np.array(y_raw, dtype=np.float32)
    if len(x_raw) < 10:
        raise ValueError("Input spectrum has too few points (<10).")

    # Setup device: fall back to CPU when CUDA was requested but unavailable.
    device = torch.device(
        args.device if (args.device == "cuda" and torch.cuda.is_available()) else "cpu"
    )

    # Run inference
    model_results = {}  # Initialize to avoid unbound variable error
    if multi_model:
        model_results = run_multi_model_inference(
            x_raw, y_raw, model_names, args, device
        )
        # Get ground truth if available
        true_label = label_file(str(in_path))
        # Count pairwise prediction agreements between the models that ran.
        preds = [r["prediction"] for r in model_results.values()]
        agreements = (
            sum(1 for i, p1 in enumerate(preds) for p2 in preds[i + 1 :] if p1 == p2)
            if len(preds) > 1
            else 0
        )
        results = {
            "input_file": str(in_path),
            "modality": args.modality,
            "models": model_results,
            "true_label": true_label,
            "preprocessing": {
                "baseline": not args.disable_baseline,
                "smooth": not args.disable_smooth,
                "normalize": not args.disable_normalize,
                "target_len": args.target_len,
            },
            "comparison": {
                "total_models": len(model_results),
                "agreements": agreements,
            },
        }
        # Default output path for multi-model
        default_output = (
            Path("outputs")
            / "inference"
            / f"{in_path.stem}_comparison.{args.output_format}"
        )
    else:
        # Single model inference
        model_result = run_single_model_inference(
            x_raw, y_raw, model_names[0], args.weights, args, device
        )
        true_label = label_file(str(in_path))
        results = {
            "input_file": str(in_path),
            "modality": args.modality,
            "arch": model_names[0],
            "weights": str(args.weights),
            "target_len": args.target_len,
            "preprocessing": {
                "baseline": not args.disable_baseline,
                "smooth": not args.disable_smooth,
                "normalize": not args.disable_normalize,
            },
            "predicted_label": model_result["prediction"],
            "predicted_class": model_result["predicted_class"],
            "true_label": true_label,
            "confidence": model_result["confidence"],
            "probs": model_result["probs"],
            "logits": model_result["logits"],
            "processing_time": model_result["processing_time"],
        }
        # Default output path for single model
        default_output = (
            Path("outputs")
            / "inference"
            / f"{in_path.stem}_{model_names[0]}.{args.output_format}"
        )

    # Save results
    output_path = Path(args.output) if args.output else default_output
    save_results(results, output_path, args.output_format)

    # Log summary
    if multi_model:
        logging.info(
            f"Multi-model inference completed with {len(model_results)} models"
        )
        for model_name, result in model_results.items():
            logging.info(
                f"{model_name}: {result['predicted_class']} (confidence: {result['confidence']:.3f})"
            )
        logging.info(f"Results saved to {output_path}")
    else:
        logging.info(
            f"Predicted Label: {results['predicted_label']} ({results['predicted_class']})"
        )
        logging.info(f"Confidence: {results['confidence']:.3f}")
        logging.info(f"True Label: {results['true_label']}")
        logging.info(f"Result saved to {output_path}")
if __name__ == "__main__": | |
main() | |