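# Streamlit app: Raman spectrum classifier for polymer aging
# (predicts "Stable (Unweathered)" vs. "Weathered (Degraded)").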
import os | |
import sys | |
# Project base path: the parent of the directory holding this file. It is put
# on sys.path ahead of the `models` / `scripts` imports that follow.
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# Streamlit re-executes this script on every interaction inside one process,
# and sys.path outlives each run; only add the entry when it is not present so
# the list does not grow by one duplicate per rerun.
if BASE_DIR not in sys.path:
    sys.path.append(BASE_DIR)
from models.figure2_cnn import Figure2CNN | |
from models.resnet_cnn import ResNet1D | |
from scripts.preprocess_dataset import resample_spectrum | |
from io import StringIO | |
from glob import glob | |
from pathlib import Path | |
import numpy as np | |
import streamlit as st | |
import torch | |
import matplotlib.pyplot as plt | |
# Class index -> human-readable label shown in the UI.
label_map = dict(
    enumerate(
        (
            "Stable (Unweathered)",
            "Weathered (Degraded)",
        )
    )
)
def label_file(filename: str) -> int:
    """Return the class index encoded in a spectrum file's name.

    Only the final path component is inspected, case-insensitively.

    Args:
        filename: File name or path of the spectrum.

    Returns:
        0 when the name starts with "sta", 1 when it starts with "wea".

    Raises:
        ValueError: If the name starts with neither prefix. The message
            includes the offending file name to ease debugging.
    """
    base_name = Path(filename).name
    name = base_name.lower()
    if name.startswith("sta"):
        return 0
    if name.startswith("wea"):
        return 1
    raise ValueError(f"Unknown label pattern: {base_name!r}")
# Page-level settings: browser tab title/icon, wide layout, sidebar collapsed
# on first load.
st.set_page_config(
    page_title="Polymer Aging Inference",
    page_icon="π¬",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# While no file has been stored in the session yet, show the idle status.
if "uploaded_file" not in st.session_state:
    st.session_state["status_message"] = "Awaiting input..."
    st.session_state["status_type"] = "info"
# Page heading followed by a one-line description of the app.
st.markdown(
    "**π§ͺ Raman Spectrum Classifier**"
)
st.caption(
    "AI-driven classification of polymer degradation using Raman spectroscopy."
)
# Sidebar: static "about" panel (project, author, mentor, repository link).
st.sidebar.header("βΉοΈ About This App")
st.sidebar.markdown("""
Part of the **AIRE 2025 Internship Project**:
`AI-Driven Polymer Aging Prediction and Classification`
Uses Raman spectra and deep learning to predict material degradation.
**Author**: Jaser Hasan
**Mentor**: Dr. Sanmukh Kuppannagari
[π GitHub](https://github.com/dev-jaser/ai-ml-polymer-aging-prediction)
""")
# Display metadata per model (badge emoji, short description, headline
# metrics). Keys must match the keys of `model_config`.
model_metadata = {
    name: {
        "emoji": emoji,
        "description": description,
        "accuracy": accuracy,
        "f1": f1,
    }
    for name, emoji, description, accuracy, f1 in (
        (
            "Figure2CNN (Baseline)",
            "π¬",
            "Baseline CNN with standard filters",
            "94.80%",
            "94.30%",
        ),
        (
            "ResNet1D (Advanced)",
            "π§ ",
            "Residual CNN with deeper feature learning",
            "96.20%",
            "95.90%",
        ),
    )
}
# Loadable models: display name -> class to instantiate and checkpoint path.
model_config = {
    name: {"model_class": model_class, "model_path": model_path}
    for name, model_class, model_path in (
        ("Figure2CNN (Baseline)", Figure2CNN, "outputs/figure2_model.pth"),
        ("ResNet1D (Advanced)", ResNet1D, "outputs/resnet_model.pth"),
    )
}
def _parse_spectrum_text(raw_text):
    """Extract paired (x, y) samples from delimited spectrum text.

    Each line is split on whitespace and commas. The first two tokens that
    parse as finite floats become one (x, y) sample; this accepts plain
    decimals as well as signed and scientific notation. Lines with fewer than
    two such tokens (headers, blank lines) are skipped.

    Args:
        raw_text: Full text content of the spectrum file.

    Returns:
        A pair of 1-D numpy arrays (x values, y values) of equal length; both
        are empty when no line yields a sample.
    """
    x_vals, y_vals = [], []
    for line in raw_text.splitlines():
        numbers = []
        for token in line.replace(",", " ").split():
            try:
                value = float(token)
            except ValueError:
                continue
            # "nan"/"inf" parse as floats but are not usable samples.
            if np.isfinite(value):
                numbers.append(value)
                if len(numbers) == 2:
                    break
        if len(numbers) == 2:
            x_vals.append(numbers[0])
            y_vals.append(numbers[1])
    return np.array(x_vals), np.array(y_vals)


# Two-column layout: inputs/controls on the left, results on the right.
col1, col2 = st.columns([1.1, 2], gap="large")  # optional for cleaner spacing

try:
    with col1:
        # Upload + model selection
        st.markdown("**π Upload Spectrum**")

        # Model selection, placed next to the data input
        with st.container():
            st.markdown("**π§ Model Selection**")
            # Map each decorated label back to its config key instead of
            # string-splitting the label, which would break if an emoji
            # prefix ever contained a space.
            label_to_model = {
                f"{model_metadata[name]['emoji']} {name}": name
                for name in model_config
            }
            model_labels = list(label_to_model)
            selected_label = st.selectbox(
                "Choose model architecture:",
                model_labels,
                key="model_selector"
            )
            model_choice = label_to_model[selected_label]

        with st.container():
            meta = model_metadata[model_choice]
            st.markdown(f"""
**π Model Overview**
*{meta['description']}*
- **Accuracy**: `{meta['accuracy']}`
- **F1 Score**: `{meta['f1']}`
""")

        # Checkpoint path for the selected model, and whether it is on disk
        MODEL_PATH = model_config[model_choice]["model_path"]
        MODEL_EXISTS = Path(MODEL_PATH).exists()
        # Number of points every spectrum is resampled to before inference
        TARGET_LEN = 500
        if not MODEL_EXISTS:
            st.error("π« Model file not found. Please train the model first.")

        tab1, tab2 = st.tabs(["Upload File", "Use Sample"])
        with tab1:
            uploaded_file = st.file_uploader("Upload Raman `.txt` spectrum", type="txt")
        with tab2:
            sample_files = sorted(glob("app/sample_spectra/*.txt"))
            sample_options = ["-- Select --"] + sample_files
            selected_sample = st.selectbox("Choose a sample:", sample_options)
            if selected_sample != "-- Select --":
                with open(selected_sample, "r", encoding="utf-8") as f:
                    file_contents = f.read()
                # Wrap the sample in a file-like object with a `.name` so it
                # can be handled the same way as an uploaded file. A selected
                # sample takes precedence over the uploader's value.
                uploaded_file = StringIO(file_contents)
                uploaded_file.name = os.path.basename(selected_sample)

        # Capture file in session
        if uploaded_file is not None:
            st.session_state['uploaded_file'] = uploaded_file
            st.session_state['filename'] = uploaded_file.name
            st.session_state.status_message = f"π File `{uploaded_file.name}` loaded. Ready to infer."
            st.session_state.status_type = "success"
            # Results are hidden again until the button is pressed for this
            # run; the button branch below sets the flag back to True.
            st.session_state.inference_run_once = False

        # Status banner
        st.markdown("**π¦ Pipeline Status**")
        status_msg = st.session_state.get("status_message", "Awaiting input...")
        status_typ = st.session_state.get("status_type", "info")
        if status_typ == "success":
            st.success(status_msg)
        elif status_typ == "error":
            st.error(status_msg)
        else:
            st.info(status_msg)

        # Inference trigger: needs a stored file and an existing checkpoint
        if st.button("βΆοΈ Run Inference") and 'uploaded_file' in st.session_state and MODEL_EXISTS:
            spectrum_name = st.session_state['filename']
            uploaded_file = st.session_state['uploaded_file']
            uploaded_file.seek(0)
            raw_data = uploaded_file.read()
            # Uploads yield bytes, the StringIO sample wrapper yields str.
            raw_text = raw_data.decode("utf-8") if isinstance(raw_data, bytes) else raw_data

            # Parse spectrum
            x_raw, y_raw = _parse_spectrum_text(raw_text)
            if x_raw.size == 0:
                # Fail with a clear message (reported by the handler at the
                # bottom) rather than passing empty arrays downstream.
                raise ValueError(f"No numeric (x, y) pairs found in `{spectrum_name}`.")
            y_resampled = resample_spectrum(x_raw, y_raw, TARGET_LEN)

            st.session_state['x_raw'] = x_raw
            st.session_state['y_raw'] = y_raw
            st.session_state['y_resampled'] = y_resampled

            # Update banner for inference
            st.session_state.status_message = f"π Inference running on: `{spectrum_name}`"
            st.session_state.status_type = "info"
            st.session_state.inference_run_once = True

    # Inference
    with col2:
        # MODEL_EXISTS is re-checked here because the flag can survive from an
        # earlier run while the currently selected model has no checkpoint.
        if st.session_state.get("inference_run_once", False) and MODEL_EXISTS:
            # Read the name from the session: this branch can run on a rerun
            # where the button branch above did not bind `spectrum_name`.
            spectrum_name = st.session_state.get("filename", "unknown")

            # Plot: Raw + Resampled
            x_raw = st.session_state.get("x_raw", None)
            y_raw = st.session_state.get("y_raw", None)
            y_resampled = st.session_state.get("y_resampled", None)
            if x_raw is not None and y_raw is not None and y_resampled is not None:
                st.subheader("π Spectrum Overview")
                st.write("")  # Spacer line for visual breathing room
                from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
                from PIL import Image
                import io

                # Create smaller figure
                fig, ax = plt.subplots(1, 2, figsize=(8, 2.5), dpi=150)
                try:
                    ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray")
                    ax[0].set_title("Raw Input")
                    ax[0].set_xlabel("Wavenumber")
                    ax[0].set_ylabel("Intensity")
                    ax[0].legend()
                    # NOTE(review): assumes resample_spectrum samples evenly
                    # between the min and max wavenumber - confirm.
                    resampled_axis = np.linspace(x_raw.min(), x_raw.max(), TARGET_LEN)
                    ax[1].plot(resampled_axis, y_resampled, label="Resampled", color="steelblue")
                    ax[1].set_title("Resampled")
                    ax[1].set_xlabel("Wavenumber")
                    ax[1].set_ylabel("Intensity")
                    ax[1].legend()
                    fig.tight_layout()

                    # Render to image buffer
                    buf = io.BytesIO()
                    FigureCanvas(fig).print_png(buf)
                    buf.seek(0)

                    # Display fixed-size image
                    st.image(Image.open(buf), caption="Raw vs. Resampled Spectrum", width=880)
                finally:
                    # pyplot keeps every figure alive until it is closed;
                    # without this each rerun leaks one figure.
                    plt.close(fig)

            if y_resampled is None:
                st.error("β Error: Missing resampled spectrum. Please upload and run inference.")
                st.stop()

            # Two leading singleton axes are added in front of the 1-D
            # spectrum before it is handed to the model.
            input_tensor = torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

            # Load selected model
            ModelClass = model_config[model_choice]["model_class"]
            model = ModelClass(input_length=TARGET_LEN)
            # NOTE(review): torch.load unpickles the checkpoint; on PyTorch
            # versions where `weights_only` defaults to False, loading an
            # untrusted file can execute arbitrary code. The path here is a
            # fixed local file. `strict=False` also lets missing/unexpected
            # keys pass silently - confirm that is intended.
            model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"), strict=False)
            model.eval()
            with torch.no_grad():
                logits = model(input_tensor)
            prediction = torch.argmax(logits, dim=1).item()
            logits_list = logits.numpy().tolist()[0]

            # Ground truth is only known when the file name follows the
            # "sta*"/"wea*" convention; otherwise it is reported as unknown.
            try:
                true_label_idx = label_file(spectrum_name)
                true_label_str = label_map[true_label_idx]
            except (ValueError, TypeError, KeyError):
                true_label_idx = None
                true_label_str = "Unknown"
            predicted_class = label_map.get(prediction, f"Class {prediction}")

            import torch.nn.functional as F
            # NOTE: class probabilities are computed but not rendered in this
            # section; the confidence text below uses the raw logit margin.
            probs = F.softmax(torch.tensor(logits_list), dim=0).numpy()

            # Prediction block: model summary, raw logits, system info, explanation
            tab_summary, tab_logits, tab_system, tab_explainer = st.tabs([
                "π§ Model Summary", "π¬ Logits", "βοΈ System Info", "π Explanation"])

            with tab_summary:
                st.markdown("### π§ AI Model Decision Summary")
                st.markdown(f"""
**π File Analyzed:** `{spectrum_name}`
**π οΈ Model Chosen:** `{model_choice}`
""")
                st.markdown("**π Internal Model Prediction**")
                st.write(f"The model believes this sample best matches: **`{predicted_class}`**")
                if true_label_idx is not None:
                    st.caption(f"Ground Truth Label: `{true_label_str}`")

                # Confidence wording is derived from the gap between the two
                # class logits (this indexing assumes exactly two outputs).
                logit_margin = abs(logits_list[0] - logits_list[1])
                if logit_margin > 1000:
                    strength_desc = "VERY STRONG"
                elif logit_margin > 250:
                    strength_desc = "STRONG"
                elif logit_margin > 100:
                    strength_desc = "MODERATE"
                else:
                    strength_desc = "UNCERTAIN"

                st.markdown("π§ͺ Final Classification")
                st.markdown("**π Model Confidence Estimate**")
                st.write(f"**Decision Confidence:** `{strength_desc}` (margin = `{logit_margin:.1f}`)")
                st.success(f"This spectrum is classified as: **`{predicted_class}`**")

            with tab_logits:
                st.markdown("π¬ View Internal Model Output (Logits)")
                st.markdown("""
These are the **raw output scores** from the model before making a final prediction.
Higher scores indicate stronger alignment between the input spectrum and that class.
""")
                st.json({
                    label_map.get(i, f"Class {i}"): float(score)
                    for i, score in enumerate(logits_list)
                })

            with tab_system:
                st.markdown("βοΈ View System Info")
                st.json({
                    "Model Chosen": model_choice,
                    "Spectrum Length": TARGET_LEN,
                    "Processing Steps": "Raw Signal β Resampled β Inference"
                })

            with tab_explainer:
                st.markdown("π What Just Happened?")
                st.markdown("""
**π Process Overview**
1. π A Raman spectrum was uploaded
2. π Data was standardized
3. π€ AI model analyzed the spectrum
4. π A classification was made
---
**π§ How the Model Operates**
Trained on known polymer conditions, the system detects spectral patterns
indicative of stable or weathered polymers.
---
**β Why It Matters**
Enables:
- π¬ Material longevity research
- π Recycling assessments
- π± Sustainability decisions
""")
except (ValueError, TypeError, RuntimeError) as e:
    # Parsing, resampling and model errors are reported in the page instead
    # of crashing the app.
    st.error(f"β Inference error: {e}")