File size: 13,757 Bytes
e484a46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import os
import sys

# Project base path
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(BASE_DIR)

from models.figure2_cnn import Figure2CNN
from models.resnet_cnn import ResNet1D
from scripts.preprocess_dataset import resample_spectrum

from io import StringIO
from glob import glob
from pathlib import Path
import numpy as np
import streamlit as st
import torch
import matplotlib.pyplot as plt



# Label map and label extractor
label_map = {0: "Stable (Unweathered)", 1: "Weathered (Degraded)"}

def label_file(filename: str) -> int:
    name = Path(filename).name.lower()
    if name.startswith("sta"):
        return 0
    elif name.startswith("wea"):
        return 1
    else:
        raise ValueError("Unknown label pattern")

# Browser-tab title/icon, wide layout and a collapsed sidebar for the app.
st.set_page_config(
    page_title="Polymer Aging Inference",
    page_icon="πŸ”¬",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Seed the pipeline-status banner while no file has been captured in this
# session yet (the upload handler stores "uploaded_file" once one arrives).
if "uploaded_file" not in st.session_state:
    st.session_state["status_message"] = "Awaiting input..."
    st.session_state["status_type"] = "info"

# Page heading and one-line description.
st.markdown("**πŸ§ͺ Raman Spectrum Classifier**")
st.caption("AI-driven classification of polymer degradation using Raman spectroscopy.")

# Sidebar: static "about" panel with project, author and repository info.
st.sidebar.header("ℹ️ About This App")
st.sidebar.markdown("""

    Part of the **AIRE 2025 Internship Project**:

    `AI-Driven Polymer Aging Prediction and Classification`



    Uses Raman spectra and deep learning to predict material degradation.



    **Author**: Jaser Hasan  

    **Mentor**: Dr. Sanmukh Kuppannagari  

    [πŸ”— GitHub](https://github.com/dev-jaser/ai-ml-polymer-aging-prediction)

    """)

# Per-architecture display info for the "Model Overview" panel: badge emoji,
# short description and the reported accuracy / F1 strings.
model_metadata = {
    "Figure2CNN (Baseline)": dict(
        emoji="πŸ”¬",
        description="Baseline CNN with standard filters",
        accuracy="94.80%",
        f1="94.30%",
    ),
    "ResNet1D (Advanced)": dict(
        emoji="🧠",
        description="Residual CNN with deeper feature learning",
        accuracy="96.20%",
        f1="95.90%",
    ),
}

# Architecture name -> model class to instantiate and checkpoint to load.
# The checkpoint paths are relative, so they resolve against the process's
# current working directory. Key order drives the model selector's order.
model_config = {
    "Figure2CNN (Baseline)": dict(
        model_class=Figure2CNN,
        model_path="outputs/figure2_model.pth",
    ),
    "ResNet1D (Advanced)": dict(
        model_class=ResNet1D,
        model_path="outputs/resnet_model.pth",
    ),
}

def _parse_spectrum(raw_text):
    """Parse free-form two-column spectrum text into (x, y) float arrays.

    Each line is split on commas and whitespace. The first two tokens that
    parse as finite floats become one (wavenumber, intensity) pair. Lines
    with fewer than two such tokens (headers, comments, blanks) are skipped.
    Using float() rather than a digit-only check also accepts scientific
    notation (e.g. "1.234E+03") and a leading "+".

    Returns:
        Tuple of two 1-D numpy arrays of equal length (possibly empty).
    """
    x_vals, y_vals = [], []
    for line in raw_text.splitlines():
        numbers = []
        for token in line.replace(",", " ").split():
            try:
                value = float(token)
            except ValueError:
                continue
            # Drop "nan"/"inf" tokens, which float() accepts but are not data.
            if np.isfinite(value):
                numbers.append(value)
            if len(numbers) == 2:
                break
        if len(numbers) == 2:
            x_vals.append(numbers[0])
            y_vals.append(numbers[1])
    return np.array(x_vals), np.array(y_vals)


def _render_spectrum_plot(x_raw, y_raw, y_resampled, target_len):
    """Draw raw and resampled spectra side by side and show them as an image.

    The resampled curve is plotted against `target_len` evenly spaced points
    spanning the raw wavenumber range. The pyplot figure is closed after it
    has been rasterised so figures do not pile up across Streamlit reruns.
    """
    import io
    from PIL import Image

    fig, ax = plt.subplots(1, 2, figsize=(8, 2.5), dpi=150)
    try:
        ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray")
        ax[0].set_title("Raw Input")
        ax[0].set_xlabel("Wavenumber")
        ax[0].set_ylabel("Intensity")
        ax[0].legend()

        x_grid = np.linspace(x_raw.min(), x_raw.max(), target_len)
        ax[1].plot(x_grid, y_resampled, label="Resampled", color="steelblue")
        ax[1].set_title("Resampled")
        ax[1].set_xlabel("Wavenumber")
        ax[1].set_ylabel("Intensity")
        ax[1].legend()

        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
        buf.seek(0)
    finally:
        plt.close(fig)

    # Fixed display width keeps the figure compact inside the wide layout.
    st.image(Image.open(buf), caption="Raw vs. Resampled Spectrum", width=880)


# Two-column layout: inputs, model choice and status on the left;
# plots and prediction details on the right.
col1, col2 = st.columns([1.1, 2], gap="large")  # optional for cleaner spacing

try:
    with col1:
        # Upload + model selection
        st.markdown("**πŸ“ Upload Spectrum**")

        # Model selection sits next to the data input.
        with st.container():
            st.markdown("**🧠 Model Selection**")
            # The widget returns the model_config key directly; the emoji badge
            # is added only for display, so no string splitting is needed.
            model_choice = st.selectbox(
                "Choose model architecture:",
                list(model_config),
                format_func=lambda name: f"{model_metadata[name]['emoji']} {name}",
                key="model_selector",
            )
            with st.container():
                meta = model_metadata[model_choice]
                st.markdown(f"""

                **πŸ“ˆ Model Overview**

                *{meta['description']}*



                - **Accuracy**: `{meta['accuracy']}`

                - **F1 Score**: `{meta['f1']}`

                """)

            # Checkpoint for the selected architecture; inference is gated on it.
            MODEL_PATH = model_config[model_choice]["model_path"]
            MODEL_EXISTS = Path(MODEL_PATH).exists()
            TARGET_LEN = 500

            if not MODEL_EXISTS:
                st.error("🚫 Model file not found. Please train the model first.")
        tab1, tab2 = st.tabs(["Upload File", "Use Sample"])
        with tab1:
            uploaded_file = st.file_uploader("Upload Raman `.txt` spectrum", type="txt")
        with tab2:
            sample_files = sorted(glob("app/sample_spectra/*.txt"))
            sample_options = ["-- Select --"] + sample_files
            selected_sample = st.selectbox("Choose a sample:", sample_options)
            if selected_sample != "-- Select --":
                # A chosen sample overrides any uploaded file; wrap it in a
                # StringIO with a .name so it looks like an uploaded file.
                with open(selected_sample, "r", encoding="utf-8") as f:
                    file_contents = f.read()
                uploaded_file = StringIO(file_contents)
                uploaded_file.name = os.path.basename(selected_sample)

        # Capture the current file in session state so it survives reruns.
        # A new/current file also clears any previously shown results.
        if uploaded_file is not None:
            st.session_state['uploaded_file'] = uploaded_file
            st.session_state['filename'] = uploaded_file.name
            st.session_state.status_message = f"πŸ“ File `{uploaded_file.name}` loaded. Ready to infer."
            st.session_state.status_type = "success"
            st.session_state.inference_run_once = False

        # Status banner
        st.markdown("**🚦 Pipeline Status**")
        status_msg = st.session_state.get("status_message", "Awaiting input...")
        status_typ = st.session_state.get("status_type", "info")
        if status_typ == "success":
            st.success(status_msg)
        elif status_typ == "error":
            st.error(status_msg)
        else:
            st.info(status_msg)

        # Inference trigger: parse + resample here, then col2 runs the model.
        run_clicked = st.button("▢️ Run Inference")
        if run_clicked and 'uploaded_file' not in st.session_state:
            st.warning("Upload a spectrum or choose a sample before running inference.")
        if run_clicked and 'uploaded_file' in st.session_state and MODEL_EXISTS:
            spectrum_name = st.session_state['filename']
            uploaded_file = st.session_state['uploaded_file']
            uploaded_file.seek(0)
            raw_data = uploaded_file.read()
            # Uploaded files yield bytes; sample StringIO objects yield str.
            raw_text = raw_data.decode("utf-8") if isinstance(raw_data, bytes) else raw_data

            x_raw, y_raw = _parse_spectrum(raw_text)
            if x_raw.size < 2:
                # Too few points to resample; report it clearly instead of
                # letting the resampler fail on empty/degenerate input.
                raise ValueError(
                    f"Could not parse at least two numeric (wavenumber, intensity) "
                    f"rows from `{spectrum_name}`."
                )
            # NOTE(review): resample_spectrum presumably interpolates onto
            # TARGET_LEN points -- see scripts.preprocess_dataset to confirm.
            y_resampled = resample_spectrum(x_raw, y_raw, TARGET_LEN)
            st.session_state['x_raw'] = x_raw
            st.session_state['y_raw'] = y_raw
            st.session_state['y_resampled'] = y_resampled

            # Update banner state for the inference that col2 is about to run.
            st.session_state.status_message = f"πŸ” Inference running on: `{spectrum_name}`"
            st.session_state.status_type = "info"
            st.session_state.inference_run_once = True

    # Inference and results
    with col2:
        # The flag can outlive the upload (e.g. the file was removed), so read
        # everything from session state and require the selected checkpoint.
        if st.session_state.get("inference_run_once", False) and MODEL_EXISTS:
            spectrum_name = st.session_state.get("filename", "Unknown")
            x_raw = st.session_state.get("x_raw")
            y_raw = st.session_state.get("y_raw")
            y_resampled = st.session_state.get("y_resampled")
            if y_resampled is None:
                st.error("❌ Error: Missing resampled spectrum. Please upload and run inference.")
                st.stop()

            if x_raw is not None and y_raw is not None:
                st.subheader("πŸ“‰ Spectrum Overview")
                st.write("")  # Spacer line for visual breathing room
                _render_spectrum_plot(x_raw, y_raw, y_resampled, TARGET_LEN)

            # Shape (1, 1, TARGET_LEN): batch of one single-channel spectrum.
            input_tensor = torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
            ModelClass = model_config[model_choice]["model_class"]
            model = ModelClass(input_length=TARGET_LEN)

            # strict=False tolerates key mismatches, but any mismatch leaves
            # those layers randomly initialised -- surface it to the user.
            load_result = model.load_state_dict(
                torch.load(MODEL_PATH, map_location="cpu"), strict=False
            )
            if load_result.missing_keys or load_result.unexpected_keys:
                st.warning(
                    f"Checkpoint `{MODEL_PATH}` does not fully match `{model_choice}` "
                    f"(missing keys: {len(load_result.missing_keys)}, "
                    f"unexpected keys: {len(load_result.unexpected_keys)}); "
                    "predictions may be unreliable."
                )
            model.eval()
            with torch.no_grad():
                logits = model(input_tensor)
                prediction = torch.argmax(logits, dim=1).item()
                logits_list = logits.numpy().tolist()[0]

            # Ground truth is only known when the filename follows the
            # sta*/wea* naming convention.
            try:
                true_label_idx = label_file(spectrum_name)
                true_label_str = label_map[true_label_idx]
            except ValueError:
                true_label_idx = None
                true_label_str = "Unknown"
            predicted_class = label_map.get(prediction, f"Class {prediction}")

            # Prediction block: model summary, raw logits, system info, explanation.
            tab_summary, tab_logits, tab_system, tab_explainer = st.tabs([
            "🧠 Model Summary", "πŸ”¬ Logits", "βš™οΈ System Info", "πŸ“˜ Explanation"])

            with tab_summary:
                st.markdown("### 🧠 AI Model Decision Summary")
                st.markdown(f"""

                **πŸ“ƒ File Analyzed:** `{spectrum_name}`



                **πŸ› οΈ Model Chosen:** `{model_choice}`

                """)
                st.markdown("**πŸ” Internal Model Prediction**")
                st.write(f"The model believes this sample best matches: **`{predicted_class}`**")
                if true_label_idx is not None:
                    st.caption(f"Ground Truth Label: `{true_label_str}`")

                # Confidence bucket from the gap between the two class logits
                # (assumes a two-class model, matching label_map).
                logit_margin = abs(logits_list[0] - logits_list[1])
                if logit_margin > 1000:
                    strength_desc = "VERY STRONG"
                elif logit_margin > 250:
                    strength_desc = "STRONG"
                elif logit_margin > 100:
                    strength_desc = "MODERATE"
                else:
                    strength_desc = "UNCERTAIN"

                st.markdown("πŸ§ͺ Final Classification")
                st.markdown("**πŸ“Š Model Confidence Estimate**")
                st.write(f"**Decision Confidence:** `{strength_desc}` (margin = `{logit_margin:.1f}`)")
                st.success(f"This spectrum is classified as: **`{predicted_class}`**")

            with tab_logits:
                st.markdown("πŸ”¬ View Internal Model Output (Logits)")
                st.markdown("""

                    These are the **raw output scores** from the model before making a final prediction.



                    Higher scores indicate stronger alignment between the input spectrum and that class.

                """)
                st.json({
                    label_map.get(i, f"Class {i}"): float(score)
                    for i, score in enumerate(logits_list)
                })

            with tab_system:
                st.markdown("βš™οΈ View System Info")
                st.json({
                    "Model Chosen": model_choice,
                    "Spectrum Length": TARGET_LEN,
                    "Processing Steps": "Raw Signal β†’ Resampled β†’ Inference"
                })

            with tab_explainer:
                st.markdown("πŸ“˜ What Just Happened?")
                st.markdown("""

                **πŸ” Process Overview**

                1. πŸ—‚ A Raman spectrum was uploaded  

                2. πŸ“ Data was standardized  

                3. πŸ€– AI model analyzed the spectrum  

                4. πŸ“Œ A classification was made  



                ---

                **🧠 How the Model Operates**



                Trained on known polymer conditions, the system detects spectral patterns  

                indicative of stable or weathered polymers.



                ---

                **βœ… Why It Matters**



                Enables:

                - πŸ”¬ Material longevity research  

                - πŸ” Recycling assessments  

                - 🌱 Sustainability decisions  

                """)

except (ValueError, TypeError, RuntimeError, OSError) as e:
    # OSError covers unreadable sample files and missing/unreadable checkpoints.
    st.error(f"❌ Inference error: {e}")