devjas1 committed · Commit 6ea9614 · Parent(s): 222f7ff

(FEAT)[CLI Tool]: Add multi-model inference, format detection, and flexible output

CLI:
- Accepts either a single model ('--arch') or a comma-separated list of models ('--models') for inference; see the example invocations below.
- Supports input files in .txt, .csv, or .json, with automatic format detection or an explicit '--file-format' override.
- Introduces modality selection (Raman/FTIR) for preprocessing via '--modality'.
- Output can be JSON or CSV ('--output-format'), with improved default naming and path handling.
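
Example invocations (illustrative; input paths and weight filenames are hypothetical, the flags are those added in this commit):

  # Single model, JSON output (default)
  python scripts/run_inference.py --input sample.txt --arch resnet --weights outputs/resnet_model.pth

  # Multi-model comparison, CSV output; '{model}' expands to each model name
  python scripts/run_inference.py --input sample.csv --models figure2,resnet --weights "outputs/{model}_model.pth" --output-format csv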

Internal logic:
- Added 'run_single_model_inference' and 'run_multi_model_inference' to modularize the inference workflows.
- Handles '{model}' placeholder patterns in the weights path for multi-model runs (sketched below).
- Results include the prediction, confidence, processing time, and per-class probabilities for each model.
- Output saving supports both formats; multi-model runs get a tabular CSV (columns: model, prediction, predicted_class, confidence, processing_time, prob_class_<i>).
- Summary logging and error handling improved for clarity.
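
A minimal sketch of the per-model weights resolution (the helper name 'resolve_weights' is illustrative only; the commit inlines this logic in the loop of 'run_multi_model_inference'):

  def resolve_weights(weights_arg, model_name):
      # '{model}' placeholder: expand to one weights file per model
      if weights_arg and "{model}" in weights_arg:
          return weights_arg.format(model=model_name)
      # plain path: the same weights are shared by every model
      if weights_arg:
          return weights_arg
      # no --weights given: fall back to the default pattern
      return f"outputs/{model_name}_model.pth"

  # resolve_weights("outputs/{model}_model.pth", "figure2") -> "outputs/figure2_model.pth"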

Files changed (1): scripts/run_inference.py (+364 −61)
scripts/run_inference.py CHANGED
@@ -17,144 +17,447 @@ python scripts/run_inference.py --input ... --arch resnet --weights ... --disabl
Before (lines 17-160):

 
 import os
 import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 
 import argparse
 import json
 import logging
 from pathlib import Path
-from typing import cast
 from torch import nn
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 
-from models.registry import build, choices
 from utils.preprocessing import preprocess_spectrum, TARGET_LENGTH
 from scripts.plot_spectrum import load_spectrum
 from scripts.discover_raman_files import label_file
 
 
 def parse_args():
-    p = argparse.ArgumentParser(description="Raman spectrum inference (parity with CLI preprocessing).")
-    p.add_argument("--input", required=True, help="Path to a single Raman .txt file (2 columns: x, y).")
-    p.add_argument("--arch", required=True, choices=choices(), help="Model architecture key.")
-    p.add_argument("--weights", required=True, help="Path to model weights (.pth).")
-    p.add_argument("--target-len", type=int, default=TARGET_LENGTH, help="Resample length (default: 500).")
 
     # Default = ON; use disable- flags to turn steps off explicitly.
-    p.add_argument("--disable-baseline", action="store_true", help="Disable baseline correction.")
-    p.add_argument("--disable-smooth", action="store_true", help="Disable Savitzky–Golay smoothing.")
-    p.add_argument("--disable-normalize", action="store_true", help="Disable min-max normalization.")
 
-    p.add_argument("--output", default=None, help="Optional output JSON path (defaults to outputs/inference/<name>.json).")
-    p.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Compute device (default: cpu).")
     return p.parse_args()
 
 
 def _load_state_dict_safe(path: str):
     """Load a state dict safely across torch versions & checkpoint formats."""
     try:
         obj = torch.load(path, map_location="cpu", weights_only=True)  # newer torch
     except TypeError:
         obj = torch.load(path, map_location="cpu")  # fallback for older torch
-
     # Accept either a plain state_dict or a checkpoint dict that contains one
     if isinstance(obj, dict):
         for k in ("state_dict", "model_state_dict", "model"):
             if k in obj and isinstance(obj[k], dict):
                 obj = obj[k]
                 break
-
     if not isinstance(obj, dict):
         raise ValueError(
             "Loaded object is not a state_dict or checkpoint with a state_dict. "
             f"Type={type(obj)} from file={path}"
         )
-
     # Strip DataParallel 'module.' prefixes if present
     if any(key.startswith("module.") for key in obj.keys()):
         obj = {key.replace("module.", "", 1): val for key, val in obj.items()}
-
     return obj
 
 
-def main():
-    logging.basicConfig(level=logging.INFO, format="INFO: %(message)s")
-    args = parse_args()
 
-    in_path = Path(args.input)
-    if not in_path.exists():
-        raise FileNotFoundError(f"Input file not found: {in_path}")
 
-    # --- Load raw spectrum
-    x_raw, y_raw = load_spectrum(str(in_path))
-    if len(x_raw) < 10:
-        raise ValueError("Input spectrum has too few points (<10).")
 
-    # --- Preprocess (single source of truth)
     _, y_proc = preprocess_spectrum(
-        np.array(x_raw),
-        np.array(y_raw),
         target_len=args.target_len,
         do_baseline=not args.disable_baseline,
         do_smooth=not args.disable_smooth,
         do_normalize=not args.disable_normalize,
         out_dtype="float32",
     )
 
-    # --- Build model & load weights (safe)
-    device = torch.device(args.device if (args.device == "cuda" and torch.cuda.is_available()) else "cpu")
-    model = cast(nn.Module, build(args.arch, args.target_len)).to(device)
-    state = _load_state_dict_safe(args.weights)
     missing, unexpected = model.load_state_dict(state, strict=False)
     if missing or unexpected:
-        logging.info("Loaded with non-strict keys. missing=%d unexpected=%d", len(missing), len(unexpected))
 
     model.eval()
 
-    # Shape: (B, C, L) = (1, 1, target_len)
     x_tensor = torch.from_numpy(y_proc[None, None, :]).to(device)
 
     with torch.no_grad():
-        logits = model(x_tensor).float().cpu()  # shape (1, num_classes)
         probs = F.softmax(logits, dim=1)
 
     probs_np = probs.numpy().ravel().tolist()
     logits_np = logits.numpy().ravel().tolist()
     pred_label = int(np.argmax(probs_np))
 
-    # Optional ground-truth from filename (if encoded)
-    true_label = label_file(str(in_path))
-
-    # --- Prepare output
-    out_dir = Path("outputs") / "inference"
-    out_dir.mkdir(parents=True, exist_ok=True)
-    out_path = Path(args.output) if args.output else (out_dir / f"{in_path.stem}_{args.arch}.json")
-
-    result = {
-        "input_file": str(in_path),
-        "arch": args.arch,
-        "weights": str(args.weights),
-        "target_len": args.target_len,
-        "preprocessing": {
-            "baseline": not args.disable_baseline,
-            "smooth": not args.disable_smooth,
-            "normalize": not args.disable_normalize,
-        },
-        "predicted_label": pred_label,
-        "true_label": true_label,
         "probs": probs_np,
         "logits": logits_np,
     }
 
-    with open(out_path, "w", encoding="utf-8") as f:
-        json.dump(result, f, indent=2)
 
-    logging.info("Predicted Label: %d True Label: %s", pred_label, true_label)
-    logging.info("Raw Logits: %s", logits_np)
-    logging.info("Result saved to %s", out_path)
 
 
 if __name__ == "__main__":
 
After (lines 17-463):

 
 import os
 import sys
+
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 
 import argparse
 import json
+import csv
 import logging
 from pathlib import Path
+from typing import cast, Dict, List, Any
 from torch import nn
+import time
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 
+from models.registry import build, choices, build_multiple, validate_model_list
 from utils.preprocessing import preprocess_spectrum, TARGET_LENGTH
+from utils.multifile import parse_spectrum_data, detect_file_format
 from scripts.plot_spectrum import load_spectrum
 from scripts.discover_raman_files import label_file
 
 
 def parse_args():
+    p = argparse.ArgumentParser(
+        description="Raman/FTIR spectrum inference with multi-model support."
+    )
+    p.add_argument(
+        "--input",
+        required=True,
+        help="Path to spectrum file (.txt, .csv, .json) or directory for batch processing.",
+    )
+
+    # Model selection - either single or multiple
+    group = p.add_mutually_exclusive_group(required=True)
+    group.add_argument(
+        "--arch", choices=choices(), help="Single model architecture key."
+    )
+    group.add_argument(
+        "--models",
+        help="Comma-separated list of models for comparison (e.g., 'figure2,resnet,resnet18vision').",
+    )
+
+    p.add_argument(
+        "--weights",
+        help="Path to model weights (.pth). For multi-model, use pattern with {model} placeholder.",
+    )
+    p.add_argument(
+        "--target-len",
+        type=int,
+        default=TARGET_LENGTH,
+        help="Resample length (default: 500).",
+    )
+
+    # Modality support
+    p.add_argument(
+        "--modality",
+        choices=["raman", "ftir"],
+        default="raman",
+        help="Spectroscopy modality for preprocessing (default: raman).",
+    )
 
     # Default = ON; use disable- flags to turn steps off explicitly.
+    p.add_argument(
+        "--disable-baseline", action="store_true", help="Disable baseline correction."
+    )
+    p.add_argument(
+        "--disable-smooth",
+        action="store_true",
+        help="Disable Savitzky–Golay smoothing.",
+    )
+    p.add_argument(
+        "--disable-normalize",
+        action="store_true",
+        help="Disable min-max normalization.",
+    )
+
+    p.add_argument(
+        "--output",
+        default=None,
+        help="Output path - JSON for single file, CSV for multi-model comparison.",
+    )
+    p.add_argument(
+        "--output-format",
+        choices=["json", "csv"],
+        default="json",
+        help="Output format for results.",
+    )
+    p.add_argument(
+        "--device",
+        default="cpu",
+        choices=["cpu", "cuda"],
+        help="Compute device (default: cpu).",
+    )
+
+    # File format options
+    p.add_argument(
+        "--file-format",
+        choices=["auto", "txt", "csv", "json"],
+        default="auto",
+        help="Input file format (auto-detect by default).",
+    )
 
     return p.parse_args()
 
 
+# /////////////////////////////////////////////////////////
+
+
 def _load_state_dict_safe(path: str):
     """Load a state dict safely across torch versions & checkpoint formats."""
     try:
         obj = torch.load(path, map_location="cpu", weights_only=True)  # newer torch
     except TypeError:
         obj = torch.load(path, map_location="cpu")  # fallback for older torch
     # Accept either a plain state_dict or a checkpoint dict that contains one
     if isinstance(obj, dict):
         for k in ("state_dict", "model_state_dict", "model"):
             if k in obj and isinstance(obj[k], dict):
                 obj = obj[k]
                 break
     if not isinstance(obj, dict):
         raise ValueError(
             "Loaded object is not a state_dict or checkpoint with a state_dict. "
             f"Type={type(obj)} from file={path}"
         )
     # Strip DataParallel 'module.' prefixes if present
     if any(key.startswith("module.") for key in obj.keys()):
         obj = {key.replace("module.", "", 1): val for key, val in obj.items()}
     return obj
 
 
+# /////////////////////////////////////////////////////////
 
 
+def run_single_model_inference(
+    x_raw: np.ndarray,
+    y_raw: np.ndarray,
+    model_name: str,
+    weights_path: str,
+    args: argparse.Namespace,
+    device: torch.device,
+) -> Dict[str, Any]:
+    """Run inference with a single model."""
+    start_time = time.time()
 
+    # Preprocess spectrum
     _, y_proc = preprocess_spectrum(
+        x_raw,
+        y_raw,
         target_len=args.target_len,
+        modality=args.modality,
         do_baseline=not args.disable_baseline,
         do_smooth=not args.disable_smooth,
         do_normalize=not args.disable_normalize,
         out_dtype="float32",
     )
 
+    # Build model & load weights
+    model = cast(nn.Module, build(model_name, args.target_len)).to(device)
+    state = _load_state_dict_safe(weights_path)
     missing, unexpected = model.load_state_dict(state, strict=False)
     if missing or unexpected:
+        logging.info(
+            f"Model {model_name}: Loaded with non-strict keys. missing={len(missing)} unexpected={len(unexpected)}"
+        )
 
     model.eval()
 
+    # Run inference
     x_tensor = torch.from_numpy(y_proc[None, None, :]).to(device)
 
     with torch.no_grad():
+        logits = model(x_tensor).float().cpu()
         probs = F.softmax(logits, dim=1)
 
+    processing_time = time.time() - start_time
     probs_np = probs.numpy().ravel().tolist()
     logits_np = logits.numpy().ravel().tolist()
     pred_label = int(np.argmax(probs_np))
 
+    # Map prediction to class name
+    class_names = ["Stable", "Weathered"]
+    predicted_class = (
+        class_names[pred_label]
+        if pred_label < len(class_names)
+        else f"Class_{pred_label}"
+    )
+
+    return {
+        "model": model_name,
+        "prediction": pred_label,
+        "predicted_class": predicted_class,
+        "confidence": max(probs_np),
         "probs": probs_np,
         "logits": logits_np,
+        "processing_time": processing_time,
     }
 
 
+# /////////////////////////////////////////////////////////
+
+
+def run_multi_model_inference(
+    x_raw: np.ndarray,
+    y_raw: np.ndarray,
+    model_names: List[str],
+    args: argparse.Namespace,
+    device: torch.device,
+) -> Dict[str, Dict[str, Any]]:
+    """Run inference with multiple models for comparison."""
+    results = {}
+
+    for model_name in model_names:
+        try:
+            # Generate weights path - either use pattern or assume same weights for all
+            if args.weights and "{model}" in args.weights:
+                weights_path = args.weights.format(model=model_name)
+            elif args.weights:
+                weights_path = args.weights
+            else:
+                # Default weights path pattern
+                weights_path = f"outputs/{model_name}_model.pth"
+
+            if not Path(weights_path).exists():
+                logging.warning(f"Weights not found for {model_name}: {weights_path}")
+                continue
+
+            result = run_single_model_inference(
+                x_raw, y_raw, model_name, weights_path, args, device
+            )
+            results[model_name] = result
+
+        except Exception as e:
+            logging.error(f"Failed to run inference with {model_name}: {str(e)}")
+            continue
+
+    return results
+
+
+# /////////////////////////////////////////////////////////
+
+
+def save_results(
+    results: Dict[str, Any], output_path: Path, format: str = "json"
+) -> None:
+    """Save results to file in specified format"""
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    if format == "json":
+        with open(output_path, "w", encoding="utf-8") as f:
+            json.dump(results, f, indent=2)
+    elif format == "csv":
+        # Convert to tabular format for CSV
+        if "models" in results:  # Multi-model results
+            rows = []
+            for model_name, model_result in results["models"].items():
+                row = {
+                    "model": model_name,
+                    "prediction": model_result["prediction"],
+                    "predicted_class": model_result["predicted_class"],
+                    "confidence": model_result["confidence"],
+                    "processing_time": model_result["processing_time"],
+                }
+                # Add individual class probabilities
+                if "probs" in model_result:
+                    for i, prob in enumerate(model_result["probs"]):
+                        row[f"prob_class_{i}"] = prob
+                rows.append(row)
+
+            # Write CSV
+            with open(output_path, "w", newline="", encoding="utf-8") as f:
+                if rows:
+                    writer = csv.DictWriter(f, fieldnames=rows[0].keys())
+                    writer.writeheader()
+                    writer.writerows(rows)
+        else:  # Single model result
+            with open(output_path, "w", newline="", encoding="utf-8") as f:
+                writer = csv.DictWriter(f, fieldnames=results.keys())
+                writer.writeheader()
+                writer.writerow(results)
+
+
+def main():
+    logging.basicConfig(level=logging.INFO, format="INFO: %(message)s")
+    args = parse_args()
+
+    # Input validation
+    in_path = Path(args.input)
+    if not in_path.exists():
+        raise FileNotFoundError(f"Input file not found: {in_path}")
+
+    # Determine if this is single or multi-model inference
+    if args.models:
+        model_names = [m.strip() for m in args.models.split(",")]
+        model_names = validate_model_list(model_names)
+        if not model_names:
+            raise ValueError(f"No valid models found in: {args.models}")
+        multi_model = True
+    else:
+        model_names = [args.arch]
+        multi_model = False
+
+    # Load and parse spectrum data
+    if args.file_format == "auto":
+        file_format = None  # Auto-detect
+    else:
+        file_format = args.file_format
+
+    try:
+        # Read file content
+        with open(in_path, "r", encoding="utf-8") as f:
+            content = f.read()
+
+        # Parse spectrum data with format detection
+        x_raw, y_raw = parse_spectrum_data(content, str(in_path))
+        x_raw = np.array(x_raw, dtype=np.float32)
+        y_raw = np.array(y_raw, dtype=np.float32)
+
+    except Exception as e:
+        logging.warning(
+            f"Failed to parse with new parser, falling back to original: {e}"
+        )
+        x_raw, y_raw = load_spectrum(str(in_path))
+        x_raw = np.array(x_raw, dtype=np.float32)
+        y_raw = np.array(y_raw, dtype=np.float32)
+
+    if len(x_raw) < 10:
+        raise ValueError("Input spectrum has too few points (<10).")
+
+    # Setup device
+    device = torch.device(
+        args.device if (args.device == "cuda" and torch.cuda.is_available()) else "cpu"
+    )
+
+    # Run inference
+    model_results = {}  # Initialize to avoid unbound variable error
+    if multi_model:
+        model_results = run_multi_model_inference(
+            np.array(x_raw, dtype=np.float32),
+            np.array(y_raw, dtype=np.float32),
+            model_names,
+            args,
+            device,
+        )
+
+        # Get ground truth if available
+        true_label = label_file(str(in_path))
+
+        # Prepare combined results
+        results = {
+            "input_file": str(in_path),
+            "modality": args.modality,
+            "models": model_results,
+            "true_label": true_label,
+            "preprocessing": {
+                "baseline": not args.disable_baseline,
+                "smooth": not args.disable_smooth,
+                "normalize": not args.disable_normalize,
+                "target_len": args.target_len,
+            },
+            "comparison": {
+                "total_models": len(model_results),
+                "agreements": (
+                    sum(
+                        1
+                        for i, (_, r1) in enumerate(model_results.items())
+                        for j, (_, r2) in enumerate(
+                            list(model_results.items())[i + 1 :]
+                        )
+                        if r1["prediction"] == r2["prediction"]
+                    )
+                    if len(model_results) > 1
+                    else 0
+                ),
+            },
+        }
+
+        # Default output path for multi-model
+        default_output = (
+            Path("outputs")
+            / "inference"
+            / f"{in_path.stem}_comparison.{args.output_format}"
+        )
+
+    else:
+        # Single model inference
+        model_result = run_single_model_inference(
+            x_raw, y_raw, model_names[0], args.weights, args, device
+        )
+        true_label = label_file(str(in_path))
+
+        results = {
+            "input_file": str(in_path),
+            "modality": args.modality,
+            "arch": model_names[0],
+            "weights": str(args.weights),
+            "target_len": args.target_len,
+            "preprocessing": {
+                "baseline": not args.disable_baseline,
+                "smooth": not args.disable_smooth,
+                "normalize": not args.disable_normalize,
+            },
+            "predicted_label": model_result["prediction"],
+            "predicted_class": model_result["predicted_class"],
+            "true_label": true_label,
+            "confidence": model_result["confidence"],
+            "probs": model_result["probs"],
+            "logits": model_result["logits"],
+            "processing_time": model_result["processing_time"],
+        }
+
+        # Default output path for single model
+        default_output = (
+            Path("outputs")
+            / "inference"
+            / f"{in_path.stem}_{model_names[0]}.{args.output_format}"
+        )
+
+    # Save results
+    output_path = Path(args.output) if args.output else default_output
+    save_results(results, output_path, args.output_format)
+
+    # Log summary
+    if multi_model:
+        logging.info(
+            f"Multi-model inference completed with {len(model_results)} models"
+        )
+        for model_name, result in model_results.items():
+            logging.info(
+                f"{model_name}: {result['predicted_class']} (confidence: {result['confidence']:.3f})"
+            )
+        logging.info(f"Results saved to {output_path}")
+    else:
+        logging.info(
+            f"Predicted Label: {results['predicted_label']} ({results['predicted_class']})"
+        )
+        logging.info(f"Confidence: {results['confidence']:.3f}")
+        logging.info(f"True Label: {results['true_label']}")
+        logging.info(f"Result saved to {output_path}")
 
 
 if __name__ == "__main__":