Spaces:
Running
Running
File size: 4,337 Bytes
e484a46 86d081c e484a46 86d081c e484a46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from pathlib import Path
import argparse
import warnings
import logging
import numpy as np
import torch
from models.figure2_cnn import Figure2CNN
from scripts.preprocess_dataset import resample_spectrum, label_file
# =============================================
# ✅ Raman-Only Inference Script
# This script supports prediction on a single Raman spectrum (.txt file).
# FTIR inference has been deprecated and removed for scientific integrity.
# See: @raman-pipeline-focus-milestone
# =============================================
# Suppress any FutureWarning whose message mentions "weights_only=False"
# (presumably the one torch.load emits about its changing default) so that
# it does not clutter the script's output.
warnings.filterwarnings(
    action="ignore",
    category=FutureWarning,
    message=".*weights_only=False.*",
)
def load_raman_spectrum(filepath):
    """Load a two-column Raman spectrum from a whitespace-delimited text file.

    Every line is split on whitespace. A line is accepted only when it has
    exactly two fields and both parse as floats; anything else (headers,
    comments, blank lines, rows with extra columns) is skipped silently
    rather than treated as an error.

    Args:
        filepath: Path to the spectrum ``.txt`` file.

    Returns:
        A tuple ``(x_vals, y_vals)`` of NumPy arrays of equal length holding
        the first and second column of every accepted row, in file order.
        Both arrays are empty when no row could be parsed.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    x_vals, y_vals = [], []
    # "utf-8-sig" reads plain UTF-8 and additionally strips a leading
    # byte-order mark. With plain "utf-8" the BOM stays glued to the first
    # token, float() rejects it, and the first data row is silently lost.
    with open(filepath, "r", encoding="utf-8-sig") as f:
        for line in f:
            parts = line.split()
            if len(parts) != 2:
                continue
            try:
                x, y = float(parts[0]), float(parts[1])
            except ValueError:
                # Non-numeric row (e.g. a column header) - best-effort skip.
                continue
            x_vals.append(x)
            y_vals.append(y)
    return np.array(x_vals), np.array(y_vals)
def _build_parser():
    """Build the command-line parser for single-spectrum Raman inference."""
    parser = argparse.ArgumentParser(
        description="Run inference on a single Raman spectrum (.txt file)."
    )
    parser.add_argument(
        "--target-len", type=int, required=True,
        help="Target length to match model input"
    )
    parser.add_argument(
        "--input", required=True,
        help="Path to Raman .txt file."
    )
    parser.add_argument(
        "--model", default="random",
        help="Path to .pth model file, or specify 'random' to use untrained weights."
    )
    parser.add_argument(
        "--output", default=None,
        help="Where to write prediction result. If omitted, prints to stdout."
    )
    # --quiet and --verbose contradict each other, so argparse rejects both.
    verbosity = parser.add_mutually_exclusive_group()
    verbosity.add_argument(
        "--quiet", action="store_true",
        help="Show only warnings and errors"
    )
    verbosity.add_argument(
        "--verbose", action="store_true",
        help="Show debug-level logging"
    )
    return parser


def _select_log_level(args):
    """Map the --verbose / --quiet flags to a logging level (default INFO)."""
    if args.verbose:
        return logging.DEBUG
    if args.quiet:
        return logging.WARNING
    return logging.INFO


def _run_inference(args, parser):
    """Load the spectrum, run the model once, and format the result text.

    Args:
        args: Parsed command-line namespace (uses ``input``, ``target_len``
            and ``model``).
        parser: The parser that produced ``args``; used to report unusable
            input as a usage error (exit status 2).

    Returns:
        The two-line result string: predicted label with the true label (or
        "Unknown"), followed by the raw logits.
    """
    # 1. Load & preprocess Raman spectrum
    if os.path.isdir(args.input):
        parser.error(f"Input must be a single Raman .txt file, got a directory: {args.input}")
    x_raw, y_raw = load_raman_spectrum(args.input)
    if len(x_raw) < 10:
        parser.error("Spectrum too short for inference.")
    data = resample_spectrum(x_raw, y_raw, target_len=args.target_len)
    # The two unsqueeze calls prepend batch and channel axes, giving
    # (1, 1, target_len) - assuming resample_spectrum returns a 1-D array of
    # target_len samples (TODO confirm against scripts.preprocess_dataset).
    input_tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    # 2. Load model; "random" keeps the freshly initialised weights.
    model = Figure2CNN(
        input_length=args.target_len,
        input_channels=1
    )
    if args.model != "random":
        model.load_state_dict(
            torch.load(args.model, map_location="cpu", weights_only=True)
        )
    model.eval()

    # 3. Inference (no gradients needed for a single forward pass)
    with torch.no_grad():
        logits = model(input_tensor)
        pred = torch.argmax(logits, dim=1).item()

    # 4. True label. Only FileNotFoundError is treated as "no label
    # available"; any other error from label_file reaches the top-level
    # handler in main().
    try:
        true_label = label_file(args.input)
        label_str = f"True Label: {true_label}"
    except FileNotFoundError:
        label_str = "True Label: Unknown"

    return f"Predicted Label: {pred} {label_str}\nRaw Logits: {logits.tolist()}"


def main(argv=None):
    """Run single-spectrum Raman inference from the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        0 on success, 1 when loading, inference or writing the result fails.

    Raises:
        SystemExit: With status 2 for usage errors (bad or missing arguments,
            directory passed as input, spectrum with fewer than 10 points,
            non-positive ``--target-len``).
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    # Reject lengths that cannot describe a model input before doing any work.
    if args.target_len < 1:
        parser.error("--target-len must be a positive integer.")

    logging.basicConfig(
        level=_select_log_level(args),
        format="%(levelname)s: %(message)s"
    )

    try:
        result = _run_inference(args, parser)

        # 5. Save or stdout
        if args.output:
            logging.info(result)
            # ensure parent dir exists (e.g., outputs/inference/)
            Path(args.output).parent.mkdir(parents=True, exist_ok=True)
            with open(args.output, "w", encoding="utf-8") as fout:
                fout.write(result)
            logging.info("Result saved to %s", args.output)
        else:
            # The --output help promises stdout. Logging alone would send the
            # result to stderr and drop it completely under --quiet.
            print(result)
    except Exception as e:
        # Top-level boundary: report and exit non-zero. The traceback is
        # attached only in --verbose mode to keep normal output short.
        logging.error(e, exc_info=args.verbose)
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
|