Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	devjas1
				
			
		(SYNC): bring parity backend (utils/ scripts/ models/ tests/) from feat/ui-parity-rebuild; no UI changes
		6373c5a
		
		| # scripts/run_inference.py | |
| """ | |
| CLI inference with preprocessing parity. | |
| Applies: resample → baseline (deg=2) → smooth (w=11,o=2) → normalize | |
| unless explicitly disabled via flags. | |
| Usage (examples): | |
| python scripts/run_inference.py \ | |
| --input datasets/rdwp/sta-1.txt \ | |
| --arch figure2 \ | |
| --weights outputs/figure2_model.pth \ | |
| --target-len 500 | |
| # Disable smoothing only: | |
| python scripts/run_inference.py --input ... --arch resnet --weights ... --disable-smooth | |
| """ | |
| import os | |
| import sys | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
| import argparse | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import cast | |
| from torch import nn | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from models.registry import build, choices | |
| from utils.preprocessing import preprocess_spectrum, TARGET_LENGTH | |
| from scripts.plot_spectrum import load_spectrum | |
| from scripts.discover_raman_files import label_file | |
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options for single-spectrum inference.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        Namespace with ``input``, ``arch``, ``weights``, ``target_len``,
        ``disable_baseline``, ``disable_smooth``, ``disable_normalize``,
        ``output`` and ``device`` attributes.
    """
    p = argparse.ArgumentParser(
        description="Raman spectrum inference (parity with CLI preprocessing)."
    )
    p.add_argument(
        "--input",
        required=True,
        help="Path to a single Raman .txt file (2 columns: x, y).",
    )
    p.add_argument(
        "--arch",
        required=True,
        choices=choices(),
        help="Model architecture key.",
    )
    p.add_argument("--weights", required=True, help="Path to model weights (.pth).")
    p.add_argument(
        "--target-len",
        type=int,
        default=TARGET_LENGTH,
        # %(default)s keeps the help text in sync with TARGET_LENGTH instead
        # of repeating a literal that can silently drift.
        help="Resample length (default: %(default)s).",
    )

    # Default = ON; use disable- flags to turn steps off explicitly.
    p.add_argument(
        "--disable-baseline",
        action="store_true",
        help="Disable baseline correction.",
    )
    p.add_argument(
        "--disable-smooth",
        action="store_true",
        help="Disable Savitzky–Golay smoothing.",
    )
    p.add_argument(
        "--disable-normalize",
        action="store_true",
        help="Disable min-max normalization.",
    )

    p.add_argument(
        "--output",
        default=None,
        # main() names the default file after the input stem AND the arch.
        help=(
            "Optional output JSON path "
            "(defaults to outputs/inference/<input-stem>_<arch>.json)."
        ),
    )
    p.add_argument(
        "--device",
        default="cpu",
        choices=["cpu", "cuda"],
        help="Compute device (default: cpu).",
    )
    return p.parse_args(argv)
def _load_state_dict_safe(path: str):
    """Load a state dict from ``path`` across torch versions and checkpoint layouts.

    The file may hold either a plain state dict or a checkpoint dict that
    nests one under ``"state_dict"``, ``"model_state_dict"`` or ``"model"``
    (checked in that order). A leading ``"module."`` on a key, as produced by
    ``DataParallel`` wrappers, is removed.

    Args:
        path: Filesystem path of the saved weights/checkpoint.

    Returns:
        A dict mapping parameter names to their saved values, loaded onto CPU.

    Raises:
        ValueError: If the loaded object is not a dict after unwrapping.
    """
    try:
        # Newer torch: restrict unpickling to tensors and plain containers.
        loaded = torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        # Older torch does not know ``weights_only``. NOTE(review): this
        # fallback is a full pickle load, so only use trusted weight files.
        loaded = torch.load(path, map_location="cpu")

    # Accept either a plain state_dict or a checkpoint dict that contains one.
    if isinstance(loaded, dict):
        for wrapper_key in ("state_dict", "model_state_dict", "model"):
            inner = loaded.get(wrapper_key)
            if isinstance(inner, dict):
                loaded = inner
                break

    if not isinstance(loaded, dict):
        raise ValueError(
            "Loaded object is not a state_dict or checkpoint with a state_dict. "
            f"Type={type(loaded)} from file={path}"
        )

    # Strip DataParallel 'module.' prefixes if present. Only a *leading*
    # prefix is removed: str.replace would also eat an interior "module."
    # from keys that do not start with it (e.g. "head.module.bias").
    prefix = "module."
    if any(key.startswith(prefix) for key in loaded):
        loaded = {
            (key[len(prefix):] if key.startswith(prefix) else key): val
            for key, val in loaded.items()
        }
    return loaded
def main():
    """Run inference on one spectrum file and write the result as JSON.

    Pipeline: parse CLI arguments, load the raw spectrum, preprocess it with
    ``preprocess_spectrum`` (each step can be switched off via the
    ``--disable-*`` flags), build the model selected by ``--arch``, load the
    weights non-strictly, run one forward pass, and dump logits,
    probabilities and labels to a JSON file.

    Raises:
        FileNotFoundError: If the ``--input`` path does not exist.
        ValueError: If the loaded spectrum has fewer than 10 points.
    """
    logging.basicConfig(level=logging.INFO, format="INFO: %(message)s")
    args = parse_args()

    in_path = Path(args.input)
    if not in_path.exists():
        raise FileNotFoundError(f"Input file not found: {in_path}")

    # --- Load raw spectrum
    x_raw, y_raw = load_spectrum(str(in_path))
    if len(x_raw) < 10:
        raise ValueError("Input spectrum has too few points (<10).")

    # --- Preprocess (single source of truth)
    _, y_proc = preprocess_spectrum(
        np.array(x_raw),
        np.array(y_raw),
        target_len=args.target_len,
        do_baseline=not args.disable_baseline,
        do_smooth=not args.disable_smooth,
        do_normalize=not args.disable_normalize,
        out_dtype="float32",
    )

    # --- Build model & load weights (safe)
    # CUDA is used only when requested AND available; otherwise this falls
    # back to CPU without raising.
    use_cuda = args.device == "cuda" and torch.cuda.is_available()
    device = torch.device(args.device if use_cuda else "cpu")
    model = cast(nn.Module, build(args.arch, args.target_len)).to(device)
    state = _load_state_dict_safe(args.weights)
    # strict=False tolerates key mismatches; the counts are logged so a
    # partially loaded model does not go unnoticed.
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        logging.info(
            "Loaded with non-strict keys. missing=%d unexpected=%d",
            len(missing),
            len(unexpected),
        )
    model.eval()

    # Two leading singleton axes are added: (B, C, L) = (1, 1, len(y_proc)).
    # Assumes y_proc is a 1-D array of length target_len — TODO confirm.
    x_tensor = torch.from_numpy(y_proc[None, None, :]).to(device)
    with torch.no_grad():
        logits = model(x_tensor).float().cpu()  # expected shape (1, num_classes)
        probs = F.softmax(logits, dim=1)

    probs_np = probs.numpy().ravel().tolist()
    logits_np = logits.numpy().ravel().tolist()
    pred_label = int(np.argmax(probs_np))

    # Optional ground-truth from filename (if encoded)
    true_label = label_file(str(in_path))

    # --- Prepare output
    if args.output:
        out_path = Path(args.output)
    else:
        out_path = Path("outputs") / "inference" / f"{in_path.stem}_{args.arch}.json"
    # Create the directory of the file actually being written. This covers a
    # custom --output in a not-yet-existing folder and avoids creating the
    # default outputs/inference directory when it is not used.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    result = {
        "input_file": str(in_path),
        "arch": args.arch,
        "weights": str(args.weights),
        "target_len": args.target_len,
        "preprocessing": {
            "baseline": not args.disable_baseline,
            "smooth": not args.disable_smooth,
            "normalize": not args.disable_normalize,
        },
        "predicted_label": pred_label,
        "true_label": true_label,
        "probs": probs_np,
        "logits": logits_np,
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)

    logging.info("Predicted Label: %d True Label: %s", pred_label, true_label)
    logging.info("Raw Logits: %s", logits_np)
    logging.info("Result saved to %s", out_path)
# Script entry point: run inference only when this file is executed directly,
# so importing the module has no side effect beyond its definitions.
if __name__ == "__main__":
    main()
