devjas1 commited on
Commit
2547be1
·
1 Parent(s): dd49e6b

fix(state): Correct UI bugs and finalize modular integration

Browse files

This commit finalizes the modular refactor by updating the main `app.py` script to correctly orchestrate the new components. It also addresses several critical UI and state management bugs that were present.

Key fixes:
- **Disappearing Buttons:** Fixes a critical bug where the 'Run Analysis' and 'Reset' buttons would disappear in Batch or Sample mode. The button rendering logic has been de-nested from conditional blocks to ensure it is drawn on every script rerun.
- **Redundant Rerun:** Removes an unnecessary `st.rerun()` call from a callback, resolving the "no-op" informational message from Streamlit.
- **State Logic:** Corrects a logical error in the `clear_batch_results` callback to ensure it targets the correct session state variable.
- **Orchestration:** The `app.py` script is now a clean, high-level orchestrator, responsible only for page setup and the main layout, delegating all other tasks to the appropriate modules.

Files changed (1) hide show
  1. app.py +19 -1269
app.py CHANGED
@@ -1,1291 +1,41 @@
1
- from typing import Union
2
- from utils.multifile import create_batch_uploader, process_multiple_files, display_batch_results
3
- from utils.confidence import calculate_softmax_confidence, get_confidence_badge, create_confidence_progress_html
4
- from utils.results_manager import ResultsManager
5
- from utils.errors import ErrorHandler, safe_execute
6
- from utils.preprocessing import resample_spectrum
7
- from models.resnet_cnn import ResNet1D
8
- from models.figure2_cnn import Figure2CNN
9
- import hashlib
10
- import gc
11
- import time
12
- import io
13
- from PIL import Image
14
- import matplotlib.pyplot as plt
15
- import matplotlib
16
- import numpy as np
17
- import torch
18
- import torch.nn.functional as F
19
  import streamlit as st
20
- import os
21
- import sys
22
- from pathlib import Path
23
 
24
- # Ensure 'utils' directory is in the Python path
25
- utils_path = Path(__file__).resolve().parent / "utils"
26
- if utils_path.is_dir() and str(utils_path) not in sys.path:
27
- sys.path.append(str(utils_path))
28
- matplotlib.use("Agg") # ensure headless rendering in Spaces
29
 
30
- # ==Import local modules + new modules==
 
 
 
 
 
31
 
32
- KEEP_KEYS = {
33
- # ==global UI context we want to keep after "Reset"==
34
- "model_select", # sidebar model key
35
- "input_mode", # radio for Upload|Sample
36
- "uploader_version", # version counter for file uploader
37
- "input_registry", # radio controlling Upload vs Sample
38
- }
39
 
40
- # ==Page Configuration==
41
  st.set_page_config(
42
  page_title="ML Polymer Classification",
43
  page_icon="🔬",
44
  layout="wide",
45
  initial_sidebar_state="expanded",
46
- menu_items={
47
- "Get help": "https://github.com/KLab-AI3/ml-polymer-recycling"}
48
  )
49
 
50
 
51
- # ==============================================================================
52
- # THEME-AWARE CUSTOM CSS
53
- # ==============================================================================
54
- # This CSS block has been refactored to use Streamlit's internal theme
55
- # variables. This ensures that all custom components will automatically adapt
56
- # to both light and dark themes selected by the user in the settings menu.
57
- st.markdown("""
58
- <style>
59
- /* ====== Font Imports (Optional but Recommended) ====== */
60
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;700&family=Fira+Code:wght@400&display=swap');
61
-
62
- /* ====== Base & Typography ====== */
63
- .stApp,
64
- section[data-testid="stSidebar"],
65
- div[data-testid="stMetricValue"],
66
- div[data-testid="stMetricLabel"] {
67
- font-family: 'Inter', sans-serif;
68
- /* Uses the main text color from the current theme (light or dark) */
69
- color: var(--text-color);
70
- }
71
-
72
- .kv-val {
73
- font-family: 'Fira Code', monospace;
74
- }
75
-
76
- /* ====== Custom Containers: Tabs & Info Boxes ====== */
77
- div[data-testid="stTabs"] > div[role="tablist"] + div {
78
- min-height: 400px;
79
- /* Uses the secondary background color, which is different in light and dark modes */
80
- background-color: var(--secondary-background-color);
81
- /* Border color uses a semi-transparent version of the text color for a subtle effect that works on any background */
82
- border: 10px solid rgba(128, 128, 128, 0.2);
83
- border-radius: 10px;
84
- padding: 24px;
85
- box-shadow: 0 2px 4px rgba(0,0,0,0.05);
86
- }
87
-
88
- .info-box {
89
- font-size: 0.9rem;
90
- padding: 12px 16px;
91
- border: 1px solid rgba(128, 128, 128, 0.2);
92
- border-radius: 10px;
93
- background-color: var(--secondary-background-color);
94
- }
95
-
96
- /* ====== Key-Value Pair Styling ====== */
97
- .kv-row {
98
- display: flex;
99
- justify-content: space-between;
100
- gap: 16px;
101
- padding: 8px 0;
102
- border-bottom: 1px solid rgba(128, 128, 128, 0.2);
103
- }
104
- .kv-row:last-child {
105
- border-bottom: none;
106
- }
107
- .kv-key {
108
- opacity: 0.7;
109
- font-size: 0.9rem;
110
- white-space: nowrap;
111
- }
112
- .kv-val {
113
- font-size: 0.9rem;
114
- overflow-wrap: break-word;
115
- text-align: right;
116
- }
117
-
118
- /* ====== Custom Expander Styling ====== */
119
- div.stExpander > details > summary::-webkit-details-marker,
120
- div.stExpander > details > summary::marker,
121
- div[data-testid="stExpander"] summary svg {
122
- display: none !important;
123
- }
124
-
125
- div.stExpander > details > summary::after {
126
- content: 'DETAILS';
127
- font-size: 0.75rem;
128
- font-weight: 600;
129
- letter-spacing: 0.5px;
130
- padding: 4px 12px;
131
- border-radius: 999px;
132
- /* The primary color is set in config.toml and adapted by Streamlit */
133
- background-color: var(--primary);
134
- /* Text on the primary color needs high contrast. White works well for our chosen purple. */
135
-
136
- transition: background-color 0.2s ease-in-out;
137
- }
138
-
139
- div.stExpander > details > summary:hover::after {
140
- /* Using a fixed darker shade on hover. A more advanced solution could use color-mix() in CSS. */
141
- filter: brightness(90%);
142
- }
143
-
144
- /* Specialized Expander Labels */
145
- .expander-results div[data-testid="stExpander"] summary::after {
146
- content: "RESULTS";
147
- background-color: #16A34A; /* Green is universal for success */
148
-
149
- }
150
- div[data-testid="stExpander"] details {
151
- content: "RESULTS";
152
- background-color: var(--primary);
153
- border-radius: 10px;
154
- padding: 10px
155
-
156
- }
157
- .expander-advanced div[data-testid="stExpander"] summary::after {
158
- content: "ADVANCED";
159
- background-color: #D97706; /* Amber is universal for warning/technical */
160
-
161
- }
162
-
163
- [data-testid="stExpanderDetails"] {
164
- padding: 16px 4px 4px 4px;
165
- background-color: transparent;
166
- border-top: 1px solid rgba(128, 128, 128, 0.2);
167
- margin-top: 12px;
168
- }
169
-
170
- /* ====== Sidebar & Metrics ====== */
171
- section[data-testid="stSidebar"] > div:first-child {
172
- background-color: var(--secondary-background-color);
173
- border-right: 1px solid rgba(128, 128, 128, 0.2);
174
- }
175
-
176
- div[data-testid="stMetricValue"] {
177
- font-size: 1.1rem !important;
178
- font-weight: 500;
179
- }
180
- div[data-testid="stMetricLabel"] {
181
- font-size: 0.85rem !important;
182
- opacity: 0.8;
183
- }
184
-
185
- /* ====== Interactivity & Accessibility ====== */
186
- :focus-visible {
187
- /* The focus outline now uses the theme's primary color */
188
- outline: 2px solid var(--primary);
189
- outline-offset: 2px;
190
- border-radius: 8px;
191
- }
192
- </style>
193
- """, unsafe_allow_html=True)
194
-
195
-
196
- # ==CONSTANTS==
197
- TARGET_LEN = 500
198
- SAMPLE_DATA_DIR = Path("sample_data")
199
- # Prefer env var, else 'model_weights' if present; else canonical 'outputs'
200
- MODEL_WEIGHTS_DIR = (
201
- os.getenv("WEIGHTS_DIR")
202
- or ("model_weights" if os.path.isdir("model_weights") else "outputs")
203
- )
204
-
205
- # Model configuration
206
- MODEL_CONFIG = {
207
- "Figure2CNN (Baseline)": {
208
- "class": Figure2CNN,
209
- "path": f"{MODEL_WEIGHTS_DIR}/figure2_model.pth",
210
- "emoji": "",
211
- "description": "Baseline CNN with standard filters",
212
- "accuracy": "94.80%",
213
- "f1": "94.30%"
214
- },
215
- "ResNet1D (Advanced)": {
216
- "class": ResNet1D,
217
- "path": f"{MODEL_WEIGHTS_DIR}/resnet_model.pth",
218
- "emoji": "",
219
- "description": "Residual CNN with deeper feature learning",
220
- "accuracy": "96.20%",
221
- "f1": "95.90%"
222
- }
223
- }
224
-
225
- # ==Label mapping==
226
- LABEL_MAP = {0: "Stable (Unweathered)", 1: "Weathered (Degraded)"}
227
-
228
-
229
- # ==UTILITY FUNCTIONS==
230
- def init_session_state():
231
- """Keep a persistent session state"""
232
- defaults = {
233
- "status_message": "Ready to analyze polymer spectra 🔬",
234
- "status_type": "info",
235
- "input_text": None,
236
- "filename": None,
237
- "input_source": None, # "upload", "batch" or "sample"
238
- "sample_select": "-- Select Sample --",
239
- "input_mode": "Upload File", # controls which pane is visible
240
- "inference_run_once": False,
241
- "x_raw": None, "y_raw": None, "y_resampled": None,
242
- "log_messages": [],
243
- "uploader_version": 0,
244
- "current_upload_key": "upload_txt_0",
245
- "active_tab": "Details",
246
- "batch_mode": False,
247
- }
248
-
249
- if 'uploader_key' not in st.session_state:
250
- st.session_state.uploader_key = 0
251
-
252
- for k, v in defaults.items():
253
- st.session_state.setdefault(k, v)
254
-
255
- for key, default_value in defaults.items():
256
- if key not in st.session_state:
257
- st.session_state[key] = default_value
258
-
259
- # ==Initialize results table==
260
- ResultsManager.init_results_table()
261
-
262
-
263
- def label_file(filename: str) -> int:
264
- """Extract label from filename based on naming convention"""
265
- name = Path(filename).name.lower()
266
- if name.startswith("sta"):
267
- return 0
268
- elif name.startswith("wea"):
269
- return 1
270
- else:
271
- # Return None for unknown patterns instead of raising error
272
- return -1 # Default value for unknown patterns
273
-
274
-
275
- @st.cache_data
276
- def load_state_dict(_mtime, model_path):
277
- """Load state dict with mtime in cache key to detect file changes"""
278
- try:
279
- return torch.load(model_path, map_location="cpu")
280
- except (FileNotFoundError, RuntimeError) as e:
281
- st.warning(f"Error loading state dict: {e}")
282
- return None
283
-
284
-
285
- @st.cache_resource
286
- def load_model(model_name):
287
- """Load and cache the specified model with error handling"""
288
- try:
289
- config = MODEL_CONFIG[model_name]
290
- model_class = config["class"]
291
- model_path = config["path"]
292
-
293
- # Initialize model
294
- model = model_class(input_length=TARGET_LEN)
295
-
296
- # Check if model file exists
297
- if not os.path.exists(model_path):
298
- st.warning(f"⚠️ Model weights not found: {model_path}")
299
- st.info("Using randomly initialized model for demonstration purposes.")
300
- return model, False
301
-
302
- # Get mtime for cache invalidation
303
- mtime = os.path.getmtime(model_path)
304
-
305
- # Load weights
306
- state_dict = load_state_dict(mtime, model_path)
307
- if state_dict:
308
- model.load_state_dict(state_dict, strict=True)
309
- if model is None:
310
- raise ValueError(
311
- "Model is not loaded. Please check the model configuration or weights.")
312
- model.eval()
313
- return model, True
314
- else:
315
- return model, False
316
-
317
- except (FileNotFoundError, KeyError, RuntimeError) as e:
318
- st.error(f"❌ Error loading model {model_name}: {str(e)}")
319
- return None, False
320
-
321
-
322
- def cleanup_memory():
323
- """Clean up memory after inference"""
324
- gc.collect()
325
- if torch.cuda.is_available():
326
- torch.cuda.empty_cache()
327
-
328
-
329
- @st.cache_data
330
- def run_inference(y_resampled, model_choice, _cache_key=None):
331
- """Run model inference and cache results"""
332
- model, model_loaded = load_model(model_choice)
333
- if not model_loaded:
334
- return None, None, None, None, None
335
-
336
- input_tensor = torch.tensor(
337
- y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
338
- start_time = time.time()
339
- model.eval()
340
- with torch.no_grad():
341
- if model is None:
342
- raise ValueError(
343
- "Model is not loaded. Please check the model configuration or weights.")
344
- logits = model(input_tensor)
345
- prediction = torch.argmax(logits, dim=1).item()
346
- logits_list = logits.detach().numpy().tolist()[0]
347
- probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
348
- inference_time = time.time() - start_time
349
- cleanup_memory()
350
- return prediction, logits_list, probs, inference_time, logits
351
-
352
-
353
- @st.cache_data
354
- def get_sample_files():
355
- """Get list of sample files if available"""
356
- sample_dir = Path(SAMPLE_DATA_DIR)
357
- if sample_dir.exists():
358
- return sorted(list(sample_dir.glob("*.txt")))
359
- return []
360
-
361
-
362
- def parse_spectrum_data(raw_text):
363
- """Parse spectrum data from text with robust error handling and validation"""
364
- x_vals, y_vals = [], []
365
-
366
- for line in raw_text.splitlines():
367
- line = line.strip()
368
- if not line or line.startswith('#'): # Skip empty lines and comments
369
- continue
370
-
371
- try:
372
- # Handle different separators
373
- parts = line.replace(",", " ").split()
374
- numbers = [p for p in parts if p.replace('.', '', 1).replace(
375
- '-', '', 1).replace('+', '', 1).isdigit()]
376
-
377
- if len(numbers) >= 2:
378
- x, y = float(numbers[0]), float(numbers[1])
379
- x_vals.append(x)
380
- y_vals.append(y)
381
-
382
- except ValueError:
383
- # Skip problematic lines but don't fail completely
384
- continue
385
-
386
- if len(x_vals) < 10: # Minimum reasonable spectrum length
387
- raise ValueError(
388
- f"Insufficient data points: {len(x_vals)}. Need at least 10 points.")
389
-
390
- x = np.array(x_vals)
391
- y = np.array(y_vals)
392
-
393
- # Check for NaNs
394
- if np.any(np.isnan(x)) or np.any(np.isnan(y)):
395
- raise ValueError("Input data contains NaN values")
396
-
397
- # Check monotonic increasing x
398
- if not np.all(np.diff(x) > 0):
399
- raise ValueError("Wavenumbers must be strictly increasing")
400
-
401
- # Check reasonable range for Raman spectroscopy
402
- if min(x) < 0 or max(x) > 10000 or (max(x) - min(x)) < 100:
403
- raise ValueError(
404
- f"Invalid wavenumber range: {min(x)} - {max(x)}. Expected ~400-4000 cm⁻¹ with span >100")
405
-
406
- return x, y
407
-
408
-
409
- @st.cache_data
410
- def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None):
411
- """Create spectrum visualization plot"""
412
- fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
413
-
414
- # == Raw spectrum ==
415
- ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
416
- ax[0].set_title("Raw Input Spectrum")
417
- ax[0].set_xlabel("Wavenumber (cm⁻¹)")
418
- ax[0].set_ylabel("Intensity")
419
- ax[0].grid(True, alpha=0.3)
420
- ax[0].legend()
421
-
422
- # == Resampled spectrum ==
423
- ax[1].plot(x_resampled, y_resampled, label="Resampled",
424
- color="steelblue", linewidth=1)
425
- ax[1].set_title(f"Resampled ({len(y_resampled)} points)")
426
- ax[1].set_xlabel("Wavenumber (cm⁻¹)")
427
- ax[1].set_ylabel("Intensity")
428
- ax[1].grid(True, alpha=0.3)
429
- ax[1].legend()
430
-
431
- plt.tight_layout()
432
- # == Convert to image ==
433
- buf = io.BytesIO()
434
- plt.savefig(buf, format='png', bbox_inches='tight', dpi=100)
435
- buf.seek(0)
436
- plt.close(fig) # Prevent memory leaks
437
-
438
- return Image.open(buf)
439
-
440
-
441
- def render_confidence_progress(
442
- probs: np.ndarray,
443
- labels: list[str] = ["Stable", "Weathered"],
444
- highlight_idx: Union[int, None] = None,
445
- side_by_side: bool = True
446
- ):
447
- """Render Streamlit native progress bars with scientific formatting."""
448
- p = np.asarray(probs, dtype=float)
449
- p = np.clip(p, 0.0, 1.0)
450
-
451
- if side_by_side:
452
- cols = st.columns(len(labels))
453
- for i, (lbl, val, col) in enumerate(zip(labels, p, cols)):
454
- with col:
455
- is_highlighted = (
456
- highlight_idx is not None and i == highlight_idx)
457
- label_text = f"**{lbl}**" if is_highlighted else lbl
458
- st.markdown(f"{label_text}: {val*100:.1f}%")
459
- st.progress(int(round(val * 100)))
460
- else:
461
- # Vertical layout for better readability
462
- for i, (lbl, val) in enumerate(zip(labels, p)):
463
- is_highlighted = (highlight_idx is not None and i == highlight_idx)
464
-
465
- # Create a container for each probability
466
- with st.container():
467
- col1, col2 = st.columns([3, 1])
468
- with col1:
469
- if is_highlighted:
470
- st.markdown(f"**{lbl}** ← Predicted")
471
- else:
472
- st.markdown(f"{lbl}")
473
- with col2:
474
- st.metric(
475
- label="",
476
- value=f"{val*100:.1f}%",
477
- delta=None
478
- )
479
-
480
- # Progress bar with conditional styling
481
- if is_highlighted:
482
- st.progress(int(round(val * 100)))
483
- st.caption("🎯 **Model Prediction**")
484
- else:
485
- st.progress(int(round(val * 100)))
486
-
487
- if i < len(labels) - 1: # Add spacing between items
488
- st.markdown("")
489
-
490
-
491
- def render_kv_grid(d: dict, ncols: int = 2):
492
- """Display dict as a clean grid of key/value rows using native Streamlit components."""
493
- if not d:
494
- return
495
- items = list(d.items())
496
- cols = st.columns(ncols)
497
- for i, (k, v) in enumerate(items):
498
- with cols[i % ncols]:
499
- st.caption(f"**{k}:** {v}")
500
-
501
-
502
- def render_model_meta(model_choice: str):
503
- info = MODEL_CONFIG.get(model_choice, {})
504
- emoji = info.get("emoji", "")
505
- desc = info.get("description", "").strip()
506
- acc = info.get("accuracy", "-")
507
- f1 = info.get("f1", "-")
508
-
509
- st.caption(f"{emoji} **Model Snapshot** - {model_choice}")
510
- cols = st.columns(2)
511
- with cols[0]:
512
- st.metric("Accuracy", acc)
513
- with cols[1]:
514
- st.metric("F1 Score", f1)
515
- if desc:
516
- st.caption(desc)
517
-
518
-
519
- def get_confidence_description(logit_margin):
520
- """Get human-readable confidence description"""
521
- if logit_margin > 1000:
522
- return "VERY HIGH", "🟢"
523
- elif logit_margin > 250:
524
- return "HIGH", "🟡"
525
- elif logit_margin > 100:
526
- return "MODERATE", "🟠"
527
- else:
528
- return "LOW", "🔴"
529
-
530
-
531
- def log_message(msg: str):
532
- """Append a timestamped line to the in-app log, creating the buffer if needed."""
533
- ErrorHandler.log_info(msg)
534
-
535
-
536
- def trigger_run():
537
- """Set a flag so we can detect button press reliably across reruns"""
538
- st.session_state['run_requested'] = True
539
-
540
-
541
- def on_sample_change():
542
- """Read selected sample once and persist as text."""
543
- sel = st.session_state.get("sample_select", "-- Select Sample --")
544
- if sel == "-- Select Sample --":
545
- return
546
- try:
547
- text = (Path(SAMPLE_DATA_DIR / sel).read_text(encoding="utf-8"))
548
- st.session_state["input_text"] = text
549
- st.session_state["filename"] = sel
550
- st.session_state["input_source"] = "sample"
551
- # 🔧 Clear previous results so right column resets immediately
552
- reset_results("New sample selected")
553
- st.session_state["status_message"] = f"📁 Sample '{sel}' ready for analysis"
554
- st.session_state["status_type"] = "success"
555
- except (FileNotFoundError, IOError) as e:
556
- st.session_state["status_message"] = f"❌ Error loading sample: {e}"
557
- st.session_state["status_type"] = "error"
558
-
559
-
560
- def on_input_mode_change():
561
- """Reset sample when switching to Upload"""
562
- if st.session_state["input_mode"] == "Upload File":
563
- st.session_state["sample_select"] = "-- Select Sample --"
564
- st.session_state["batch_mode"] = False # Reset batch mode
565
- elif st.session_state["input_mode"] == "Sample Data":
566
- st.session_state["batch_mode"] = False # Reset batch mode
567
- # 🔧 Reset when switching modes to prevent stale right-column visuals
568
- reset_results("Switched input mode")
569
-
570
-
571
- def on_model_change():
572
- """Force the right column back to init state when the model changes"""
573
- reset_results("Model changed")
574
-
575
-
576
- def reset_results(reason: str = ""):
577
- """Clear previous inference artifacts so the right column returns to initial state."""
578
- st.session_state["inference_run_once"] = False
579
- st.session_state["x_raw"] = None
580
- st.session_state["y_raw"] = None
581
- st.session_state["y_resampled"] = None
582
- # ||== Clear batch results when resetting ==||
583
- if "batch_results" in st.session_state:
584
- del st.session_state["batch_results"]
585
- # ||== Clear logs between runs ==||
586
- st.session_state["log_messages"] = []
587
- # ||== Always reset the status box ==||
588
- st.session_state["status_message"] = (
589
- f"ℹ️ {reason}"
590
- if reason else "Ready to analyze polymer spectra 🔬"
591
- )
592
- st.session_state["status_type"] = "info"
593
-
594
-
595
- def reset_ephemeral_state():
596
- """Comprehensive reset for the entire app state."""
597
- # Define keys that should NOT be cleared by a full reset
598
- keep_keys = {"model_select", "input_mode"}
599
-
600
- for k in list(st.session_state.keys()):
601
- if k not in keep_keys:
602
- st.session_state.pop(k, None)
603
-
604
- # Re-initialize the core state after clearing
605
- init_session_state()
606
-
607
- # CRITICAL: Bump the uploader version to force a widget reset
608
- st.session_state["uploader_version"] += 1
609
- st.session_state["current_upload_key"] = f"upload_txt_{st.session_state['uploader_version']}"
610
-
611
-
612
- # --- START: BUG 2 FIX (Callback Function) ---
613
-
614
-
615
- def clear_batch_results():
616
- """Callback to clear only the batch results and the results log table."""
617
- if "batch_results" in st.session_state:
618
- del st.session_state["batch_files"]
619
- # Also clear the persistent table from the ResultsManager utility
620
- ResultsManager.clear_results()
621
- st.rerun()
622
- # --- END: BUG 2 FIX (Callback Function) ---
623
-
624
-
625
- def reset_all():
626
- # Increment the key to force the file uploader to re-render
627
- st.session_state.uploader_key += 1
628
-
629
-
630
- # Main app
631
  def main():
 
 
632
  init_session_state()
633
 
634
- # Sidebar
635
- with st.sidebar:
636
- # Header
637
- st.header("AI-Driven Polymer Classification")
638
- st.caption(
639
- "Predict polymer degradation (Stable vs Weathered) from Raman spectra using validated CNN models. — v0.1")
640
- model_labels = [
641
- f"{MODEL_CONFIG[name]['emoji']} {name}" for name in MODEL_CONFIG.keys()]
642
- selected_label = st.selectbox(
643
- "Choose AI Model", model_labels, key="model_select", on_change=on_model_change)
644
- model_choice = selected_label.split(" ", 1)[1]
645
-
646
- # ===Compact metadata directly under dropdown===
647
- render_model_meta(model_choice)
648
-
649
- # ===Collapsed info to reduce clutter===
650
- with st.expander("About This App", icon=":material/info:", expanded=False):
651
- st.markdown("""
652
- AI-Driven Polymer Aging Prediction and Classification
653
-
654
- **Purpose**: Classify polymer degradation using AI
655
- **Input**: Raman spectroscopy `.txt` files
656
- **Models**: CNN architectures for binary classification
657
- **Next**: More trained CNNs in evaluation pipeline
658
-
659
-
660
- **Contributors**
661
- Dr. Sanmukh Kuppannagari (Mentor)
662
- Dr. Metin Karailyan (Mentor)
663
- Jaser Hasan (Author)
664
-
665
-
666
- **Links**
667
- [Live HF Space](https://huggingface.co/spaces/dev-jas/polymer-aging-ml)
668
- [GitHub Repository](https://github.com/KLab-AI3/ml-polymer-recycling)
669
-
670
 
671
- **Citation Figure2CNN (baseline)**
672
- Neo et al., 2023, *Resour. Conserv. Recycl.*, 188, 106718.
673
- [https://doi.org/10.1016/j.resconrec.2022.106718](https://doi.org/10.1016/j.resconrec.2022.106718)
674
- """, )
675
-
676
- # Main content area
677
  col1, col2 = st.columns([1, 1.35], gap="small")
678
-
679
  with col1:
680
- st.markdown("##### Data Input")
681
-
682
- mode = st.radio(
683
- "Input mode",
684
- ["Upload File", "Batch Upload", "Sample Data"],
685
- key="input_mode",
686
- horizontal=True,
687
- on_change=on_input_mode_change
688
- )
689
-
690
- # ==Upload tab==
691
- if mode == "Upload File":
692
- upload_key = st.session_state["current_upload_key"]
693
- up = st.file_uploader(
694
- "Upload Raman spectrum (.txt)",
695
- type="txt",
696
- help="Upload a text file with wavenumber and intensity columns",
697
- key=upload_key, # ← versioned key
698
- )
699
-
700
- # ==Process change immediately (no on_change; simpler & reliable)==
701
- if up is not None:
702
- raw = up.read()
703
- text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
704
- # == only reparse if its a different file|source ==
705
- if st.session_state.get("filename") != getattr(up, "name", None) or st.session_state.get("input_source") != "upload":
706
- st.session_state["input_text"] = text
707
- st.session_state["filename"] = getattr(up, "name", None)
708
- st.session_state["input_source"] = "upload"
709
- # Ensure single file mode
710
- st.session_state["batch_mode"] = False
711
- st.session_state["status_message"] = f"File '{st.session_state['filename']}' ready for analysis"
712
- st.session_state["status_type"] = "success"
713
- reset_results("New file uploaded")
714
-
715
- # ==Batch Upload tab==
716
- elif mode == "Batch Upload":
717
- st.session_state["batch_mode"] = True
718
- # --- START: BUG 1 & 3 FIX ---
719
- # Use a versioned key to ensure the file uploader resets properly.
720
- batch_upload_key = f"batch_upload_{st.session_state['uploader_version']}"
721
- uploaded_files = st.file_uploader(
722
- "Upload multiple Raman spectrum files (.txt)",
723
- type="txt",
724
- accept_multiple_files=True,
725
- help="Upload one or more text files with wavenumber and intensity columns.",
726
- key=batch_upload_key
727
- )
728
- # --- END: BUG 1 & 3 FIX ---
729
-
730
- if uploaded_files:
731
- # --- START: Bug 1 Fix ---
732
- # Use a dictionary to keep only unique files based on name and size
733
- unique_files = {(file.name, file.size)
734
- : file for file in uploaded_files}
735
- unique_file_list = list(unique_files.values())
736
-
737
- num_uploaded = len(uploaded_files)
738
- num_unique = len(unique_file_list)
739
-
740
- # Optionally, inform the user that duplicates were removed
741
- if num_uploaded > num_unique:
742
- st.info(
743
- f"ℹ️ {num_uploaded - num_unique} duplicate file(s) were removed.")
744
-
745
- # Use the unique list
746
- st.session_state["batch_files"] = unique_file_list
747
- st.session_state["status_message"] = f"{num_unique} ready for batch analysis"
748
- st.session_state["status_type"] = "success"
749
- # --- END: Bug 1 Fix ---
750
- else:
751
- st.session_state["batch_files"] = []
752
- # This check prevents resetting the status if files are already staged
753
- if not st.session_state.get("batch_files"):
754
- st.session_state["status_message"] = "No files selected for batch processing"
755
- st.session_state["status_type"] = "info"
756
-
757
- # ==Sample tab==
758
- elif mode == "Sample Data":
759
- st.session_state["batch_mode"] = False
760
- sample_files = get_sample_files()
761
- if sample_files:
762
- options = ["-- Select Sample --"] + \
763
- [p.name for p in sample_files]
764
- sel = st.selectbox(
765
- "Choose sample spectrum:",
766
- options,
767
- key="sample_select",
768
- on_change=on_sample_change,
769
- )
770
- if sel != "-- Select Sample --":
771
- st.session_state["status_message"] = f"📁 Sample '{sel}' ready for analysis"
772
- st.session_state["status_type"] = "success"
773
- else:
774
- st.info("No sample data available")
775
-
776
- # ==Status box==
777
- msg = st.session_state.get("status_message", "Ready")
778
- typ = st.session_state.get("status_type", "info")
779
- if typ == "success":
780
- st.success(msg)
781
- elif typ == "error":
782
- st.error(msg)
783
- else:
784
- st.info(msg)
785
-
786
- # ==Model load==
787
- model, model_loaded = load_model(model_choice)
788
- if not model_loaded:
789
- st.warning("⚠️ Model weights not available - using demo mode")
790
-
791
- # ==Ready to run if we have text (single) or files (batch) and a model==|
792
- is_batch_mode = st.session_state.get("batch_mode", False)
793
- batch_files = st.session_state.get("batch_files", [])
794
-
795
- inference_ready = False # Initialize with a default value
796
- if is_batch_mode:
797
- inference_ready = len(batch_files) > 0 and (model is not None)
798
- else:
799
- inference_ready = st.session_state.get(
800
- "input_text") is not None and (model is not None)
801
-
802
- # === Run Analysis (form submit batches state) ===
803
- with st.form("analysis_form", clear_on_submit=False):
804
- submitted = st.form_submit_button(
805
- "Run Analysis",
806
- type="primary",
807
- disabled=not inference_ready,
808
- )
809
-
810
- # Renamed for clarity and uses the robust on_click callback
811
- st.button("Reset All", on_click=reset_ephemeral_state,
812
- help="Clear all uploaded files and results.")
813
-
814
- if submitted and inference_ready:
815
- if is_batch_mode:
816
- with st.spinner(f"Processing {len(batch_files)} files ..."):
817
- try:
818
- batch_results = process_multiple_files(
819
- uploaded_files=batch_files,
820
- model_choice=model_choice,
821
- load_model_func=load_model,
822
- run_inference_func=run_inference,
823
- label_file_func=label_file
824
- )
825
- st.session_state["batch_results"] = batch_results
826
- st.success(
827
- f"Successfully processed {len([r for r in batch_results if r.get('success', False)])}/{len(batch_files)} files")
828
- except Exception as e:
829
- st.error(f"Error during batch processing: {e}")
830
- else:
831
- try:
832
- x_raw, y_raw = parse_spectrum_data(
833
- st.session_state["input_text"])
834
- x_resampled, y_resampled = resample_spectrum(
835
- x_raw, y_raw, TARGET_LEN)
836
- st.session_state["x_raw"] = x_raw
837
- st.session_state["y_raw"] = y_raw
838
- st.session_state["x_resampled"] = x_resampled
839
- st.session_state["y_resampled"] = y_resampled
840
- st.session_state["inference_run_once"] = True
841
- except (ValueError, TypeError) as e:
842
- st.error(f"Error processing spectrum data: {e}")
843
- st.session_state["status_message"] = f"❌ Error: {e}"
844
- st.session_state["status_type"] = "error"
845
-
846
- # Results column
847
  with col2:
848
-
849
- # Check if we're in batch more or have batch results
850
- is_batch_mode = st.session_state.get("batch_mode", False)
851
- has_batch_results = "batch_results" in st.session_state
852
-
853
- if is_batch_mode and has_batch_results:
854
- # Display batch results
855
- st.markdown("##### Batch Analysis Results")
856
- batch_results = st.session_state["batch_results"]
857
- display_batch_results(batch_results)
858
-
859
- # Add session results table
860
- st.markdown("---")
861
-
862
- # --- START: BUG 2 FIX (Button) ---
863
- # This button will clear all results from col2 correctly.
864
- # st.button("Clear Results", on_click=clear_batch_results,
865
- # help="Clear all uploaded files and results.")
866
- # --- END: BUG 2 FIX (Button) ---
867
-
868
- ResultsManager.display_results_table()
869
-
870
- elif st.session_state.get("inference_run_once", False) and not is_batch_mode:
871
- st.markdown("##### Analysis Results")
872
-
873
- # Get data from session state
874
- x_raw = st.session_state.get('x_raw')
875
- y_raw = st.session_state.get('y_raw')
876
- x_resampled = st.session_state.get('x_resampled') # ← NEW
877
- y_resampled = st.session_state.get('y_resampled')
878
- filename = st.session_state.get('filename', 'Unknown')
879
-
880
- if all(v is not None for v in [x_raw, y_raw, y_resampled]):
881
- # ===Run inference===
882
- if y_resampled is None:
883
- raise ValueError(
884
- "y_resampled is None. Ensure spectrum data is properly resampled before proceeding.")
885
- cache_key = hashlib.md5(
886
- f"{y_resampled.tobytes()}{model_choice}".encode()).hexdigest()
887
- prediction, logits_list, probs, inference_time, logits = run_inference(
888
- y_resampled, model_choice, _cache_key=cache_key
889
- )
890
- if prediction is None:
891
- st.error(
892
- "❌ Inference failed: Model not loaded. Please check that weights are available.")
893
- st.stop() # prevents the rest of the code in this block from executing
894
-
895
- log_message(
896
- f"Inference completed in {inference_time:.2f}s, prediction: {prediction}")
897
-
898
- # ===Get ground truth===
899
- true_label_idx = label_file(filename)
900
- true_label_str = LABEL_MAP.get(
901
- true_label_idx, "Unknown") if true_label_idx is not None else "Unknown"
902
- # ===Get prediction===
903
- predicted_class = LABEL_MAP.get(
904
- int(prediction), f"Class {int(prediction)}")
905
-
906
- # Enhanced confidence calculation
907
- if logits is not None:
908
- # Use new softmax-based confidence
909
- probs_np, max_confidence, confidence_level, confidence_emoji = calculate_softmax_confidence(
910
- logits)
911
- confidence_desc = confidence_level
912
- else:
913
- # Fallback to legace method
914
- logit_margin = abs(
915
- (logits_list[0] - logits_list[1]) if logits_list is not None and len(logits_list) >= 2 else 0)
916
- confidence_desc, confidence_emoji = get_confidence_description(
917
- logit_margin)
918
- max_confidence = logit_margin / 10.0 # Normalize for display
919
- probs_np = np.array([])
920
-
921
- # Store result in results manager for single file too
922
- ResultsManager.add_results(
923
- filename=filename,
924
- model_name=model_choice,
925
- prediction=int(prediction),
926
- predicted_class=predicted_class,
927
- confidence=max_confidence,
928
- logits=logits_list if logits_list else [],
929
- ground_truth=true_label_idx if true_label_idx >= 0 else None,
930
- processing_time=inference_time if inference_time is not None else 0.0,
931
- metadata={
932
- "confidence_level": confidence_desc,
933
- "confidence_emoji": confidence_emoji
934
- }
935
- )
936
-
937
- # ===Precompute Stats===
938
- spec_stats = {
939
- "Original Length": len(x_raw) if x_raw is not None else 0,
940
- "Resampled Length": TARGET_LEN,
941
- "Wavenumber Range": f"{min(x_raw):.1f}-{max(x_raw):.1f} cm⁻¹" if x_raw is not None else "N/A",
942
- "Intensity Range": f"{min(y_raw):.1f}-{max(y_raw):.1f} au" if y_raw is not None else "N/A",
943
- "Confidence Bucket": confidence_desc,
944
- }
945
- model_path = MODEL_CONFIG[model_choice]["path"]
946
- mtime = os.path.getmtime(
947
- model_path) if os.path.exists(model_path) else None
948
- file_hash = (
949
- hashlib.md5(open(model_path, 'rb').read()).hexdigest()
950
- if os.path.exists(model_path) else "N/A"
951
- )
952
- input_tensor = torch.tensor(
953
- y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
954
- model_stats = {
955
- "Architecture": model_choice,
956
- "Model Path": model_path,
957
- "Weights Last Modified": time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime)) if mtime else "N/A",
958
- "Weights Hash (md5)": file_hash,
959
- "Input Shape": list(input_tensor.shape),
960
- "Output Shape": list(logits.shape) if logits is not None else "N/A",
961
- "Inference Time": f"{inference_time:.3f}s",
962
- "Device": "CPU",
963
- "Model Loaded": model_loaded,
964
- }
965
-
966
- start_render = time.time()
967
-
968
- active_tab = st.selectbox(
969
- "View Results",
970
- ["Details", "Technical", "Explanation"],
971
- key="active_tab", # reuse the key you were managing manually
972
- )
973
-
974
- if active_tab == "Details":
975
- st.markdown('<div class="expander-results">',
976
- unsafe_allow_html=True)
977
- # Use a dynamic and informative title for the expander
978
- with st.expander(f"Results for {filename}", expanded=True):
979
-
980
- # --- START: STREAMLINED METRICS ---
981
- # A single, powerful row for the most important results.
982
- key_metric_cols = st.columns(3)
983
-
984
- # Metric 1: The Prediction
985
- key_metric_cols[0].metric(
986
- "Prediction", predicted_class)
987
-
988
- # Metric 2: The Confidence (with level in tooltip)
989
- confidence_icon = "🟢" if max_confidence >= 0.8 else "🟡" if max_confidence >= 0.6 else "🔴"
990
- key_metric_cols[1].metric(
991
- "Confidence",
992
- f"{confidence_icon} {max_confidence:.1%}",
993
- help=f"Confidence Level: {confidence_desc}"
994
- )
995
-
996
- # Metric 3: Ground Truth + Correctness (Combined)
997
- if true_label_idx is not None:
998
- is_correct = (predicted_class == true_label_str)
999
- delta_text = "✅ Correct" if is_correct else "❌ Incorrect"
1000
- # Use delta_color="normal" to let the icon provide the visual cue
1001
- key_metric_cols[2].metric(
1002
- "Ground Truth", true_label_str, delta=delta_text, delta_color="normal")
1003
- else:
1004
- key_metric_cols[2].metric("Ground Truth", "N/A")
1005
-
1006
- st.divider()
1007
- # --- END: STREAMLINED METRICS ---
1008
-
1009
- # --- START: CONSOLIDATED CONFIDENCE ANALYSIS ---
1010
- st.markdown("##### Probability Breakdown")
1011
-
1012
- # This custom bullet bar logic remains as it is highly specific and valuable
1013
- def create_bullet_bar(probability, width=20, predicted=False):
1014
- filled_count = int(probability * width)
1015
- bar = "▤" * filled_count + \
1016
- "▢" * (width - filled_count)
1017
- percentage = f"{probability:.1%}"
1018
- pred_marker = "↩ Predicted" if predicted else ""
1019
- return f"{bar} {percentage} {pred_marker}"
1020
-
1021
- stable_prob, weathered_prob = probs[0], probs[1]
1022
- is_stable_predicted, is_weathered_predicted = (
1023
- int(prediction) == 0), (int(prediction) == 1)
1024
-
1025
- st.markdown(f"""
1026
- <div style="font-family: 'Fira Code', monospace;">
1027
- Stable (Unweathered)<br>
1028
- {create_bullet_bar(stable_prob, predicted=is_stable_predicted)}<br><br>
1029
- Weathered (Degraded)<br>
1030
- {create_bullet_bar(weathered_prob, predicted=is_weathered_predicted)}
1031
- </div>
1032
- """, unsafe_allow_html=True)
1033
- # --- END: CONSOLIDATED CONFIDENCE ANALYSIS ---
1034
-
1035
- st.divider()
1036
-
1037
- # --- START: CLEAN METADATA FOOTER ---
1038
- # Secondary info is now a clean, single-line caption
1039
- st.caption(
1040
- f"Analyzed with **{model_choice}** in **{inference_time:.2f}s**.")
1041
- # --- END: CLEAN METADATA FOOTER ---
1042
-
1043
- st.markdown('</div>', unsafe_allow_html=True)
1044
-
1045
- elif active_tab == "Technical":
1046
- with st.container():
1047
- st.markdown("Technical Diagnostics")
1048
-
1049
- # Model performance metrics
1050
- with st.container(border=True):
1051
- st.markdown("##### **Model Performance**")
1052
- tech_col1, tech_col2 = st.columns(2)
1053
-
1054
- with tech_col1:
1055
- st.metric("Inference Time",
1056
- f"{inference_time:.3f}s")
1057
- st.metric(
1058
- "Input Length", f"{len(x_raw) if x_raw is not None else 0} points")
1059
- st.metric("Resampled Length",
1060
- f"{TARGET_LEN} points")
1061
-
1062
- with tech_col2:
1063
- st.metric("Model Loaded",
1064
- "✅ Yes" if model_loaded else "❌ No")
1065
- st.metric("Device", "CPU")
1066
- st.metric("Confidence Score",
1067
- f"{max_confidence:.3f}")
1068
-
1069
- # Raw logits display
1070
- with st.container(border=True):
1071
- st.markdown("##### **Raw Model Outputs (Logits)**")
1072
- if logits_list is not None:
1073
- logits_df = {
1074
- "Class": [LABEL_MAP.get(i, f"Class {i}") for i in range(len(logits_list))],
1075
- "Logit Value": [f"{score:.4f}" for score in logits_list],
1076
- "Probability": [f"{prob:.4f}" for prob in probs_np] if len(probs_np) > 0 else ["N/A"] * len(logits_list)
1077
- }
1078
-
1079
- # Display as a simple table format
1080
- for i, (cls, logit, prob) in enumerate(zip(logits_df["Class"], logits_df["Logit Value"], logits_df["Probability"])):
1081
- col1, col2, col3 = st.columns([2, 1, 1])
1082
- with col1:
1083
- if i == prediction:
1084
- st.markdown(f"**{cls}** ← Predicted")
1085
- else:
1086
- st.markdown(cls)
1087
- with col2:
1088
- st.caption(f"Logit: {logit}")
1089
- with col3:
1090
- st.caption(f"Prob: {prob}")
1091
-
1092
- # Spectrum statistics in organized sections
1093
- with st.container(border=True):
1094
- st.markdown("##### **Spectrum Analysis**")
1095
- spec_cols = st.columns(2)
1096
-
1097
- with spec_cols[0]:
1098
- st.markdown("**Original Spectrum:**")
1099
- render_kv_grid({
1100
- "Length": f"{len(x_raw) if x_raw is not None else 0} points",
1101
- "Range": f"{min(x_raw):.1f} - {max(x_raw):.1f} cm⁻¹" if x_raw is not None else "N/A",
1102
- "Min Intensity": f"{min(y_raw):.2e}" if y_raw is not None else "N/A",
1103
- "Max Intensity": f"{max(y_raw):.2e}" if y_raw is not None else "N/A"
1104
- }, ncols=1)
1105
-
1106
- with spec_cols[1]:
1107
- st.markdown("**Processed Spectrum:**")
1108
- render_kv_grid({
1109
- "Length": f"{TARGET_LEN} points",
1110
- "Resampling": "Linear interpolation",
1111
- "Normalization": "None",
1112
- "Input Shape": f"(1, 1, {TARGET_LEN})"
1113
- }, ncols=1)
1114
-
1115
- # Model information
1116
- with st.container(border=True):
1117
- st.markdown("##### **Model Information**")
1118
- model_info_cols = st.columns(2)
1119
-
1120
- with model_info_cols[0]:
1121
- render_kv_grid({
1122
- "Architecture": model_choice,
1123
- "Path": MODEL_CONFIG[model_choice]["path"],
1124
- "Weights Modified": time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime)) if mtime else "N/A"
1125
- }, ncols=1)
1126
-
1127
- with model_info_cols[1]:
1128
- if os.path.exists(model_path):
1129
- file_hash = hashlib.md5(
1130
- open(model_path, 'rb').read()).hexdigest()
1131
- render_kv_grid({
1132
- "Weights Hash": f"{file_hash[:16]}...",
1133
- "Output Shape": f"(1, {len(LABEL_MAP)})",
1134
- "Activation": "Softmax"
1135
- }, ncols=1)
1136
-
1137
- # Debug logs (collapsed by default)
1138
- with st.expander("📋 Debug Logs", expanded=False):
1139
- log_content = "\n".join(
1140
- st.session_state.get("log_messages", []))
1141
- if log_content.strip():
1142
- st.code(log_content, language="text")
1143
- else:
1144
- st.caption("No debug logs available")
1145
-
1146
- elif active_tab == "Explanation":
1147
- with st.container():
1148
- st.markdown("### 🔍 Methodology & Interpretation")
1149
-
1150
- # Process explanation
1151
- st.markdown("Analysis Pipeline")
1152
- process_steps = [
1153
- "📁 **Data Upload**: Raman spectrum file loaded and validated",
1154
- "🔍 **Preprocessing**: Spectrum parsed and resampled to 500 data points using linear interpolation",
1155
- "🧠 **AI Inference**: Convolutional Neural Network analyzes spectral patterns and molecular signatures",
1156
- "📊 **Classification**: Binary prediction with confidence scoring using softmax probabilities",
1157
- "✅ **Validation**: Ground truth comparison (when available from filename)"
1158
- ]
1159
-
1160
- for step in process_steps:
1161
- st.markdown(step)
1162
-
1163
- st.markdown("---")
1164
-
1165
- # Model interpretation
1166
- st.markdown("#### Scientific Interpretation")
1167
-
1168
- interp_col1, interp_col2 = st.columns(2)
1169
-
1170
- with interp_col1:
1171
- st.markdown("**Stable (Unweathered) Polymers:**")
1172
- st.info("""
1173
- - Well-preserved molecular structure
1174
- - Minimal oxidative degradation
1175
- - Characteristic Raman peaks intact
1176
- - Suitable for recycling applications
1177
- """)
1178
-
1179
- with interp_col2:
1180
- st.markdown("**Weathered (Degraded) Polymers:**")
1181
- st.warning("""
1182
- - Oxidized molecular bonds
1183
- - Surface degradation present
1184
- - Altered spectral signatures
1185
- - May require additional processing
1186
- """)
1187
-
1188
- st.markdown("---")
1189
-
1190
- # Applications
1191
- st.markdown("#### Research Applications")
1192
-
1193
- applications = [
1194
- "🔬 **Material Science**: Polymer degradation studies",
1195
- "♻️ **Recycling Research**: Viability assessment for circular economy",
1196
- "🌱 **Environmental Science**: Microplastic weathering analysis",
1197
- "🏭 **Quality Control**: Manufacturing process monitoring",
1198
- "📈 **Longevity Studies**: Material aging prediction"
1199
- ]
1200
-
1201
- for app in applications:
1202
- st.markdown(app)
1203
-
1204
- # Technical details
1205
- # MODIFIED: Wrap the expander in a div with the 'expander-advanced' class
1206
- st.markdown('<div class="expander-advanced">',
1207
- unsafe_allow_html=True)
1208
- with st.expander("🔧 Technical Details", expanded=False):
1209
- st.markdown("""
1210
- **Model Architecture:**
1211
- - Convolutional layers for feature extraction
1212
- - Residual connections for gradient flow
1213
- - Fully connected layers for classification
1214
- - Softmax activation for probability distribution
1215
-
1216
- **Performance Metrics:**
1217
- - Accuracy: 94.8-96.2% on validation set
1218
- - F1-Score: 94.3-95.9% across classes
1219
- - Robust to spectral noise and baseline variations
1220
-
1221
- **Data Processing:**
1222
- - Input: Raman spectra (any length)
1223
- - Resampling: Linear interpolation to 500 points
1224
- - Normalization: None (preserves intensity relationships)
1225
- """)
1226
- st.markdown(
1227
- '</div>', unsafe_allow_html=True) # Close the wrapper div
1228
-
1229
- render_time = time.time() - start_render
1230
- log_message(
1231
- f"col2 rendered in {render_time:.2f}s, active tab: {active_tab}")
1232
-
1233
- with st.expander("Spectrum Preprocessing Results", expanded=False):
1234
- st.caption("<br>Spectral Analysis", unsafe_allow_html=True)
1235
-
1236
- # Add some context about the preprocessing
1237
- st.markdown("""
1238
- **Preprocessing Overview:**
1239
- - **Original Spectrum**: Raw Raman data as uploaded
1240
- - **Resampled Spectrum**: Data interpolated to 500 points for model input
1241
- - **Purpose**: Ensures consistent input dimensions for neural network
1242
- """)
1243
-
1244
- # Create and display plot
1245
- cache_key = hashlib.md5(
1246
- f"{(x_raw.tobytes() if x_raw is not None else b'')}"
1247
- f"{(y_raw.tobytes() if y_raw is not None else b'')}"
1248
- f"{(x_resampled.tobytes() if x_resampled is not None else b'')}"
1249
- f"{(y_resampled.tobytes() if y_resampled is not None else b'')}".encode()
1250
- ).hexdigest()
1251
- spectrum_plot = create_spectrum_plot(
1252
- x_raw, y_raw, x_resampled, y_resampled, _cache_key=cache_key)
1253
- st.image(
1254
- spectrum_plot, caption="Raman Spectrum: Raw vs Processed", use_container_width=True)
1255
-
1256
- else:
1257
- st.error(
1258
- "❌ Missing spectrum data. Please upload a file and run analysis.")
1259
- else:
1260
- # ===Getting Started===
1261
- st.markdown("""
1262
- ##### How to Get Started
1263
-
1264
- 1. **Select an AI Model:** Use the dropdown menu in the sidebar to choose a model.
1265
- 2. **Provide Your Data:** Select one of the three input modes:
1266
- - **Upload File:** Analyze a single spectrum.
1267
- - **Batch Upload:** Process multiple files at once.
1268
- - **Sample Data:** Explore functionality with pre-loaded examples.
1269
- 3. **Run Analysis:** Click the "Run Analysis" button to generate the classification results.
1270
-
1271
- ---
1272
-
1273
- ##### Supported Data Format
1274
-
1275
- - **File Type:** Plain text (`.txt`)
1276
- - **Content:** Must contain two columns: `wavenumber` and `intensity`.
1277
- - **Separators:** Values can be separated by spaces or commas.
1278
- - **Preprocessing:** Your spectrum will be automatically resampled to 500 data points to match the model's input requirements.
1279
-
1280
- ---
1281
-
1282
- ##### Example Applications
1283
- - 🔬 Research on polymer degradation
1284
- - ♻️ Recycling feasibility assessment
1285
- - 🌱 Sustainability impact studies
1286
- - 🏭 Quality control in manufacturing
1287
- """)
1288
 
1289
 
1290
- # Run the application
1291
- main()
 
1
+ """Streamlit main entrance; modularized for clarity"""
2
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import streamlit as st
 
 
 
4
 
5
+ from modules.callbacks import init_session_state
 
 
 
 
6
 
7
+ from modules.ui_components import (
8
+ render_sidebar,
9
+ render_results_column,
10
+ render_input_column,
11
+ load_css,
12
+ )
13
 
 
 
 
 
 
 
 
14
 
15
+ # --- Page Setup (Called only ONCE) ---
16
  st.set_page_config(
17
  page_title="ML Polymer Classification",
18
  page_icon="🔬",
19
  layout="wide",
20
  initial_sidebar_state="expanded",
21
+ menu_items={"Get help": "https://github.com/KLab-AI3/ml-polymer-recycling"},
 
22
  )
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def main():
26
+ """Modularized main content to other scripts to clean the main app"""
27
+ load_css("static/style.css")
28
  init_session_state()
29
 
30
+ # Render UI components
31
+ render_sidebar()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
 
 
 
 
 
 
33
  col1, col2 = st.columns([1, 1.35], gap="small")
 
34
  with col1:
35
+ render_input_column()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  with col2:
37
+ render_results_column()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
+ if __name__ == "__main__":
41
+ main()