devjas1 commited on
Commit
27f8f90
·
1 Parent(s): 9e5cab1

(UI/UX): stabilize Stage 2 dashboard -

Browse files

- Removed index= override from results selectbox >> fixed UX bug (instant updates on first click)
- Standardized font sizes across sidebar metrics, expander content, and diagnostics
- Finalized expander redesign with badges + removed disclosure arrows
- Improved error handling for inference edge cases
- Achieved [GRN] baseline for UI - fully operational and polished

Files changed (1) hide show
  1. app.py +147 -143
app.py CHANGED
@@ -58,7 +58,7 @@ div[data-testid="stTabs"] > div[role="tablist"] + div { min-height: 420px; }
58
  /* Clean key–value rows for technical info */
59
  .kv-row { display:flex; justify-content:space-between;
60
  border-bottom: 1px dotted rgba(0,0,0,.10); padding: 3px 0; gap: 12px; }
61
- .kv-key { opacity:.75; font-size: 0.92rem; white-space: nowrap; }
62
  .kv-val { font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace;
63
  overflow-wrap: anywhere; }
64
 
@@ -79,7 +79,7 @@ div.stExpander > details > summary {
79
  margin: 6px 0;
80
  background: rgba(0,0,0,0.04);
81
  font-weight: 600;
82
- font-size: 0.92rem;
83
  }
84
 
85
  /* Remove ugly default disclosure triangle */
@@ -113,16 +113,18 @@ div.stExpander > details > summary::after {
113
  color: #111827;
114
  }
115
 
116
- /* === Variants by Keyword === */
117
- div.stExpander:has(summary:contains("Prediction")) > details > summary {
118
  border-left-color: #2e7d32;
119
  background: rgba(46,125,50,0.08);
120
  }
121
- div.stExpander:has(summary:contains("Prediction")) > details > summary::after {
122
  content: "RESULTS";
123
- background: rgba(46,125,50,0.15); color: #184a1d;
 
124
  }
125
 
 
126
  div.stExpander:has(summary:contains("Technical")) > details > summary {
127
  border-left-color: #ed6c02;
128
  background: rgba(237,108,2,0.08);
@@ -145,7 +147,7 @@ div[data-testid="stMetricLabel"] {
145
 
146
  /* Sidebar expander text */
147
  section[data-testid="stSidebar"] .stMarkdown p {
148
- font-size: 0.92rem !important;
149
  line-height: 1.4;
150
  }
151
 
@@ -211,6 +213,7 @@ def init_session_state():
211
  "log_messages": [],
212
  "uploader_version": 0,
213
  "current_upload_key": "upload_txt_0",
 
214
  }
215
  for k, v in defaults.items():
216
  st.session_state.setdefault(k, v)
@@ -285,6 +288,27 @@ def cleanup_memory():
285
  if torch.cuda.is_available():
286
  torch.cuda.empty_cache()
287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
  @st.cache_data
290
  def get_sample_files():
@@ -341,8 +365,8 @@ def parse_spectrum_data(raw_text):
341
 
342
  return x, y
343
 
344
-
345
- def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled):
346
  """Create spectrum visualization plot"""
347
  fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
348
 
@@ -370,15 +394,13 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled):
370
  plt.close(fig) # Prevent memory leaks
371
 
372
  return Image.open(buf)
373
-
374
- def _pct(p: float) -> str:
375
- # Fixed-width percent like " 98.7%" or " 2.3%"
376
- return f"{float(p)*100:5.1f}%"
377
 
378
  def render_confidence_progress(
379
  probs: np.ndarray,
380
  labels: list[str] = ["Stable", "Weathered"],
381
- highlight_idx: int | None = None,
382
  side_by_side: bool = True
383
  ):
384
  """Render Streamlit native progress bars (0 - 100). Optionally bold the winning class
@@ -402,10 +424,6 @@ def render_confidence_progress(
402
  st.progress(int(round(val * 100)))
403
 
404
 
405
-
406
-
407
-
408
-
409
  def render_kv_grid(d: dict, ncols: int = 2):
410
  """Display dict as a clean grid of key/value rows."""
411
  if not d:
@@ -731,81 +749,83 @@ def main():
731
  filename = st.session_state.get('filename', 'Unknown')
732
 
733
  if all(v is not None for v in [x_raw, y_raw, y_resampled]):
 
734
 
735
- # Create and display plot
736
- try:
737
- spectrum_plot = create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled)
738
- st.image(
739
- spectrum_plot, caption="Spectrum Preprocessing Results", use_container_width=True)
740
- except (ValueError, RuntimeError, TypeError) as e:
741
- st.warning(f"Could not generate plot: {e}")
742
- log_message(f"Plot generation error: {e}")
743
-
744
- # Run inference
745
- try:
746
- with st.spinner("Running AI inference..."):
747
- start_time = time.time()
748
-
749
- # Prepare input tensor
750
- input_tensor = torch.tensor(
751
- y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
752
-
753
- # Run inference
754
- model.eval()
755
- with torch.no_grad():
756
- if model is None:
757
- raise ValueError(
758
- "Model is not loaded. Please check the model configuration or weights.")
759
- logits = model(input_tensor)
760
- prediction = torch.argmax(logits, dim=1).item()
761
- logits_list = logits.detach().numpy().tolist()[0]
762
-
763
- probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
764
-
765
-
766
- inference_time = time.time() - start_time
767
- log_message(
768
- f"Inference completed in {inference_time:.2f}s, prediction: {prediction}")
769
-
770
- # Clean up memory
771
- cleanup_memory()
772
-
773
- # Get ground truth if available
774
- true_label_idx = label_file(filename)
775
- true_label_str = LABEL_MAP.get(
776
- true_label_idx, "Unknown") if true_label_idx is not None else "Unknown"
777
-
778
- # Get prediction
779
- predicted_class = LABEL_MAP.get(
780
- int(prediction), f"Class {int(prediction)}")
781
-
782
- # === confidence metrics ===
783
- logit_margin = abs(
784
- logits_list[0] - logits_list[1]) if len(logits_list) >= 2 else 0
785
- confidence_desc, confidence_emoji = get_confidence_description(
786
- logit_margin)
787
-
788
- # ===Detailed results tabs===
789
- tab1, tab2, tab3 = st.tabs(
790
- ["Details", "Technical", "Explanation"])
791
-
792
- with tab1:
793
- # Main prediction
 
 
 
794
  st.markdown(f"""
795
  **Sample**: `{filename}`
796
  **Model**: `{model_choice}`
797
  **Processing Time**: `{inference_time:.2f}s`
798
  """)
799
-
800
- # ===Prediction box && Confidence Margin===
801
- with st.expander("Prediction/Ground Truth & Model Confidence Margin", expanded=False):
802
  if predicted_class == "Stable (Unweathered)":
803
  st.markdown(f"🟢 **Prediction**: {predicted_class}")
804
  else:
805
  st.markdown(f"🟡 **Prediction**: {predicted_class}")
806
  st.markdown(
807
  f"**{confidence_emoji} Confidence**: {confidence_desc} (margin: {logit_margin:.1f})")
808
- # Ground truth comparison
809
  if true_label_idx is not None:
810
  if predicted_class == true_label_str:
811
  st.markdown(
@@ -819,85 +839,69 @@ def main():
819
 
820
  st.markdown("###### Confidence Overview")
821
  render_confidence_progress(
822
- probs,
823
  labels=["Stable", "Weathered"],
824
  highlight_idx=int(prediction),
825
  side_by_side=True, # Set false for stacked <<
826
  )
827
-
828
 
829
- with tab2:
830
- with st.expander("Diagnostics/Technical Info (advanced)", expanded=False):
 
 
831
  st.markdown("###### Model Output (Logits)")
832
  cols = st.columns(2)
833
- for i, score in enumerate(logits_list):
834
- label = LABEL_MAP.get(i, f"Class {i}")
835
- (cols[i % 2]).metric(label, f"{score:.2f}")
836
-
837
  st.markdown("###### Spectrum Statistics")
838
- spec_stats = {
839
- "Original Length": len(x_raw) if x_raw is not None else 0,
840
- "Resampled Length": TARGET_LEN,
841
- "Wavenumber Range": f"{min(x_raw):.1f}–{max(x_raw):.1f} cm⁻¹" if x_raw is not None else "N/A",
842
- "Intensity Range": f"{min(y_raw):.1f}–{max(y_raw):.1f}" if y_raw is not None else "N/A",
843
- "Confidence Bucket": confidence_desc,
844
- }
845
  render_kv_grid(spec_stats, ncols=2)
846
  st.markdown("---")
847
-
848
  st.markdown("###### Model Statistics")
849
- model_path = MODEL_CONFIG[model_choice]["path"]
850
- mtime = os.path.getmtime(model_path) if os.path.exists(model_path) else None
851
- file_hash = (
852
- hashlib.md5(open(model_path, 'rb').read()).hexdigest()
853
- if os.path.exists(model_path) else "N/A"
854
- )
855
- model_stats = {
856
- "Architecture": model_choice,
857
- "Model Path": model_path,
858
- "Weights Last Modified": time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime)) if mtime else "N/A",
859
- "Weights Hash (md5)": file_hash,
860
- "Input Shape": list(input_tensor.shape),
861
- "Output Shape": list(logits.shape),
862
- "Inference Time": f"{inference_time:.3f}s",
863
- "Device": "CPU",
864
- "Model Loaded": model_loaded,
865
- }
866
  render_kv_grid(model_stats, ncols=2)
867
-
868
  st.markdown("---")
869
-
870
-
871
  st.markdown("###### Debug Log")
872
  st.text_area("Logs", "\n".join(st.session_state.get("log_messages", [])), height=110)
873
 
874
-
875
- with tab3:
876
- st.markdown("""
877
- **🔍 Analysis Process**
878
-
879
- 1. **Data Upload**: Raman spectrum file loaded
880
- 2. **Preprocessing**: Data parsed and resampled to 500 points
881
- 3. **AI Inference**: CNN model analyzes spectral patterns
882
- 4. **Classification**: Binary prediction with confidence scores
883
-
884
- **🧠 Model Interpretation**
885
-
886
- The AI model identifies spectral features indicative of:
887
- - **Stable polymers**: Well-preserved molecular structure
888
- - **Weathered polymers**: Degraded/oxidized molecular bonds
889
 
890
- **🎯 Applications**
891
-
892
- - Material longevity assessment
893
- - Recycling viability evaluation
894
- - Quality control in manufacturing
895
- - Environmental impact studies
896
- """)
897
-
898
- except (ValueError, RuntimeError) as e:
899
- st.error(f"❌ Inference failed: {str(e)}")
900
- log_message(f"Inference error: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
901
 
902
  else:
903
  st.error(
 
58
  /* Clean key–value rows for technical info */
59
  .kv-row { display:flex; justify-content:space-between;
60
  border-bottom: 1px dotted rgba(0,0,0,.10); padding: 3px 0; gap: 12px; }
61
+ .kv-key { opacity:.75; font-size: 0.95rem; white-space: nowrap; }
62
  .kv-val { font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace;
63
  overflow-wrap: anywhere; }
64
 
 
79
  margin: 6px 0;
80
  background: rgba(0,0,0,0.04);
81
  font-weight: 600;
82
+ font-size: 0.95rem;
83
  }
84
 
85
  /* Remove ugly default disclosure triangle */
 
113
  color: #111827;
114
  }
115
 
116
+ /* === Stable cross-browser expander behavior === */
117
+ .expander-marker + div[data-testid="stExpander"] summary {
118
  border-left-color: #2e7d32;
119
  background: rgba(46,125,50,0.08);
120
  }
121
+ .expander-marker + div[data-testid="stExpander"] summary::after {
122
  content: "RESULTS";
123
+ background: rgba(46,125,50,0.15);
124
+ color: #184a1d;
125
  }
126
 
127
+
128
  div.stExpander:has(summary:contains("Technical")) > details > summary {
129
  border-left-color: #ed6c02;
130
  background: rgba(237,108,2,0.08);
 
147
 
148
  /* Sidebar expander text */
149
  section[data-testid="stSidebar"] .stMarkdown p {
150
+ font-size: 0.95rem !important;
151
  line-height: 1.4;
152
  }
153
 
 
213
  "log_messages": [],
214
  "uploader_version": 0,
215
  "current_upload_key": "upload_txt_0",
216
+ "active_tab": "Details",
217
  }
218
  for k, v in defaults.items():
219
  st.session_state.setdefault(k, v)
 
288
  if torch.cuda.is_available():
289
  torch.cuda.empty_cache()
290
 
291
+ @st.cache_data
292
+ def run_inference(y_resampled, model_choice, _cache_key=None):
293
+ """Run model inference and cache results"""
294
+ model, model_loaded = load_model(model_choice)
295
+ if not model_loaded:
296
+ return None, None, None, None, None
297
+
298
+ input_tensor = torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
299
+ start_time = time.time()
300
+ model.eval()
301
+ with torch.no_grad():
302
+ if model is None:
303
+ raise ValueError("Model is not loaded. Please check the model configuration or weights.")
304
+ logits = model(input_tensor)
305
+ prediction = torch.argmax(logits, dim=1).item()
306
+ logits_list = logits.detach().numpy().tolist()[0]
307
+ probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
308
+ inference_time = time.time() - start_time
309
+ cleanup_memory()
310
+ return prediction, logits_list, probs, inference_time, logits
311
+
312
 
313
  @st.cache_data
314
  def get_sample_files():
 
365
 
366
  return x, y
367
 
368
+ @st.cache_data
369
+ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None):
370
  """Create spectrum visualization plot"""
371
  fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
372
 
 
394
  plt.close(fig) # Prevent memory leaks
395
 
396
  return Image.open(buf)
397
+
398
+ from typing import Union
 
 
399
 
400
  def render_confidence_progress(
401
  probs: np.ndarray,
402
  labels: list[str] = ["Stable", "Weathered"],
403
+ highlight_idx: Union[int, None] = None,
404
  side_by_side: bool = True
405
  ):
406
  """Render Streamlit native progress bars (0 - 100). Optionally bold the winning class
 
424
  st.progress(int(round(val * 100)))
425
 
426
 
 
 
 
 
427
  def render_kv_grid(d: dict, ncols: int = 2):
428
  """Display dict as a clean grid of key/value rows."""
429
  if not d:
 
749
  filename = st.session_state.get('filename', 'Unknown')
750
 
751
  if all(v is not None for v in [x_raw, y_raw, y_resampled]):
752
+ # ===Run inference===
753
 
754
+ if y_resampled is None:
755
+ raise ValueError("y_resampled is None. Ensure spectrum data is properly resampled before proceeding.")
756
+ cache_key = hashlib.md5(f"{y_resampled.tobytes()}{model_choice}".encode()).hexdigest()
757
+ prediction, logits_list, probs, inference_time, logits = run_inference(
758
+ y_resampled, model_choice, _cache_key=cache_key
759
+ )
760
+ if prediction is None:
761
+ st.error(" Inference failed: Model not loaded. Please check that weights are available.")
762
+ st.stop() # prevents the rest of the code in this block from executing
763
+
764
+ log_message(f"Inference completed in {inference_time:.2f}s, prediction: {prediction}")
765
+
766
+ # ===Get ground truth===
767
+ true_label_idx = label_file(filename)
768
+ true_label_str = LABEL_MAP.get(
769
+ true_label_idx, "Unknown") if true_label_idx is not None else "Unknown"
770
+ # ===Get prediction===
771
+ predicted_class = LABEL_MAP.get(
772
+ int(prediction), f"Class {int(prediction)}")
773
+ # === confidence metrics ===
774
+ logit_margin = abs(
775
+ (logits_list[0] - logits_list[1]) if logits_list is not None and len(logits_list) >= 2 else 0
776
+ )
777
+ confidence_desc, confidence_emoji = get_confidence_description(logit_margin)
778
+
779
+ #===Precompute Stats===
780
+ spec_stats = {
781
+ "Original Length": len(x_raw) if x_raw is not None else 0,
782
+ "Resampled Length": TARGET_LEN,
783
+ "Wavenumber Range": f"{min(x_raw):.1f}-{max(x_raw):.1f} cm⁻¹" if x_raw is not None else "N/A",
784
+ "Intensity Range": f"{min(y_raw):.1f}-{max(y_raw):.1f} cm⁻¹" if y_raw is not None else "N/A",
785
+ "Confidence Bucket": confidence_desc,
786
+ }
787
+ model_path = MODEL_CONFIG[model_choice]["path"]
788
+ mtime = os.path.getmtime(model_path) if os.path.exists(model_path) else None
789
+ file_hash = (
790
+ hashlib.md5(open(model_path, 'rb').read()).hexdigest()
791
+ if os.path.exists(model_path) else "N/A"
792
+ )
793
+ input_tensor = torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
794
+ model_stats = {
795
+ "Architecture": model_choice,
796
+ "Model Path": model_path,
797
+ "Weights Last Modified": time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime)) if mtime else "N/A",
798
+ "Weights Hash (md5)": file_hash,
799
+ "Input Shape": list(input_tensor.shape),
800
+ "Output Shape": list(logits.shape) if logits is not None else "N/A",
801
+ "Inference Time": f"{inference_time:.3f}s",
802
+ "Device": "CPU",
803
+ "Model Loaded": model_loaded,
804
+ }
805
+
806
+ start_render = time.time()
807
+
808
+ active_tab = st.selectbox(
809
+ "View Results",
810
+ ["Details", "Technical", "Explanation"],
811
+ key="active_tab", # reuse the key you were managing manually
812
+ )
813
+
814
+ if active_tab == "Details":
815
+ with st.container():
816
  st.markdown(f"""
817
  **Sample**: `{filename}`
818
  **Model**: `{model_choice}`
819
  **Processing Time**: `{inference_time:.2f}s`
820
  """)
821
+ st.markdown("<div class='expander-marker expander-success'></div>", unsafe_allow_html=True)
822
+ with st.expander("Prediction/Ground Truth & Model Confidence Margin", expanded=True):
 
823
  if predicted_class == "Stable (Unweathered)":
824
  st.markdown(f"🟢 **Prediction**: {predicted_class}")
825
  else:
826
  st.markdown(f"🟡 **Prediction**: {predicted_class}")
827
  st.markdown(
828
  f"**{confidence_emoji} Confidence**: {confidence_desc} (margin: {logit_margin:.1f})")
 
829
  if true_label_idx is not None:
830
  if predicted_class == true_label_str:
831
  st.markdown(
 
839
 
840
  st.markdown("###### Confidence Overview")
841
  render_confidence_progress(
842
+ probs if probs is not None else np.array([]),
843
  labels=["Stable", "Weathered"],
844
  highlight_idx=int(prediction),
845
  side_by_side=True, # Set false for stacked <<
846
  )
 
847
 
848
+ elif active_tab == "Technical":
849
+ with st.container():
850
+ st.markdown("<div class='expander-marker expander-success'></div>", unsafe_allow_html=True)
851
+ with st.expander("Diagnostics/Technical Info (advanced)", expanded=True):
852
  st.markdown("###### Model Output (Logits)")
853
  cols = st.columns(2)
854
+ if logits_list is not None:
855
+ for i, score in enumerate(logits_list):
856
+ label = LABEL_MAP.get(i, f"Class {i}")
857
+ cols[i % 2].metric(label, f"{score:.2f}")
858
  st.markdown("###### Spectrum Statistics")
 
 
 
 
 
 
 
859
  render_kv_grid(spec_stats, ncols=2)
860
  st.markdown("---")
 
861
  st.markdown("###### Model Statistics")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
862
  render_kv_grid(model_stats, ncols=2)
 
863
  st.markdown("---")
 
 
864
  st.markdown("###### Debug Log")
865
  st.text_area("Logs", "\n".join(st.session_state.get("log_messages", [])), height=110)
866
 
867
+ elif active_tab == "Explanation":
868
+ with st.container():
869
+ st.markdown("""
870
+ **🔍 Analysis Process**
 
 
 
 
 
 
 
 
 
 
 
871
 
872
+ 1. **Data Upload**: Raman spectrum file loaded
873
+ 2. **Preprocessing**: Data parsed and resampled to 500 points
874
+ 3. **AI Inference**: CNN model analyzes spectral patterns
875
+ 4. **Classification**: Binary prediction with confidence scores
876
+
877
+ **🧠 Model Interpretation**
878
+
879
+ The AI model identifies spectral features indicative of:
880
+ - **Stable polymers**: Well-preserved molecular structure
881
+ - **Weathered polymers**: Degraded/oxidized molecular bonds
882
+
883
+ **🎯 Applications**
884
+
885
+ - Material longevity assessment
886
+ - Recycling viability evaluation
887
+ - Quality control in manufacturing
888
+ - Environmental impact studies
889
+ """)
890
+
891
+ render_time = time.time() - start_render
892
+ log_message(f"col2 rendered in {render_time:.2f}s, active tab: {active_tab}")
893
+
894
+ st.markdown("<div class='expander-marker expander-success'></div>", unsafe_allow_html=True)
895
+ with st.expander("Spectrum Preprocessing Results", expanded=False):
896
+ # Create and display plot
897
+ cache_key = hashlib.md5(
898
+ f"{(x_raw.tobytes() if x_raw is not None else b'')}"
899
+ f"{(y_raw.tobytes() if y_raw is not None else b'')}"
900
+ f"{(x_resampled.tobytes() if x_resampled is not None else b'')}"
901
+ f"{(y_resampled.tobytes() if y_resampled is not None else b'')}".encode()
902
+ ).hexdigest()
903
+ spectrum_plot = create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=cache_key)
904
+ st.image(spectrum_plot, caption="Spectrum Preprocessing Results", use_container_width=True)
905
 
906
  else:
907
  st.error(