polymer-aging-ml / backend /inference_utils.py
devjas1
Initial migration from original polymer_project
e484a46
raw
history blame
2.55 kB
def load_model(name):
return "mock_model"
def run_inference(model, spectrum):
return {
"prediction": "Stubbed Output",
"class_index": 0,
"logits": [0.0, 1.0],
"class_labels": ["Stub", "Output"]
}
# ---------- ACTUAL MODEL LOADING/INFERENCE CODE ---------------------|
# import torch
# import numpy as np
# from pathlib import Path
# from scripts.preprocess_dataset import resample_spectrum
# from models.figure2_cnn import Figure2CNN
# from models.resnet_cnn import ResNet1D
# # -- Label Map --
# LABELS = ["Stable (Unweathered)", "Weathered (Degraded)"]
# # -- Model Paths --
# MODEL_CONFIG = {
# "figure2": {
# "class": Figure2CNN,
# "path": "outputs/figure2_model.pth"
# },
# "resnet": {
# "class": ResNet1D,
# "path": "outputs/resnet_model.pth"
# }
# }
# def load_model(model_name: str):
# if model_name not in MODEL_CONFIG:
# raise ValueError(f"Unknown model '{model_name}'. Valid options: {list(MODEL_CONFIG.keys())}")
# config = MODEL_CONFIG[model_name]
# model = config["class"]()
# state_dict = torch.load(config["path"], map_location=torch.device("cpu"), weights_only=True)
# model.load_state_dict(state_dict)
# model.eval()
# return model
# def run_inference(model, spectrum: list):
# # -- Validate Input --
# if not isinstance(spectrum, list) or len(spectrum) < 10:
# raise ValueError("Spectrum must be a list of floats with reasonable length")
# # -- Convert to Numpy --
# spectrum = np.array(spectrum, dtype=np.float32)
# # -- Resample --
# x_vals = np.arange(len(spectrum))
# spectrum = resample_spectrum(x_vals, spectrum, target_len=500)
# # -- Normalize --
# mean = np.mean(spectrum)
# std = np.std(spectrum)
# if std == 0:
# raise ValueError("Standard deviation of spectrum is zero; normalization will fail.")
# spectrum = (spectrum - mean) / std
# # -- To Tensor --
# x = torch.tensor(spectrum, dtype=torch.float32).unsqueeze(0).unsqueeze(0) # Shape (1, 1, 500)
# with torch.no_grad():
# logits = model(x)
# pred_index = torch.argmax(logits, dim=1).item()
# return {
# "prediction": LABELS[pred_index],
# "class_index": pred_index,
# "logits": logits.squeeze().tolist(),
# "class_labels": LABELS
# }
# ---------- ACTUAL MODEL LOADING/INFERENCE CODE ---------------------|