devjas1 committed
Commit 222f7ff · 1 Parent(s): 71b3dbd

(FEAT)[UI/UX]: Add support for FTIR, multi-format upload, and model comparison tab


Sidebar:
- Added spectroscopy modality selection (Raman/FTIR) with explanatory info for each.
- Expanded model selection and improved project description to reflect FTIR and multi-model features.
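
  A minimal sketch of the new selection pattern (widget key and values match the diff below; reading the value back through session state is how other components pick up the chosen modality):

    import streamlit as st

    # Sidebar widget; the chosen value is stored under the "modality_select" key
    modality = st.sidebar.selectbox(
        "Choose Modality",
        ["raman", "ftir"],
        index=0,
        key="modality_select",
        format_func=lambda x: "Raman" if x == "raman" else "FTIR",
    )

    # Any other component can read it back, defaulting to Raman
    modality = st.session_state.get("modality_select", "raman")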

Input column:
- File uploader now accepts .txt, .csv, and .json for single and batch uploads.
- Updated help text and file type validation.
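
  In outline, the single-file path now accepts three extensions and decodes the payload before handing it to the existing parser (a sketch; parse_spectrum_data is the parser imported from core_logic in the diff below):

    import streamlit as st
    from core_logic import parse_spectrum_data

    up = st.file_uploader(
        "Upload spectrum file (.txt, .csv, .json)",
        type=["txt", "csv", "json"],
        help="Upload spectroscopy data: TXT (2-column), CSV (with headers), or JSON format",
    )
    if up is not None:
        raw = up.read()
        text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        x_raw, y_raw = parse_spectrum_data(text, up.name)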

New function 'render_comparison_tab':
- Lets users select multiple models and either upload a spectrum or choose sample data for side-by-side prediction.
- Displays comparison results in tables and visualizations (confidence bar chart, agreement stats, performance metrics).
- Supports exporting results in JSON/full report formats.
- Shows historical comparison statistics with agreement matrix and heatmap.
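
  Per model, the tab stores one result record; the shape below mirrors the comparison_results entries and the 'last_comparison_results' session key in the diff (model name and numbers are illustrative only):

    comparison_results = {
        "ModelA": {                      # placeholder name; real keys come from models.registry.choices()
            "prediction": 1,             # raw class index returned by run_inference
            "predicted_class": "Weathered",
            "confidence": 0.94,          # max softmax probability
            "probs": [0.06, 0.94],
            "logits": [-1.3, 1.4],
            "processing_time": 0.21,     # wall-clock seconds for this model
        },
    }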

New function 'render_performance_tab':
- Integrates the performance dashboard from the performance tracker utility.
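
  How the new tabs plug into the main app is not part of this diff; a hypothetical wiring (file name and tab labels are assumptions, the render functions are the ones added or updated here) would look like:

    # app.py (hypothetical)
    import streamlit as st
    from modules.ui_components import (
        render_sidebar,
        render_input_column,
        render_results_column,
        render_comparison_tab,
        render_performance_tab,
    )

    render_sidebar()
    analysis_tab, comparison_tab, performance_tab = st.tabs(
        ["Analysis", "Model Comparison", "Performance"]
    )
    with analysis_tab:
        left, right = st.columns(2)
        with left:
            render_input_column()
        with right:
            render_results_column()
    with comparison_tab:
        render_comparison_tab()
    with performance_tab:
        render_performance_tab()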

Files changed (1)
modules/ui_components.py +478 -51
modules/ui_components.py CHANGED
@@ -13,9 +13,9 @@ from modules.callbacks import (
13
  on_model_change,
14
  on_input_mode_change,
15
  on_sample_change,
 
16
  reset_ephemeral_state,
17
  log_message,
18
- clear_batch_results,
19
  )
20
  from core_logic import (
21
  get_sample_files,
@@ -24,7 +24,6 @@ from core_logic import (
24
  parse_spectrum_data,
25
  label_file,
26
  )
27
- from modules.callbacks import reset_results
28
  from utils.results_manager import ResultsManager
29
  from utils.confidence import calculate_softmax_confidence
30
  from utils.multifile import process_multiple_files, display_batch_results
@@ -41,7 +40,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
41
  """Create spectrum visualization plot"""
42
  fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
43
 
44
- # == Raw spectrum ==
45
  ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
46
  ax[0].set_title("Raw Input Spectrum")
47
  ax[0].set_xlabel("Wavenumber (cm⁻¹)")
@@ -49,7 +48,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
49
  ax[0].grid(True, alpha=0.3)
50
  ax[0].legend()
51
 
52
- # == Resampled spectrum ==
53
  ax[1].plot(
54
  x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1
55
  )
@@ -60,7 +59,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
60
  ax[1].legend()
61
 
62
  fig.tight_layout()
63
- # == Convert to image ==
64
  buf = io.BytesIO()
65
  plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
66
  buf.seek(0)
@@ -69,6 +68,9 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
69
  return Image.open(buf)
70
 
71
 
 
 
 
72
  def render_confidence_progress(
73
  probs: np.ndarray,
74
  labels: list[str] = ["Stable", "Weathered"],
@@ -114,7 +116,10 @@ def render_confidence_progress(
114
  st.markdown("")
115
 
116
 
117
- def render_kv_grid(d: dict = {}, ncols: int = 2):
 
 
 
118
  if d is None:
119
  d = {}
120
  if not d:
@@ -126,6 +131,9 @@ def render_kv_grid(d: dict = {}, ncols: int = 2):
126
  st.caption(f"**{k}:** {v}")
127
 
128
 
 
 
 
129
  def render_model_meta(model_choice: str):
130
  info = MODEL_CONFIG.get(model_choice, {})
131
  emoji = info.get("emoji", "")
@@ -143,6 +151,9 @@ def render_model_meta(model_choice: str):
143
  st.caption(desc)
144
 
145
 
 
 
 
146
  def get_confidence_description(logit_margin):
147
  """Get human-readable confidence description"""
148
  if logit_margin > 1000:
@@ -155,13 +166,35 @@ def get_confidence_description(logit_margin):
155
  return "LOW", "🔴"
156
 
157
 
 
 
 
158
  def render_sidebar():
159
  with st.sidebar:
160
  # Header
161
  st.header("AI-Driven Polymer Classification")
162
  st.caption(
163
- "Predict polymer degradation (Stable vs Weathered) from Raman spectra using validated CNN models. — v0.1"
 
 
 
 
 
 
 
 
 
 
164
  )
 
 
 
 
 
 
 
 
 
165
  model_labels = [
166
  f"{MODEL_CONFIG[name]['emoji']} {name}" for name in MODEL_CONFIG.keys()
167
  ]
@@ -173,10 +206,10 @@ def render_sidebar():
173
  )
174
  model_choice = selected_label.split(" ", 1)[1]
175
 
176
- # ===Compact metadata directly under dropdown===
177
  render_model_meta(model_choice)
178
 
179
- # ===Collapsed info to reduce clutter===
180
  with st.expander("About This App", icon=":material/info:", expanded=False):
181
  st.markdown(
182
  """
@@ -184,8 +217,9 @@ def render_sidebar():
184
 
185
  **Purpose**: Classify polymer degradation using AI<br>
186
  **Input**: Raman spectroscopy .txt files<br>
187
- **Models**: CNN architectures for binary classification<br>
188
- **Next**: More trained CNNs in evaluation pipeline<br>
 
189
 
190
 
191
  **Contributors**<br>
@@ -207,11 +241,7 @@ def render_sidebar():
207
  )
208
 
209
 
210
- # col1 goes here
211
-
212
- # In modules/ui_components.py
213
-
214
-
215
  def render_input_column():
216
  st.markdown("##### Data Input")
217
 
@@ -224,22 +254,20 @@ def render_input_column():
224
  )
225
 
226
  # == Input Mode Logic ==
227
- # ... (The if/elif/else block for Upload, Batch, and Sample modes remains exactly the same) ...
228
- # ==Upload tab==
229
  if mode == "Upload File":
230
  upload_key = st.session_state["current_upload_key"]
231
  up = st.file_uploader(
232
- "Upload Raman spectrum (.txt)",
233
- type="txt",
234
- help="Upload a text file with wavenumber and intensity columns",
235
  key=upload_key, # ← versioned key
236
  )
237
 
238
- # ==Process change immediately (no on_change; simpler & reliable)==
239
  if up is not None:
240
  raw = up.read()
241
  text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
242
- # == only reparse if its a different file|source ==
243
  if (
244
  st.session_state.get("filename") != getattr(up, "name", None)
245
  or st.session_state.get("input_source") != "upload"
@@ -255,23 +283,20 @@ def render_input_column():
255
  st.session_state["status_type"] = "success"
256
  reset_results("New file uploaded")
257
 
258
- # ==Batch Upload tab==
259
  elif mode == "Batch Upload":
260
  st.session_state["batch_mode"] = True
261
- # --- START: BUG 1 & 3 FIX ---
262
  # Use a versioned key to ensure the file uploader resets properly.
263
  batch_upload_key = f"batch_upload_{st.session_state['uploader_version']}"
264
  uploaded_files = st.file_uploader(
265
- "Upload multiple Raman spectrum files (.txt)",
266
- type="txt",
267
  accept_multiple_files=True,
268
- help="Upload one or more text files with wavenumber and intensity columns.",
269
  key=batch_upload_key,
270
  )
271
- # --- END: BUG 1 & 3 FIX ---
272
 
273
  if uploaded_files:
274
- # --- START: Bug 1 Fix ---
275
  # Use a dictionary to keep only unique files based on name and size
276
  unique_files = {(file.name, file.size): file for file in uploaded_files}
277
  unique_file_list = list(unique_files.values())
@@ -281,9 +306,7 @@ def render_input_column():
281
 
282
  # Optionally, inform the user that duplicates were removed
283
  if num_uploaded > num_unique:
284
- st.info(
285
- f"ℹ️ {num_uploaded - num_unique} duplicate file(s) were removed."
286
- )
287
 
288
  # Use the unique list
289
  st.session_state["batch_files"] = unique_file_list
@@ -291,7 +314,6 @@ def render_input_column():
291
  f"{num_unique} ready for batch analysis"
292
  )
293
  st.session_state["status_type"] = "success"
294
- # --- END: Bug 1 Fix ---
295
  else:
296
  st.session_state["batch_files"] = []
297
  # This check prevents resetting the status if files are already staged
@@ -301,7 +323,7 @@ def render_input_column():
301
  )
302
  st.session_state["status_type"] = "info"
303
 
304
- # ==Sample tab==
305
  elif mode == "Sample Data":
306
  st.session_state["batch_mode"] = False
307
  sample_files = get_sample_files()
@@ -330,9 +352,6 @@ def render_input_column():
330
  else:
331
  st.info(msg)
332
 
333
- # --- DE-NESTED LOGIC STARTS HERE ---
334
- # This code now runs on EVERY execution, guaranteeing the buttons will appear.
335
-
336
  # Safely get model choice from session state
337
  model_choice = st.session_state.get("model_select", " ").split(" ", 1)[1]
338
  model = load_model(model_choice)
@@ -388,7 +407,7 @@ def render_input_column():
388
  st.error(f"Error processing spectrum data: {e}")
389
 
390
 
391
- # col2 goes here
392
 
393
 
394
  def render_results_column():
@@ -410,7 +429,7 @@ def render_results_column():
410
  filename = st.session_state.get("filename", "Unknown")
411
 
412
  if all(v is not None for v in [x_raw, y_raw, y_resampled]):
413
- # ===Run inference===
414
  if y_resampled is None:
415
  raise ValueError(
416
  "y_resampled is None. Ensure spectrum data is properly resampled before proceeding."
@@ -437,14 +456,14 @@ def render_results_column():
437
  f"Inference completed in {inference_time:.2f}s, prediction: {prediction}"
438
  )
439
 
440
- # ===Get ground truth===
441
  true_label_idx = label_file(filename)
442
  true_label_str = (
443
  LABEL_MAP.get(true_label_idx, "Unknown")
444
  if true_label_idx is not None
445
  else "Unknown"
446
  )
447
- # ===Get prediction===
448
  predicted_class = LABEL_MAP.get(int(prediction), f"Class {int(prediction)}")
449
 
450
  # Enhanced confidence calculation
@@ -455,7 +474,7 @@ def render_results_column():
455
  )
456
  confidence_desc = confidence_level
457
  else:
458
- # Fallback to legace method
459
  logit_margin = abs(
460
  (logits_list[0] - logits_list[1])
461
  if logits_list is not None and len(logits_list) >= 2
@@ -487,7 +506,7 @@ def render_results_column():
487
  },
488
  )
489
 
490
- # ===Precompute Stats===
491
  model_choice = (
492
  st.session_state.get("model_select", "").split(" ", 1)[1]
493
  if "model_select" in st.session_state
@@ -505,7 +524,6 @@ def render_results_column():
505
  if os.path.exists(model_path)
506
  else "N/A"
507
  )
508
- # Removed unused variable 'input_tensor'
509
 
510
  start_render = time.time()
511
 
@@ -590,17 +608,13 @@ def render_results_column():
590
  """,
591
  unsafe_allow_html=True,
592
  )
593
- # --- END: CONSOLIDATED CONFIDENCE ANALYSIS ---
594
 
595
  st.divider()
596
 
597
- # --- START: CLEAN METADATA FOOTER ---
598
- # Secondary info is now a clean, single-line caption
599
  st.caption(
600
  f"Analyzed with **{st.session_state.get('model_select', 'Unknown')}** in **{inference_time:.2f}s**."
601
  )
602
- # --- END: CLEAN METADATA FOOTER ---
603
-
604
  st.markdown("</div>", unsafe_allow_html=True)
605
 
606
  elif active_tab == "Technical":
@@ -918,7 +932,7 @@ def render_results_column():
918
  """
919
  )
920
  else:
921
- # ===Getting Started===
922
  st.markdown(
923
  """
924
  ##### How to Get Started
@@ -948,3 +962,416 @@ def render_results_column():
948
  - 🏭 Quality control in manufacturing
949
  """
950
  )
13
  on_model_change,
14
  on_input_mode_change,
15
  on_sample_change,
16
+ reset_results,
17
  reset_ephemeral_state,
18
  log_message,
 
19
  )
20
  from core_logic import (
21
  get_sample_files,
 
24
  parse_spectrum_data,
25
  label_file,
26
  )
 
27
  from utils.results_manager import ResultsManager
28
  from utils.confidence import calculate_softmax_confidence
29
  from utils.multifile import process_multiple_files, display_batch_results
 
40
  """Create spectrum visualization plot"""
41
  fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
42
 
43
+ # Raw spectrum
44
  ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
45
  ax[0].set_title("Raw Input Spectrum")
46
  ax[0].set_xlabel("Wavenumber (cm⁻¹)")
 
48
  ax[0].grid(True, alpha=0.3)
49
  ax[0].legend()
50
 
51
+ # Resampled spectrum
52
  ax[1].plot(
53
  x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1
54
  )
 
59
  ax[1].legend()
60
 
61
  fig.tight_layout()
62
+ # Convert to image
63
  buf = io.BytesIO()
64
  plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
65
  buf.seek(0)
 
68
  return Image.open(buf)
69
 
70
 
71
+ # //////////////////////////////////////////
72
+
73
+
74
  def render_confidence_progress(
75
  probs: np.ndarray,
76
  labels: list[str] = ["Stable", "Weathered"],
 
116
  st.markdown("")
117
 
118
 
119
+ from typing import Optional
120
+
121
+
122
+ def render_kv_grid(d: Optional[dict] = None, ncols: int = 2):
123
  if d is None:
124
  d = {}
125
  if not d:
 
131
  st.caption(f"**{k}:** {v}")
132
 
133
 
134
+ # //////////////////////////////////////////
135
+
136
+
137
  def render_model_meta(model_choice: str):
138
  info = MODEL_CONFIG.get(model_choice, {})
139
  emoji = info.get("emoji", "")
 
151
  st.caption(desc)
152
 
153
 
154
+ # //////////////////////////////////////////
155
+
156
+
157
  def get_confidence_description(logit_margin):
158
  """Get human-readable confidence description"""
159
  if logit_margin > 1000:
 
166
  return "LOW", "🔴"
167
 
168
 
169
+ # //////////////////////////////////////////
170
+
171
+
172
  def render_sidebar():
173
  with st.sidebar:
174
  # Header
175
  st.header("AI-Driven Polymer Classification")
176
  st.caption(
177
+ "Predict polymer degradation (Stable vs Weathered) from Raman/FTIR spectra using validated CNN models. — v0.01"
178
+ )
179
+
180
+ # Modality Selection
181
+ st.markdown("##### Spectroscopy Modality")
182
+ modality = st.selectbox(
183
+ "Choose Modality",
184
+ ["raman", "ftir"],
185
+ index=0,
186
+ key="modality_select",
187
+ format_func=lambda x: f"{'Raman' if x == 'raman' else 'FTIR'}",
188
  )
189
+
190
+ # Display modality info
191
+ if modality == "ftir":
192
+ st.info("FTIR mode: 400-4000 cm-1 range with atmospheric correction")
193
+ else:
194
+ st.info("Raman mode: 200-4000 cm-1 range with standard preprocessing")
195
+
196
+ # Model selection
197
+ st.markdown("##### AI Model Selection")
198
  model_labels = [
199
  f"{MODEL_CONFIG[name]['emoji']} {name}" for name in MODEL_CONFIG.keys()
200
  ]
 
206
  )
207
  model_choice = selected_label.split(" ", 1)[1]
208
 
209
+ # Compact metadata directly under dropdown
210
  render_model_meta(model_choice)
211
 
212
+ # Collapsed info to reduce clutter
213
  with st.expander("About This App", icon=":material/info:", expanded=False):
214
  st.markdown(
215
  """
 
217
 
218
  **Purpose**: Classify polymer degradation using AI<br>
219
  **Input**: Raman spectroscopy .txt files<br>
220
+ **Models**: CNN architectures for classification<br>
221
+ **Modalities**: Raman and FTIR spectroscopy support<br>
222
+ **Features**: Multi-model comparison and analysis<br>
223
 
224
 
225
  **Contributors**<br>
 
241
  )
242
 
243
 
244
+ # //////////////////////////////////////////
 
 
 
 
245
  def render_input_column():
246
  st.markdown("##### Data Input")
247
 
 
254
  )
255
 
256
  # == Input Mode Logic ==
 
 
257
  if mode == "Upload File":
258
  upload_key = st.session_state["current_upload_key"]
259
  up = st.file_uploader(
260
+ "Upload spectrum file (.txt, .csv, .json)",
261
+ type=["txt", "csv", "json"],
262
+ help="Upload spectroscopy data: TXT (2-column), CSV (with headers), or JSON format",
263
  key=upload_key, # ← versioned key
264
  )
265
 
266
+ # Process change immediately
267
  if up is not None:
268
  raw = up.read()
269
  text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
270
+ # only reparse if its a different file|source
271
  if (
272
  st.session_state.get("filename") != getattr(up, "name", None)
273
  or st.session_state.get("input_source") != "upload"
 
283
  st.session_state["status_type"] = "success"
284
  reset_results("New file uploaded")
285
 
286
+ # Batch Upload tab
287
  elif mode == "Batch Upload":
288
  st.session_state["batch_mode"] = True
 
289
  # Use a versioned key to ensure the file uploader resets properly.
290
  batch_upload_key = f"batch_upload_{st.session_state['uploader_version']}"
291
  uploaded_files = st.file_uploader(
292
+ "Upload multiple spectrum files (.txt, .csv, .json)",
293
+ type=["txt", "csv", "json"],
294
  accept_multiple_files=True,
295
+ help="Upload spectroscopy files in TXT, CSV, or JSON format.",
296
  key=batch_upload_key,
297
  )
 
298
 
299
  if uploaded_files:
 
300
  # Use a dictionary to keep only unique files based on name and size
301
  unique_files = {(file.name, file.size): file for file in uploaded_files}
302
  unique_file_list = list(unique_files.values())
 
306
 
307
  # Optionally, inform the user that duplicates were removed
308
  if num_uploaded > num_unique:
309
+ st.info(f"{num_uploaded - num_unique} duplicate file(s) were removed.")
 
 
310
 
311
  # Use the unique list
312
  st.session_state["batch_files"] = unique_file_list
 
314
  f"{num_unique} ready for batch analysis"
315
  )
316
  st.session_state["status_type"] = "success"
 
317
  else:
318
  st.session_state["batch_files"] = []
319
  # This check prevents resetting the status if files are already staged
 
323
  )
324
  st.session_state["status_type"] = "info"
325
 
326
+ # Sample tab
327
  elif mode == "Sample Data":
328
  st.session_state["batch_mode"] = False
329
  sample_files = get_sample_files()
 
352
  else:
353
  st.info(msg)
354
 
 
 
 
355
  # Safely get model choice from session state
356
  model_choice = st.session_state.get("model_select", " ").split(" ", 1)[1]
357
  model = load_model(model_choice)
 
407
  st.error(f"Error processing spectrum data: {e}")
408
 
409
 
410
+ # //////////////////////////////////////////
411
 
412
 
413
  def render_results_column():
 
429
  filename = st.session_state.get("filename", "Unknown")
430
 
431
  if all(v is not None for v in [x_raw, y_raw, y_resampled]):
432
+ # Run inference
433
  if y_resampled is None:
434
  raise ValueError(
435
  "y_resampled is None. Ensure spectrum data is properly resampled before proceeding."
 
456
  f"Inference completed in {inference_time:.2f}s, prediction: {prediction}"
457
  )
458
 
459
+ # Get ground truth
460
  true_label_idx = label_file(filename)
461
  true_label_str = (
462
  LABEL_MAP.get(true_label_idx, "Unknown")
463
  if true_label_idx is not None
464
  else "Unknown"
465
  )
466
+ # Get prediction
467
  predicted_class = LABEL_MAP.get(int(prediction), f"Class {int(prediction)}")
468
 
469
  # Enhanced confidence calculation
 
474
  )
475
  confidence_desc = confidence_level
476
  else:
477
+ # Fallback to legacy method
478
  logit_margin = abs(
479
  (logits_list[0] - logits_list[1])
480
  if logits_list is not None and len(logits_list) >= 2
 
506
  },
507
  )
508
 
509
+ # Precompute Stats
510
  model_choice = (
511
  st.session_state.get("model_select", "").split(" ", 1)[1]
512
  if "model_select" in st.session_state
 
524
  if os.path.exists(model_path)
525
  else "N/A"
526
  )
 
527
 
528
  start_render = time.time()
529
 
 
608
  """,
609
  unsafe_allow_html=True,
610
  )
 
611
 
612
  st.divider()
613
 
614
+ # METADATA FOOTER
 
615
  st.caption(
616
  f"Analyzed with **{st.session_state.get('model_select', 'Unknown')}** in **{inference_time:.2f}s**."
617
  )
 
 
618
  st.markdown("</div>", unsafe_allow_html=True)
619
 
620
  elif active_tab == "Technical":
 
932
  """
933
  )
934
  else:
935
+ # Getting Started
936
  st.markdown(
937
  """
938
  ##### How to Get Started
 
962
  - 🏭 Quality control in manufacturing
963
  """
964
  )
965
+
966
+
967
+ # //////////////////////////////////////////
968
+
969
+
970
+ def render_comparison_tab():
971
+ """Render the multi-model comparison interface"""
972
+ import streamlit as st
973
+ import matplotlib.pyplot as plt
974
+ from models.registry import choices, validate_model_list
975
+ from utils.results_manager import ResultsManager
976
+ from core_logic import get_sample_files, run_inference, parse_spectrum_data
977
+ from utils.preprocessing import preprocess_spectrum
978
+ from utils.multifile import parse_spectrum_data
979
+ import numpy as np
980
+ import time
981
+
982
+ st.markdown("### Multi-Model Comparison Analysis")
983
+ st.markdown(
984
+ "Compare predictions across different AI models for comprehensive analysis."
985
+ )
986
+
987
+ # Model selection for comparison
988
+ st.markdown("##### Select Models for Comparison")
989
+
990
+ available_models = choices()
991
+ selected_models = st.multiselect(
992
+ "Choose models to compare",
993
+ available_models,
994
+ default=(
995
+ available_models[:2] if len(available_models) >= 2 else available_models
996
+ ),
997
+ help="Select 2 or more models to compare their predictions side-by-side",
998
+ )
999
+
1000
+ if len(selected_models) < 2:
1001
+ st.warning("⚠️ Please select at least 2 models for comparison.")
1002
+
1003
+ # Input selection for comparison
1004
+ col1, col2 = st.columns([1, 1.5])
1005
+
1006
+ with col1:
1007
+ st.markdown("###### Input Data")
1008
+
1009
+ # File upload for comparison
1010
+ comparison_file = st.file_uploader(
1011
+ "Upload spectrum for comparison",
1012
+ type=["txt", "csv", "json"],
1013
+ key="comparison_file_upload",
1014
+ help="Upload a spectrum file to test across all selected models",
1015
+ )
1016
+
1017
+ # Or select sample data
1018
+ selected_sample = None # Initialize with a default value
1019
+ sample_files = get_sample_files()
1020
+ if sample_files:
1021
+ sample_options = ["-- Select Sample --"] + [p.name for p in sample_files]
1022
+ selected_sample = st.selectbox(
1023
+ "Or choose sample data", sample_options, key="comparison_sample_select"
1024
+ )
1025
+
1026
+ # Get modality from session state
1027
+ modality = st.session_state.get("modality_select", "raman")
1028
+ st.info(f"Using {modality.upper()} preprocessing parameters")
1029
+
1030
+ # Run comparison button
1031
+ run_comparison = st.button(
1032
+ "Run Multi-Model Comparison",
1033
+ type="primary",
1034
+ disabled=not (
1035
+ comparison_file
1036
+ or (sample_files and selected_sample != "-- Select Sample --")
1037
+ ),
1038
+ )
1039
+
1040
+ with col2:
1041
+ st.markdown("###### Comparison Results")
1042
+
1043
+ if run_comparison:
1044
+ # Determine input source
1045
+ input_text = None
1046
+ filename = "unknown"
1047
+
1048
+ if comparison_file:
1049
+ raw = comparison_file.read()
1050
+ input_text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
1051
+ filename = comparison_file.name
1052
+ elif sample_files and selected_sample != "-- Select Sample --":
1053
+ sample_path = next(p for p in sample_files if p.name == selected_sample)
1054
+ with open(sample_path, "r") as f:
1055
+ input_text = f.read()
1056
+ filename = selected_sample
1057
+
1058
+ if input_text:
1059
+ try:
1060
+ # Parse spectrum data
1061
+ x_raw, y_raw = parse_spectrum_data(
1062
+ str(input_text), filename or "unknown_filename"
1063
+ )
1064
+
1065
+ # Store results
1066
+ comparison_results = {}
1067
+ processing_times = {}
1068
+
1069
+ progress_bar = st.progress(0)
1070
+ status_text = st.empty()
1071
+
1072
+ for i, model_name in enumerate(selected_models):
1073
+ status_text.text(f"Running inference with {model_name}...")
1074
+
1075
+ start_time = time.time()
1076
+
1077
+ # Preprocess spectrum with modality-specific parameters
1078
+ _, y_processed = preprocess_spectrum(
1079
+ x_raw, y_raw, modality=modality, target_len=500
1080
+ )
1081
+
1082
+ # Run inference
1083
+ prediction, logits_list, probs, inference_time, logits = (
1084
+ run_inference(y_processed, model_name)
1085
+ )
1086
+
1087
+ processing_time = time.time() - start_time
1088
+
1089
+ if prediction is not None:
1090
+ # Map prediction to class name
1091
+ class_names = ["Stable", "Weathered"]
1092
+ predicted_class = (
1093
+ class_names[int(prediction)]
1094
+ if prediction < len(class_names)
1095
+ else f"Class_{prediction}"
1096
+ )
1097
+ confidence = (
1098
+ max(probs)
1099
+ if probs is not None and len(probs) > 0
1100
+ else 0.0
1101
+ )
1102
+
1103
+ comparison_results[model_name] = {
1104
+ "prediction": prediction,
1105
+ "predicted_class": predicted_class,
1106
+ "confidence": confidence,
1107
+ "probs": probs if probs is not None else [],
1108
+ "logits": (
1109
+ logits_list if logits_list is not None else []
1110
+ ),
1111
+ "processing_time": processing_time,
1112
+ }
1113
+ processing_times[model_name] = processing_time
1114
+
1115
+ progress_bar.progress((i + 1) / len(selected_models))
1116
+
1117
+ status_text.text("Comparison complete!")
1118
+
1119
+ # Display results
1120
+ if comparison_results:
1121
+ st.markdown("###### Model Predictions")
1122
+
1123
+ # Create comparison table
1124
+ import pandas as pd
1125
+
1126
+ table_data = []
1127
+ for model_name, result in comparison_results.items():
1128
+ row = {
1129
+ "Model": model_name,
1130
+ "Prediction": result["predicted_class"],
1131
+ "Confidence": f"{result['confidence']:.3f}",
1132
+ "Processing Time (s)": f"{result['processing_time']:.3f}",
1133
+ }
1134
+ table_data.append(row)
1135
+
1136
+ df = pd.DataFrame(table_data)
1137
+ st.dataframe(df, use_container_width=True)
1138
+
1139
+ # Show confidence comparison
1140
+ st.markdown("##### Confidence Comparison")
1141
+ conf_col1, conf_col2 = st.columns(2)
1142
+
1143
+ with conf_col1:
1144
+ # Bar chart of confidences
1145
+ models = list(comparison_results.keys())
1146
+ confidences = [
1147
+ comparison_results[m]["confidence"] for m in models
1148
+ ]
1149
+
1150
+ fig, ax = plt.subplots(figsize=(8, 5))
1151
+ bars = ax.bar(
1152
+ models,
1153
+ confidences,
1154
+ alpha=0.7,
1155
+ color=["steelblue", "orange", "green", "red"][
1156
+ : len(models)
1157
+ ],
1158
+ )
1159
+ ax.set_ylabel("Confidence")
1160
+ ax.set_title("Model Confidence Comparison")
1161
+ ax.set_ylim(0, 1)
1162
+ plt.xticks(rotation=45)
1163
+
1164
+ # Add value labels on bars
1165
+ for bar, conf in zip(bars, confidences):
1166
+ height = bar.get_height()
1167
+ ax.text(
1168
+ bar.get_x() + bar.get_width() / 2.0,
1169
+ height + 0.01,
1170
+ f"{conf:.3f}",
1171
+ ha="center",
1172
+ va="bottom",
1173
+ )
1174
+
1175
+ plt.tight_layout()
1176
+ st.pyplot(fig)
1177
+
1178
+ with conf_col2:
1179
+ # Agreement analysis
1180
+ predictions = [
1181
+ comparison_results[m]["prediction"] for m in models
1182
+ ]
1183
+ unique_predictions = set(predictions)
1184
+
1185
+ if len(unique_predictions) == 1:
1186
+ st.success("✅ All models agree on the prediction!")
1187
+ else:
1188
+ st.warning("⚠️ Models disagree on the prediction")
1189
+
1190
+ # Show prediction distribution
1191
+ from collections import Counter
1192
+
1193
+ pred_counts = Counter(predictions)
1194
+
1195
+ st.markdown("**Prediction Distribution:**")
1196
+ for pred, count in pred_counts.items():
1197
+ class_name = (
1198
+ ["Stable", "Weathered"][pred]
1199
+ if pred < 2
1200
+ else f"Class_{pred}"
1201
+ )
1202
+ percentage = (count / len(predictions)) * 100
1203
+ st.write(
1204
+ f"- {class_name}: {count}/{len(predictions)} models ({percentage:.1f}%)"
1205
+ )
1206
+
1207
+ # Performance metrics
1208
+ st.markdown("##### Performance Metrics")
1209
+ perf_col1, perf_col2 = st.columns(2)
1210
+
1211
+ with perf_col1:
1212
+ avg_time = np.mean(list(processing_times.values()))
1213
+ fastest_model = min(
1214
+ processing_times.keys(),
1215
+ key=lambda k: processing_times[k],
1216
+ )
1217
+ slowest_model = max(
1218
+ processing_times.keys(),
1219
+ key=lambda k: processing_times[k],
1220
+ )
1221
+
1222
+ st.metric("Average Processing Time", f"{avg_time:.3f}s")
1223
+ st.metric(
1224
+ "Fastest Model",
1225
+ f"{fastest_model}",
1226
+ f"{processing_times[fastest_model]:.3f}s",
1227
+ )
1228
+ st.metric(
1229
+ "Slowest Model",
1230
+ f"{slowest_model}",
1231
+ f"{processing_times[slowest_model]:.3f}s",
1232
+ )
1233
+
1234
+ with perf_col2:
1235
+ most_confident = max(
1236
+ comparison_results.keys(),
1237
+ key=lambda k: comparison_results[k]["confidence"],
1238
+ )
1239
+ least_confident = min(
1240
+ comparison_results.keys(),
1241
+ key=lambda k: comparison_results[k]["confidence"],
1242
+ )
1243
+
1244
+ st.metric(
1245
+ "Most Confident",
1246
+ f"{most_confident}",
1247
+ f"{comparison_results[most_confident]['confidence']:.3f}",
1248
+ )
1249
+ st.metric(
1250
+ "Least Confident",
1251
+ f"{least_confident}",
1252
+ f"{comparison_results[least_confident]['confidence']:.3f}",
1253
+ )
1254
+
1255
+ # Store results in session state for potential export
1256
+ # Store results in session state for potential export
1257
+ st.session_state["last_comparison_results"] = {
1258
+ "filename": filename,
1259
+ "modality": modality,
1260
+ "models": comparison_results,
1261
+ "summary": {
1262
+ "agreement": len(unique_predictions) == 1,
1263
+ "avg_processing_time": avg_time,
1264
+ "fastest_model": fastest_model,
1265
+ "most_confident": most_confident,
1266
+ },
1267
+ }
1268
+
1269
+ except Exception as e:
1270
+ st.error(f"Error during comparison: {str(e)}")
1271
+
1272
+ # Show recent comparison results if available
1273
+ elif "last_comparison_results" in st.session_state:
1274
+ st.info(
1275
+ "Previous comparison results available. Upload a new file or select a sample to run new comparison."
1276
+ )
1277
+
1278
+ # Show comparison history
1279
+ comparison_stats = ResultsManager.get_comparison_stats()
1280
+ if comparison_stats:
1281
+ st.markdown("#### Comparison History")
1282
+
1283
+ with st.expander("View detailed comparison statistics", expanded=False):
1284
+ # Show model statistics table
1285
+ stats_data = []
1286
+ for model_name, stats in comparison_stats.items():
1287
+ row = {
1288
+ "Model": model_name,
1289
+ "Total Predictions": stats["total_predictions"],
1290
+ "Avg Confidence": f"{stats['avg_confidence']:.3f}",
1291
+ "Avg Processing Time": f"{stats['avg_processing_time']:.3f}s",
1292
+ "Accuracy": (
1293
+ f"{stats['accuracy']:.3f}"
1294
+ if stats["accuracy"] is not None
1295
+ else "N/A"
1296
+ ),
1297
+ }
1298
+ stats_data.append(row)
1299
+
1300
+ if stats_data:
1301
+ import pandas as pd
1302
+
1303
+ stats_df = pd.DataFrame(stats_data)
1304
+ st.dataframe(stats_df, use_container_width=True)
1305
+
1306
+ # Show agreement matrix if multiple models
1307
+ agreement_matrix = ResultsManager.get_agreement_matrix()
1308
+ if not agreement_matrix.empty and len(agreement_matrix) > 1:
1309
+ st.markdown("**Model Agreement Matrix**")
1310
+ st.dataframe(agreement_matrix.round(3), use_container_width=True)
1311
+
1312
+ # Plot agreement heatmap
1313
+ fig, ax = plt.subplots(figsize=(8, 6))
1314
+ im = ax.imshow(
1315
+ agreement_matrix.values, cmap="RdYlGn", vmin=0, vmax=1
1316
+ )
1317
+
1318
+ # Add text annotations
1319
+ for i in range(len(agreement_matrix)):
1320
+ for j in range(len(agreement_matrix.columns)):
1321
+ text = ax.text(
1322
+ j,
1323
+ i,
1324
+ f"{agreement_matrix.iloc[i, j]:.2f}",
1325
+ ha="center",
1326
+ va="center",
1327
+ color="black",
1328
+ )
1329
+
1330
+ ax.set_xticks(range(len(agreement_matrix.columns)))
1331
+ ax.set_yticks(range(len(agreement_matrix)))
1332
+ ax.set_xticklabels(agreement_matrix.columns, rotation=45)
1333
+ ax.set_yticklabels(agreement_matrix.index)
1334
+ ax.set_title("Model Agreement Matrix")
1335
+
1336
+ plt.colorbar(im, ax=ax, label="Agreement Rate")
1337
+ plt.tight_layout()
1338
+ st.pyplot(fig)
1339
+
1340
+ # Export functionality
1341
+ if "last_comparison_results" in st.session_state:
1342
+ st.markdown("##### Export Results")
1343
+
1344
+ export_col1, export_col2 = st.columns(2)
1345
+
1346
+ with export_col1:
1347
+ if st.button("📥 Export Comparison (JSON)"):
1348
+ import json
1349
+
1350
+ results = st.session_state["last_comparison_results"]
1351
+ json_str = json.dumps(results, indent=2, default=str)
1352
+ st.download_button(
1353
+ label="Download JSON",
1354
+ data=json_str,
1355
+ file_name=f"comparison_{results['filename'].split('.')[0]}.json",
1356
+ mime="application/json",
1357
+ )
1358
+
1359
+ with export_col2:
1360
+ if st.button("📊 Export Full Report"):
1361
+ report = ResultsManager.export_comparison_report()
1362
+ st.download_button(
1363
+ label="Download Full Report",
1364
+ data=report,
1365
+ file_name="model_comparison_report.json",
1366
+ mime="application/json",
1367
+ )
1368
+
1369
+
1370
+ # //////////////////////////////////////////
1371
+
1372
+
1373
+ def render_performance_tab():
1374
+ """Render the performance tracking and analysis tab."""
1375
+ from utils.performance_tracker import display_performance_dashboard
1376
+
1377
+ display_performance_dashboard()