devjas1 committed on
Commit
723ebe4
·
1 Parent(s): 65f2520

(DEPLOY): make app.py portable for HF + canonical

Browse files

- Use Agg backend for headless plotting.
- Dual-path import for resample_spectrum (scripts/ then utils/)
- Flexible weights path (WEIGHTS_DIR env -> model_weights -> outputs)
- Detach logits before numpy to avoid autograd refs

Files changed (1) hide show
  1. deploy/hf-space/app.py +537 -0
deploy/hf-space/app.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AI-Driven Polymer Aging Prediction and Classification
3
+ Hugging Face Spaces Deployment
4
+ This is an adapted version of the Streamlit app optimized for Hugging Face Spaces deployment.
5
+ It maintains all the functionality of the original app while being self-contained and cloud-ready.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ # Ensure 'utils' directory is in the Python path
13
+ utils_path = Path(__file__).resolve().parent / "utils"
14
+ if utils_path.is_dir() and str(utils_path) not in sys.path:
15
+ sys.path.append(str(utils_path))
16
+ import streamlit as st
17
+ import torch
18
+ import numpy as np
19
+ import matplotlib
20
+ matplotlib.use("Agg") # ensure headless rendering in Spaces
21
+ import matplotlib.pyplot as plt
22
+ from PIL import Image
23
+ import io
24
+ from pathlib import Path
25
+ import time
26
+ import gc
27
+ from io import StringIO
28
+
29
+ # Import local modules
30
+ from models.figure2_cnn import Figure2CNN
31
+ from models.resnet_cnn import ResNet1D
32
+ # Prefer canonical script; fallback to local utils for HF hard-copy scenario
33
+ try:
34
+ from scripts.preprocess_dataset import resample_spectrum
35
+ except ImportError:
36
+ from utils.preprocessing import resample_spectrum
37
+
38
# Configuration
# Streamlit page setup; issued at module level before any other st.* call below.
st.set_page_config(
    page_title="ML Polymer Classification",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Constants
TARGET_LEN = 500  # number of points every spectrum is resampled to before inference
SAMPLE_DATA_DIR = "sample_data"
# Prefer env var, else 'model_weights' if present; else canonical 'outputs'
MODEL_WEIGHTS_DIR = (
    os.getenv("WEIGHTS_DIR")
    or ("model_weights" if os.path.isdir("model_weights") else "outputs")
)

# Model configuration: display name -> constructor, weights file and the
# static metadata shown in the sidebar.
MODEL_CONFIG = {
    "Figure2CNN (Baseline)": {
        "class": Figure2CNN,
        "path": f"{MODEL_WEIGHTS_DIR}/figure2_model.pth",
        "emoji": "🔬",
        "description": "Baseline CNN with standard filters",
        "accuracy": "94.80%",
        "f1": "94.30%",
    },
    "ResNet1D (Advanced)": {
        "class": ResNet1D,
        "path": f"{MODEL_WEIGHTS_DIR}/resnet_model.pth",
        "emoji": "🧠",
        "description": "Residual CNN with deeper feature learning",
        "accuracy": "96.20%",
        "f1": "95.90%",
    },
}

# Label mapping: class index produced by the models -> human-readable name
LABEL_MAP = {0: "Stable (Unweathered)", 1: "Weathered (Degraded)"}
77
+
78
+ # Utility functions
79
def label_file(filename: str) -> int:
    """Derive the ground-truth class index from a spectrum file name.

    Only the final path component is inspected, case-insensitively.

    Args:
        filename: File name or path of the spectrum.

    Returns:
        0 when the name starts with "sta", 1 when it starts with "wea",
        and -1 for any other name. -1 is a sentinel meaning "no label can
        be inferred"; this function never returns None and never raises
        for an unrecognised name.
    """
    name = Path(filename).name.lower()
    if name.startswith("sta"):
        return 0
    if name.startswith("wea"):
        return 1
    # Unknown naming pattern: report the sentinel instead of raising.
    return -1
89
+
90
@st.cache_resource
def load_model(model_name):
    """Build the model registered under ``model_name`` and load its weights.

    Args:
        model_name: A key of ``MODEL_CONFIG``.

    Returns:
        A ``(model, weights_loaded)`` tuple:
        * ``(model, True)``  - weights were read from the configured path.
        * ``(model, False)`` - the weights file is missing; the model keeps
          its random initialisation (demo mode).
        * ``(None, False)``  - construction or loading raised; the error is
          shown in the UI instead of propagating.
    """
    try:
        config = MODEL_CONFIG[model_name]
        model_path = config["path"]

        # Initialize model
        model = config["class"](input_length=TARGET_LEN)
        # Switch to inference mode up front so that every successful return
        # path, including the demo fallback, hands back a model in eval mode.
        model.eval()

        # Check if model file exists (a directory at that path is not usable)
        if not os.path.isfile(model_path):
            st.warning(f"⚠️ Model weights not found: {model_path}")
            st.info("Using randomly initialized model for demonstration purposes.")
            return model, False

        # Load weights
        # NOTE(review): torch.load unpickles the file; only point WEIGHTS_DIR
        # at trusted checkpoints (consider weights_only=True where the
        # installed torch version supports it).
        state_dict = torch.load(model_path, map_location="cpu")
        # NOTE(review): strict=False ignores missing/unexpected keys, so a
        # mismatched checkpoint still reports success here.
        model.load_state_dict(state_dict, strict=False)

        return model, True

    except Exception as e:
        st.error(f"❌ Error loading model {model_name}: {str(e)}")
        return None, False
117
+
118
def cleanup_memory():
    """Release memory after inference.

    Runs a garbage-collection pass and, when CUDA is available, empties
    PyTorch's CUDA cache. Returns nothing.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
123
+
124
@st.cache_data
def get_sample_files():
    """Return the bundled sample spectra.

    Returns:
        A sorted list of ``Path`` objects for every ``*.txt`` file directly
        inside ``SAMPLE_DATA_DIR``, or an empty list when that directory
        does not exist.
    """
    sample_dir = Path(SAMPLE_DATA_DIR)
    if not sample_dir.is_dir():
        return []
    return sorted(sample_dir.glob("*.txt"))
131
+
132
def _spectrum_token_to_float(token):
    """Return ``token`` as a finite float, or None when it is not a plain number."""
    # float() would accept digit-group underscores ("1_000"); data files do
    # not use them, so treat such tokens as non-numeric.
    if "_" in token:
        return None
    try:
        value = float(token)
    except ValueError:
        return None
    # "nan" / "inf" parse successfully but are not usable data points.
    return value if np.isfinite(value) else None


def parse_spectrum_data(raw_text):
    """Parse two-column spectrum text into x and y arrays.

    Blank lines and lines starting with ``#`` are ignored. Columns may be
    separated by whitespace and/or commas. On each remaining line the
    first two numeric tokens are taken as (x, y); non-numeric tokens are
    skipped, and lines with fewer than two numeric tokens are dropped.
    Numbers may be written in plain or scientific notation (e.g. ``1.5e3``).

    Args:
        raw_text: Full text content of the spectrum file.

    Returns:
        ``(x, y)`` as two NumPy arrays of equal length.

    Raises:
        ValueError: Fewer than 10 data points could be parsed.
    """
    x_vals, y_vals = [], []

    for line in raw_text.splitlines():
        line = line.strip()
        if not line or line.startswith('#'):  # Skip empty lines and comments
            continue

        # Handle different separators
        tokens = line.replace(",", " ").split()
        numbers = [
            value
            for value in map(_spectrum_token_to_float, tokens)
            if value is not None
        ]

        if len(numbers) >= 2:
            x_vals.append(numbers[0])
            y_vals.append(numbers[1])

    if len(x_vals) < 10:  # Minimum reasonable spectrum length
        raise ValueError(f"Insufficient data points: {len(x_vals)}. Need at least 10 points.")

    return np.array(x_vals), np.array(y_vals)
159
+
160
def create_spectrum_plot(x_raw, y_raw, y_resampled):
    """Render the raw and resampled spectra side by side as a PNG image.

    Args:
        x_raw: Wavenumber values of the original spectrum.
        y_raw: Intensity values of the original spectrum.
        y_resampled: Intensities after resampling; plotted against an evenly
            spaced grid spanning ``min(x_raw)``..``max(x_raw)``.

    Returns:
        A ``PIL.Image`` opened on an in-memory PNG of the two-panel figure.
    """
    n_points = len(y_resampled)
    fig, ax = plt.subplots(1, 2, figsize=(12, 4), dpi=100)
    try:
        # Raw spectrum
        ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
        ax[0].set_title("Raw Input Spectrum")
        ax[0].set_xlabel("Wavenumber (cm⁻¹)")
        ax[0].set_ylabel("Intensity")
        ax[0].grid(True, alpha=0.3)
        ax[0].legend()

        # Resampled spectrum
        x_resampled = np.linspace(np.min(x_raw), np.max(x_raw), n_points)
        ax[1].plot(x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1)
        ax[1].set_title(f"Resampled ({n_points} points)")
        ax[1].set_xlabel("Wavenumber (cm⁻¹)")
        ax[1].set_ylabel("Intensity")
        ax[1].grid(True, alpha=0.3)
        ax[1].legend()

        fig.tight_layout()

        # Convert to image. Save through the figure object rather than
        # pyplot's implicit "current figure" so the right figure is written.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
    finally:
        # Always release the figure, even when plotting fails, to prevent
        # memory leaks across reruns.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
190
+
191
def get_confidence_description(logit_margin):
    """Map a logit margin to a human-readable confidence level.

    Args:
        logit_margin: Absolute difference between the two class logits.

    Returns:
        A ``(description, emoji)`` tuple. Thresholds are exclusive:
        > 1000 -> "VERY HIGH", > 250 -> "HIGH", > 100 -> "MODERATE",
        anything else -> "LOW".
    """
    # Ordered from the strictest threshold down; the first match wins.
    levels = (
        (1000, "VERY HIGH", "🟢"),
        (250, "HIGH", "🟡"),
        (100, "MODERATE", "🟠"),
    )
    for threshold, description, emoji in levels:
        if logit_margin > threshold:
            return description, emoji
    return "LOW", "🔴"
201
+
202
+ # Initialize session state
203
def init_session_state():
    """Seed ``st.session_state`` with the keys the app reads.

    Existing keys are left untouched, so values survive Streamlit reruns;
    only missing keys receive their default.
    """
    defaults = {
        'status_message': "Ready to analyze polymer spectra 🔬",
        'status_type': "info",
        'uploaded_file': None,
        'filename': None,
        'inference_run_once': False,
        'x_raw': None,
        'y_raw': None,
        'y_resampled': None,
    }

    for key, default_value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default_value
219
+
220
+ # Main app
221
def _render_sidebar():
    """Draw the sidebar (about text and model picker); return the chosen model name."""
    with st.sidebar:
        st.header("ℹ️ About This App")
        st.markdown("""
        **AIRE 2025 Internship Project**
        AI-Driven Polymer Aging Prediction and Classification

        🎯 **Purpose**: Classify polymer degradation using AI
        📊 **Input**: Raman spectroscopy data
        🧠 **Models**: CNN architectures for binary classification

        **Team**:
        - **Mentor**: Dr. Sanmukh Kuppannagari
        - **Mentor**: Dr. Metin Karailyan
        - **Author**: Jaser Hasan

        🔗 [GitHub Repository](https://github.com/KLab-AI3/ml-polymer-recycling)
        """)

        st.markdown("---")

        # Model selection. Map each display label back to its config key
        # explicitly instead of re-parsing the label text.
        st.subheader("🧠 Model Selection")
        label_to_name = {
            f"{cfg['emoji']} {name}": name for name, cfg in MODEL_CONFIG.items()
        }
        selected_label = st.selectbox("Choose AI model:", list(label_to_name))
        model_choice = label_to_name[selected_label]

        # Model info
        config = MODEL_CONFIG[model_choice]
        st.markdown(f"""
        **📈 {config['emoji']} Model Details**

        *{config['description']}*

        - **Accuracy**: `{config['accuracy']}`
        - **F1 Score**: `{config['f1']}`
        """)

    return model_choice


def _render_input_column(model_choice):
    """Draw file selection, status and the run button; return ``(model, model_loaded)``.

    When the run button is pressed, the selected file is parsed and
    resampled, the arrays are stored in ``st.session_state`` and the script
    is rerun so the results column picks them up.
    """
    st.subheader("📁 Data Input")

    # File upload tabs
    tab_upload, tab_sample = st.tabs(["📤 Upload File", "🧪 Sample Data"])

    uploaded_file = None

    with tab_upload:
        uploaded_file = st.file_uploader(
            "Upload Raman spectrum (.txt)",
            type="txt",
            help="Upload a text file with wavenumber and intensity columns"
        )

        if uploaded_file:
            st.success(f"✅ Loaded: {uploaded_file.name}")

    with tab_sample:
        sample_files = get_sample_files()
        if sample_files:
            sample_options = ["-- Select Sample --"] + [f.name for f in sample_files]
            selected_sample = st.selectbox("Choose sample spectrum:", sample_options)

            if selected_sample != "-- Select Sample --":
                selected_path = Path(SAMPLE_DATA_DIR) / selected_sample
                try:
                    with open(selected_path, "r", encoding="utf-8") as f:
                        file_contents = f.read()
                    # Wrap the text so it exposes the same .name/.seek/.read
                    # surface the run handler uses for uploaded files.
                    uploaded_file = StringIO(file_contents)
                    uploaded_file.name = selected_sample
                    st.success(f"✅ Loaded sample: {selected_sample}")
                except Exception as e:
                    st.error(f"Error loading sample: {e}")
        else:
            st.info("No sample data available")

    # Update session state
    if uploaded_file is not None:
        st.session_state['uploaded_file'] = uploaded_file
        st.session_state['filename'] = uploaded_file.name
        st.session_state['status_message'] = f"📁 File '{uploaded_file.name}' ready for analysis"
        st.session_state['status_type'] = "success"

    # Status display
    st.subheader("🚦 Status")
    status_msg = st.session_state.get("status_message", "Ready")
    status_type = st.session_state.get("status_type", "info")

    if status_type == "success":
        st.success(status_msg)
    elif status_type == "error":
        st.error(status_msg)
    else:
        st.info(status_msg)

    # Load model
    model, model_loaded = load_model(model_choice)

    # Inference button
    inference_ready = (
        st.session_state.get('uploaded_file') is not None
        and model is not None
    )

    if not model_loaded:
        st.warning("⚠️ Model weights not available - using demo mode")

    if st.button("▶️ Run Analysis", disabled=not inference_ready, type="primary") and inference_ready:
        try:
            # Get file data
            pending_file = st.session_state['uploaded_file']
            filename = st.session_state['filename']

            # Read file content (uploads yield bytes, samples yield str)
            pending_file.seek(0)
            raw_data = pending_file.read()
            raw_text = raw_data.decode("utf-8") if isinstance(raw_data, bytes) else raw_data

            # Parse spectrum
            with st.spinner("Parsing spectrum data..."):
                x_raw, y_raw = parse_spectrum_data(raw_text)

            # Resample spectrum
            with st.spinner("Resampling spectrum..."):
                y_resampled = resample_spectrum(x_raw, y_raw, TARGET_LEN)

        except Exception as e:
            st.error(f"❌ Analysis failed: {str(e)}")
            st.session_state['status_message'] = f"❌ Error: {str(e)}"
            st.session_state['status_type'] = "error"
        else:
            # Store in session state
            st.session_state['x_raw'] = x_raw
            st.session_state['y_raw'] = y_raw
            st.session_state['y_resampled'] = y_resampled
            st.session_state['inference_run_once'] = True
            st.session_state['status_message'] = f"🔍 Analysis completed for: {filename}"
            st.session_state['status_type'] = "success"

            # Kept outside the try block so the rerun signal can never be
            # intercepted by the broad exception handler above.
            st.rerun()

    return model, model_loaded


def _render_results(model, model_loaded, model_choice):
    """Draw the results column: plot, prediction, confidence and detail tabs."""
    if not st.session_state.get("inference_run_once", False):
        # Welcome message
        st.markdown(f"""
        ### 👋 Welcome to AI Polymer Classification

        **Get started by:**
        1. 🧠 Select an AI model in the sidebar
        2. 📁 Upload a Raman spectrum file or choose a sample
        3. ▶️ Click "Run Analysis" to get predictions

        **Supported formats:**
        - Text files (.txt) with wavenumber and intensity columns
        - Space or comma-separated values
        - Any length (automatically resampled to {TARGET_LEN} points)

        **Example applications:**
        - 🔬 Research on polymer degradation
        - ♻️ Recycling feasibility assessment
        - 🌱 Sustainability impact studies
        - 🏭 Quality control in manufacturing
        """)
        return

    st.subheader("📊 Analysis Results")

    # Get data from session state
    x_raw = st.session_state.get('x_raw')
    y_raw = st.session_state.get('y_raw')
    y_resampled = st.session_state.get('y_resampled')
    # The key exists with value None until a file is chosen, so fall back on
    # falsiness rather than on key absence.
    filename = st.session_state.get('filename') or 'Unknown'

    if any(v is None for v in (x_raw, y_raw, y_resampled)):
        st.error("❌ Missing spectrum data. Please upload a file and run analysis.")
        return

    # Create and display plot
    try:
        spectrum_plot = create_spectrum_plot(x_raw, y_raw, y_resampled)
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width - confirm against the pinned version.
        st.image(spectrum_plot, caption="Spectrum Preprocessing Results", use_column_width=True)
    except Exception as e:
        st.warning(f"Could not generate plot: {e}")

    # Run inference
    try:
        # Validate before touching the model: load_model returns None on failure.
        if model is None:
            raise ValueError("Model is not loaded. Please check the model configuration or weights.")

        with st.spinner("Running AI inference..."):
            start_time = time.perf_counter()

            # Prepare input tensor: two leading singleton dimensions are
            # added in front of the resampled values.
            input_tensor = torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

            # Run inference
            model.eval()
            with torch.no_grad():
                logits = model(input_tensor)
                prediction = torch.argmax(logits, dim=1).item()
                logits_list = logits.detach().numpy().tolist()[0]

            inference_time = time.perf_counter() - start_time

            # Clean up memory
            cleanup_memory()

        # Get ground truth if available. label_file signals "unknown" with a
        # value that is not a key of LABEL_MAP, so membership is the test.
        true_label_idx = label_file(filename)
        has_ground_truth = true_label_idx in LABEL_MAP
        true_label_str = LABEL_MAP.get(true_label_idx, "Unknown")

        # Get prediction
        predicted_class = LABEL_MAP.get(int(prediction), f"Class {int(prediction)}")

        # Calculate confidence metrics
        logit_margin = abs(logits_list[0] - logits_list[1]) if len(logits_list) >= 2 else 0
        confidence_desc, confidence_emoji = get_confidence_description(logit_margin)

        # Display results
        st.markdown("### 🎯 Prediction Results")

        # Main prediction
        st.markdown(f"""
        **🔬 Sample**: `{filename}`
        **🧠 Model**: `{model_choice}`
        **⏱️ Processing Time**: `{inference_time:.2f}s`
        """)

        # Prediction box
        if predicted_class == "Stable (Unweathered)":
            st.success(f"🟢 **Prediction**: {predicted_class}")
        else:
            st.warning(f"🟡 **Prediction**: {predicted_class}")

        # Confidence
        st.markdown(f"**{confidence_emoji} Confidence**: {confidence_desc} (margin: {logit_margin:.1f})")

        # Ground truth comparison
        if has_ground_truth:
            if predicted_class == true_label_str:
                st.success(f"✅ **Ground Truth**: {true_label_str} - **Correct!**")
            else:
                st.error(f"❌ **Ground Truth**: {true_label_str} - **Incorrect**")
        else:
            st.info("ℹ️ **Ground Truth**: Unknown (filename doesn't follow naming convention)")

        # Detailed results tabs
        tab_details, tab_technical, tab_explain = st.tabs(["📊 Details", "🔬 Technical", "📘 Explanation"])

        with tab_details:
            st.markdown("**Model Output (Logits)**")
            for i, score in enumerate(logits_list):
                label = LABEL_MAP.get(i, f"Class {i}")
                st.metric(label, f"{score:.2f}")

            st.markdown("**Spectrum Statistics**")
            st.json({
                "Original Length": len(x_raw),
                "Resampled Length": TARGET_LEN,
                "Wavenumber Range": f"{np.min(x_raw):.1f} - {np.max(x_raw):.1f} cm⁻¹",
                "Intensity Range": f"{np.min(y_raw):.1f} - {np.max(y_raw):.1f}",
                "Model Confidence": confidence_desc
            })

        with tab_technical:
            st.markdown("**Technical Information**")
            st.json({
                "Model Architecture": model_choice,
                "Input Shape": list(input_tensor.shape),
                "Output Shape": list(logits.shape),
                "Inference Time": f"{inference_time:.3f}s",
                "Device": "CPU",
                "Model Loaded": model_loaded
            })

            if not model_loaded:
                st.warning("⚠️ Demo mode: Using randomly initialized weights")

        with tab_explain:
            st.markdown(f"""
            **🔍 Analysis Process**

            1. **Data Upload**: Raman spectrum file loaded
            2. **Preprocessing**: Data parsed and resampled to {TARGET_LEN} points
            3. **AI Inference**: CNN model analyzes spectral patterns
            4. **Classification**: Binary prediction with confidence scores

            **🧠 Model Interpretation**

            The AI model identifies spectral features indicative of:
            - **Stable polymers**: Well-preserved molecular structure
            - **Weathered polymers**: Degraded/oxidized molecular bonds

            **🎯 Applications**

            - Material longevity assessment
            - Recycling viability evaluation
            - Quality control in manufacturing
            - Environmental impact studies
            """)

    except Exception as e:
        st.error(f"❌ Inference failed: {str(e)}")


def main():
    """Render the whole Streamlit page.

    Layout: a sidebar with project info and the model picker, a left column
    for choosing a spectrum and starting the analysis, and a right column
    that shows either a welcome message or the results of the last run.
    """
    init_session_state()

    # Header
    st.title("🔬 AI-Driven Polymer Classification")
    st.markdown("**Predict polymer degradation states using Raman spectroscopy and deep learning**")

    # Sidebar
    model_choice = _render_sidebar()

    # Main content area
    col1, col2 = st.columns([1, 1.5], gap="large")

    with col1:
        model, model_loaded = _render_input_column(model_choice)

    # Results column
    with col2:
        _render_results(model, model_loaded, model_choice)
535
+
536
# Run the application
main()