devjas1 committed on
Commit 9a4db95 · 1 Parent(s): 6927358

(UI): refine expander styling + confidence display


- Enhance and streamline the UI/UX for visual clarity
- Reorganize page sections
- Add custom CSS styling and a native progress-bar confidence display (see the sketch below)
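
Below is a minimal, self-contained sketch of the two patterns this commit relies on: scoped CSS injected through `st.markdown(..., unsafe_allow_html=True)` and per-class confidence rendered with native `st.progress` bars. It condenses the `render_confidence_progress` helper added in this diff; the `EXPANDER_CSS` constant, its selector values, and the example probabilities are illustrative only, not code copied from `app.py`.

```python
import numpy as np
import streamlit as st

# Scoped CSS in the spirit of the committed stylesheet; selector and values
# here are illustrative, not copied verbatim from app.py.
EXPANDER_CSS = """
<style>
div.stExpander > details > summary {
    border-left: 4px solid #9ca3af;   /* gray accent, overridden per variant */
    border-radius: 6px;
    padding: 6px 12px;
}
</style>
"""

def render_confidence_progress(probs: np.ndarray, labels: list[str]) -> None:
    """Condensed version of the helper added in this diff:
    one native st.progress bar per class, laid out side by side."""
    p = np.clip(np.asarray(probs, dtype=float), 0.0, 1.0)
    for col, label, val in zip(st.columns(len(labels)), labels, p):
        with col:
            st.markdown(f"**{label} - {float(val) * 100:.1f}%**")
            st.progress(int(round(float(val) * 100)))

st.markdown(EXPANDER_CSS, unsafe_allow_html=True)
render_confidence_progress(np.array([0.82, 0.18]), ["Stable", "Weathered"])
```

The committed version goes further: `render_confidence_progress` accepts a `highlight_idx` to bold the winning class, and the stylesheet adds keyword-based expander variants (RESULTS/ADVANCED badges) plus font-size standardization.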

Files changed (1)
  1. app.py +314 -213
app.py CHANGED
@@ -40,15 +40,128 @@ st.set_page_config(
40
  initial_sidebar_state="expanded"
41
  )
42
 
43
- # Stabilize tab panel height on HF Spaces to prevent visible column jitter.
44
- # This sets a minimum height for the content area under the tab headers.
45
  st.markdown("""
46
  <style>
47
- /* Tabs content area: the sibling after the tablist */
48
- div[data-testid="stTabs"] > div[role="tablist"] + div { min-height: 420px;}
 
 
 
 
49
  </style>
50
  """, unsafe_allow_html=True)
51
 
 
52
  # Constants
53
  TARGET_LEN = 500
54
  SAMPLE_DATA_DIR = Path("sample_data")
@@ -63,7 +176,7 @@ MODEL_CONFIG = {
63
  "Figure2CNN (Baseline)": {
64
  "class": Figure2CNN,
65
  "path": f"{MODEL_WEIGHTS_DIR}/figure2_model.pth",
66
- "emoji": "🔬",
67
  "description": "Baseline CNN with standard filters",
68
  "accuracy": "94.80%",
69
  "f1": "94.30%"
@@ -71,7 +184,7 @@ MODEL_CONFIG = {
71
  "ResNet1D (Advanced)": {
72
  "class": ResNet1D,
73
  "path": f"{MODEL_WEIGHTS_DIR}/resnet_model.pth",
74
- "emoji": "🧠",
75
  "description": "Residual CNN with deeper feature learning",
76
  "accuracy": "96.20%",
77
  "f1": "95.90%"
@@ -84,6 +197,7 @@ LABEL_MAP = {0: "Stable (Unweathered)", 1: "Weathered (Degraded)"}
84
 
85
  # === UTILITY FUNCTIONS ===
86
  def init_session_state():
 
87
  defaults = {
88
  "status_message": "Ready to analyze polymer spectra 🔬",
89
  "status_type": "info",
@@ -256,11 +370,74 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled):
256
  plt.close(fig) # Prevent memory leaks
257
 
258
  return Image.open(buf)
 
 
 
 
259
 
260
- def render_confidence_bar(probabilities, class_labels):
261
- bar = lambda p: "█" * int(p * 20)
262
- for label, prob in zip(class_labels, probabilities):
263
- st.write(f"**{label}**: {bar(prob)} {prob*100:.1f}%")
 
 
 
 
264
 
265
 
266
  def get_confidence_description(logit_margin):
@@ -331,7 +508,7 @@ def reset_results(reason: str = ""):
331
  st.session_state["status_type"] = "info"
332
 
333
  def reset_ephemeral_state():
334
- # === remove everything except KEPT global UI context ===
335
  for k in list(st.session_state.keys()):
336
  if k not in KEEP_KEYS:
337
  st.session_state.pop(k, None)
@@ -356,99 +533,57 @@ def reset_ephemeral_state():
356
 
357
  st.rerun()
358
 
359
- def plot_confidence_bar(probabilities: list[float], class_labels: list[str]) -> None:
360
- """Renders a horizontal bar chart of prediction confidences per class."""
361
- fig, ax = plt.subplots(figsize=(4, 1.5))
362
- bars = ax.barh(class_labels, probabilities, color=[
363
- "green" if i == np.argmax(probabilities) else "gray"
364
- for i in range(len(probabilities))
365
- ])
366
- ax.set_xlabel("Confidence")
367
- ax.set_title("Prediction Confidence")
368
- ax.xaxis.set_ticks([0, 0.5, 1.0])
369
- ax.set_xlim(0, 1.0)
370
- for i, (label, prob) in enumerate(zip(class_labels, probabilities)):
371
- ax.text(prob + 0.01, i, f"{prob*100:.1f}%", va='center', fontsize=8)
372
-
373
- st.pyplot(fig)
374
-
375
-
376
  # Main app
377
  def main():
378
  init_session_state()
379
- # Header
380
- st.title("🔬 AI-Driven Polymer Classification")
381
- st.markdown(
382
- "**Predict polymer degradation states using Raman spectroscopy and deep learning**")
383
- st.info(
384
- "**Prototype Notice:** v0.1 Raman-only. "
385
- "Multi-model CNN evaluation in progress. "
386
- "FTIR support planned.",
387
- icon="⚡"
388
- )
389
 
390
  # Sidebar
391
  with st.sidebar:
392
- st.header("ℹ️ About This App")
393
- st.sidebar.markdown("""
394
- AI-Driven Polymer Aging Prediction and Classification
395
-
396
- 🎯 **Purpose**: Classify polymer degradation using AI
397
- 📊 **Input**: Raman spectroscopy `.txt` files
398
- 🧠 **Models**: CNN architectures for binary classification
399
- 💾 **Current**: Figure2CNN (baseline)
400
- 📈 **Next**: More trained CNNs in evaluation pipeline
401
-
402
- ---
403
 
404
- **Team**
405
- Dr. Sanmukh Kuppannagari (Mentor)
406
- Dr. Metin Karailyan (Mentor)
407
- 👨‍💻 Jaser Hasan (Author)
408
 
409
- ---
 
 
 
410
 
411
- **Links**
412
- 🔗 [Live HF Space](https://huggingface.co/spaces/dev-jas/polymer-aging-ml)
413
- 📂 [GitHub Repository](https://github.com/KLab-AI3/ml-polymer-recycling)
 
414
 
415
- ---
416
 
417
- **Model Credit**
418
- Baseline model inspired by *Figure 2 CNN* from:
419
- > Neo, E.R.K., Low, J.S.C., Goodship, V., Debattista, K. (2023).
420
- > *Deep learning for chemometric analysis of plastic spectral data from infrared and Raman databases*.
421
- > _Resources, Conservation & Recycling_, **188**, 106718.
422
 
423
- [https://doi.org/10.1016/j.resconrec.2022.106718](https://doi.org/10.1016/j.resconrec.2022.106718)
424
- """)
425
 
426
- st.markdown("---")
 
 
427
 
428
- # Model selection
429
- st.subheader("🧠 Model Selection")
430
- model_labels = [
431
- f"{MODEL_CONFIG[name]['emoji']} {name}" for name in MODEL_CONFIG.keys()]
432
- selected_label = st.selectbox("Choose AI model:", model_labels,
433
- key="model_select", on_change=on_model_change)
434
- model_choice = selected_label.split(" ", 1)[1]
435
 
436
- # Model info
437
- config = MODEL_CONFIG[model_choice]
438
- st.markdown(f"""
439
- **📈 {config['emoji']} Model Details**
440
-
441
- *{config['description']}*
442
-
443
- - **Accuracy**: `{config['accuracy']}`
444
- - **F1 Score**: `{config['f1']}`
445
- """)
446
 
447
  # Main content area
448
- col1, col2 = st.columns([1, 1.5], gap="large")
449
 
450
  with col1:
451
- st.subheader("📁 Data Input")
452
 
453
  mode = st.radio(
454
  "Input mode",
@@ -484,7 +619,7 @@ def main():
484
  st.session_state["status_type"] = "success"
485
 
486
  if up:
487
- st.success(f"✅ Loaded: {up.name}")
488
 
489
  # ---- Sample tab ----
490
  else:
@@ -499,12 +634,12 @@ def main():
499
  on_change=on_sample_change, # <-- critical
500
  )
501
  if sel != "-- Select Sample --":
502
- st.success(f"✅ Loaded sample: {sel}")
503
  else:
504
  st.info("No sample data available")
505
 
506
  # ---- Status box ----
507
- st.subheader("🚦 Status")
508
  msg = st.session_state.get("status_message", "Ready")
509
  typ = st.session_state.get("status_type", "info")
510
  if typ == "success":
@@ -553,11 +688,8 @@ def main():
553
  r1, r2 = resample_spectrum(x_raw, y_raw, TARGET_LEN)
554
 
555
  def _is_strictly_increasing(a):
556
- try:
557
- a = np.asarray(a)
558
- return a.ndim == 1 and a.size >= 2 and np.all(np.diff(a) > 0)
559
- except Exception:
560
- return False
561
 
562
  if _is_strictly_increasing(r1) and not _is_strictly_increasing(r2):
563
  x_resampled, y_resampled = np.asarray(r1), np.asarray(r2)
@@ -592,7 +724,7 @@ def main():
592
  # Results column
593
  with col2:
594
  if st.session_state.get("inference_run_once", False):
595
- st.subheader("📊 Analysis Results")
596
 
597
  # Get data from session state
598
  x_raw = st.session_state.get('x_raw')
@@ -650,127 +782,98 @@ def main():
650
  predicted_class = LABEL_MAP.get(
651
  int(prediction), f"Class {int(prediction)}")
652
 
653
- # Calculate confidence metrics
654
  logit_margin = abs(
655
  logits_list[0] - logits_list[1]) if len(logits_list) >= 2 else 0
656
  confidence_desc, confidence_emoji = get_confidence_description(
657
  logit_margin)
658
 
659
- # Display results
660
- st.markdown("### 🎯 Prediction Results")
661
-
662
- # Main prediction
663
- st.markdown(f"""
664
- **🔬 Sample**: `{filename}`
665
- **🧠 Model**: `{model_choice}`
666
- **⏱️ Processing Time**: `{inference_time:.2f}s`
667
- """)
668
-
669
- # Prediction box
670
- if predicted_class == "Stable (Unweathered)":
671
- st.success(f"🟢 **Prediction**: {predicted_class}")
672
- else:
673
- st.warning(f"🟡 **Prediction**: {predicted_class}")
674
-
675
- # Confidence
676
- st.markdown(
677
- f"**{confidence_emoji} Confidence**: {confidence_desc} (margin: {logit_margin:.1f})")
678
-
679
- # Ground truth comparison
680
- if true_label_idx is not None:
681
- if predicted_class == true_label_str:
682
- st.success(
683
- f"✅ **Ground Truth**: {true_label_str} - **Correct!**")
684
- else:
685
- st.error(
686
- f"❌ **Ground Truth**: {true_label_str} - **Incorrect**")
687
- else:
688
- st.info(
689
- "ℹ️ **Ground Truth**: Unknown (filename doesn't follow naming convention)")
690
-
691
- # ===display confidence results===
692
- class_labels = ["Stable", "Weathered"]
693
- st.markdown("#### 🔬 Confidence Overview")
694
- def render_confidence_bar(prob, length=20):
695
- filled = int(prob + length)
696
- return "█" * filled + "░" * (length - filled)
697
-
698
- for i, label in enumerate(class_labels):
699
- st.write(f"**{label}**: {render_confidence_bar(probs[i])} {probs[i]*100:.1f}%")
700
-
701
  # ===Detailed results tabs===
702
  tab1, tab2, tab3 = st.tabs(
703
- ["📊 Details", "🔬 Technical", "📘 Explanation"])
704
 
705
  with tab1:
706
- st.markdown("**Model Output (Logits)**")
707
- for i, score in enumerate(logits_list):
708
- label = LABEL_MAP.get(i, f"Class {i}")
709
- st.metric(label, f"{score:.2f}")
710
-
711
- st.markdown("**Spectrum Statistics**")
712
- st.json({
713
- "Original Length": len(x_raw) if x_raw is not None else 0,
714
- "Resampled Length": TARGET_LEN,
715
- "Wavenumber Range": f"{min(x_raw):.1f} - {max(x_raw):.1f} cm⁻¹" if x_raw is not None else "N/A",
716
- "Intensity Range": f"{min(y_raw):.1f} - {max(y_raw):.1f}" if y_raw is not None else "N/A",
717
- "Model Confidence": confidence_desc
718
- })
 
 
 
 
719
 
720
  with tab2:
721
- st.markdown("**Technical Information**")
722
- model_path = MODEL_CONFIG[model_choice]["path"]
723
- mtime = os.path.getmtime(model_path) if os.path.exists(
724
- model_path) else "N/A"
725
- file_hash = hashlib.md5(open(model_path, 'rb').read(
726
- )).hexdigest() if os.path.exists(model_path) else "N/A"
727
- st.json({
728
- "Model Architecture": model_choice,
729
- "Model Path": model_path,
730
- "Weights Last Modified": time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime)) if mtime != "N/A" else "N/A",
731
- "Weights Hash": file_hash,
732
- "Input Shape": list(input_tensor.shape),
733
- "Output Shape": list(logits.shape),
734
- "Inference Time": f"{inference_time:.3f}s",
735
- "Device": "CPU",
736
- "Model Loaded": model_loaded
737
- })
738
-
739
- if not model_loaded:
740
- st.warning(
741
- "⚠️ Demo mode: Using randomly initialized weights")
742
-
743
- # Debug log
744
- st.markdown("**Debug Log**")
745
- st.text_area("Logs", "\n".join(
746
- st.session_state.get("log_messages", [])), height=200)
747
-
748
- try:
749
- resampler_mod = getattr(resample_spectrum, "__module__", "unknown")
750
- resampler_doc = getattr(resample_spectrum, "__doc__", None)
751
- resampler_doc = resampler_doc.splitlines()[0] if isinstance(resampler_doc, str) and resampler_doc else "no doc"
752
-
753
- y_rs = st.session_state.get("y_resampled", None)
754
- diag = {}
755
- if y_rs is not None:
756
- arr = np.asarray(y_rs)
757
- diag = {
758
- "y_resampled_len": int(arr.size),
759
- "y_resampled_min": float(np.min(arr)) if arr.size else None,
760
- "y_resampled_max": float(np.max(arr)) if arr.size else None,
761
- "y_resampled_ptp": float(np.ptp(arr)) if arr.size else None,
762
- "y_resampled_unique": int(np.unique(arr).size) if arr.size else None,
763
- "y_resampled_all_equal": bool(np.ptp(arr) == 0.0) if arr.size else None,
764
- }
765
-
766
- st.markdown("**Resampler Info")
767
- st.json({
768
- "module": resampler_mod,
769
- "doc": resampler_doc,
770
- **({"y_resampled_stats": diag} if diag else {})
771
- })
772
- except Exception as _e:
773
- st.warning(f"Diagnostics skipped: {_e}")
774
 
775
  with tab3:
776
  st.markdown("""
@@ -803,21 +906,19 @@ def main():
803
  st.error(
804
  "❌ Missing spectrum data. Please upload a file and run analysis.")
805
  else:
806
- # Welcome message
807
  st.markdown("""
808
- ### 👋 Welcome to AI Polymer Classification
809
-
810
- **Get started by:**
811
- 1. 🧠 Select an AI model in the sidebar
812
- 2. 📁 Upload a Raman spectrum file or choose a sample
813
- 3. ▶️ Click "Run Analysis" to get predictions
814
 
815
- **Supported formats:**
816
  - Text files (.txt) with wavenumber and intensity columns
817
  - Space or comma-separated values
818
  - Any length (automatically resampled to 500 points)
819
 
820
- **Example applications:**
821
  - 🔬 Research on polymer degradation
822
  - ♻️ Recycling feasibility assessment
823
  - 🌱 Sustainability impact studies
 
40
  initial_sidebar_state="expanded"
41
  )
42
 
43
  st.markdown("""
44
  <style>
45
+ /* Keep only scoped utility styles; no .block-container edits */
46
+
47
+ /* Tabs content area height (your original intent) */
48
+ div[data-testid="stTabs"] > div[role="tablist"] + div { min-height: 420px; }
49
+
50
+ /* Compact info box for confidence bar */
51
+ .confbox {
52
+ font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace;
53
+ font-size: 0.95rem;
54
+ padding: 8px 10px; border: 1px solid rgba(0,0,0,.07);
55
+ border-radius: 8px; background: rgba(0,0,0,.02);
56
+ }
57
+
58
+ /* Clean key–value rows for technical info */
59
+ .kv-row { display:flex; justify-content:space-between;
60
+ border-bottom: 1px dotted rgba(0,0,0,.10); padding: 3px 0; gap: 12px; }
61
+ .kv-key { opacity:.75; font-size: 0.92rem; white-space: nowrap; }
62
+ .kv-val { font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace;
63
+ overflow-wrap: anywhere; }
64
+
65
+ /* Ensure markdown h5 headings remain visible after layout shifts */
66
+ :where(h5, .stMarkdown h5) { margin-top: 0.25rem; }
67
+
68
+ /* === Base Expander Header === */
69
+ div.stExpander > details > summary {
70
+ display: flex;
71
+ align-items: center;
72
+ justify-content: space-between;
73
+ list-style: none; /* remove default arrow */
74
+ cursor: pointer;
75
+ border: 1px solid rgba(0,0,0,.15);
76
+ border-left: 4px solid #9ca3af; /* default gray accent */
77
+ border-radius: 6px;
78
+ padding: 6px 12px;
79
+ margin: 6px 0;
80
+ background: rgba(0,0,0,0.04);
81
+ font-weight: 600;
82
+ font-size: 0.92rem;
83
+ }
84
+
85
+ /* Remove ugly default disclosure triangle */
86
+ div.stExpander > details > summary::-webkit-details-marker {
87
+ display: none;
88
+ }
89
+ div.stExpander > details > summary::marker {
90
+ display: none;
91
+ }
92
+
93
+ /* Hover/active subtlety */
94
+ div.stExpander > details[open] > summary {
95
+ background: rgba(0,0,0,0.06);
96
+ }
97
+
98
+ /* Hide Streamlit's custom arrow icon inside expanders */
99
+ div[data-testid="stExpander"] summary svg {
100
+ display: none !important;
101
+ }
102
+
103
+ /* === Right Badge === */
104
+ div.stExpander > details > summary::after {
105
+ content: "MORE ↓";
106
+ font-size: 0.70rem;
107
+ font-weight: 600;
108
+ letter-spacing: .04em;
109
+ padding: 2px 8px;
110
+ border-radius: 999px;
111
+ margin-left: auto;
112
+ background: #e5e7eb;
113
+ color: #111827;
114
+ }
115
+
116
+ /* === Variants by Keyword === */
117
+ div.stExpander:has(summary:contains("Prediction")) > details > summary {
118
+ border-left-color: #2e7d32;
119
+ background: rgba(46,125,50,0.08);
120
+ }
121
+ div.stExpander:has(summary:contains("Prediction")) > details > summary::after {
122
+ content: "RESULTS";
123
+ background: rgba(46,125,50,0.15); color: #184a1d;
124
+ }
125
+
126
+ div.stExpander:has(summary:contains("Technical")) > details > summary {
127
+ border-left-color: #ed6c02;
128
+ background: rgba(237,108,2,0.08);
129
+ }
130
+ div.stExpander:has(summary:contains("Technical")) > details > summary::after {
131
+ content: "ADVANCED";
132
+ background: rgba(237,108,2,0.18); color: #7a3d00;
133
+ }
134
+
135
+ /* === FONT SIZE STANDARDIZATION === */
136
+
137
+ /* Sidebar metrics (Accuracy, F1 Score) */
138
+ div[data-testid="stMetricValue"] {
139
+ font-size: 0.95rem !important; /* uniform body size */
140
+ }
141
+ div[data-testid="stMetricLabel"] {
142
+ font-size: 0.85rem !important;
143
+ opacity: 0.85;
144
+ }
145
+
146
+ /* Sidebar expander text */
147
+ section[data-testid="stSidebar"] .stMarkdown p {
148
+ font-size: 0.92rem !important;
149
+ line-height: 1.4;
150
+ }
151
+
152
+ /* Diagnostics tab metrics (Logits) */
153
+ div[data-testid="stMetricValue"] {
154
+ font-size: 0.95rem !important;
155
+ }
156
+ div[data-testid="stMetricLabel"] {
157
+ font-size: 0.85rem !important;
158
+ }
159
+
160
+
161
  </style>
162
  """, unsafe_allow_html=True)
163
 
164
+
165
  # Constants
166
  TARGET_LEN = 500
167
  SAMPLE_DATA_DIR = Path("sample_data")
 
176
  "Figure2CNN (Baseline)": {
177
  "class": Figure2CNN,
178
  "path": f"{MODEL_WEIGHTS_DIR}/figure2_model.pth",
179
+ "emoji": "",
180
  "description": "Baseline CNN with standard filters",
181
  "accuracy": "94.80%",
182
  "f1": "94.30%"
 
184
  "ResNet1D (Advanced)": {
185
  "class": ResNet1D,
186
  "path": f"{MODEL_WEIGHTS_DIR}/resnet_model.pth",
187
+ "emoji": "",
188
  "description": "Residual CNN with deeper feature learning",
189
  "accuracy": "96.20%",
190
  "f1": "95.90%"
 
197
 
198
  # === UTILITY FUNCTIONS ===
199
  def init_session_state():
200
+ """Keep a persistent session state"""
201
  defaults = {
202
  "status_message": "Ready to analyze polymer spectra 🔬",
203
  "status_type": "info",
 
370
  plt.close(fig) # Prevent memory leaks
371
 
372
  return Image.open(buf)
373
+
374
+ def _pct(p: float) -> str:
375
+ # Fixed-width percent like " 98.7%" or " 2.3%"
376
+ return f"{float(p)*100:5.1f}%"
377
+
378
+ def render_confidence_progress(
379
+ probs: np.ndarray,
380
+ labels: list[str] = ["Stable", "Weathered"],
381
+ highlight_idx: int | None = None,
382
+ side_by_side: bool = True
383
+ ):
384
+ """Render Streamlit native progress bars (0 - 100). Optionally bold the winning class
385
+ and place the two bars side-by-side for compactness."""
386
+ p = np.asarray(probs, dtype=float)
387
+ p = np.clip(p, 0.0, 1.0)
388
+
389
+ def _title(i: int, lbl: str, val: float) -> str:
390
+ t = f"{lbl} - {val*100:.1f}%"
391
+ return f"**{t}**" if (highlight_idx is not None and i == highlight_idx) else t
392
+
393
+ if side_by_side:
394
+ cols = st.columns(len(labels))
395
+ for i, (lbl, val, col) in enumerate(zip(labels, p, cols)):
396
+ with col:
397
+ st.markdown(_title(i, lbl, float(val)))
398
+ st.progress(int(round(val * 100)))
399
+ else:
400
+ for i, (lbl, val) in enumerate(zip(labels, p)):
401
+ st.markdown(_title(i, lbl, float(val)))
402
+ st.progress(int(round(val * 100)))
403
+
404
 
405
+
406
+
407
+
408
+
409
+ def render_kv_grid(d: dict, ncols: int = 2):
410
+ """Display dict as a clean grid of key/value rows."""
411
+ if not d:
412
+ return
413
+ items = list(d.items())
414
+ cols = st.columns(ncols)
415
+ for i, (k, v) in enumerate(items):
416
+ with cols[i % ncols]:
417
+ st.markdown(
418
+ f"<div class='kv-row'><span class='kv-key'>{k}</span>"
419
+ f"<span class='kv-val'>{v}</span></div>",
420
+ unsafe_allow_html=True
421
+ )
422
+
423
+
424
+
425
+
426
+ def render_model_meta(model_choice: str):
427
+ info = MODEL_CONFIG.get(model_choice, {})
428
+ emoji = info.get("emoji", "")
429
+ desc = info.get("description", "").strip()
430
+ acc = info.get("accuracy", "-")
431
+ f1 = info.get("f1", "-")
432
+
433
+ st.caption(f"{emoji} **Model Snapshot** - {model_choice}")
434
+ cols = st.columns(2)
435
+ with cols[0]:
436
+ st.metric("Accuracy", acc)
437
+ with cols[1]:
438
+ st.metric("F1 Score", f1)
439
+ if desc:
440
+ st.caption(desc)
441
 
442
 
443
  def get_confidence_description(logit_margin):
 
508
  st.session_state["status_type"] = "info"
509
 
510
  def reset_ephemeral_state():
511
+ """remove everything except KEPT global UI context"""
512
  for k in list(st.session_state.keys()):
513
  if k not in KEEP_KEYS:
514
  st.session_state.pop(k, None)
 
533
 
534
  st.rerun()
535
 
 
 
 
 
 
 
536
  # Main app
537
  def main():
538
  init_session_state()
 
 
 
 
 
 
 
 
 
 
539
 
540
  # Sidebar
541
  with st.sidebar:
542
+ # Header
543
+ st.header("AI-Driven Polymer Classification")
544
+ st.caption("Predict polymer degradation (Stable vs Weathered) from Raman spectra using validated CNN models. — v0.1")
545
+ model_labels = [f"{MODEL_CONFIG[name]['emoji']} {name}" for name in MODEL_CONFIG.keys()]
546
+ selected_label = st.selectbox("Choose AI Model", model_labels, key="model_select", on_change=on_model_change)
547
+ model_choice = selected_label.split(" ", 1)[1]
 
 
 
 
 
548
 
549
+ # ===Compact metadata directly under dropdown===
550
+ render_model_meta(model_choice)
 
 
551
 
552
+ # ===Collapsed info to reduce clutter===
553
+ with st.expander("About This App",icon=":material/info:", expanded=False):
554
+ st.markdown("""
555
+ AI-Driven Polymer Aging Prediction and Classification
556
 
557
+ **Purpose**: Classify polymer degradation using AI
558
+ **Input**: Raman spectroscopy `.txt` files
559
+ **Models**: CNN architectures for binary classification
560
+ **Next**: More trained CNNs in evaluation pipeline
561
 
562
+ ---
563
 
564
+ **Contributors**
565
+ Dr. Sanmukh Kuppannagari (Mentor)
566
+ Dr. Metin Karailyan (Mentor)
567
+ 👨‍💻 Jaser Hasan (Author)
 
568
 
569
+ ---
 
570
 
571
+ **Links**
572
+ 🔗 [Live HF Space](https://huggingface.co/spaces/dev-jas/polymer-aging-ml)
573
+ 📂 [GitHub Repository](https://github.com/KLab-AI3/ml-polymer-recycling)
574
 
575
+ ---
 
 
 
 
 
 
576
 
577
+ **Citation Figure2CNN (baseline)**
578
+ Neo et al., 2023, *Resour. Conserv. Recycl.*, 188, 106718.
579
+ [https://doi.org/10.1016/j.resconrec.2022.106718](https://doi.org/10.1016/j.resconrec.2022.106718)
580
+ """)
 
 
 
 
 
 
581
 
582
  # Main content area
583
+ col1, col2 = st.columns([1, 1.35], gap="small")
584
 
585
  with col1:
586
+ st.markdown("##### Data Input")
587
 
588
  mode = st.radio(
589
  "Input mode",
 
619
  st.session_state["status_type"] = "success"
620
 
621
  if up:
622
+ st.markdown(f"✅ Loaded: {up.name}")
623
 
624
  # ---- Sample tab ----
625
  else:
 
634
  on_change=on_sample_change, # <-- critical
635
  )
636
  if sel != "-- Select Sample --":
637
+ st.markdown(f"✅ Loaded sample: {sel}")
638
  else:
639
  st.info("No sample data available")
640
 
641
  # ---- Status box ----
642
+ st.markdown("##### Status")
643
  msg = st.session_state.get("status_message", "Ready")
644
  typ = st.session_state.get("status_type", "info")
645
  if typ == "success":
 
688
  r1, r2 = resample_spectrum(x_raw, y_raw, TARGET_LEN)
689
 
690
  def _is_strictly_increasing(a):
691
+ a = np.asarray(a)
692
+ return a.ndim == 1 and a.size >= 2 and np.all(np.diff(a) > 0)
 
 
 
693
 
694
  if _is_strictly_increasing(r1) and not _is_strictly_increasing(r2):
695
  x_resampled, y_resampled = np.asarray(r1), np.asarray(r2)
 
724
  # Results column
725
  with col2:
726
  if st.session_state.get("inference_run_once", False):
727
+ st.markdown("##### Analysis Results")
728
 
729
  # Get data from session state
730
  x_raw = st.session_state.get('x_raw')
 
782
  predicted_class = LABEL_MAP.get(
783
  int(prediction), f"Class {int(prediction)}")
784
 
785
+ # === confidence metrics ===
786
  logit_margin = abs(
787
  logits_list[0] - logits_list[1]) if len(logits_list) >= 2 else 0
788
  confidence_desc, confidence_emoji = get_confidence_description(
789
  logit_margin)
790
 
 
 
 
 
 
 
 
 
 
 
791
  # ===Detailed results tabs===
792
  tab1, tab2, tab3 = st.tabs(
793
+ ["Details", "Technical", "Explanation"])
794
 
795
  with tab1:
796
+ # Main prediction
797
+ st.markdown(f"""
798
+ **Sample**: `{filename}`
799
+ **Model**: `{model_choice}`
800
+ **Processing Time**: `{inference_time:.2f}s`
801
+ """)
802
+
803
+ # ===Prediction box && Confidence Margin===
804
+ with st.expander("Prediction/Ground Truth & Model Confidence Margin", expanded=False):
805
+ if predicted_class == "Stable (Unweathered)":
806
+ st.markdown(f"🟢 **Prediction**: {predicted_class}")
807
+ else:
808
+ st.markdown(f"🟡 **Prediction**: {predicted_class}")
809
+ st.markdown(
810
+ f"**{confidence_emoji} Confidence**: {confidence_desc} (margin: {logit_margin:.1f})")
811
+ # Ground truth comparison
812
+ if true_label_idx is not None:
813
+ if predicted_class == true_label_str:
814
+ st.markdown(
815
+ f"✅ **Ground Truth**: {true_label_str} - **Correct!**")
816
+ else:
817
+ st.markdown(
818
+ f"❌ **Ground Truth**: {true_label_str} - **Incorrect**")
819
+ else:
820
+ st.markdown(
821
+ "**Ground Truth**: Unknown (filename doesn't follow naming convention)")
822
+
823
+ st.markdown("###### Confidence Overview")
824
+ render_confidence_progress(
825
+ probs,
826
+ labels=["Stable", "Weathered"],
827
+ highlight_idx=int(prediction),
828
+ side_by_side=True, # Set false for stacked <<
829
+ )
830
+
831
 
832
  with tab2:
833
+ with st.expander("Diagnostics/Technical Info (advanced)", expanded=False):
834
+ st.markdown("###### Model Output (Logits)")
835
+ cols = st.columns(2)
836
+ for i, score in enumerate(logits_list):
837
+ label = LABEL_MAP.get(i, f"Class {i}")
838
+ (cols[i % 2]).metric(label, f"{score:.2f}")
839
+
840
+ st.markdown("###### Spectrum Statistics")
841
+ spec_stats = {
842
+ "Original Length": len(x_raw) if x_raw is not None else 0,
843
+ "Resampled Length": TARGET_LEN,
844
+ "Wavenumber Range": f"{min(x_raw):.1f}–{max(x_raw):.1f} cm⁻¹" if x_raw is not None else "N/A",
845
+ "Intensity Range": f"{min(y_raw):.1f}–{max(y_raw):.1f}" if y_raw is not None else "N/A",
846
+ "Confidence Bucket": confidence_desc,
847
+ }
848
+ render_kv_grid(spec_stats, ncols=2)
849
+ st.markdown("---")
850
+
851
+ st.markdown("###### Model Statistics")
852
+ model_path = MODEL_CONFIG[model_choice]["path"]
853
+ mtime = os.path.getmtime(model_path) if os.path.exists(model_path) else None
854
+ file_hash = (
855
+ hashlib.md5(open(model_path, 'rb').read()).hexdigest()
856
+ if os.path.exists(model_path) else "N/A"
857
+ )
858
+ model_stats = {
859
+ "Architecture": model_choice,
860
+ "Model Path": model_path,
861
+ "Weights Last Modified": time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime)) if mtime else "N/A",
862
+ "Weights Hash (md5)": file_hash,
863
+ "Input Shape": list(input_tensor.shape),
864
+ "Output Shape": list(logits.shape),
865
+ "Inference Time": f"{inference_time:.3f}s",
866
+ "Device": "CPU",
867
+ "Model Loaded": model_loaded,
868
+ }
869
+ render_kv_grid(model_stats, ncols=2)
870
+
871
+ st.markdown("---")
872
+
873
+
874
+ st.markdown("###### Debug Log")
875
+ st.text_area("Logs", "\n".join(st.session_state.get("log_messages", [])), height=110)
876
+
 
 
 
 
 
 
 
 
 
877
 
878
  with tab3:
879
  st.markdown("""
 
906
  st.error(
907
  "❌ Missing spectrum data. Please upload a file and run analysis.")
908
  else:
909
+ # ===Getting Started===
910
  st.markdown("""
911
+ ##### Get started by:
912
+ 1. Select an AI model in the sidebar
913
+ 2. Upload a Raman spectrum file or choose a sample
914
+ 3. Click "Run Analysis" to get predictions
 
 
915
 
916
+ ##### Supported formats:
917
  - Text files (.txt) with wavenumber and intensity columns
918
  - Space or comma-separated values
919
  - Any length (automatically resampled to 500 points)
920
 
921
+ ##### Example applications:
922
  - 🔬 Research on polymer degradation
923
  - ♻️ Recycling feasibility assessment
924
  - 🌱 Sustainability impact studies