devjas1 committed
Commit: 40a522b
Parent: c8f5637

(FIX/FEATURE): Enhance app.py — improve error handling, stabilize UI, and add logging for analysis
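The UI-stability part of this change moves input handling into `on_change` callbacks that persist the uploaded or selected spectrum text in `st.session_state`, and wraps the Run button in an `st.form` so the button press and the cached input arrive in the same rerun. Below is a minimal, self-contained sketch of that pattern; the widget keys and the trivial "analysis" step are illustrative placeholders, not the app's actual names.

import streamlit as st

def on_upload_change():
    # Persist the uploaded file's text once so later reruns can reuse it
    # without re-reading a possibly stale buffer.
    up = st.session_state.get("demo_upload")
    if up is not None:
        st.session_state["demo_text"] = up.read().decode("utf-8")

st.file_uploader(
    "Raman spectrum (.txt)", type="txt",
    key="demo_upload", on_change=on_upload_change,
)

with st.form("demo_form"):
    submitted = st.form_submit_button(
        "Run Analysis", disabled="demo_text" not in st.session_state
    )

if submitted and "demo_text" in st.session_state:
    # Placeholder for the real parse / resample / inference steps.
    st.write(f"{len(st.session_state['demo_text'].splitlines())} data lines ready")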

Files changed (1):
  app.py  (+195, -106)
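On the error-handling side, the diff below splits weight loading into a cached helper that receives the weight file's modification time, so a changed file on disk can be detected and reloaded. A rough sketch of that caching idea, assuming a plain PyTorch nn.Module; the WEIGHTS_PATH constant and attach_weights helper are hypothetical names, not the app's:

import os
import torch
import streamlit as st

WEIGHTS_PATH = "model_weights/demo_model.pth"  # hypothetical path

@st.cache_data
def cached_state_dict(mtime: float, path: str):
    # mtime is unused in the body; it only widens the cache key so a
    # newer file on disk invalidates the cached state dict.
    return torch.load(path, map_location="cpu")

def attach_weights(model: torch.nn.Module):
    # Hypothetical helper: load weights if present, otherwise keep the
    # random initialization (demo mode), then switch to eval mode.
    if os.path.exists(WEIGHTS_PATH):
        state = cached_state_dict(os.path.getmtime(WEIGHTS_PATH), WEIGHTS_PATH)
        model.load_state_dict(state, strict=True)
    model.eval()
    return model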
app.py CHANGED
@@ -5,6 +5,12 @@ This is an adapted version of the Streamlit app optimized for Hugging Face Space
It maintains all the functionality of the original app while being self-contained and cloud-ready.
"""

+ BUILD_LABEL = "proof-2025-08-24-01"
+ import os, streamlit as st, sys
+ st.sidebar.caption(
+     f"Build: {BUILD_LABEL} | __file__: {__file__} | cwd: {os.getcwd()} | py: {sys.version.split()[0]}"
+ )
+
import os
import sys
from pathlib import Path
@@ -24,7 +30,8 @@ import io
from pathlib import Path
import time
import gc
- from io import StringIO
+ import hashlib
+ import logging

# Import local modules
from models.figure2_cnn import Figure2CNN
@@ -43,6 +50,15 @@ st.set_page_config(
    initial_sidebar_state="expanded"
)

+ # Stabilize tab panel height on HF Spaces to prevent visible column jitter.
+ # This sets a minimum height for the content area under the tab headers.
+ st.markdown("""
+ <style>
+ /* Tabs content area: the sibling after the tablist */
+ div[data-testid="stTabs"] > div[role="tablist"] + div { min-height: 420px; }
+ </style>
+ """, unsafe_allow_html=True)
+
# Constants
TARGET_LEN = 500
SAMPLE_DATA_DIR = "sample_data"
@@ -87,6 +103,15 @@ def label_file(filename: str) -> int:
    # Return None for unknown patterns instead of raising error
    return -1  # Default value for unknown patterns

+ @st.cache_data
+ def load_state_dict(_mtime, model_path):
+     """Load state dict with mtime in cache key to detect file changes"""
+     try:
+         return torch.load(model_path, map_location="cpu")
+     except Exception as e:
+         st.warning(f"Error loading state dict: {e}")
+         return None
+
@st.cache_resource
def load_model(model_name):
    """Load and cache the specified model with error handling"""
@@ -104,15 +129,17 @@ def load_model(model_name):
            st.info("Using randomly initialized model for demonstration purposes.")
            return model, False

+         # Get mtime for cache invalidation
+         mtime = os.path.getmtime(model_path)
+
        # Load weights
-         state_dict = torch.load(model_path, map_location="cpu")
-         model.load_state_dict(state_dict, strict=False)
-         if model is not None:
+         state_dict = load_state_dict(mtime, model_path)
+         if state_dict:
+             model.load_state_dict(state_dict, strict=True)
            model.eval()
+             return model, True
        else:
-             raise ValueError("Model is not loaded. Please check the model configuration or weights.")
-
-         return model, True
+             return model, False

    except Exception as e:
        st.error(f"❌ Error loading model {model_name}: {str(e)}")
@@ -133,7 +160,7 @@ def get_sample_files():
    return []

def parse_spectrum_data(raw_text):
-     """Parse spectrum data from text with robust error handling"""
+     """Parse spectrum data from text with robust error handling and validation"""
    x_vals, y_vals = [], []

    for line in raw_text.splitlines():
@@ -158,7 +185,22 @@ def parse_spectrum_data(raw_text):
    if len(x_vals) < 10:  # Minimum reasonable spectrum length
        raise ValueError(f"Insufficient data points: {len(x_vals)}. Need at least 10 points.")

-     return np.array(x_vals), np.array(y_vals)
+     x = np.array(x_vals)
+     y = np.array(y_vals)
+
+     # Check for NaNs
+     if np.any(np.isnan(x)) or np.any(np.isnan(y)):
+         raise ValueError("Input data contains NaN values")
+
+     # Check monotonic increasing x
+     if not np.all(np.diff(x) > 0):
+         raise ValueError("Wavenumbers must be strictly increasing")
+
+     # Check reasonable range for Raman spectroscopy
+     if min(x) < 0 or max(x) > 10000 or (max(x) - min(x)) < 100:
+         raise ValueError(f"Invalid wavenumber range: {min(x)} - {max(x)}. Expected ~400-4000 cm⁻¹ with span >100")
+
+     return x, y

def create_spectrum_plot(x_raw, y_raw, y_resampled):
    """Create spectrum visualization plot"""
@@ -202,29 +244,73 @@ def get_confidence_description(logit_margin):
    else:
        return "LOW", "🔴"

- # Initialize session state
def init_session_state():
-     """Initialize session state variables"""
    defaults = {
-         'status_message': "Ready to analyze polymer spectra 🔬",
-         'status_type': "info",
-         'uploaded_file': None,  # legacy; kept for compatibility
-         'input_text': None,  # ←←← NEW: canonical store for spectrum text
-         'filename': None,
-         'inference_run_once': False,
-         'x_raw': None,
-         'y_raw': None,
-         'y_resampled': None
+         "status_message": "Ready to analyze polymer spectra 🔬",
+         "status_type": "info",
+         "input_text": None,
+         "filename": None,
+         "input_source": None,  # "upload" or "sample"
+         "sample_select": "-- Select Sample --",
+         "input_mode": "Upload File",  # controls which pane is visible
+         "inference_run_once": False,
+         "x_raw": None, "y_raw": None, "y_resampled": None,
    }
+     for k, v in defaults.items():
+         st.session_state.setdefault(k, v)
+

    for key, default_value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default_value

+ def log_message(msg):
+     """Log message for observability"""
+     st.session_state.setdefault('log_messages', []).append(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}")
+
+ def trigger_run():
+     """Set a flag so we can detect button press reliably across reruns"""
+     st.session_state['run_requested'] = True
+
+ def on_upload_change():
+     """Read uploaded file once and persist as text."""
+     up = st.session_state.get("upload_txt")  # the uploader's key
+     if not up:
+         return
+     raw = up.read()
+     text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
+     st.session_state["input_text"] = text
+     st.session_state["filename"] = getattr(up, "name", "uploaded.txt")
+     st.session_state["input_source"] = "upload"
+     st.session_state["status_message"] = f"📁 File '{st.session_state['filename']}' ready for analysis"
+     st.session_state["status_type"] = "success"
+
+ def on_sample_change():
+     """Read selected sample once and persist as text."""
+     sel = st.session_state.get("sample_select", "-- Select Sample --")
+     if sel == "-- Select Sample --":
+         # Do nothing; leave current input intact (prevents clobbering uploads)
+         return
+     try:
+         text = (Path(SAMPLE_DATA_DIR) / sel).read_text(encoding="utf-8")
+         st.session_state["input_text"] = text
+         st.session_state["filename"] = sel
+         st.session_state["input_source"] = "sample"
+         st.session_state["status_message"] = f"📁 Sample '{sel}' ready for analysis"
+         st.session_state["status_type"] = "success"
+     except Exception as e:
+         st.session_state["status_message"] = f"❌ Error loading sample: {e}"
+         st.session_state["status_type"] = "error"
+
+ def on_input_mode_change():
+     if st.session_state["input_mode"] == "Upload File":
+         # reset sample when switching to Upload
+         st.session_state["sample_select"] = "-- Select Sample --"
+
+
# Main app
def main():
    init_session_state()
-
    # Header
    st.title("🔬 AI-Driven Polymer Classification")
    st.markdown("**Predict polymer degradation states using Raman spectroscopy and deep learning**")
@@ -273,107 +359,97 @@ def main():
    with col1:
        st.subheader("📁 Data Input")

-         # File upload tabs
-         tab1, tab2 = st.tabs(["📤 Upload File", "🧪 Sample Data"])
-
-         uploaded_file = None
-
-         with tab1:
-             uploaded_file = st.file_uploader(
-                 "Upload Raman spectrum (.txt)",
+         mode = st.radio(
+             "Input mode",
+             ["Upload File", "Sample Data"],
+             key="input_mode",
+             horizontal=True,
+             on_change=on_input_mode_change
+         )
+
+         # ---- Upload tab ----
+         if mode == "Upload File":
+             up = st.file_uploader(
+                 "Upload Raman spectrum (.txt)",
                type="txt",
                help="Upload a text file with wavenumber and intensity columns",
-                 key="upload_text"
+                 key="upload_txt",
+                 on_change=on_upload_change,  # <-- critical
            )
+             if up:
+                 st.success(f"✅ Loaded: {up.name}")

-         if uploaded_file:
-             # Read now and persist raw text; avoid holding open buffers in session_state
-             raw = uploaded_file.read()
-             text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
-             st.session_state['input_text'] = text
-             st.session_state['filename'] = uploaded_file.name
-             st.session_state['uploaded_file'] = None  # avoid stale buffers
-             st.success(f"✅ Loaded: {uploaded_file.name}")
-
-         with tab2:
+         # ---- Sample tab ----
+         else:
            sample_files = get_sample_files()
            if sample_files:
-                 sample_options = ["-- Select Sample --"] + [f.name for f in sample_files]
-                 selected_sample = st.selectbox("Choose sample spectrum:", sample_options, key="sample_select")
-
-                 if selected_sample != "-- Select Sample --":
-                     selected_path = Path(SAMPLE_DATA_DIR) / selected_sample
-                     try:
-                         with open(selected_path, "r", encoding="utf-8") as f:
-                             file_contents = f.read()
-                         # Persist raw text + name; no open file handles in session_state
-                         st.session_state['input_text'] = file_contents
-                         st.session_state['filename'] = selected_sample
-                         st.session_state['uploaded_file'] = None
-                         st.success(f"✅ Loaded sample: {selected_sample}")
-                     except (FileNotFoundError, IOError) as e:
-                         st.error(f"Error loading sample: {e}")
+                 options = ["-- Select Sample --"] + [p.name for p in sample_files]
+                 sel = st.selectbox(
+                     "Choose sample spectrum:",
+                     options,
+                     key="sample_select",
+                     on_change=on_sample_change,  # <-- critical
+                 )
+                 if sel != "-- Select Sample --":
+                     st.success(f"✅ Loaded sample: {sel}")
            else:
                st.info("No sample data available")

-         # Update session state
-         # If we captured text via either tab, reflect readiness in status
-         if st.session_state.get('input_text'):
-             st.session_state['status_message'] = f"📁 File '{st.session_state.get('filename', '(unnamed)')}' ready for analysis"
-             st.session_state['status_type'] = "success"
-
-         # Status display
+         # ---- Status box ----
        st.subheader("🚦 Status")
-         status_msg = st.session_state.get("status_message", "Ready")
-         status_type = st.session_state.get("status_type", "info")
-
-         if status_type == "success":
-             st.success(status_msg)
-         elif status_type == "error":
-             st.error(status_msg)
+         msg = st.session_state.get("status_message", "Ready")
+         typ = st.session_state.get("status_type", "info")
+         if typ == "success":
+             st.success(msg)
+         elif typ == "error":
+             st.error(msg)
        else:
-             st.info(status_msg)
+             st.info(msg)

-         # Load model
+         # ---- Model load ----
        model, model_loaded = load_model(model_choice)
-
-         # Ready if we have cached text and a model instance
-         inference_ready = bool(st.session_state.get('input_text')) and (model is not None)
-
        if not model_loaded:
            st.warning("⚠️ Model weights not available - using demo mode")

-         if st.button("▶️ Run Analysis", disabled=not inference_ready, type="primary", key="run_btn"):
-             if inference_ready:
-                 try:
-                     # Use persisted text + filename (works for uploads and samples)
-                     raw_text = st.session_state.get('input_text')
-                     filename = st.session_state.get('filename') or "unknown.txt"
-                     if not raw_text:
-                         raise ValueError("No input text available. Please upload or select a sample.")
-
-                     # Parse spectrum
-                     with st.spinner("Parsing spectrum data..."):
-                         x_raw, y_raw = parse_spectrum_data(raw_text)
-
-                     # Resample spectrum
-                     with st.spinner("Resampling spectrum..."):
-                         y_resampled = resample_spectrum(x_raw, y_raw, TARGET_LEN)
-
-                     # Store in session state
-                     st.session_state['x_raw'] = x_raw
-                     st.session_state['y_raw'] = y_raw
-                     st.session_state['y_resampled'] = y_resampled
-                     st.session_state['inference_run_once'] = True
-                     st.session_state['status_message'] = f"🔍 Analysis completed for: {filename}"
-                     st.session_state['status_type'] = "success"
-
-                     st.rerun()
+         # Ready to run if we have text and a model
+         inference_ready = bool(st.session_state.get("input_text")) and (model is not None)
+
+         # ---- Run Analysis (form submit batches state + submit atomically) ----
+         with st.form("analysis_form", clear_on_submit=False):
+             submitted = st.form_submit_button(
+                 "▶️ Run Analysis",
+                 type="primary",
+                 disabled=not inference_ready,
+             )
+
+         if submitted and inference_ready:
+             try:
+                 raw_text = st.session_state["input_text"]
+                 filename = st.session_state.get("filename") or "unknown.txt"
+
+                 # Parse
+                 with st.spinner("Parsing spectrum data..."):
+                     x_raw, y_raw = parse_spectrum_data(raw_text)
+
+                 # Resample
+                 with st.spinner("Resampling spectrum..."):
+                     y_resampled = resample_spectrum(x_raw, y_raw, TARGET_LEN)
+
+                 # Persist results (drives right column)
+                 st.session_state["x_raw"] = x_raw
+                 st.session_state["y_raw"] = y_raw
+                 st.session_state["y_resampled"] = y_resampled
+                 st.session_state["inference_run_once"] = True
+                 st.session_state["status_message"] = f"🔍 Analysis completed for: {filename}"
+                 st.session_state["status_type"] = "success"
+
+                 st.rerun()
+
+             except Exception as e:
+                 st.error(f"❌ Analysis failed: {e}")
+                 st.session_state["status_message"] = f"❌ Error: {e}"
+                 st.session_state["status_type"] = "error"

-             except Exception as e:
-                 st.error(f"❌ Analysis failed: {str(e)}")
-                 st.session_state['status_message'] = f"❌ Error: {str(e)}"
-                 st.session_state['status_type'] = "error"

    # Results column
    with col2:
@@ -394,6 +470,7 @@ def main():
                st.image(spectrum_plot, caption="Spectrum Preprocessing Results", use_container_width=True)
            except Exception as e:
                st.warning(f"Could not generate plot: {e}")
+                 log_message(f"Plot generation error: {e}")

            # Run inference
            try:
@@ -413,6 +490,7 @@ def main():
                logits_list = logits.detach().numpy().tolist()[0]

                inference_time = time.time() - start_time
+                 log_message(f"Inference completed in {inference_time:.2f}s, prediction: {prediction}")

                # Clean up memory
                cleanup_memory()
@@ -476,8 +554,14 @@ def main():

                with tab2:
                    st.markdown("**Technical Information**")
+                     model_path = MODEL_CONFIG[model_choice]["path"]
+                     mtime = os.path.getmtime(model_path) if os.path.exists(model_path) else "N/A"
+                     file_hash = hashlib.md5(open(model_path, 'rb').read()).hexdigest() if os.path.exists(model_path) else "N/A"
                    st.json({
                        "Model Architecture": model_choice,
+                         "Model Path": model_path,
+                         "Weights Last Modified": time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime)) if mtime != "N/A" else "N/A",
+                         "Weights Hash": file_hash,
                        "Input Shape": list(input_tensor.shape),
                        "Output Shape": list(logits.shape),
                        "Inference Time": f"{inference_time:.3f}s",
@@ -488,6 +572,10 @@ def main():
                    if not model_loaded:
                        st.warning("⚠️ Demo mode: Using randomly initialized weights")

+                     # Debug log
+                     st.markdown("**Debug Log**")
+                     st.text_area("Logs", "\n".join(st.session_state['log_messages']), height=200)
+
                with tab3:
                    st.markdown("""
                    **🔍 Analysis Process**
@@ -513,6 +601,7 @@ def main():

            except Exception as e:
                st.error(f"❌ Inference failed: {str(e)}")
+                 log_message(f"Inference error: {str(e)}")

        else:
            st.error("❌ Missing spectrum data. Please upload a file and run analysis.")
@@ -539,4 +628,4 @@ def main():
    """)

# Run the application
- main()
+ main()