dev-jas commited on
Commit
bcdc411
·
unverified ·
2 Parent(s): 1ffa0fe 0392c68

Merge pull request #3 from devjas1/new-space-deploy

Browse files
.gitignore CHANGED
@@ -26,3 +26,4 @@ datasets/**
26
  # ---------------------------------------
27
 
28
  __pycache__.py
 
 
26
  # ---------------------------------------
27
 
28
  __pycache__.py
29
+ outputs/performance_tracking.db
CODEBASE_INVENTORY.md CHANGED
@@ -2,40 +2,38 @@
2
 
3
  ## Executive Summary
4
 
5
- This audit provides a complete technical inventory of the `dev-jas/polymer-aging-ml` repository, a sophisticated machine learning platform for polymer degradation classification using Raman spectroscopy. The system demonstrates production-ready architecture with comprehensive error handling, batch processing capabilities, and an extensible model framework spanning **34 files across 7 directories**.[^1_1][^1_2]
6
 
7
  ## 🏗️ System Architecture
8
 
9
  ### Core Infrastructure
10
 
11
- The platform employs a **Streamlit-based web application** (`app.py` - 53.7 kB) as its primary interface, supported by a modular backend architecture. The system integrates **PyTorch for deep learning**, **Docker for deployment**, and implements a plugin-based model registry for extensibility.[^1_2][^1_3][^1_4]
12
-
13
- ### Directory Structure Analysis
14
-
15
- The codebase maintains clean separation of concerns across seven primary directories:[^1_1]
16
-
17
- **Root Level Files:**
18
-
19
- - `app.py` (53.7 kB) - Main Streamlit application with two-column UI layout
20
- - `README.md` (4.8 kB) - Comprehensive project documentation
21
- - `Dockerfile` (421 Bytes) - Python 3.13-slim containerization
22
- - `requirements.txt` (132 Bytes) - Dependency management without version pinning
23
-
24
- **Core Directories:**
25
-
26
- - `models/` - Neural network architectures with registry pattern
27
- - `utils/` - Shared utility modules (43.2 kB total)
28
- - `scripts/` - CLI tools and automation workflows
29
- - `outputs/` - Pre-trained model weights storage
30
- - `sample_data/` - Demo spectrum files for testing
31
- - `tests/` - Unit testing infrastructure
32
- - `datasets/` - Data storage directory (content ignored)
33
 
34
  ## 🤖 Machine Learning Framework
35
 
36
- ### Model Registry System
37
 
38
- The platform implements a **sophisticated factory pattern** for model management in `models/registry.py`. This design enables dynamic model selection and provides a unified interface for different architectures:[^1_5]
39
 
40
  ```python
41
  _REGISTRY: Dict[str, Callable[[int], object]] = {
@@ -47,35 +45,31 @@ _REGISTRY: Dict[str, Callable[[int], object]] = {
47
 
48
  ### Neural Network Architectures
49
 
50
- **1. Figure2CNN (Baseline Model)**[^1_6]
51
 
52
- - **Architecture**: 4 convolutional layers with progressive channel expansion (1→16→32→64→128)
53
- - **Classification Head**: 3 fully connected layers (256→128→2 neurons)
54
- - **Performance**: 94.80% accuracy, 94.30% F1-score
55
- - **Designation**: Validated exclusively for Raman spectra input
56
- - **Parameters**: Dynamic flattened size calculation for input flexibility
57
 
58
- **2. ResNet1D (Advanced Model)**[^1_7]
 
 
59
 
60
- - **Architecture**: 3 residual blocks with skip connections
61
- - **Innovation**: 1D residual connections for spectral feature learning
62
- - **Performance**: 96.20% accuracy, 95.90% F1-score
63
- - **Efficiency**: Global average pooling reduces parameter count
64
- - **Parameters**: Approximately 100K (more efficient than baseline)
65
 
66
- **3. ResNet18Vision (Deep Architecture)**[^1_8]
 
 
67
 
68
- - **Design**: 1D adaptation of ResNet-18 with BasicBlock1D modules
69
- - **Structure**: 4 residual layers with 2 blocks each
70
- - **Initialization**: Kaiming normal initialization for optimal training
71
- - **Status**: Under evaluation for spectral analysis applications
 
72
 
73
  ## 🔧 Data Processing Infrastructure
74
 
75
  ### Preprocessing Pipeline
76
 
77
- The system implements a **modular preprocessing pipeline** in `utils/preprocessing.py` with five configurable stages:[^1_9]
78
-
79
  **1. Input Validation Framework:**
80
 
81
  - File format verification (`.txt` files exclusively)
@@ -84,16 +78,16 @@ The system implements a **modular preprocessing pipeline** in `utils/preprocessi
84
  - Monotonic sequence verification for spectral consistency
85
  - NaN value detection and automatic rejection
86
 
87
- **2. Core Processing Steps:**[^1_9]
88
 
89
  - **Linear Resampling**: Uniform grid interpolation to 500 points using `scipy.interpolate.interp1d`
90
  - **Baseline Correction**: Polynomial detrending (configurable degree, default=2)
91
  - **Savitzky-Golay Smoothing**: Noise reduction (window=11, order=2, configurable)
92
- - **Min-Max Normalization**: Scaling to range with constant-signal protection[^1_1]
93
 
94
  ### Batch Processing Framework
95
 
96
- The `utils/multifile.py` module (12.5 kB) provides **enterprise-grade batch processing** capabilities:[^1_10]
97
 
98
  - **Multi-File Upload**: Streamlit widget supporting simultaneous file selection
99
  - **Error-Tolerant Processing**: Individual file failures don't interrupt batch operations
@@ -123,7 +117,7 @@ The main application implements a **sophisticated two-column layout** with compr
123
 
124
  ### State Management System
125
 
126
- The application employs **advanced session state management**:[^1_2]
127
 
128
  - Persistent state across Streamlit reruns using `st.session_state`
129
  - Intelligent caching with content-based hash keys for expensive operations
@@ -134,46 +128,24 @@ The application employs **advanced session state management**:[^1_2]
134
 
135
  ### Centralized Error Handling
136
 
137
- The `utils/errors.py` module (5.51 kB) implements **production-grade error management**:[^1_11]
138
-
139
- ```python
140
- class ErrorHandler:
141
- @staticmethod
142
- def log_error(error: Exception, context: str = "", include_traceback: bool = False)
143
- @staticmethod
144
- def handle_file_error(filename: str, error: Exception) -> str
145
- @staticmethod
146
- def handle_inference_error(model_name: str, error: Exception) -> str
147
- ```
148
 
149
- **Key Features:**
150
 
151
- - Context-aware error messages for different operation types
152
- - Graceful degradation with fallback modes
153
- - Structured logging with configurable verbosity
154
- - User-friendly error translation from technical exceptions
155
 
156
- ### Confidence Analysis System
 
 
157
 
158
- The `utils/confidence.py` module provides **scientific confidence metrics**
159
-
160
- :
161
-
162
- **Softmax-Based Confidence:**
163
-
164
- - Normalized probability distributions from model logits
165
- - Three-tier confidence levels: HIGH (≥80%), MEDIUM (≥60%), LOW (<60%)
166
- - Color-coded visual indicators with emoji representations
167
- - Legacy compatibility with logit margin calculations
168
-
169
- ### Session Results Management
170
 
171
- The `utils/results_manager.py` module (8.16 kB) enables **comprehensive session tracking**:
172
 
173
- - **In-Memory Storage**: Session-wide results persistence
174
- - **Export Capabilities**: CSV and JSON download with timestamp formatting
175
- - **Statistical Analysis**: Automatic accuracy calculation when ground truth available
176
- - **Data Integrity**: Results survive page refreshes within session boundaries
177
 
178
  ## 📜 Command-Line Interface
179
 
@@ -194,17 +166,6 @@ The `scripts/train_model.py` module (6.27 kB) implements **robust model training
194
  - Deterministic CUDA operations when GPU available
195
  - Standardized train/validation splitting methodology
196
 
197
- ### Inference Pipeline
198
-
199
- The `scripts/run_inference.py` module (5.88 kB) provides **automated inference capabilities**:
200
-
201
- **CLI Features:**
202
-
203
- - Preprocessing parity with web interface ensuring consistent results
204
- - Multiple output formats with detailed metadata inclusion
205
- - Safe model loading across PyTorch versions with fallback mechanisms
206
- - Flexible architecture selection via command-line arguments
207
-
208
  ### Data Utilities
209
 
210
  **File Discovery System:**
@@ -213,17 +174,6 @@ The `scripts/run_inference.py` module (5.88 kB) provides **automated inference c
213
  - Filename-based labeling convention (`sta-*` = stable, `wea-*` = weathered)
214
  - Dataset inventory generation with statistical summaries
215
 
216
- ## 🐳 Deployment Infrastructure
217
-
218
- ### Docker Configuration
219
-
220
- The `Dockerfile` (421 Bytes) implements **optimized containerization**:[^1_12]
221
-
222
- - **Base Image**: Python 3.13-slim for minimal attack surface
223
- - **System Dependencies**: Essential build tools and scientific libraries
224
- - **Health Monitoring**: HTTP endpoint checking for container wellness
225
- - **Caching Strategy**: Layered builds with dependency caching for faster rebuilds
226
-
227
  ### Dependency Management
228
 
229
  The `requirements.txt` specifies **core dependencies without version pinning**:[^1_12]
@@ -234,6 +184,36 @@ The `requirements.txt` specifies **core dependencies without version pinning**:[
234
  - **Visualization**: `matplotlib` for spectrum plotting
235
  - **API Framework**: `fastapi`, `uvicorn` for potential REST API expansion
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  ## 🧪 Testing Framework
238
 
239
  ### Test Infrastructure
@@ -244,12 +224,12 @@ The `tests/` directory implements **basic validation framework**:
244
  - **Preprocessing Tests**: Core pipeline functionality validation in `test_preprocessing.py`
245
  - **Limited Coverage**: Currently covers preprocessing functions only
246
 
247
- **Testing Gaps Identified:**
248
 
249
- - No model architecture unit tests
250
- - Missing integration tests for UI components
251
- - No performance benchmarking tests
252
- - Limited error handling validation
253
 
254
  ## 🔍 Security \& Quality Assessment
255
 
@@ -271,27 +251,11 @@ The `tests/` directory implements **basic validation framework**:
271
  - **Error Boundaries**: Multi-level exception handling with graceful degradation
272
  - **Logging**: Structured logging with appropriate severity levels
273
 
274
- ### Security Considerations
275
-
276
- **Current Protections:**
277
-
278
- - Input sanitization through strict parsing rules
279
- - No arbitrary code execution paths
280
- - Containerized deployment limiting attack surface
281
- - Session-based storage preventing data persistence attacks
282
-
283
- **Areas Requiring Enhancement:**
284
-
285
- - No explicit security headers in web responses
286
- - Basic authentication/authorization framework absent
287
- - File upload size limits not explicitly configured
288
- - No rate limiting mechanisms implemented
289
-
290
  ## 🚀 Extensibility Analysis
291
 
292
  ### Model Architecture Extensibility
293
 
294
- The **registry pattern enables seamless model addition**:[^1_5]
295
 
296
  1. **Implementation**: Create new model class with standardized interface
297
  2. **Registration**: Add to `models/registry.py` with factory function
@@ -344,72 +308,15 @@ The **registry pattern enables seamless model addition**:[^1_5]
344
  - Session state pruning for long-running sessions
345
  - Caching with content-based invalidation
346
 
347
- ## 🎯 Production Readiness Evaluation
348
-
349
- ### Strengths
350
-
351
- **Architecture Excellence:**
352
-
353
- - Clean separation of concerns with modular design
354
- - Production-grade error handling and logging
355
- - Intuitive user experience with real-time feedback
356
- - Scalable batch processing with progress tracking
357
- - Well-documented, type-hinted codebase
358
-
359
- **Operational Readiness:**
360
-
361
- - Containerized deployment with health checks
362
- - Comprehensive preprocessing validation
363
- - Multiple export formats for integration
364
- - Session-based results management
365
-
366
- ### Enhancement Opportunities
367
-
368
- **Testing Infrastructure:**
369
-
370
- - Expand unit test coverage beyond preprocessing
371
- - Implement integration tests for UI workflows
372
- - Add performance regression testing
373
- - Include security vulnerability scanning
374
-
375
- **Monitoring \& Observability:**
376
-
377
- - Application performance monitoring integration
378
- - User analytics and usage patterns tracking
379
- - Model performance drift detection
380
- - Resource utilization monitoring
381
-
382
- **Security Hardening:**
383
-
384
- - Implement proper authentication mechanisms
385
- - Add rate limiting for API endpoints
386
- - Configure security headers for web responses
387
- - Establish audit logging for sensitive operations
388
-
389
  ## 🔮 Strategic Development Roadmap
390
 
391
- Based on the documented roadmap in `README.md`, the platform targets three strategic expansion paths:[^1_13]
392
-
393
- **1. Multi-Model Dashboard Evolution**
394
-
395
- - Comparative model evaluation framework
396
- - Side-by-side performance reporting
397
- - Automated model retraining pipelines
398
- - Model versioning and rollback capabilities
399
-
400
- **2. Multi-Modal Input Support**
401
-
402
- - FTIR spectroscopy integration with dedicated preprocessing
403
- - Image-based polymer classification via computer vision
404
- - Cross-modal validation and ensemble methods
405
- - Unified preprocessing pipeline for multiple modalities
406
 
407
- **3. Enterprise Integration Features**
408
-
409
- - RESTful API development for programmatic access
410
- - Database integration for persistent storage
411
- - User authentication and authorization systems
412
- - Audit trails and compliance reporting
413
 
414
  ## 💼 Business Logic \& Scientific Workflow
415
 
@@ -424,7 +331,7 @@ Based on the documented roadmap in `README.md`, the platform targets three strat
424
 
425
  ### Scientific Applications
426
 
427
- **Research Use Cases:**[^1_13]
428
 
429
  - Material science polymer degradation studies
430
  - Recycling viability assessment for circular economy
@@ -434,7 +341,7 @@ Based on the documented roadmap in `README.md`, the platform targets three strat
434
 
435
  ### Data Workflow Architecture
436
 
437
- ```
438
  Input Validation → Spectrum Preprocessing → Model Inference →
439
  Confidence Analysis → Results Visualization → Export Options
440
  ```
@@ -475,10 +382,7 @@ The platform successfully bridges academic research and practical application, p
475
 
476
  **Risk Assessment:** Low - The codebase demonstrates mature engineering practices with appropriate validation and error handling for production deployment.
477
 
478
- **Recommendation:** This platform is ready for production deployment with minimal additional hardening, representing a solid foundation for polymer classification research and industrial applications.
479
- <span style="display:none">[^1_14][^1_15][^1_16][^1_17][^1_18]</span>
480
-
481
- <div style="text-align: center">⁂</div>
482
 
483
  ### EXTRA
484
 
@@ -529,22 +433,3 @@ The platform successfully bridges academic research and practical application, p
529
  Column 1 (Input): Contains the main st.radio for mode selection and the conditional logic to display the single file uploader, batch uploader, or sample selector. It also holds the "Run Analysis" and "Reset All" buttons.
530
  Column 2 (Results): Contains all the logic for displaying either the batch results or the detailed, tabbed results for a single file (Details, Technical, Explanation).
531
  ```
532
-
533
- [^1_1]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/tree/main
534
- [^1_2]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/tree/main/datasets
535
- [^1_3]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml
536
- [^1_4]: https://github.com/KLab-AI3/ml-polymer-recycling
537
- [^1_5]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/.gitignore
538
- [^1_6]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/blob/main/models/resnet_cnn.py
539
- [^1_7]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/utils/multifile.py
540
- [^1_8]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/utils/preprocessing.py
541
- [^1_9]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/utils/audit.py
542
- [^1_10]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/utils/results_manager.py
543
- [^1_11]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/blob/main/scripts/train_model.py
544
- [^1_12]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/requirements.txt
545
- [^1_13]: https://doi.org/10.1016/j.resconrec.2022.106718
546
- [^1_14]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/app.py
547
- [^1_15]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/Dockerfile
548
- [^1_16]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/utils/errors.py
549
- [^1_17]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/utils/confidence.py
550
- [^1_18]: https://ppl-ai-code-interpreter-files.s3.amazonaws.com/web/direct-files/9fd1eb2028a28085942cb82c9241b5ae/a25e2c38-813f-4d8b-89b3-713f7d24f1fe/3e70b172.md
 
2
 
3
  ## Executive Summary
4
 
5
+ This audit provides a technical inventory of the dev-jas/polymer-aging-ml repositorya modular machine learning platform for polymer degradation classification using Raman and FTIR spectroscopy. The system features robust error handling, multi-format batch processing, and persistent performance tracking, making it suitable for research, education, and industrial applications.
6
 
7
  ## 🏗️ System Architecture
8
 
9
  ### Core Infrastructure
10
 
11
+ - **Streamlit-based web app** (`app.py`) as the main interface
12
+ - **PyTorch** for deep learning
13
+ - **Docker** for deployment
14
+ - **SQLite** (`outputs/performance_tracking.db`) for performance metrics
15
+ - **Plugin-based model registry** for extensibility
16
+
17
+ ### Directory Structure
18
+
19
+ - **app.py**: Main Streamlit application
20
+ - **README.md**: Project documentation
21
+ - **Dockerfile**: Containerization (Python 3.13-slim)
22
+ - **requirements.txt**: Dependency management
23
+ - **models/**: Neural network architectures and registry
24
+ - **utils/**: Shared utilities (preprocessing, batch, results, performance, errors, confidence)
25
+ - **scripts/**: CLI tools for training, inference, data management
26
+ - **outputs/**: Model weights, inference results, performance DB
27
+ - **sample_data/**: Demo spectrum files
28
+ - **tests/**: Unit tests (PyTest)
29
+ - **datasets/**: Data storage
30
+ - **pages/**: Streamlit dashboard pages
 
 
31
 
32
  ## 🤖 Machine Learning Framework
33
 
34
+ ### Model Registry
35
 
36
+ Factory pattern in `models/registry.py` enables dynamic model selection:
37
 
38
  ```python
39
  _REGISTRY: Dict[str, Callable[[int], object]] = {
 
45
 
46
  ### Neural Network Architectures
47
 
48
+ The platform supports three architectures, offering diverse options for spectral analysis:
49
 
50
+ **Figure2CNN (Baseline Model):**
 
 
 
 
51
 
52
+ - Architecture: 4 convolutional layers (1→16→32→64→128), 3 fully connected layers (256→128→2).
53
+ - Performance: 94.80% accuracy, 94.30% F1-score (Raman-only).
54
+ - Parameters: ~500K, supports dynamic input handling.
55
 
56
+ **ResNet1D (Advanced Model):**
 
 
 
 
57
 
58
+ - Architecture: 3 residual blocks with 1D skip connections.
59
+ - Performance: 96.20% accuracy, 95.90% F1-score.
60
+ - Parameters: ~100K, efficient via global average pooling.
61
 
62
+ **ResNet18Vision (Experimental):**
63
+
64
+ - Architecture: 1D-adapted ResNet-18 with 4 layers (2 blocks each).
65
+ - Status: Under evaluation, ~11M parameters.
66
+ - Opportunity: Expand validation for broader spectral applications.
67
 
68
  ## 🔧 Data Processing Infrastructure
69
 
70
  ### Preprocessing Pipeline
71
 
72
+ The system implements a **modular preprocessing pipeline** in `utils/preprocessing.py` with five configurable stages:
 
73
  **1. Input Validation Framework:**
74
 
75
  - File format verification (`.txt` files exclusively)
 
78
  - Monotonic sequence verification for spectral consistency
79
  - NaN value detection and automatic rejection
80
 
81
+ **2. Core Processing Steps:**
82
 
83
  - **Linear Resampling**: Uniform grid interpolation to 500 points using `scipy.interpolate.interp1d`
84
  - **Baseline Correction**: Polynomial detrending (configurable degree, default=2)
85
  - **Savitzky-Golay Smoothing**: Noise reduction (window=11, order=2, configurable)
86
+ - **Min-Max Normalization**: Scaling to range with constant-signal protection
87
 
88
  ### Batch Processing Framework
89
 
90
+ The `utils/multifile.py` module (12.5 kB) provides **enterprise-grade batch processing** capabilities:
91
 
92
  - **Multi-File Upload**: Streamlit widget supporting simultaneous file selection
93
  - **Error-Tolerant Processing**: Individual file failures don't interrupt batch operations
 
117
 
118
  ### State Management System
119
 
120
+ The application employs **advanced session state management**:
121
 
122
  - Persistent state across Streamlit reruns using `st.session_state`
123
  - Intelligent caching with content-based hash keys for expensive operations
 
128
 
129
  ### Centralized Error Handling
130
 
131
+ The `utils/errors.py` module provides with **context-aware** logging and user-friendly error messages.
 
 
 
 
 
 
 
 
 
 
132
 
133
+ ### Performance Tracking System
134
 
135
+ The `utils/performance_tracker.py` module provides a robust system for logging and analyzing performance metrics.
 
 
 
136
 
137
+ - **Database Logging**: Persists metrics to a SQLite database.
138
+ - **Automated Tracking**: Uses a context manager to automatically track inference time, preprocessing time, and memory usage.
139
+ - **Dashboarding**: Includes functions to generate performance visualizations and summary statistics for the UI.
140
 
141
+ ### Enhanced Results Management
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ The `utils/results_manager.py` module enables comprehensive session and persistent results tracking.
144
 
145
+ - **In-Memory Storage**: Manages results for the current session.
146
+ - **Multi-Model Handling**: Aggregates results from multiple models for comparison.
147
+ - **Export Capabilities**: Exports results to CSV and JSON.
148
+ - **Statistical Analysis**: Calculates accuracy, confidence, and other metrics.
149
 
150
  ## 📜 Command-Line Interface
151
 
 
166
  - Deterministic CUDA operations when GPU available
167
  - Standardized train/validation splitting methodology
168
 
 
 
 
 
 
 
 
 
 
 
 
169
  ### Data Utilities
170
 
171
  **File Discovery System:**
 
174
  - Filename-based labeling convention (`sta-*` = stable, `wea-*` = weathered)
175
  - Dataset inventory generation with statistical summaries
176
 
 
 
 
 
 
 
 
 
 
 
 
177
  ### Dependency Management
178
 
179
  The `requirements.txt` specifies **core dependencies without version pinning**:[^1_12]
 
184
  - **Visualization**: `matplotlib` for spectrum plotting
185
  - **API Framework**: `fastapi`, `uvicorn` for potential REST API expansion
186
 
187
+ ## 🐳 Deployment Infrastructure
188
+
189
+ ### Docker Configuration
190
+
191
+ The Dockerfile uses Python 3.13-slim for efficient containerization:
192
+
193
+ - Includes essential build tools and scientific libraries.
194
+ - Supports health checks for container wellness.
195
+ - **Roadmap**: Implement multi-stage builds and environment variables for streamlined deployments.
196
+
197
+ ### Confidence Analysis System
198
+
199
+ The `utils/confidence.py` module provides **scientific confidence metrics**
200
+
201
+ **Softmax-Based Confidence:**
202
+
203
+ - Normalized probability distributions from model logits
204
+ - Three-tier confidence levels: HIGH (≥80%), MEDIUM (≥60%), LOW (<60%)
205
+ - Color-coded visual indicators with emoji representations
206
+ - Legacy compatibility with logit margin calculations
207
+
208
+ ### Session Results Management
209
+
210
+ The `utils/results_manager.py` module (8.16 kB) enables **comprehensive session tracking**:
211
+
212
+ - **In-Memory Storage**: Session-wide results persistence
213
+ - **Export Capabilities**: CSV and JSON download with timestamp formatting
214
+ - **Statistical Analysis**: Automatic accuracy calculation when ground truth available
215
+ - **Data Integrity**: Results survive page refreshes within session boundaries
216
+
217
  ## 🧪 Testing Framework
218
 
219
  ### Test Infrastructure
 
224
  - **Preprocessing Tests**: Core pipeline functionality validation in `test_preprocessing.py`
225
  - **Limited Coverage**: Currently covers preprocessing functions only
226
 
227
+ **Testing Coming Soon:**
228
 
229
+ - Add model architecture unit tests
230
+ - Integration tests for UI components
231
+ - Performance benchmarking tests
232
+ - Improved error handling validation
233
 
234
  ## 🔍 Security \& Quality Assessment
235
 
 
251
  - **Error Boundaries**: Multi-level exception handling with graceful degradation
252
  - **Logging**: Structured logging with appropriate severity levels
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  ## 🚀 Extensibility Analysis
255
 
256
  ### Model Architecture Extensibility
257
 
258
+ The **registry pattern enables seamless model addition**:
259
 
260
  1. **Implementation**: Create new model class with standardized interface
261
  2. **Registration**: Add to `models/registry.py` with factory function
 
308
  - Session state pruning for long-running sessions
309
  - Caching with content-based invalidation
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  ## 🔮 Strategic Development Roadmap
312
 
313
+ The project roadmap has been updated to reflect recent progress:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
+ - [x] **FTIR Support**: Modular integration of FTIR spectroscopy is complete.
316
+ - [x] **Multi-Model Dashboard**: A model comparison tab has been implemented.
317
+ - [ ] **Image-based Inference**: Future work to include image-based polymer classification.
318
+ - [x] **Performance Tracking**: A performance tracking dashboard has been implemented.
319
+ - [ ] **Enterprise Integration**: Future work to include a RESTful API and more advanced database integration.
 
320
 
321
  ## 💼 Business Logic \& Scientific Workflow
322
 
 
331
 
332
  ### Scientific Applications
333
 
334
+ **Research Use Cases:**
335
 
336
  - Material science polymer degradation studies
337
  - Recycling viability assessment for circular economy
 
341
 
342
  ### Data Workflow Architecture
343
 
344
+ ```text
345
  Input Validation → Spectrum Preprocessing → Model Inference →
346
  Confidence Analysis → Results Visualization → Export Options
347
  ```
 
382
 
383
  **Risk Assessment:** Low - The codebase demonstrates mature engineering practices with appropriate validation and error handling for production deployment.
384
 
385
+ **Recommendation:** This platform is ready for production deployment, representing a solid foundation for polymer classification research and industrial applications.
 
 
 
386
 
387
  ### EXTRA
388
 
 
433
  Column 1 (Input): Contains the main st.radio for mode selection and the conditional logic to display the single file uploader, batch uploader, or sample selector. It also holds the "Run Analysis" and "Reset All" buttons.
434
  Column 2 (Results): Contains all the logic for displaying either the batch results or the detailed, tabbed results for a single file (Details, Technical, Explanation).
435
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: AI Polymer Classification
3
  emoji: 🔬
4
  colorFrom: indigo
5
  colorTo: yellow
@@ -8,19 +8,21 @@ app_file: app.py
8
  pinned: false
9
  license: apache-2.0
10
  ---
11
- ## AI-Driven Polymer Aging Prediction and Classification (v0.1)
12
 
13
- This web application classifies the degradation state of polymers using Raman spectroscopy and deep learning.
14
 
15
- It was developed as part of the AIRE 2025 internship project at the Imageomics Institute and demonstrates a prototype pipeline for evaluating multiple convolutional neural networks (CNNs) on spectral data.
 
16
 
17
  ---
18
 
19
  ## 🧪 Current Scope
20
 
21
- - 🔬 **Modality**: Raman spectroscopy (.txt)
22
- - 🧠 **Model**: Figure2CNN (baseline)
 
23
  - 📊 **Task**: Binary classification — Stable vs Weathered polymers
 
24
  - 🛠️ **Architecture**: PyTorch + Streamlit
25
 
26
  ---
@@ -29,84 +31,76 @@ It was developed as part of the AIRE 2025 internship project at the Imageomics I
29
 
30
  - [x] Inference from Raman `.txt` files
31
  - [x] Model selection (Figure2CNN, ResNet1D)
 
 
 
32
  - [ ] Add more trained CNNs for comparison
33
- - [ ] FTIR support (modular integration planned)
34
  - [ ] Image-based inference (future modality)
 
35
 
36
  ---
37
 
38
  ## 🧭 How to Use
39
 
40
- 1. Upload a Raman spectrum `.txt` file (or select a sample)
41
- 2. Choose a model from the sidebar
42
- 3. Run analysis
43
- 4. View prediction, logits, and technical information
44
 
45
- Supported input:
46
 
47
- - Plaintext `.txt` files with 1–2 columns
48
- - Space- or comma-separated
49
- - Comment lines (#) are ignored
50
- - Automatically resampled to 500 points
51
 
52
- ---
53
-
54
- ## Contributors
55
 
56
- 👨‍🏫 Dr. Sanmukh Kuppannagari (Mentor)
57
- 👨‍🏫 Dr. Metin Karailyan (Mentor)
58
- 👨‍💻 Jaser Hasan (Author/Developer)
59
 
60
- ## 🧠 Model Credit
 
 
 
61
 
62
- Baseline model inspired by:
63
 
64
- Neo, E.R.K., Low, J.S.C., Goodship, V., Debattista, K. (2023).
65
- *Deep learning for chemometric analysis of plastic spectral data from infrared and Raman databases.*
66
- _Resources, Conservation & Recycling_, **188**, 106718.
67
- [https://doi.org/10.1016/j.resconrec.2022.106718](https://doi.org/10.1016/j.resconrec.2022.106718)
68
 
69
  ---
70
 
71
- ## 🔗 Links
72
-
73
- - 💻 **Live App**: [Hugging Face Space](https://huggingface.co/spaces/dev-jas/polymer-aging-ml)
74
- - 📂 **GitHub Repo**: [ml-polymer-recycling](https://github.com/KLab-AI3/ml-polymer-recycling)
75
-
76
-
77
- ## 🎯 Strategic Expansion Objectives (Roadmap)
78
-
79
- **The roadmap defines three major expansion paths designed to broaden the system’s capabilities and impact:**
80
-
81
- 1. **Model Expansion: Multi-Model Dashboard**
82
-
83
- > The dashboard will evolve into a hub for multiple model architectures rather than being tied to a single baseline. Planned work includes:
84
 
85
- - **Retraining & Fine-Tuning**: Incorporating publicly available vision models and retraining them with the polymer dataset.
86
- - **Model Registry**: Automatically detecting available .pth weights and exposing them in the dashboard for easy selection.
87
- - **Side-by-Side Reporting**: Running comparative experiments and reporting each model’s accuracy and diagnostics in a standardized format.
88
- - **Reproducible Integration**: Maintaining modular scripts and pipelines so each model’s results can be replicated without conflict.
89
 
90
- This ensures flexibility for future research and transparency in performance comparisons.
91
 
92
- 2. **Image Input Modality**
93
 
94
- > The system will support classification on images as an additional modality, extending beyond spectra. Key features will include:
 
 
 
95
 
96
- - **Upload Support**: Users can upload single images or batches directly through the dashboard.
97
- - **Multi-Model Execution**: Selected models from the registry can be applied to all uploaded images simultaneously.
98
- - **Batch Results**: Output will be returned in a structured, accessible way, showing both individual predictions and aggregate statistics.
99
- - **Enhanced Feedback**: Outputs will include predicted class, model confidence, and potentially annotated image previews.
100
 
101
- This expands the system toward a multi-modal framework, supporting broader research workflows.
102
 
103
- 3. **FTIR Dataset Integration**
 
104
 
105
- > Although previously deferred, FTIR support will be added back in a modular, distinct fashion. Planned steps are:
106
 
107
- - **Dedicated Preprocessing**: Tailored scripts to handle FTIR-specific signal characteristics (multi-layer handling, baseline correction, normalization).
108
- - **Architecture Compatibility**: Ensuring existing and retrained models can process FTIR data without mixing it with Raman workflows.
109
- - **UI Integration**: Introducing FTIR as a separate option in the modality selector, keeping Raman, Image, and FTIR workflows clearly delineated.
110
- - **Phased Development**: Implementation details to be refined during meetings to ensure scientific rigor.
111
 
112
- This guarantees FTIR becomes a supported modality without undermining the validated Raman foundation.
 
 
 
 
 
 
 
 
1
  ---
2
+ title: AI Polymer Classification (Raman & FTIR)
3
  emoji: 🔬
4
  colorFrom: indigo
5
  colorTo: yellow
 
8
  pinned: false
9
  license: apache-2.0
10
  ---
 
11
 
12
+ ## AI-Driven Polymer Aging Prediction and Classification (v0.1)
13
 
14
+ This web application classifies the degradation state of polymers using **Raman and FTIR spectroscopy** and deep learning.
15
+ It is a prototype pipeline for evaluating multiple convolutional neural networks (CNNs) on spectral data.
16
 
17
  ---
18
 
19
  ## 🧪 Current Scope
20
 
21
+ - 🔬 **Modalities**: Raman & FTIR spectroscopy
22
+ - 💾 **Input Formats**: `.txt`, `.csv`, `.json` (with auto-detection)
23
+ - 🧠 **Models**: Figure2CNN (baseline), ResNet1D, ResNet18Vision
24
  - 📊 **Task**: Binary classification — Stable vs Weathered polymers
25
+ - 🚀 **Features**: Multi-model comparison, performance tracking dashboard
26
  - 🛠️ **Architecture**: PyTorch + Streamlit
27
 
28
  ---
 
31
 
32
  - [x] Inference from Raman `.txt` files
33
  - [x] Model selection (Figure2CNN, ResNet1D)
34
+ - [x] **FTIR support** (modular integration complete)
35
+ - [x] **Multi-model comparison dashboard**
36
+ - [x] **Performance tracking dashboard**
37
  - [ ] Add more trained CNNs for comparison
 
38
  - [ ] Image-based inference (future modality)
39
+ - [ ] RESTful API for programmatic access
40
 
41
  ---
42
 
43
  ## 🧭 How to Use
44
 
45
+ The application provides three main analysis modes in a tabbed interface:
 
 
 
46
 
47
+ 1. **Standard Analysis**:
48
 
49
+ - Upload a single spectrum file (`.txt`, `.csv`, `.json`) or a batch of files.
50
+ - Choose a model from the sidebar.
51
+ - Run analysis and view the prediction, confidence, and technical details.
 
52
 
53
+ 2. **Model Comparison**:
 
 
54
 
55
+ - Upload a single spectrum file.
56
+ - The app runs inference with all available models.
57
+ - View a side-by-side comparison of the models' predictions and performance.
58
 
59
+ 3. **Performance Tracking**:
60
+ - Explore a dashboard with visualizations of historical performance data.
61
+ - Compare model performance across different metrics.
62
+ - Export performance data in CSV or JSON format.
63
 
64
+ ### Supported Input
65
 
66
+ - Plaintext `.txt`, `.csv`, or `.json` files.
67
+ - Data can be space-, comma-, or tab-separated.
68
+ - Comment lines (`#`, `%`) are ignored.
69
+ - The app automatically detects the file format and resamples the data to a standard length.
70
 
71
  ---
72
 
73
+ ## Contributors
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ Dr. Sanmukh Kuppannagari (Mentor)
76
+ Dr. Metin Karailyan (Mentor)
77
+ Jaser Hasan (Author/Developer)
 
78
 
79
+ ## Model Credit
80
 
81
+ Baseline model inspired by:
82
 
83
+ Neo, E.R.K., Low, J.S.C., Goodship, V., Debattista, K. (2023).
84
+ _Deep learning for chemometric analysis of plastic spectral data from infrared and Raman databases._
85
+ _Resources, Conservation & Recycling_, **188**, 106718.
86
+ [https://doi.org/10.1016/j.resconrec.2022.106718](https://doi.org/10.1016/j.resconrec.2022.106718)
87
 
88
+ ---
 
 
 
89
 
90
+ ## 🔗 Links
91
 
92
+ - **Live App**: [Hugging Face Space](https://huggingface.co/spaces/dev-jas/polymer-aging-ml)
93
+ - **GitHub Repo**: [ml-polymer-recycling](https://github.com/KLab-AI3/ml-polymer-recycling)
94
 
95
+ ## 🚀 Technical Architecture
96
 
97
+ **The system is built on a modular, production-ready architecture designed for scalability and maintainability.**
 
 
 
98
 
99
+ - **Frontend**: A Streamlit-based web application (`app.py`) provides an interactive, multi-tab user interface.
100
+ - **Backend**: PyTorch handles all deep learning operations, including model loading and inference.
101
+ - **Model Management**: A registry pattern (`models/registry.py`) allows for dynamic model loading and easy integration of new architectures.
102
+ - **Data Processing**: A robust, modality-aware preprocessing pipeline (`utils/preprocessing.py`) ensures data integrity and standardization for both Raman and FTIR data.
103
+ - **Multi-Format Parsing**: The `utils/multifile.py` module handles parsing of `.txt`, `.csv`, and `.json` files.
104
+ - **Results Management**: The `utils/results_manager.py` module manages session and persistent results, with support for multi-model comparison and data export.
105
+ - **Performance Tracking**: The `utils/performance_tracker.py` module logs performance metrics to a SQLite database and provides a dashboard for visualization.
106
+ - **Deployment**: The application is containerized using Docker (`Dockerfile`) for reproducible, cross-platform execution.
app.py CHANGED
@@ -8,6 +8,8 @@ from modules.ui_components import (
8
  render_sidebar,
9
  render_results_column,
10
  render_input_column,
 
 
11
  load_css,
12
  )
13
 
@@ -27,14 +29,28 @@ def main():
27
  load_css("static/style.css")
28
  init_session_state()
29
 
30
- # Render UI components
31
  render_sidebar()
32
 
33
- col1, col2 = st.columns([1, 1.35], gap="small")
34
- with col1:
35
- render_input_column()
36
- with col2:
37
- render_results_column()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  if __name__ == "__main__":
 
8
  render_sidebar,
9
  render_results_column,
10
  render_input_column,
11
+ render_comparison_tab,
12
+ render_performance_tab,
13
  load_css,
14
  )
15
 
 
29
  load_css("static/style.css")
30
  init_session_state()
31
 
 
32
  render_sidebar()
33
 
34
+ # Create main tabs for difference analysis modes
35
+ tab1, tab2, tab3 = st.tabs(
36
+ ["Standard Analysis", "Model Comparison", "Peformance Tracking"]
37
+ )
38
+
39
+ with tab1:
40
+ # Standard single-model analysis
41
+ col1, col2 = st.columns([1, 1.35], gap="small")
42
+ with col1:
43
+ render_input_column()
44
+ with col2:
45
+ render_results_column()
46
+
47
+ with tab2:
48
+ # Multi-model comparison interface
49
+ render_comparison_tab()
50
+
51
+ with tab3:
52
+ # Performance tracking interface
53
+ render_performance_tab()
54
 
55
 
56
  if __name__ == "__main__":
core_logic.py CHANGED
@@ -10,6 +10,7 @@ import numpy as np
10
  import streamlit as st
11
  from pathlib import Path
12
  from config import SAMPLE_DATA_DIR
 
13
 
14
 
15
  def label_file(filename: str) -> int:
@@ -89,16 +90,26 @@ def cleanup_memory():
89
 
90
  @st.cache_data
91
  def run_inference(y_resampled, model_choice, _cache_key=None):
92
- """Run model inference and cache results"""
 
 
 
93
  model, model_loaded = load_model(model_choice)
94
  if not model_loaded:
95
  return None, None, None, None, None
96
 
 
 
 
97
  input_tensor = (
98
  torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
99
  )
 
 
100
  start_time = time.time()
101
- model.eval()
 
 
102
  with torch.no_grad():
103
  if model is None:
104
  raise ValueError(
@@ -108,11 +119,51 @@ def run_inference(y_resampled, model_choice, _cache_key=None):
108
  prediction = torch.argmax(logits, dim=1).item()
109
  logits_list = logits.detach().numpy().tolist()[0]
110
  probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
 
111
  inference_time = time.time() - start_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  cleanup_memory()
113
  return prediction, logits_list, probs, inference_time, logits
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
116
  @st.cache_data
117
  def get_sample_files():
118
  """Get list of sample files if available"""
 
10
  import streamlit as st
11
  from pathlib import Path
12
  from config import SAMPLE_DATA_DIR
13
+ from datetime import datetime
14
 
15
 
16
  def label_file(filename: str) -> int:
 
90
 
91
  @st.cache_data
92
  def run_inference(y_resampled, model_choice, _cache_key=None):
93
+ """Run model inference and cache results with performance tracking"""
94
+ from utils.performance_tracker import get_performance_tracker, PerformanceMetrics
95
+ from datetime import datetime
96
+
97
  model, model_loaded = load_model(model_choice)
98
  if not model_loaded:
99
  return None, None, None, None, None
100
 
101
+ # Performance tracking setup
102
+ tracker = get_performance_tracker()
103
+
104
  input_tensor = (
105
  torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
106
  )
107
+
108
+ # Track inference performance
109
  start_time = time.time()
110
+ start_memory = _get_memory_usage()
111
+
112
+ model.eval() # type: ignore
113
  with torch.no_grad():
114
  if model is None:
115
  raise ValueError(
 
119
  prediction = torch.argmax(logits, dim=1).item()
120
  logits_list = logits.detach().numpy().tolist()[0]
121
  probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
122
+
123
  inference_time = time.time() - start_time
124
+ end_memory = _get_memory_usage()
125
+ memory_usage = max(end_memory - start_memory, 0)
126
+
127
+ # Log performance metrics
128
+ try:
129
+ modality = st.session_state.get("modality_select", "raman")
130
+ confidence = float(max(probs)) if probs is not None and len(probs) > 0 else 0.0
131
+
132
+ metrics = PerformanceMetrics(
133
+ model_name=model_choice,
134
+ prediction_time=inference_time,
135
+ preprocessing_time=0.0, # Will be updated by calling function if available
136
+ total_time=inference_time,
137
+ memory_usage_mb=memory_usage,
138
+ accuracy=None, # Will be updated if ground truth is available
139
+ confidence=confidence,
140
+ timestamp=datetime.now().isoformat(),
141
+ input_size=(
142
+ len(y_resampled) if hasattr(y_resampled, "__len__") else TARGET_LEN
143
+ ),
144
+ modality=modality,
145
+ )
146
+
147
+ tracker.log_performance(metrics)
148
+ except (AttributeError, ValueError, KeyError) as e:
149
+ # Don't fail inference if performance tracking fails
150
+ print(f"Performance tracking failed: {e}")
151
+
152
  cleanup_memory()
153
  return prediction, logits_list, probs, inference_time, logits
154
 
155
 
156
+ def _get_memory_usage() -> float:
157
+ """Get current memory usage in MB"""
158
+ try:
159
+ import psutil
160
+
161
+ process = psutil.Process()
162
+ return process.memory_info().rss / 1024 / 1024 # Convert to MB
163
+ except ImportError:
164
+ return 0.0 # psutil not available
165
+
166
+
167
  @st.cache_data
168
  def get_sample_files():
169
  """Get list of sample files if available"""
models/registry.py CHANGED
@@ -1,35 +1,138 @@
1
  # models/registry.py
2
- from typing import Callable, Dict
3
  from models.figure2_cnn import Figure2CNN
4
  from models.resnet_cnn import ResNet1D
5
- from models.resnet18_vision import ResNet18Vision
6
 
7
  # Internal registry of model builders keyed by short name.
8
  _REGISTRY: Dict[str, Callable[[int], object]] = {
9
  "figure2": lambda L: Figure2CNN(input_length=L),
10
  "resnet": lambda L: ResNet1D(input_length=L),
11
- "resnet18vision": lambda L: ResNet18Vision(input_length=L)
12
  }
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def choices():
15
  """Return the list of available model keys."""
16
  return list(_REGISTRY.keys())
17
 
 
 
 
 
 
 
18
  def build(name: str, input_length: int):
19
  """Instantiate a model by short name with the given input length."""
20
  if name not in _REGISTRY:
21
  raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
22
  return _REGISTRY[name](input_length)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def spec(name: str):
25
  """Return expected input length and number of classes for a model key."""
26
- if name == "figure2":
27
- return {"input_length": 500, "num_classes": 2}
28
- if name == "resnet":
29
- return {"input_length": 500, "num_classes": 2}
30
- if name == "resnet18vision":
31
- return {"input_length": 500, "num_classes": 2}
32
- raise KeyError(f"Unknown model '{name}'")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
- __all__ = ["choices", "build"]
 
 
 
 
 
 
 
 
 
 
 
1
  # models/registry.py
2
+ from typing import Callable, Dict, List, Any
3
  from models.figure2_cnn import Figure2CNN
4
  from models.resnet_cnn import ResNet1D
5
+ from models.resnet18_vision import ResNet18Vision
6
 
7
  # Internal registry of model builders keyed by short name.
8
  _REGISTRY: Dict[str, Callable[[int], object]] = {
9
  "figure2": lambda L: Figure2CNN(input_length=L),
10
  "resnet": lambda L: ResNet1D(input_length=L),
11
+ "resnet18vision": lambda L: ResNet18Vision(input_length=L),
12
  }
13
 
14
+ # Model specifications with metadata for enhanced features
15
+ _MODEL_SPECS: Dict[str, Dict[str, Any]] = {
16
+ "figure2": {
17
+ "input_length": 500,
18
+ "num_classes": 2,
19
+ "description": "Figure 2 baseline custom implemetation",
20
+ "modalities": ["raman", "ftir"],
21
+ "citation": "Neo et al., 2023, Resour. Conserv. Recycl., 188, 106718",
22
+ },
23
+ "resnet": {
24
+ "input_length": 500,
25
+ "num_classes": 2,
26
+ "description": "(Residual Network) uses skip connections to train much deeper networks",
27
+ "modalities": ["raman", "ftir"],
28
+ "citation": "Custom ResNet implementation",
29
+ },
30
+ "resnet18vision": {
31
+ "input_length": 500,
32
+ "num_classes": 2,
33
+ "description": "excels at image recognition tasks by using 'residual blocks' to train more efficiently",
34
+ "modalities": ["raman", "ftir"],
35
+ "citation": "ResNet18 Vision adaptation",
36
+ },
37
+ }
38
+
39
+ # Placeholder for future model expansions
40
+ _FUTURE_MODELS = {
41
+ "densenet1d": {
42
+ "description": "DenseNet1D for spectroscopy (placeholder)",
43
+ "status": "planned",
44
+ },
45
+ "ensemble_cnn": {
46
+ "description": "Ensemble of CNN variants (placeholder)",
47
+ "status": "planned",
48
+ },
49
+ }
50
+
51
+
52
  def choices():
53
  """Return the list of available model keys."""
54
  return list(_REGISTRY.keys())
55
 
56
+
57
+ def planned_models():
58
+ """Return the list of planned future model keys."""
59
+ return list(_FUTURE_MODELS.keys())
60
+
61
+
62
  def build(name: str, input_length: int):
63
  """Instantiate a model by short name with the given input length."""
64
  if name not in _REGISTRY:
65
  raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
66
  return _REGISTRY[name](input_length)
67
 
68
+
69
+ def build_multiple(names: List[str], input_length: int) -> Dict[str, Any]:
70
+ """Nuild multiple models for comparison."""
71
+ models = {}
72
+ for name in names:
73
+ if name in _REGISTRY:
74
+ models[name] = build(name, input_length)
75
+ else:
76
+ raise ValueError(f"Unknown model '{name}'. Available: {choices()}")
77
+ return models
78
+
79
+
80
+ def register_model(
81
+ name: str, builder: Callable[[int], object], spec: Dict[str, Any]
82
+ ) -> None:
83
+ """Dynamically register a new model."""
84
+ if name in _REGISTRY:
85
+ raise ValueError(f"Model '{name}' already registered.")
86
+ if not callable(builder):
87
+ raise TypeError("Builder must be a callable that accepts an integer argument.")
88
+ _REGISTRY[name] = builder
89
+ _MODEL_SPECS[name] = spec
90
+
91
+
92
  def spec(name: str):
93
  """Return expected input length and number of classes for a model key."""
94
+ if name in _MODEL_SPECS:
95
+ return _MODEL_SPECS[name].copy()
96
+ raise KeyError(f"Unknown model '{name}'. Available: {choices()}")
97
+
98
+
99
+ def get_model_info(name: str) -> Dict[str, Any]:
100
+ """Get comprehensive model information including metadata."""
101
+ if name in _MODEL_SPECS:
102
+ return _MODEL_SPECS[name].copy()
103
+ elif name in _FUTURE_MODELS:
104
+ return _FUTURE_MODELS[name].copy()
105
+ else:
106
+ raise KeyError(f"Unknown model '{name}'")
107
+
108
+
109
+ def models_for_modality(modality: str) -> List[str]:
110
+ """Get list of models that support a specific modality."""
111
+ compatible = []
112
+ for name, spec_info in _MODEL_SPECS.items():
113
+ if modality in spec_info.get("modalities", []):
114
+ compatible.append(name)
115
+ return compatible
116
+
117
+
118
+ def validate_model_list(names: List[str]) -> List[str]:
119
+ """Validate and return list of available models from input list."""
120
+ available = choices()
121
+ valid_models = []
122
+ for name in names:
123
+ if name is available:
124
+ valid_models.append(name)
125
+ return valid_models
126
 
127
 
128
+ __all__ = [
129
+ "choices",
130
+ "build",
131
+ "spec",
132
+ "build_multiple",
133
+ "register_model",
134
+ "get_model_info",
135
+ "models_for_modality",
136
+ "validate_model_list",
137
+ "planned_models",
138
+ ]
modules/ui_components.py CHANGED
@@ -13,9 +13,9 @@ from modules.callbacks import (
13
  on_model_change,
14
  on_input_mode_change,
15
  on_sample_change,
 
16
  reset_ephemeral_state,
17
  log_message,
18
- clear_batch_results,
19
  )
20
  from core_logic import (
21
  get_sample_files,
@@ -24,7 +24,6 @@ from core_logic import (
24
  parse_spectrum_data,
25
  label_file,
26
  )
27
- from modules.callbacks import reset_results
28
  from utils.results_manager import ResultsManager
29
  from utils.confidence import calculate_softmax_confidence
30
  from utils.multifile import process_multiple_files, display_batch_results
@@ -41,7 +40,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
41
  """Create spectrum visualization plot"""
42
  fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
43
 
44
- # == Raw spectrum ==
45
  ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
46
  ax[0].set_title("Raw Input Spectrum")
47
  ax[0].set_xlabel("Wavenumber (cm⁻¹)")
@@ -49,7 +48,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
49
  ax[0].grid(True, alpha=0.3)
50
  ax[0].legend()
51
 
52
- # == Resampled spectrum ==
53
  ax[1].plot(
54
  x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1
55
  )
@@ -60,7 +59,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
60
  ax[1].legend()
61
 
62
  fig.tight_layout()
63
- # == Convert to image ==
64
  buf = io.BytesIO()
65
  plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
66
  buf.seek(0)
@@ -69,6 +68,9 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
69
  return Image.open(buf)
70
 
71
 
 
 
 
72
  def render_confidence_progress(
73
  probs: np.ndarray,
74
  labels: list[str] = ["Stable", "Weathered"],
@@ -114,7 +116,10 @@ def render_confidence_progress(
114
  st.markdown("")
115
 
116
 
117
- def render_kv_grid(d: dict = {}, ncols: int = 2):
 
 
 
118
  if d is None:
119
  d = {}
120
  if not d:
@@ -126,6 +131,9 @@ def render_kv_grid(d: dict = {}, ncols: int = 2):
126
  st.caption(f"**{k}:** {v}")
127
 
128
 
 
 
 
129
  def render_model_meta(model_choice: str):
130
  info = MODEL_CONFIG.get(model_choice, {})
131
  emoji = info.get("emoji", "")
@@ -143,6 +151,9 @@ def render_model_meta(model_choice: str):
143
  st.caption(desc)
144
 
145
 
 
 
 
146
  def get_confidence_description(logit_margin):
147
  """Get human-readable confidence description"""
148
  if logit_margin > 1000:
@@ -155,13 +166,35 @@ def get_confidence_description(logit_margin):
155
  return "LOW", "🔴"
156
 
157
 
 
 
 
158
  def render_sidebar():
159
  with st.sidebar:
160
  # Header
161
  st.header("AI-Driven Polymer Classification")
162
  st.caption(
163
- "Predict polymer degradation (Stable vs Weathered) from Raman spectra using validated CNN models. — v0.1"
 
 
 
 
 
 
 
 
 
 
164
  )
 
 
 
 
 
 
 
 
 
165
  model_labels = [
166
  f"{MODEL_CONFIG[name]['emoji']} {name}" for name in MODEL_CONFIG.keys()
167
  ]
@@ -173,10 +206,10 @@ def render_sidebar():
173
  )
174
  model_choice = selected_label.split(" ", 1)[1]
175
 
176
- # ===Compact metadata directly under dropdown===
177
  render_model_meta(model_choice)
178
 
179
- # ===Collapsed info to reduce clutter===
180
  with st.expander("About This App", icon=":material/info:", expanded=False):
181
  st.markdown(
182
  """
@@ -184,8 +217,9 @@ def render_sidebar():
184
 
185
  **Purpose**: Classify polymer degradation using AI<br>
186
  **Input**: Raman spectroscopy .txt files<br>
187
- **Models**: CNN architectures for binary classification<br>
188
- **Next**: More trained CNNs in evaluation pipeline<br>
 
189
 
190
 
191
  **Contributors**<br>
@@ -207,11 +241,7 @@ def render_sidebar():
207
  )
208
 
209
 
210
- # col1 goes here
211
-
212
- # In modules/ui_components.py
213
-
214
-
215
  def render_input_column():
216
  st.markdown("##### Data Input")
217
 
@@ -224,22 +254,20 @@ def render_input_column():
224
  )
225
 
226
  # == Input Mode Logic ==
227
- # ... (The if/elif/else block for Upload, Batch, and Sample modes remains exactly the same) ...
228
- # ==Upload tab==
229
  if mode == "Upload File":
230
  upload_key = st.session_state["current_upload_key"]
231
  up = st.file_uploader(
232
- "Upload Raman spectrum (.txt)",
233
- type="txt",
234
- help="Upload a text file with wavenumber and intensity columns",
235
  key=upload_key, # ← versioned key
236
  )
237
 
238
- # ==Process change immediately (no on_change; simpler & reliable)==
239
  if up is not None:
240
  raw = up.read()
241
  text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
242
- # == only reparse if its a different file|source ==
243
  if (
244
  st.session_state.get("filename") != getattr(up, "name", None)
245
  or st.session_state.get("input_source") != "upload"
@@ -255,23 +283,20 @@ def render_input_column():
255
  st.session_state["status_type"] = "success"
256
  reset_results("New file uploaded")
257
 
258
- # ==Batch Upload tab==
259
  elif mode == "Batch Upload":
260
  st.session_state["batch_mode"] = True
261
- # --- START: BUG 1 & 3 FIX ---
262
  # Use a versioned key to ensure the file uploader resets properly.
263
  batch_upload_key = f"batch_upload_{st.session_state['uploader_version']}"
264
  uploaded_files = st.file_uploader(
265
- "Upload multiple Raman spectrum files (.txt)",
266
- type="txt",
267
  accept_multiple_files=True,
268
- help="Upload one or more text files with wavenumber and intensity columns.",
269
  key=batch_upload_key,
270
  )
271
- # --- END: BUG 1 & 3 FIX ---
272
 
273
  if uploaded_files:
274
- # --- START: Bug 1 Fix ---
275
  # Use a dictionary to keep only unique files based on name and size
276
  unique_files = {(file.name, file.size): file for file in uploaded_files}
277
  unique_file_list = list(unique_files.values())
@@ -281,9 +306,7 @@ def render_input_column():
281
 
282
  # Optionally, inform the user that duplicates were removed
283
  if num_uploaded > num_unique:
284
- st.info(
285
- f"ℹ️ {num_uploaded - num_unique} duplicate file(s) were removed."
286
- )
287
 
288
  # Use the unique list
289
  st.session_state["batch_files"] = unique_file_list
@@ -291,7 +314,6 @@ def render_input_column():
291
  f"{num_unique} ready for batch analysis"
292
  )
293
  st.session_state["status_type"] = "success"
294
- # --- END: Bug 1 Fix ---
295
  else:
296
  st.session_state["batch_files"] = []
297
  # This check prevents resetting the status if files are already staged
@@ -301,7 +323,7 @@ def render_input_column():
301
  )
302
  st.session_state["status_type"] = "info"
303
 
304
- # ==Sample tab==
305
  elif mode == "Sample Data":
306
  st.session_state["batch_mode"] = False
307
  sample_files = get_sample_files()
@@ -330,9 +352,6 @@ def render_input_column():
330
  else:
331
  st.info(msg)
332
 
333
- # --- DE-NESTED LOGIC STARTS HERE ---
334
- # This code now runs on EVERY execution, guaranteeing the buttons will appear.
335
-
336
  # Safely get model choice from session state
337
  model_choice = st.session_state.get("model_select", " ").split(" ", 1)[1]
338
  model = load_model(model_choice)
@@ -388,7 +407,7 @@ def render_input_column():
388
  st.error(f"Error processing spectrum data: {e}")
389
 
390
 
391
- # col2 goes here
392
 
393
 
394
  def render_results_column():
@@ -410,7 +429,7 @@ def render_results_column():
410
  filename = st.session_state.get("filename", "Unknown")
411
 
412
  if all(v is not None for v in [x_raw, y_raw, y_resampled]):
413
- # ===Run inference===
414
  if y_resampled is None:
415
  raise ValueError(
416
  "y_resampled is None. Ensure spectrum data is properly resampled before proceeding."
@@ -437,14 +456,14 @@ def render_results_column():
437
  f"Inference completed in {inference_time:.2f}s, prediction: {prediction}"
438
  )
439
 
440
- # ===Get ground truth===
441
  true_label_idx = label_file(filename)
442
  true_label_str = (
443
  LABEL_MAP.get(true_label_idx, "Unknown")
444
  if true_label_idx is not None
445
  else "Unknown"
446
  )
447
- # ===Get prediction===
448
  predicted_class = LABEL_MAP.get(int(prediction), f"Class {int(prediction)}")
449
 
450
  # Enhanced confidence calculation
@@ -455,7 +474,7 @@ def render_results_column():
455
  )
456
  confidence_desc = confidence_level
457
  else:
458
- # Fallback to legace method
459
  logit_margin = abs(
460
  (logits_list[0] - logits_list[1])
461
  if logits_list is not None and len(logits_list) >= 2
@@ -487,7 +506,7 @@ def render_results_column():
487
  },
488
  )
489
 
490
- # ===Precompute Stats===
491
  model_choice = (
492
  st.session_state.get("model_select", "").split(" ", 1)[1]
493
  if "model_select" in st.session_state
@@ -505,7 +524,6 @@ def render_results_column():
505
  if os.path.exists(model_path)
506
  else "N/A"
507
  )
508
- # Removed unused variable 'input_tensor'
509
 
510
  start_render = time.time()
511
 
@@ -590,17 +608,13 @@ def render_results_column():
590
  """,
591
  unsafe_allow_html=True,
592
  )
593
- # --- END: CONSOLIDATED CONFIDENCE ANALYSIS ---
594
 
595
  st.divider()
596
 
597
- # --- START: CLEAN METADATA FOOTER ---
598
- # Secondary info is now a clean, single-line caption
599
  st.caption(
600
  f"Analyzed with **{st.session_state.get('model_select', 'Unknown')}** in **{inference_time:.2f}s**."
601
  )
602
- # --- END: CLEAN METADATA FOOTER ---
603
-
604
  st.markdown("</div>", unsafe_allow_html=True)
605
 
606
  elif active_tab == "Technical":
@@ -918,7 +932,7 @@ def render_results_column():
918
  """
919
  )
920
  else:
921
- # ===Getting Started===
922
  st.markdown(
923
  """
924
  ##### How to Get Started
@@ -948,3 +962,416 @@ def render_results_column():
948
  - 🏭 Quality control in manufacturing
949
  """
950
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  on_model_change,
14
  on_input_mode_change,
15
  on_sample_change,
16
+ reset_results,
17
  reset_ephemeral_state,
18
  log_message,
 
19
  )
20
  from core_logic import (
21
  get_sample_files,
 
24
  parse_spectrum_data,
25
  label_file,
26
  )
 
27
  from utils.results_manager import ResultsManager
28
  from utils.confidence import calculate_softmax_confidence
29
  from utils.multifile import process_multiple_files, display_batch_results
 
40
  """Create spectrum visualization plot"""
41
  fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
42
 
43
+ # Raw spectrum
44
  ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
45
  ax[0].set_title("Raw Input Spectrum")
46
  ax[0].set_xlabel("Wavenumber (cm⁻¹)")
 
48
  ax[0].grid(True, alpha=0.3)
49
  ax[0].legend()
50
 
51
+ # Resampled spectrum
52
  ax[1].plot(
53
  x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1
54
  )
 
59
  ax[1].legend()
60
 
61
  fig.tight_layout()
62
+ # Convert to image
63
  buf = io.BytesIO()
64
  plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
65
  buf.seek(0)
 
68
  return Image.open(buf)
69
 
70
 
71
+ # //////////////////////////////////////////
72
+
73
+
74
  def render_confidence_progress(
75
  probs: np.ndarray,
76
  labels: list[str] = ["Stable", "Weathered"],
 
116
  st.markdown("")
117
 
118
 
119
+ from typing import Optional
120
+
121
+
122
+ def render_kv_grid(d: Optional[dict] = None, ncols: int = 2):
123
  if d is None:
124
  d = {}
125
  if not d:
 
131
  st.caption(f"**{k}:** {v}")
132
 
133
 
134
+ # //////////////////////////////////////////
135
+
136
+
137
  def render_model_meta(model_choice: str):
138
  info = MODEL_CONFIG.get(model_choice, {})
139
  emoji = info.get("emoji", "")
 
151
  st.caption(desc)
152
 
153
 
154
+ # //////////////////////////////////////////
155
+
156
+
157
  def get_confidence_description(logit_margin):
158
  """Get human-readable confidence description"""
159
  if logit_margin > 1000:
 
166
  return "LOW", "🔴"
167
 
168
 
169
+ # //////////////////////////////////////////
170
+
171
+
172
  def render_sidebar():
173
  with st.sidebar:
174
  # Header
175
  st.header("AI-Driven Polymer Classification")
176
  st.caption(
177
+ "Predict polymer degradation (Stable vs Weathered) from Raman/FTIR spectra using validated CNN models. — v0.01"
178
+ )
179
+
180
+ # Modality Selection
181
+ st.markdown("##### Spectroscopy Modality")
182
+ modality = st.selectbox(
183
+ "Choose Modality",
184
+ ["raman", "ftir"],
185
+ index=0,
186
+ key="modality_select",
187
+ format_func=lambda x: f"{'Raman' if x == 'raman' else 'FTIR'}",
188
  )
189
+
190
+ # Display modality info
191
+ if modality == "ftir":
192
+ st.info("FTIR mode: 400-4000 cm-1 range with atmospheric correction")
193
+ else:
194
+ st.info("Raman mode: 200-4000 cm-1 range with standard preprocessing")
195
+
196
+ # Model selection
197
+ st.markdown("##### AI Model Selection")
198
  model_labels = [
199
  f"{MODEL_CONFIG[name]['emoji']} {name}" for name in MODEL_CONFIG.keys()
200
  ]
 
206
  )
207
  model_choice = selected_label.split(" ", 1)[1]
208
 
209
+ # Compact metadata directly under dropdown
210
  render_model_meta(model_choice)
211
 
212
+ # Collapsed info to reduce clutter
213
  with st.expander("About This App", icon=":material/info:", expanded=False):
214
  st.markdown(
215
  """
 
217
 
218
  **Purpose**: Classify polymer degradation using AI<br>
219
  **Input**: Raman spectroscopy .txt files<br>
220
+ **Models**: CNN architectures for classification<br>
221
+ **Modalities**: Raman and FTIR spectroscopy support<br>
222
+ **Features**: Multi-model comparison and analysis<br>
223
 
224
 
225
  **Contributors**<br>
 
241
  )
242
 
243
 
244
+ # //////////////////////////////////////////
 
 
 
 
245
  def render_input_column():
246
  st.markdown("##### Data Input")
247
 
 
254
  )
255
 
256
  # == Input Mode Logic ==
 
 
257
  if mode == "Upload File":
258
  upload_key = st.session_state["current_upload_key"]
259
  up = st.file_uploader(
260
+ "Upload spectrum file (.txt, .csv, .json)",
261
+ type=["txt", "csv", "json"],
262
+ help="Upload spectroscopy data: TXT (2-column), CSV (with headers), or JSON format",
263
  key=upload_key, # ← versioned key
264
  )
265
 
266
+ # Process change immediately
267
  if up is not None:
268
  raw = up.read()
269
  text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
270
+ # only reparse if its a different file|source
271
  if (
272
  st.session_state.get("filename") != getattr(up, "name", None)
273
  or st.session_state.get("input_source") != "upload"
 
283
  st.session_state["status_type"] = "success"
284
  reset_results("New file uploaded")
285
 
286
+ # Batch Upload tab
287
  elif mode == "Batch Upload":
288
  st.session_state["batch_mode"] = True
 
289
  # Use a versioned key to ensure the file uploader resets properly.
290
  batch_upload_key = f"batch_upload_{st.session_state['uploader_version']}"
291
  uploaded_files = st.file_uploader(
292
+ "Upload multiple spectrum files (.txt, .csv, .json)",
293
+ type=["txt", "csv", "json"],
294
  accept_multiple_files=True,
295
+ help="Upload spectroscopy files in TXT, CSV, or JSON format.",
296
  key=batch_upload_key,
297
  )
 
298
 
299
  if uploaded_files:
 
300
  # Use a dictionary to keep only unique files based on name and size
301
  unique_files = {(file.name, file.size): file for file in uploaded_files}
302
  unique_file_list = list(unique_files.values())
 
306
 
307
  # Optionally, inform the user that duplicates were removed
308
  if num_uploaded > num_unique:
309
+ st.info(f"{num_uploaded - num_unique} duplicate file(s) were removed.")
 
 
310
 
311
  # Use the unique list
312
  st.session_state["batch_files"] = unique_file_list
 
314
  f"{num_unique} ready for batch analysis"
315
  )
316
  st.session_state["status_type"] = "success"
 
317
  else:
318
  st.session_state["batch_files"] = []
319
  # This check prevents resetting the status if files are already staged
 
323
  )
324
  st.session_state["status_type"] = "info"
325
 
326
+ # Sample tab
327
  elif mode == "Sample Data":
328
  st.session_state["batch_mode"] = False
329
  sample_files = get_sample_files()
 
352
  else:
353
  st.info(msg)
354
 
 
 
 
355
  # Safely get model choice from session state
356
  model_choice = st.session_state.get("model_select", " ").split(" ", 1)[1]
357
  model = load_model(model_choice)
 
407
  st.error(f"Error processing spectrum data: {e}")
408
 
409
 
410
+ # //////////////////////////////////////////
411
 
412
 
413
  def render_results_column():
 
429
  filename = st.session_state.get("filename", "Unknown")
430
 
431
  if all(v is not None for v in [x_raw, y_raw, y_resampled]):
432
+ # Run inference
433
  if y_resampled is None:
434
  raise ValueError(
435
  "y_resampled is None. Ensure spectrum data is properly resampled before proceeding."
 
456
  f"Inference completed in {inference_time:.2f}s, prediction: {prediction}"
457
  )
458
 
459
+ # Get ground truth
460
  true_label_idx = label_file(filename)
461
  true_label_str = (
462
  LABEL_MAP.get(true_label_idx, "Unknown")
463
  if true_label_idx is not None
464
  else "Unknown"
465
  )
466
+ # Get prediction
467
  predicted_class = LABEL_MAP.get(int(prediction), f"Class {int(prediction)}")
468
 
469
  # Enhanced confidence calculation
 
474
  )
475
  confidence_desc = confidence_level
476
  else:
477
+ # Fallback to legacy method
478
  logit_margin = abs(
479
  (logits_list[0] - logits_list[1])
480
  if logits_list is not None and len(logits_list) >= 2
 
506
  },
507
  )
508
 
509
+ # Precompute Stats
510
  model_choice = (
511
  st.session_state.get("model_select", "").split(" ", 1)[1]
512
  if "model_select" in st.session_state
 
524
  if os.path.exists(model_path)
525
  else "N/A"
526
  )
 
527
 
528
  start_render = time.time()
529
 
 
608
  """,
609
  unsafe_allow_html=True,
610
  )
 
611
 
612
  st.divider()
613
 
614
+ # METADATA FOOTER
 
615
  st.caption(
616
  f"Analyzed with **{st.session_state.get('model_select', 'Unknown')}** in **{inference_time:.2f}s**."
617
  )
 
 
618
  st.markdown("</div>", unsafe_allow_html=True)
619
 
620
  elif active_tab == "Technical":
 
932
  """
933
  )
934
  else:
935
+ # Getting Started
936
  st.markdown(
937
  """
938
  ##### How to Get Started
 
962
  - 🏭 Quality control in manufacturing
963
  """
964
  )
965
+
966
+
967
+ # //////////////////////////////////////////
968
+
969
+
970
+ def render_comparison_tab():
971
+ """Render the multi-model comparison interface"""
972
+ import streamlit as st
973
+ import matplotlib.pyplot as plt
974
+ from models.registry import choices, validate_model_list
975
+ from utils.results_manager import ResultsManager
976
+ from core_logic import get_sample_files, run_inference, parse_spectrum_data
977
+ from utils.preprocessing import preprocess_spectrum
978
+ from utils.multifile import parse_spectrum_data
979
+ import numpy as np
980
+ import time
981
+
982
+ st.markdown("### Multi-Model Comparison Analysis")
983
+ st.markdown(
984
+ "Compare predictions across different AI models for comprehensive analysis."
985
+ )
986
+
987
+ # Model selection for comparison
988
+ st.markdown("##### Select Models for Comparison")
989
+
990
+ available_models = choices()
991
+ selected_models = st.multiselect(
992
+ "Choose models to compare",
993
+ available_models,
994
+ default=(
995
+ available_models[:2] if len(available_models) >= 2 else available_models
996
+ ),
997
+ help="Select 2 or more models to compare their predictions side-by-side",
998
+ )
999
+
1000
+ if len(selected_models) < 2:
1001
+ st.warning("⚠️ Please select at least 2 models for comparison.")
1002
+
1003
+ # Input selection for comparison
1004
+ col1, col2 = st.columns([1, 1.5])
1005
+
1006
+ with col1:
1007
+ st.markdown("###### Input Data")
1008
+
1009
+ # File upload for comparison
1010
+ comparison_file = st.file_uploader(
1011
+ "Upload spectrum for comparison",
1012
+ type=["txt", "csv", "json"],
1013
+ key="comparison_file_upload",
1014
+ help="Upload a spectrum file to test across all selected models",
1015
+ )
1016
+
1017
+ # Or select sample data
1018
+ selected_sample = None # Initialize with a default value
1019
+ sample_files = get_sample_files()
1020
+ if sample_files:
1021
+ sample_options = ["-- Select Sample --"] + [p.name for p in sample_files]
1022
+ selected_sample = st.selectbox(
1023
+ "Or choose sample data", sample_options, key="comparison_sample_select"
1024
+ )
1025
+
1026
+ # Get modality from session state
1027
+ modality = st.session_state.get("modality_select", "raman")
1028
+ st.info(f"Using {modality.upper()} preprocessing parameters")
1029
+
1030
+ # Run comparison button
1031
+ run_comparison = st.button(
1032
+ "Run Multi-Model Comparison",
1033
+ type="primary",
1034
+ disabled=not (
1035
+ comparison_file
1036
+ or (sample_files and selected_sample != "-- Select Sample --")
1037
+ ),
1038
+ )
1039
+
1040
+ with col2:
1041
+ st.markdown("###### Comparison Results")
1042
+
1043
+ if run_comparison:
1044
+ # Determine input source
1045
+ input_text = None
1046
+ filename = "unknown"
1047
+
1048
+ if comparison_file:
1049
+ raw = comparison_file.read()
1050
+ input_text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
1051
+ filename = comparison_file.name
1052
+ elif sample_files and selected_sample != "-- Select Sample --":
1053
+ sample_path = next(p for p in sample_files if p.name == selected_sample)
1054
+ with open(sample_path, "r") as f:
1055
+ input_text = f.read()
1056
+ filename = selected_sample
1057
+
1058
+ if input_text:
1059
+ try:
1060
+ # Parse spectrum data
1061
+ x_raw, y_raw = parse_spectrum_data(
1062
+ str(input_text), filename or "unknown_filename"
1063
+ )
1064
+
1065
+ # Store results
1066
+ comparison_results = {}
1067
+ processing_times = {}
1068
+
1069
+ progress_bar = st.progress(0)
1070
+ status_text = st.empty()
1071
+
1072
+ for i, model_name in enumerate(selected_models):
1073
+ status_text.text(f"Running inference with {model_name}...")
1074
+
1075
+ start_time = time.time()
1076
+
1077
+ # Preprocess spectrum with modality-specific parameters
1078
+ _, y_processed = preprocess_spectrum(
1079
+ x_raw, y_raw, modality=modality, target_len=500
1080
+ )
1081
+
1082
+ # Run inference
1083
+ prediction, logits_list, probs, inference_time, logits = (
1084
+ run_inference(y_processed, model_name)
1085
+ )
1086
+
1087
+ processing_time = time.time() - start_time
1088
+
1089
+ if prediction is not None:
1090
+ # Map prediction to class name
1091
+ class_names = ["Stable", "Weathered"]
1092
+ predicted_class = (
1093
+ class_names[int(prediction)]
1094
+ if prediction < len(class_names)
1095
+ else f"Class_{prediction}"
1096
+ )
1097
+ confidence = (
1098
+ max(probs)
1099
+ if probs is not None and len(probs) > 0
1100
+ else 0.0
1101
+ )
1102
+
1103
+ comparison_results[model_name] = {
1104
+ "prediction": prediction,
1105
+ "predicted_class": predicted_class,
1106
+ "confidence": confidence,
1107
+ "probs": probs if probs is not None else [],
1108
+ "logits": (
1109
+ logits_list if logits_list is not None else []
1110
+ ),
1111
+ "processing_time": processing_time,
1112
+ }
1113
+ processing_times[model_name] = processing_time
1114
+
1115
+ progress_bar.progress((i + 1) / len(selected_models))
1116
+
1117
+ status_text.text("Comparison complete!")
1118
+
1119
+ # Display results
1120
+ if comparison_results:
1121
+ st.markdown("###### Model Predictions")
1122
+
1123
+ # Create comparison table
1124
+ import pandas as pd
1125
+
1126
+ table_data = []
1127
+ for model_name, result in comparison_results.items():
1128
+ row = {
1129
+ "Model": model_name,
1130
+ "Prediction": result["predicted_class"],
1131
+ "Confidence": f"{result['confidence']:.3f}",
1132
+ "Processing Time (s)": f"{result['processing_time']:.3f}",
1133
+ }
1134
+ table_data.append(row)
1135
+
1136
+ df = pd.DataFrame(table_data)
1137
+ st.dataframe(df, use_container_width=True)
1138
+
1139
+ # Show confidence comparison
1140
+ st.markdown("##### Confidence Comparison")
1141
+ conf_col1, conf_col2 = st.columns(2)
1142
+
1143
+ with conf_col1:
1144
+ # Bar chart of confidences
1145
+ models = list(comparison_results.keys())
1146
+ confidences = [
1147
+ comparison_results[m]["confidence"] for m in models
1148
+ ]
1149
+
1150
+ fig, ax = plt.subplots(figsize=(8, 5))
1151
+ bars = ax.bar(
1152
+ models,
1153
+ confidences,
1154
+ alpha=0.7,
1155
+ color=["steelblue", "orange", "green", "red"][
1156
+ : len(models)
1157
+ ],
1158
+ )
1159
+ ax.set_ylabel("Confidence")
1160
+ ax.set_title("Model Confidence Comparison")
1161
+ ax.set_ylim(0, 1)
1162
+ plt.xticks(rotation=45)
1163
+
1164
+ # Add value labels on bars
1165
+ for bar, conf in zip(bars, confidences):
1166
+ height = bar.get_height()
1167
+ ax.text(
1168
+ bar.get_x() + bar.get_width() / 2.0,
1169
+ height + 0.01,
1170
+ f"{conf:.3f}",
1171
+ ha="center",
1172
+ va="bottom",
1173
+ )
1174
+
1175
+ plt.tight_layout()
1176
+ st.pyplot(fig)
1177
+
1178
+ with conf_col2:
1179
+ # Agreement analysis
1180
+ predictions = [
1181
+ comparison_results[m]["prediction"] for m in models
1182
+ ]
1183
+ unique_predictions = set(predictions)
1184
+
1185
+ if len(unique_predictions) == 1:
1186
+ st.success("✅ All models agree on the prediction!")
1187
+ else:
1188
+ st.warning("⚠️ Models disagree on the prediction")
1189
+
1190
+ # Show prediction distribution
1191
+ from collections import Counter
1192
+
1193
+ pred_counts = Counter(predictions)
1194
+
1195
+ st.markdown("**Prediction Distribution:**")
1196
+ for pred, count in pred_counts.items():
1197
+ class_name = (
1198
+ ["Stable", "Weathered"][pred]
1199
+ if pred < 2
1200
+ else f"Class_{pred}"
1201
+ )
1202
+ percentage = (count / len(predictions)) * 100
1203
+ st.write(
1204
+ f"- {class_name}: {count}/{len(predictions)} models ({percentage:.1f}%)"
1205
+ )
1206
+
1207
+ # Performance metrics
1208
+ st.markdown("##### Performance Metrics")
1209
+ perf_col1, perf_col2 = st.columns(2)
1210
+
1211
+ with perf_col1:
1212
+ avg_time = np.mean(list(processing_times.values()))
1213
+ fastest_model = min(
1214
+ processing_times.keys(),
1215
+ key=lambda k: processing_times[k],
1216
+ )
1217
+ slowest_model = max(
1218
+ processing_times.keys(),
1219
+ key=lambda k: processing_times[k],
1220
+ )
1221
+
1222
+ st.metric("Average Processing Time", f"{avg_time:.3f}s")
1223
+ st.metric(
1224
+ "Fastest Model",
1225
+ f"{fastest_model}",
1226
+ f"{processing_times[fastest_model]:.3f}s",
1227
+ )
1228
+ st.metric(
1229
+ "Slowest Model",
1230
+ f"{slowest_model}",
1231
+ f"{processing_times[slowest_model]:.3f}s",
1232
+ )
1233
+
1234
+ with perf_col2:
1235
+ most_confident = max(
1236
+ comparison_results.keys(),
1237
+ key=lambda k: comparison_results[k]["confidence"],
1238
+ )
1239
+ least_confident = min(
1240
+ comparison_results.keys(),
1241
+ key=lambda k: comparison_results[k]["confidence"],
1242
+ )
1243
+
1244
+ st.metric(
1245
+ "Most Confident",
1246
+ f"{most_confident}",
1247
+ f"{comparison_results[most_confident]['confidence']:.3f}",
1248
+ )
1249
+ st.metric(
1250
+ "Least Confident",
1251
+ f"{least_confident}",
1252
+ f"{comparison_results[least_confident]['confidence']:.3f}",
1253
+ )
1254
+
1255
+ # Store results in session state for potential export
1256
+ # Store results in session state for potential export
1257
+ st.session_state["last_comparison_results"] = {
1258
+ "filename": filename,
1259
+ "modality": modality,
1260
+ "models": comparison_results,
1261
+ "summary": {
1262
+ "agreement": len(unique_predictions) == 1,
1263
+ "avg_processing_time": avg_time,
1264
+ "fastest_model": fastest_model,
1265
+ "most_confident": most_confident,
1266
+ },
1267
+ }
1268
+
1269
+ except Exception as e:
1270
+ st.error(f"Error during comparison: {str(e)}")
1271
+
1272
+ # Show recent comparison results if available
1273
+ elif "last_comparison_results" in st.session_state:
1274
+ st.info(
1275
+ "Previous comparison results available. Upload a new file or select a sample to run new comparison."
1276
+ )
1277
+
1278
+ # Show comparison history
1279
+ comparison_stats = ResultsManager.get_comparison_stats()
1280
+ if comparison_stats:
1281
+ st.markdown("#### Comparison History")
1282
+
1283
+ with st.expander("View detailed comparison statistics", expanded=False):
1284
+ # Show model statistics table
1285
+ stats_data = []
1286
+ for model_name, stats in comparison_stats.items():
1287
+ row = {
1288
+ "Model": model_name,
1289
+ "Total Predictions": stats["total_predictions"],
1290
+ "Avg Confidence": f"{stats['avg_confidence']:.3f}",
1291
+ "Avg Processing Time": f"{stats['avg_processing_time']:.3f}s",
1292
+ "Accuracy": (
1293
+ f"{stats['accuracy']:.3f}"
1294
+ if stats["accuracy"] is not None
1295
+ else "N/A"
1296
+ ),
1297
+ }
1298
+ stats_data.append(row)
1299
+
1300
+ if stats_data:
1301
+ import pandas as pd
1302
+
1303
+ stats_df = pd.DataFrame(stats_data)
1304
+ st.dataframe(stats_df, use_container_width=True)
1305
+
1306
+ # Show agreement matrix if multiple models
1307
+ agreement_matrix = ResultsManager.get_agreement_matrix()
1308
+ if not agreement_matrix.empty and len(agreement_matrix) > 1:
1309
+ st.markdown("**Model Agreement Matrix**")
1310
+ st.dataframe(agreement_matrix.round(3), use_container_width=True)
1311
+
1312
+ # Plot agreement heatmap
1313
+ fig, ax = plt.subplots(figsize=(8, 6))
1314
+ im = ax.imshow(
1315
+ agreement_matrix.values, cmap="RdYlGn", vmin=0, vmax=1
1316
+ )
1317
+
1318
+ # Add text annotations
1319
+ for i in range(len(agreement_matrix)):
1320
+ for j in range(len(agreement_matrix.columns)):
1321
+ text = ax.text(
1322
+ j,
1323
+ i,
1324
+ f"{agreement_matrix.iloc[i, j]:.2f}",
1325
+ ha="center",
1326
+ va="center",
1327
+ color="black",
1328
+ )
1329
+
1330
+ ax.set_xticks(range(len(agreement_matrix.columns)))
1331
+ ax.set_yticks(range(len(agreement_matrix)))
1332
+ ax.set_xticklabels(agreement_matrix.columns, rotation=45)
1333
+ ax.set_yticklabels(agreement_matrix.index)
1334
+ ax.set_title("Model Agreement Matrix")
1335
+
1336
+ plt.colorbar(im, ax=ax, label="Agreement Rate")
1337
+ plt.tight_layout()
1338
+ st.pyplot(fig)
1339
+
1340
+ # Export functionality
1341
+ if "last_comparison_results" in st.session_state:
1342
+ st.markdown("##### Export Results")
1343
+
1344
+ export_col1, export_col2 = st.columns(2)
1345
+
1346
+ with export_col1:
1347
+ if st.button("📥 Export Comparison (JSON)"):
1348
+ import json
1349
+
1350
+ results = st.session_state["last_comparison_results"]
1351
+ json_str = json.dumps(results, indent=2, default=str)
1352
+ st.download_button(
1353
+ label="Download JSON",
1354
+ data=json_str,
1355
+ file_name=f"comparison_{results['filename'].split('.')[0]}.json",
1356
+ mime="application/json",
1357
+ )
1358
+
1359
+ with export_col2:
1360
+ if st.button("📊 Export Full Report"):
1361
+ report = ResultsManager.export_comparison_report()
1362
+ st.download_button(
1363
+ label="Download Full Report",
1364
+ data=report,
1365
+ file_name="model_comparison_report.json",
1366
+ mime="application/json",
1367
+ )
1368
+
1369
+
1370
+ # //////////////////////////////////////////
1371
+
1372
+
1373
+ def render_performance_tab():
1374
+ """Render the performance tracking and analysis tab."""
1375
+ from utils.performance_tracker import display_performance_dashboard
1376
+
1377
+ display_performance_dashboard()
sample_data/ftir-stable-1.txt ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Sample FTIR spectrum data - Stable polymer
2
+ # Wavenumber (cm^-1) Absorbance
3
+ 400.0 0.045
4
+ 450.0 0.048
5
+ 500.0 0.052
6
+ 550.0 0.056
7
+ 600.0 0.061
8
+ 650.0 0.065
9
+ 700.0 0.070
10
+ 750.0 0.075
11
+ 800.0 0.082
12
+ 850.0 0.089
13
+ 900.0 0.096
14
+ 950.0 0.104
15
+ 1000.0 0.112
16
+ 1050.0 0.121
17
+ 1100.0 0.130
18
+ 1150.0 0.140
19
+ 1200.0 0.151
20
+ 1250.0 0.162
21
+ 1300.0 0.174
22
+ 1350.0 0.187
23
+ 1400.0 0.200
24
+ 1450.0 0.215
25
+ 1500.0 0.230
26
+ 1550.0 0.246
27
+ 1600.0 0.263
28
+ 1650.0 0.281
29
+ 1700.0 0.300
30
+ 1750.0 0.320
31
+ 1800.0 0.341
32
+ 1850.0 0.363
33
+ 1900.0 0.386
34
+ 1950.0 0.410
35
+ 2000.0 0.435
36
+ 2050.0 0.461
37
+ 2100.0 0.488
38
+ 2150.0 0.516
39
+ 2200.0 0.545
40
+ 2250.0 0.575
41
+ 2300.0 0.606
42
+ 2350.0 0.638
43
+ 2400.0 0.671
44
+ 2450.0 0.705
45
+ 2500.0 0.740
46
+ 2550.0 0.776
47
+ 2600.0 0.813
48
+ 2650.0 0.851
49
+ 2700.0 0.890
50
+ 2750.0 0.930
51
+ 2800.0 0.971
52
+ 2850.0 1.013
53
+ 2900.0 1.056
54
+ 2950.0 1.100
55
+ 3000.0 1.145
56
+ 3050.0 1.191
57
+ 3100.0 1.238
58
+ 3150.0 1.286
59
+ 3200.0 1.335
60
+ 3250.0 1.385
61
+ 3300.0 1.436
62
+ 3350.0 1.488
63
+ 3400.0 1.541
64
+ 3450.0 1.595
65
+ 3500.0 1.650
66
+ 3550.0 1.706
67
+ 3600.0 1.763
68
+ 3650.0 1.821
69
+ 3700.0 1.880
70
+ 3750.0 1.940
71
+ 3800.0 2.001
72
+ 3850.0 2.063
73
+ 3900.0 2.126
74
+ 3950.0 2.190
75
+ 4000.0 2.255
sample_data/ftir-weathered-1.txt ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Sample FTIR spectrum data - Weathered polymer
2
+ # Wavenumber (cm^-1) Absorbance
3
+ 400.0 0.062
4
+ 450.0 0.069
5
+ 500.0 0.077
6
+ 550.0 0.086
7
+ 600.0 0.095
8
+ 650.0 0.105
9
+ 700.0 0.116
10
+ 750.0 0.128
11
+ 800.0 0.141
12
+ 850.0 0.155
13
+ 900.0 0.170
14
+ 950.0 0.186
15
+ 1000.0 0.203
16
+ 1050.0 0.221
17
+ 1100.0 0.240
18
+ 1150.0 0.260
19
+ 1200.0 0.281
20
+ 1250.0 0.303
21
+ 1300.0 0.326
22
+ 1350.0 0.350
23
+ 1400.0 0.375
24
+ 1450.0 0.401
25
+ 1500.0 0.428
26
+ 1550.0 0.456
27
+ 1600.0 0.485
28
+ 1650.0 0.515
29
+ 1700.0 0.546
30
+ 1750.0 0.578
31
+ 1800.0 0.611
32
+ 1850.0 0.645
33
+ 1900.0 0.680
34
+ 1950.0 0.716
35
+ 2000.0 0.753
36
+ 2050.0 0.791
37
+ 2100.0 0.830
38
+ 2150.0 0.870
39
+ 2200.0 0.911
40
+ 2250.0 0.953
41
+ 2300.0 0.996
42
+ 2350.0 1.040
43
+ 2400.0 1.085
44
+ 2450.0 1.131
45
+ 2500.0 1.178
46
+ 2550.0 1.226
47
+ 2600.0 1.275
48
+ 2650.0 1.325
49
+ 2700.0 1.376
50
+ 2750.0 1.428
51
+ 2800.0 1.481
52
+ 2850.0 1.535
53
+ 2900.0 1.590
54
+ 2950.0 1.646
55
+ 3000.0 1.703
56
+ 3050.0 1.761
57
+ 3100.0 1.820
58
+ 3150.0 1.880
59
+ 3200.0 1.941
60
+ 3250.0 2.003
61
+ 3300.0 2.066
62
+ 3350.0 2.130
63
+ 3400.0 2.195
64
+ 3450.0 2.261
65
+ 3500.0 2.328
66
+ 3550.0 2.396
67
+ 3600.0 2.465
68
+ 3650.0 2.535
69
+ 3700.0 2.606
70
+ 3750.0 2.678
71
+ 3800.0 2.751
72
+ 3850.0 2.825
73
+ 3900.0 2.900
74
+ 3950.0 2.976
75
+ 4000.0 3.053
sample_data/stable.sample.csv ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wavenumber,intensity
2
+ 200.0,1542.3
3
+ 205.0,1543.1
4
+ 210.0,1544.8
5
+ 215.0,1546.2
6
+ 220.0,1547.9
7
+ 225.0,1549.1
8
+ 230.0,1550.4
9
+ 235.0,1551.8
10
+ 240.0,1553.2
11
+ 245.0,1554.6
12
+ 250.0,1556.1
13
+ 255.0,1557.6
14
+ 260.0,1559.1
15
+ 265.0,1560.7
16
+ 270.0,1562.3
17
+ 275.0,1563.9
18
+ 280.0,1565.6
19
+ 285.0,1567.3
20
+ 290.0,1569.0
21
+ 295.0,1570.8
22
+ 300.0,1572.6
scripts/run_inference.py CHANGED
@@ -17,144 +17,447 @@ python scripts/run_inference.py --input ... --arch resnet --weights ... --disabl
17
 
18
  import os
19
  import sys
 
20
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
21
 
22
  import argparse
23
  import json
 
24
  import logging
25
  from pathlib import Path
26
- from typing import cast
27
  from torch import nn
 
28
 
29
  import numpy as np
30
  import torch
31
  import torch.nn.functional as F
32
 
33
- from models.registry import build, choices
34
  from utils.preprocessing import preprocess_spectrum, TARGET_LENGTH
 
35
  from scripts.plot_spectrum import load_spectrum
36
  from scripts.discover_raman_files import label_file
37
 
38
 
39
  def parse_args():
40
- p = argparse.ArgumentParser(description="Raman spectrum inference (parity with CLI preprocessing).")
41
- p.add_argument("--input", required=True, help="Path to a single Raman .txt file (2 columns: x, y).")
42
- p.add_argument("--arch", required=True, choices=choices(), help="Model architecture key.")
43
- p.add_argument("--weights", required=True, help="Path to model weights (.pth).")
44
- p.add_argument("--target-len", type=int, default=TARGET_LENGTH, help="Resample length (default: 500).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # Default = ON; use disable- flags to turn steps off explicitly.
47
- p.add_argument("--disable-baseline", action="store_true", help="Disable baseline correction.")
48
- p.add_argument("--disable-smooth", action="store_true", help="Disable Savitzky–Golay smoothing.")
49
- p.add_argument("--disable-normalize", action="store_true", help="Disable min-max normalization.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- p.add_argument("--output", default=None, help="Optional output JSON path (defaults to outputs/inference/<name>.json).")
52
- p.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Compute device (default: cpu).")
53
  return p.parse_args()
54
 
55
 
 
 
 
56
  def _load_state_dict_safe(path: str):
57
  """Load a state dict safely across torch versions & checkpoint formats."""
58
  try:
59
  obj = torch.load(path, map_location="cpu", weights_only=True) # newer torch
60
  except TypeError:
61
  obj = torch.load(path, map_location="cpu") # fallback for older torch
62
-
63
  # Accept either a plain state_dict or a checkpoint dict that contains one
64
  if isinstance(obj, dict):
65
  for k in ("state_dict", "model_state_dict", "model"):
66
  if k in obj and isinstance(obj[k], dict):
67
  obj = obj[k]
68
  break
69
-
70
  if not isinstance(obj, dict):
71
  raise ValueError(
72
  "Loaded object is not a state_dict or checkpoint with a state_dict. "
73
  f"Type={type(obj)} from file={path}"
74
  )
75
-
76
  # Strip DataParallel 'module.' prefixes if present
77
  if any(key.startswith("module.") for key in obj.keys()):
78
  obj = {key.replace("module.", "", 1): val for key, val in obj.items()}
79
-
80
  return obj
81
 
82
 
83
- def main():
84
- logging.basicConfig(level=logging.INFO, format="INFO: %(message)s")
85
- args = parse_args()
86
 
87
- in_path = Path(args.input)
88
- if not in_path.exists():
89
- raise FileNotFoundError(f"Input file not found: {in_path}")
90
 
91
- # --- Load raw spectrum
92
- x_raw, y_raw = load_spectrum(str(in_path))
93
- if len(x_raw) < 10:
94
- raise ValueError("Input spectrum has too few points (<10).")
 
 
 
 
 
 
95
 
96
- # --- Preprocess (single source of truth)
97
  _, y_proc = preprocess_spectrum(
98
- np.array(x_raw),
99
- np.array(y_raw),
100
  target_len=args.target_len,
 
101
  do_baseline=not args.disable_baseline,
102
  do_smooth=not args.disable_smooth,
103
  do_normalize=not args.disable_normalize,
104
  out_dtype="float32",
105
  )
106
 
107
- # --- Build model & load weights (safe)
108
- device = torch.device(args.device if (args.device == "cuda" and torch.cuda.is_available()) else "cpu")
109
- model = cast(nn.Module, build(args.arch, args.target_len)).to(device)
110
- state = _load_state_dict_safe(args.weights)
111
  missing, unexpected = model.load_state_dict(state, strict=False)
112
  if missing or unexpected:
113
- logging.info("Loaded with non-strict keys. missing=%d unexpected=%d", len(missing), len(unexpected))
 
 
114
 
115
  model.eval()
116
 
117
- # Shape: (B, C, L) = (1, 1, target_len)
118
  x_tensor = torch.from_numpy(y_proc[None, None, :]).to(device)
119
 
120
  with torch.no_grad():
121
- logits = model(x_tensor).float().cpu() # shape (1, num_classes)
122
  probs = F.softmax(logits, dim=1)
123
 
 
124
  probs_np = probs.numpy().ravel().tolist()
125
  logits_np = logits.numpy().ravel().tolist()
126
  pred_label = int(np.argmax(probs_np))
127
 
128
- # Optional ground-truth from filename (if encoded)
129
- true_label = label_file(str(in_path))
130
-
131
- # --- Prepare output
132
- out_dir = Path("outputs") / "inference"
133
- out_dir.mkdir(parents=True, exist_ok=True)
134
- out_path = Path(args.output) if args.output else (out_dir / f"{in_path.stem}_{args.arch}.json")
135
-
136
- result = {
137
- "input_file": str(in_path),
138
- "arch": args.arch,
139
- "weights": str(args.weights),
140
- "target_len": args.target_len,
141
- "preprocessing": {
142
- "baseline": not args.disable_baseline,
143
- "smooth": not args.disable_smooth,
144
- "normalize": not args.disable_normalize,
145
- },
146
- "predicted_label": pred_label,
147
- "true_label": true_label,
148
  "probs": probs_np,
149
  "logits": logits_np,
 
150
  }
151
 
152
- with open(out_path, "w", encoding="utf-8") as f:
153
- json.dump(result, f, indent=2)
154
 
155
- logging.info("Predicted Label: %d True Label: %s", pred_label, true_label)
156
- logging.info("Raw Logits: %s", logits_np)
157
- logging.info("Result saved to %s", out_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
 
160
  if __name__ == "__main__":
 
17
 
18
  import os
19
  import sys
20
+
21
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
22
 
23
  import argparse
24
  import json
25
+ import csv
26
  import logging
27
  from pathlib import Path
28
+ from typing import cast, Dict, List, Any
29
  from torch import nn
30
+ import time
31
 
32
  import numpy as np
33
  import torch
34
  import torch.nn.functional as F
35
 
36
+ from models.registry import build, choices, build_multiple, validate_model_list
37
  from utils.preprocessing import preprocess_spectrum, TARGET_LENGTH
38
+ from utils.multifile import parse_spectrum_data, detect_file_format
39
  from scripts.plot_spectrum import load_spectrum
40
  from scripts.discover_raman_files import label_file
41
 
42
 
43
  def parse_args():
44
+ p = argparse.ArgumentParser(
45
+ description="Raman/FTIR spectrum inference with multi-model support."
46
+ )
47
+ p.add_argument(
48
+ "--input",
49
+ required=True,
50
+ help="Path to spectrum file (.txt, .csv, .json) or directory for batch processing.",
51
+ )
52
+
53
+ # Model selection - either single or multiple
54
+ group = p.add_mutually_exclusive_group(required=True)
55
+ group.add_argument(
56
+ "--arch", choices=choices(), help="Single model architecture key."
57
+ )
58
+ group.add_argument(
59
+ "--models",
60
+ help="Comma-separated list of models for comparison (e.g., 'figure2,resnet,resnet18vision').",
61
+ )
62
+
63
+ p.add_argument(
64
+ "--weights",
65
+ help="Path to model weights (.pth). For multi-model, use pattern with {model} placeholder.",
66
+ )
67
+ p.add_argument(
68
+ "--target-len",
69
+ type=int,
70
+ default=TARGET_LENGTH,
71
+ help="Resample length (default: 500).",
72
+ )
73
+
74
+ # Modality support
75
+ p.add_argument(
76
+ "--modality",
77
+ choices=["raman", "ftir"],
78
+ default="raman",
79
+ help="Spectroscopy modality for preprocessing (default: raman).",
80
+ )
81
 
82
  # Default = ON; use disable- flags to turn steps off explicitly.
83
+ p.add_argument(
84
+ "--disable-baseline", action="store_true", help="Disable baseline correction."
85
+ )
86
+ p.add_argument(
87
+ "--disable-smooth",
88
+ action="store_true",
89
+ help="Disable Savitzky–Golay smoothing.",
90
+ )
91
+ p.add_argument(
92
+ "--disable-normalize",
93
+ action="store_true",
94
+ help="Disable min-max normalization.",
95
+ )
96
+
97
+ p.add_argument(
98
+ "--output",
99
+ default=None,
100
+ help="Output path - JSON for single file, CSV for multi-model comparison.",
101
+ )
102
+ p.add_argument(
103
+ "--output-format",
104
+ choices=["json", "csv"],
105
+ default="json",
106
+ help="Output format for results.",
107
+ )
108
+ p.add_argument(
109
+ "--device",
110
+ default="cpu",
111
+ choices=["cpu", "cuda"],
112
+ help="Compute device (default: cpu).",
113
+ )
114
+
115
+ # File format options
116
+ p.add_argument(
117
+ "--file-format",
118
+ choices=["auto", "txt", "csv", "json"],
119
+ default="auto",
120
+ help="Input file format (auto-detect by default).",
121
+ )
122
 
 
 
123
  return p.parse_args()
124
 
125
 
126
+ # /////////////////////////////////////////////////////////
127
+
128
+
129
  def _load_state_dict_safe(path: str):
130
  """Load a state dict safely across torch versions & checkpoint formats."""
131
  try:
132
  obj = torch.load(path, map_location="cpu", weights_only=True) # newer torch
133
  except TypeError:
134
  obj = torch.load(path, map_location="cpu") # fallback for older torch
 
135
  # Accept either a plain state_dict or a checkpoint dict that contains one
136
  if isinstance(obj, dict):
137
  for k in ("state_dict", "model_state_dict", "model"):
138
  if k in obj and isinstance(obj[k], dict):
139
  obj = obj[k]
140
  break
 
141
  if not isinstance(obj, dict):
142
  raise ValueError(
143
  "Loaded object is not a state_dict or checkpoint with a state_dict. "
144
  f"Type={type(obj)} from file={path}"
145
  )
 
146
  # Strip DataParallel 'module.' prefixes if present
147
  if any(key.startswith("module.") for key in obj.keys()):
148
  obj = {key.replace("module.", "", 1): val for key, val in obj.items()}
 
149
  return obj
150
 
151
 
152
+ # /////////////////////////////////////////////////////////
 
 
153
 
 
 
 
154
 
155
+ def run_single_model_inference(
156
+ x_raw: np.ndarray,
157
+ y_raw: np.ndarray,
158
+ model_name: str,
159
+ weights_path: str,
160
+ args: argparse.Namespace,
161
+ device: torch.device,
162
+ ) -> Dict[str, Any]:
163
+ """Run inference with a single model."""
164
+ start_time = time.time()
165
 
166
+ # Preprocess spectrum
167
  _, y_proc = preprocess_spectrum(
168
+ x_raw,
169
+ y_raw,
170
  target_len=args.target_len,
171
+ modality=args.modality,
172
  do_baseline=not args.disable_baseline,
173
  do_smooth=not args.disable_smooth,
174
  do_normalize=not args.disable_normalize,
175
  out_dtype="float32",
176
  )
177
 
178
+ # Build model & load weights
179
+ model = cast(nn.Module, build(model_name, args.target_len)).to(device)
180
+ state = _load_state_dict_safe(weights_path)
 
181
  missing, unexpected = model.load_state_dict(state, strict=False)
182
  if missing or unexpected:
183
+ logging.info(
184
+ f"Model {model_name}: Loaded with non-strict keys. missing={len(missing)} unexpected={len(unexpected)}"
185
+ )
186
 
187
  model.eval()
188
 
189
+ # Run inference
190
  x_tensor = torch.from_numpy(y_proc[None, None, :]).to(device)
191
 
192
  with torch.no_grad():
193
+ logits = model(x_tensor).float().cpu()
194
  probs = F.softmax(logits, dim=1)
195
 
196
+ processing_time = time.time() - start_time
197
  probs_np = probs.numpy().ravel().tolist()
198
  logits_np = logits.numpy().ravel().tolist()
199
  pred_label = int(np.argmax(probs_np))
200
 
201
+ # Map prediction to class name
202
+ class_names = ["Stable", "Weathered"]
203
+ predicted_class = (
204
+ class_names[pred_label]
205
+ if pred_label < len(class_names)
206
+ else f"Class_{pred_label}"
207
+ )
208
+
209
+ return {
210
+ "model": model_name,
211
+ "prediction": pred_label,
212
+ "predicted_class": predicted_class,
213
+ "confidence": max(probs_np),
 
 
 
 
 
 
 
214
  "probs": probs_np,
215
  "logits": logits_np,
216
+ "processing_time": processing_time,
217
  }
218
 
 
 
219
 
220
+ # /////////////////////////////////////////////////////////
221
+
222
+
223
+ def run_multi_model_inference(
224
+ x_raw: np.ndarray,
225
+ y_raw: np.ndarray,
226
+ model_names: List[str],
227
+ args: argparse.Namespace,
228
+ device: torch.device,
229
+ ) -> Dict[str, Dict[str, Any]]:
230
+ """Run inference with multiple models for comparison."""
231
+ results = {}
232
+
233
+ for model_name in model_names:
234
+ try:
235
+ # Generate weights path - either use pattern or assume same weights for all
236
+ if args.weights and "{model}" in args.weights:
237
+ weights_path = args.weights.format(model=model_name)
238
+ elif args.weights:
239
+ weights_path = args.weights
240
+ else:
241
+ # Default weights path pattern
242
+ weights_path = f"outputs/{model_name}_model.pth"
243
+
244
+ if not Path(weights_path).exists():
245
+ logging.warning(f"Weights not found for {model_name}: {weights_path}")
246
+ continue
247
+
248
+ result = run_single_model_inference(
249
+ x_raw, y_raw, model_name, weights_path, args, device
250
+ )
251
+ results[model_name] = result
252
+
253
+ except Exception as e:
254
+ logging.error(f"Failed to run inference with {model_name}: {str(e)}")
255
+ continue
256
+
257
+ return results
258
+
259
+
260
+ # /////////////////////////////////////////////////////////
261
+
262
+
263
+ def save_results(
264
+ results: Dict[str, Any], output_path: Path, format: str = "json"
265
+ ) -> None:
266
+ """Save results to file in specified format"""
267
+ output_path.parent.mkdir(parents=True, exist_ok=True)
268
+
269
+ if format == "json":
270
+ with open(output_path, "w", encoding="utf-8") as f:
271
+ json.dump(results, f, indent=2)
272
+ elif format == "csv":
273
+ # Convert to tabular format for CSV
274
+ if "models" in results: # Multi-model results
275
+ rows = []
276
+ for model_name, model_result in results["models"].items():
277
+ row = {
278
+ "model": model_name,
279
+ "prediction": model_result["prediction"],
280
+ "predicted_class": model_result["predicted_class"],
281
+ "confidence": model_result["confidence"],
282
+ "processing_time": model_result["processing_time"],
283
+ }
284
+ # Add individual class probabilities
285
+ if "probs" in model_result:
286
+ for i, prob in enumerate(model_result["probs"]):
287
+ row[f"prob_class_{i}"] = prob
288
+ rows.append(row)
289
+
290
+ # Write CSV
291
+ with open(output_path, "w", newline="", encoding="utf-8") as f:
292
+ if rows:
293
+ writer = csv.DictWriter(f, fieldnames=rows[0].keys())
294
+ writer.writeheader()
295
+ writer.writerows(rows)
296
+ else: # Single model result
297
+ with open(output_path, "w", newline="", encoding="utf-8") as f:
298
+ writer = csv.DictWriter(f, fieldnames=results.keys())
299
+ writer.writeheader()
300
+ writer.writerow(results)
301
+
302
+
303
+ def main():
304
+ logging.basicConfig(level=logging.INFO, format="INFO: %(message)s")
305
+ args = parse_args()
306
+
307
+ # Input validation
308
+ in_path = Path(args.input)
309
+ if not in_path.exists():
310
+ raise FileNotFoundError(f"Input file not found: {in_path}")
311
+
312
+ # Determine if this is single or multi-model inference
313
+ if args.models:
314
+ model_names = [m.strip() for m in args.models.split(",")]
315
+ model_names = validate_model_list(model_names)
316
+ if not model_names:
317
+ raise ValueError(f"No valid models found in: {args.models}")
318
+ multi_model = True
319
+ else:
320
+ model_names = [args.arch]
321
+ multi_model = False
322
+
323
+ # Load and parse spectrum data
324
+ if args.file_format == "auto":
325
+ file_format = None # Auto-detect
326
+ else:
327
+ file_format = args.file_format
328
+
329
+ try:
330
+ # Read file content
331
+ with open(in_path, "r", encoding="utf-8") as f:
332
+ content = f.read()
333
+
334
+ # Parse spectrum data with format detection
335
+ x_raw, y_raw = parse_spectrum_data(content, str(in_path))
336
+ x_raw = np.array(x_raw, dtype=np.float32)
337
+ y_raw = np.array(y_raw, dtype=np.float32)
338
+
339
+ except Exception as e:
340
+ x_raw, y_raw = load_spectrum(str(in_path))
341
+ x_raw = np.array(x_raw, dtype=np.float32)
342
+ y_raw = np.array(y_raw, dtype=np.float32)
343
+ logging.warning(
344
+ f"Failed to parse with new parser, falling back to original: {e}"
345
+ )
346
+ x_raw, y_raw = load_spectrum(str(in_path))
347
+
348
+ if len(x_raw) < 10:
349
+ raise ValueError("Input spectrum has too few points (<10).")
350
+
351
+ # Setup device
352
+ device = torch.device(
353
+ args.device if (args.device == "cuda" and torch.cuda.is_available()) else "cpu"
354
+ )
355
+
356
+ # Run inference
357
+ model_results = {} # Initialize to avoid unbound variable error
358
+ if multi_model:
359
+ model_results = run_multi_model_inference(
360
+ np.array(x_raw, dtype=np.float32),
361
+ np.array(y_raw, dtype=np.float32),
362
+ model_names,
363
+ args,
364
+ device,
365
+ )
366
+
367
+ # Get ground truth if available
368
+ true_label = label_file(str(in_path))
369
+
370
+ # Prepare combined results
371
+ results = {
372
+ "input_file": str(in_path),
373
+ "modality": args.modality,
374
+ "models": model_results,
375
+ "true_label": true_label,
376
+ "preprocessing": {
377
+ "baseline": not args.disable_baseline,
378
+ "smooth": not args.disable_smooth,
379
+ "normalize": not args.disable_normalize,
380
+ "target_len": args.target_len,
381
+ },
382
+ "comparison": {
383
+ "total_models": len(model_results),
384
+ "agreements": (
385
+ sum(
386
+ 1
387
+ for i, (_, r1) in enumerate(model_results.items())
388
+ for j, (_, r2) in enumerate(
389
+ list(model_results.items())[i + 1 :]
390
+ )
391
+ if r1["prediction"] == r2["prediction"]
392
+ )
393
+ if len(model_results) > 1
394
+ else 0
395
+ ),
396
+ },
397
+ }
398
+
399
+ # Default output path for multi-model
400
+ default_output = (
401
+ Path("outputs")
402
+ / "inference"
403
+ / f"{in_path.stem}_comparison.{args.output_format}"
404
+ )
405
+
406
+ else:
407
+ # Single model inference
408
+ model_result = run_single_model_inference(
409
+ x_raw, y_raw, model_names[0], args.weights, args, device
410
+ )
411
+ true_label = label_file(str(in_path))
412
+
413
+ results = {
414
+ "input_file": str(in_path),
415
+ "modality": args.modality,
416
+ "arch": model_names[0],
417
+ "weights": str(args.weights),
418
+ "target_len": args.target_len,
419
+ "preprocessing": {
420
+ "baseline": not args.disable_baseline,
421
+ "smooth": not args.disable_smooth,
422
+ "normalize": not args.disable_normalize,
423
+ },
424
+ "predicted_label": model_result["prediction"],
425
+ "predicted_class": model_result["predicted_class"],
426
+ "true_label": true_label,
427
+ "confidence": model_result["confidence"],
428
+ "probs": model_result["probs"],
429
+ "logits": model_result["logits"],
430
+ "processing_time": model_result["processing_time"],
431
+ }
432
+
433
+ # Default output path for single model
434
+ default_output = (
435
+ Path("outputs")
436
+ / "inference"
437
+ / f"{in_path.stem}_{model_names[0]}.{args.output_format}"
438
+ )
439
+
440
+ # Save results
441
+ output_path = Path(args.output) if args.output else default_output
442
+ save_results(results, output_path, args.output_format)
443
+
444
+ # Log summary
445
+ if multi_model:
446
+ logging.info(
447
+ f"Multi-model inference completed with {len(model_results)} models"
448
+ )
449
+ for model_name, result in model_results.items():
450
+ logging.info(
451
+ f"{model_name}: {result['predicted_class']} (confidence: {result['confidence']:.3f})"
452
+ )
453
+ logging.info(f"Results saved to {output_path}")
454
+ else:
455
+ logging.info(
456
+ f"Predicted Label: {results['predicted_label']} ({results['predicted_class']})"
457
+ )
458
+ logging.info(f"Confidence: {results['confidence']:.3f}")
459
+ logging.info(f"True Label: {results['true_label']}")
460
+ logging.info(f"Result saved to {output_path}")
461
 
462
 
463
  if __name__ == "__main__":
tests/test_ftir_preprocessing.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for FTIR preprocessing functionality."""
2
+
3
+ import pytest
4
+ import numpy as np
5
+ from utils.preprocessing import (
6
+ preprocess_spectrum,
7
+ validate_spectrum_range,
8
+ get_modality_info,
9
+ MODALITY_RANGES,
10
+ MODALITY_PARAMS,
11
+ )
12
+
13
+
14
+ def test_modality_ranges():
15
+ """Test that modality ranges are correctly defined."""
16
+ assert "raman" in MODALITY_RANGES
17
+ assert "ftir" in MODALITY_RANGES
18
+
19
+ raman_range = MODALITY_RANGES["raman"]
20
+ ftir_range = MODALITY_RANGES["ftir"]
21
+
22
+ assert raman_range[0] < raman_range[1] # Valid range
23
+ assert ftir_range[0] < ftir_range[1] # Valid range
24
+ assert ftir_range[0] >= 400 # FTIR starts at 400 cm⁻¹
25
+ assert ftir_range[1] <= 4000 # FTIR ends at 4000 cm⁻¹
26
+
27
+
28
+ def test_validate_spectrum_range():
29
+ """Test spectrum range validation for different modalities."""
30
+ # Test Raman range validation
31
+ raman_x = np.linspace(300, 3500, 100) # Typical Raman range
32
+ assert validate_spectrum_range(raman_x, "raman") == True
33
+
34
+ # Test FTIR range validation
35
+ ftir_x = np.linspace(500, 3800, 100) # Typical FTIR range
36
+ assert validate_spectrum_range(ftir_x, "ftir") == True
37
+
38
+ # Test out-of-range data
39
+ out_of_range_x = np.linspace(50, 150, 100) # Too low for either
40
+ assert validate_spectrum_range(out_of_range_x, "raman") == False
41
+ assert validate_spectrum_range(out_of_range_x, "ftir") == False
42
+
43
+
44
+ def test_ftir_preprocessing():
45
+ """Test FTIR-specific preprocessing parameters."""
46
+ # Generate synthetic FTIR spectrum
47
+ x = np.linspace(400, 4000, 200) # FTIR range
48
+ y = np.sin(x / 500) + 0.1 * np.random.randn(len(x)) + 2.0 # Synthetic absorbance
49
+
50
+ # Test FTIR preprocessing
51
+ x_proc, y_proc = preprocess_spectrum(x, y, modality="ftir", target_len=500)
52
+
53
+ assert x_proc.shape == (500,)
54
+ assert y_proc.shape == (500,)
55
+ assert np.all(np.diff(x_proc) > 0) # Monotonic increasing
56
+ assert np.min(y_proc) >= 0.0 # Normalized to [0, 1]
57
+ assert np.max(y_proc) <= 1.0
58
+
59
+
60
+ def test_raman_preprocessing():
61
+ """Test Raman-specific preprocessing parameters."""
62
+ # Generate synthetic Raman spectrum
63
+ x = np.linspace(200, 3500, 200) # Raman range
64
+ y = np.exp(-(((x - 1500) / 200) ** 2)) + 0.05 * np.random.randn(
65
+ len(x)
66
+ ) # Gaussian peak
67
+
68
+ # Test Raman preprocessing
69
+ x_proc, y_proc = preprocess_spectrum(x, y, modality="raman", target_len=500)
70
+
71
+ assert x_proc.shape == (500,)
72
+ assert y_proc.shape == (500,)
73
+ assert np.all(np.diff(x_proc) > 0) # Monotonic increasing
74
+ assert np.min(y_proc) >= 0.0 # Normalized to [0, 1]
75
+ assert np.max(y_proc) <= 1.0
76
+
77
+
78
+ def test_modality_specific_parameters():
79
+ """Test that different modalities use different default parameters."""
80
+ x = np.linspace(400, 4000, 200)
81
+ y = np.sin(x / 500) + 1.0
82
+
83
+ # Test that FTIR uses different window length than Raman
84
+ ftir_params = MODALITY_PARAMS["ftir"]
85
+ raman_params = MODALITY_PARAMS["raman"]
86
+
87
+ assert ftir_params["smooth_window"] != raman_params["smooth_window"]
88
+
89
+ # Preprocess with both modalities (should use different parameters)
90
+ x_raman, y_raman = preprocess_spectrum(x, y, modality="raman")
91
+ x_ftir, y_ftir = preprocess_spectrum(x, y, modality="ftir")
92
+
93
+ # Results should be slightly different due to different parameters
94
+ assert not np.allclose(y_raman, y_ftir, rtol=1e-10)
95
+
96
+
97
+ def test_get_modality_info():
98
+ """Test modality information retrieval."""
99
+ raman_info = get_modality_info("raman")
100
+ ftir_info = get_modality_info("ftir")
101
+
102
+ assert "range" in raman_info
103
+ assert "params" in raman_info
104
+ assert "range" in ftir_info
105
+ assert "params" in ftir_info
106
+
107
+ # Check that ranges match expected values
108
+ assert raman_info["range"] == MODALITY_RANGES["raman"]
109
+ assert ftir_info["range"] == MODALITY_RANGES["ftir"]
110
+
111
+ # Check that parameters are present
112
+ assert "baseline_degree" in raman_info["params"]
113
+ assert "smooth_window" in ftir_info["params"]
114
+
115
+
116
+ def test_invalid_modality():
117
+ """Test handling of invalid modality."""
118
+ x = np.linspace(1000, 2000, 100)
119
+ y = np.sin(x / 100)
120
+
121
+ with pytest.raises(ValueError, match="Unsupported modality"):
122
+ preprocess_spectrum(x, y, modality="invalid")
123
+
124
+ with pytest.raises(ValueError, match="Unknown modality"):
125
+ validate_spectrum_range(x, "invalid")
126
+
127
+ with pytest.raises(ValueError, match="Unknown modality"):
128
+ get_modality_info("invalid")
129
+
130
+
131
+ def test_modality_parameter_override():
132
+ """Test that modality defaults can be overridden."""
133
+ x = np.linspace(400, 4000, 100)
134
+ y = np.sin(x / 500) + 1.0
135
+
136
+ # Override FTIR default window length
137
+ custom_window = 21 # Different from FTIR default (13)
138
+
139
+ x_proc, y_proc = preprocess_spectrum(
140
+ x, y, modality="ftir", window_length=custom_window
141
+ )
142
+
143
+ assert x_proc.shape[0] > 0
144
+ assert y_proc.shape[0] > 0
145
+
146
+
147
+ def test_range_validation_warning():
148
+ """Test that range validation warnings work correctly."""
149
+ # Create spectrum outside typical FTIR range
150
+ x_bad = np.linspace(100, 300, 50) # Too low for FTIR
151
+ y_bad = np.ones_like(x_bad)
152
+
153
+ # Should still process but with validation disabled
154
+ x_proc, y_proc = preprocess_spectrum(
155
+ x_bad, y_bad, modality="ftir", validate_range=False # Disable validation
156
+ )
157
+
158
+ assert len(x_proc) > 0
159
+ assert len(y_proc) > 0
160
+
161
+
162
+ def test_backwards_compatibility():
163
+ """Test that old preprocessing calls still work (defaults to Raman)."""
164
+ x = np.linspace(1000, 2000, 100)
165
+ y = np.sin(x / 100)
166
+
167
+ # Old style call (should default to Raman)
168
+ x_old, y_old = preprocess_spectrum(x, y)
169
+
170
+ # New style call with explicit Raman
171
+ x_new, y_new = preprocess_spectrum(x, y, modality="raman")
172
+
173
+ # Should be identical
174
+ np.testing.assert_array_equal(x_old, x_new)
175
+ np.testing.assert_array_equal(y_old, y_new)
176
+
177
+
178
+ if __name__ == "__main__":
179
+ pytest.main([__file__])
tests/test_multi_format.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for multi-format file parsing functionality."""
2
+
3
+ import pytest
4
+ import numpy as np
5
+ from utils.multifile import (
6
+ parse_spectrum_data,
7
+ detect_file_format,
8
+ parse_json_spectrum,
9
+ parse_csv_spectrum,
10
+ parse_txt_spectrum,
11
+ )
12
+
13
+
14
+ def test_detect_file_format():
15
+ """Test automatic file format detection."""
16
+ # JSON detection
17
+ json_content = '{"wavenumbers": [1, 2, 3], "intensities": [0.1, 0.2, 0.3]}'
18
+ assert detect_file_format("test.json", json_content) == "json"
19
+
20
+ # CSV detection
21
+ csv_content = "wavenumber,intensity\n1000,0.5\n1001,0.6"
22
+ assert detect_file_format("test.csv", csv_content) == "csv"
23
+
24
+ # TXT detection (default)
25
+ txt_content = "1000 0.5\n1001 0.6"
26
+ assert detect_file_format("test.txt", txt_content) == "txt"
27
+
28
+
29
+ def test_parse_json_spectrum():
30
+ """Test JSON spectrum parsing."""
31
+ # Test object format
32
+ json_content = '{"wavenumbers": [1000, 1001, 1002], "intensities": [0.1, 0.2, 0.3]}'
33
+ x, y = parse_json_spectrum(json_content)
34
+
35
+ expected_x = np.array([1000, 1001, 1002])
36
+ expected_y = np.array([0.1, 0.2, 0.3])
37
+
38
+ np.testing.assert_array_equal(x, expected_x)
39
+ np.testing.assert_array_equal(y, expected_y)
40
+
41
+ # Test alternative key names
42
+ json_content_alt = '{"x": [1000, 1001, 1002], "y": [0.1, 0.2, 0.3]}'
43
+ x_alt, y_alt = parse_json_spectrum(json_content_alt)
44
+ np.testing.assert_array_equal(x_alt, expected_x)
45
+ np.testing.assert_array_equal(y_alt, expected_y)
46
+
47
+ # Test array of objects format
48
+ json_array = """[
49
+ {"wavenumber": 1000, "intensity": 0.1},
50
+ {"wavenumber": 1001, "intensity": 0.2},
51
+ {"wavenumber": 1002, "intensity": 0.3}
52
+ ]"""
53
+ x_arr, y_arr = parse_json_spectrum(json_array)
54
+ np.testing.assert_array_equal(x_arr, expected_x)
55
+ np.testing.assert_array_equal(y_arr, expected_y)
56
+
57
+
58
+ def test_parse_csv_spectrum():
59
+ """Test CSV spectrum parsing."""
60
+ # Test with headers
61
+ csv_with_headers = """wavenumber,intensity
62
+ 1000,0.1
63
+ 1001,0.2
64
+ 1002,0.3
65
+ 1003,0.4
66
+ 1004,0.5
67
+ 1005,0.6
68
+ 1006,0.7
69
+ 1007,0.8
70
+ 1008,0.9
71
+ 1009,1.0
72
+ 1010,1.1
73
+ 1011,1.2"""
74
+
75
+ x, y = parse_csv_spectrum(csv_with_headers)
76
+ expected_x = np.array(
77
+ [1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011]
78
+ )
79
+ expected_y = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2])
80
+
81
+ np.testing.assert_array_equal(x, expected_x)
82
+ np.testing.assert_array_equal(y, expected_y)
83
+
84
+ # Test without headers
85
+ csv_no_headers = """1000,0.1
86
+ 1001,0.2
87
+ 1002,0.3
88
+ 1003,0.4
89
+ 1004,0.5
90
+ 1005,0.6
91
+ 1006,0.7
92
+ 1007,0.8
93
+ 1008,0.9
94
+ 1009,1.0
95
+ 1010,1.1
96
+ 1011,1.2"""
97
+
98
+ x_no_h, y_no_h = parse_csv_spectrum(csv_no_headers)
99
+ np.testing.assert_array_equal(x_no_h, expected_x)
100
+ np.testing.assert_array_equal(y_no_h, expected_y)
101
+
102
+ # Test semicolon delimiter
103
+ csv_semicolon = """1000;0.1
104
+ 1001;0.2
105
+ 1002;0.3
106
+ 1003;0.4
107
+ 1004;0.5
108
+ 1005;0.6
109
+ 1006;0.7
110
+ 1007;0.8
111
+ 1008;0.9
112
+ 1009;1.0
113
+ 1010;1.1
114
+ 1011;1.2"""
115
+
116
+ x_semi, y_semi = parse_csv_spectrum(csv_semicolon)
117
+ np.testing.assert_array_equal(x_semi, expected_x)
118
+ np.testing.assert_array_equal(y_semi, expected_y)
119
+
120
+
121
+ def test_parse_txt_spectrum():
122
+ """Test TXT spectrum parsing."""
123
+ txt_content = """# Comment line
124
+ 1000 0.1
125
+ 1001 0.2
126
+ 1002 0.3
127
+ 1003 0.4
128
+ 1004 0.5
129
+ 1005 0.6
130
+ 1006 0.7
131
+ 1007 0.8
132
+ 1008 0.9
133
+ 1009 1.0
134
+ 1010 1.1
135
+ 1011 1.2"""
136
+
137
+ x, y = parse_txt_spectrum(txt_content)
138
+ expected_x = np.array(
139
+ [1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011]
140
+ )
141
+ expected_y = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2])
142
+
143
+ np.testing.assert_array_equal(x, expected_x)
144
+ np.testing.assert_array_equal(y, expected_y)
145
+
146
+ # Test comma-separated
147
+ txt_comma = """1000,0.1
148
+ 1001,0.2
149
+ 1002,0.3
150
+ 1003,0.4
151
+ 1004,0.5
152
+ 1005,0.6
153
+ 1006,0.7
154
+ 1007,0.8
155
+ 1008,0.9
156
+ 1009,1.0
157
+ 1010,1.1
158
+ 1011,1.2"""
159
+
160
+ x_comma, y_comma = parse_txt_spectrum(txt_comma)
161
+ np.testing.assert_array_equal(x_comma, expected_x)
162
+ np.testing.assert_array_equal(y_comma, expected_y)
163
+
164
+
165
+ def test_parse_spectrum_data_integration():
166
+ """Test integrated spectrum data parsing with format detection."""
167
+ # Test automatic format detection and parsing
168
+ test_cases = [
169
+ (
170
+ '{"wavenumbers": [1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011], "intensities": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]}',
171
+ "test.json",
172
+ ),
173
+ (
174
+ "wavenumber,intensity\n1000,0.1\n1001,0.2\n1002,0.3\n1003,0.4\n1004,0.5\n1005,0.6\n1006,0.7\n1007,0.8\n1008,0.9\n1009,1.0\n1010,1.1\n1011,1.2",
175
+ "test.csv",
176
+ ),
177
+ (
178
+ "1000 0.1\n1001 0.2\n1002 0.3\n1003 0.4\n1004 0.5\n1005 0.6\n1006 0.7\n1007 0.8\n1008 0.9\n1009 1.0\n1010 1.1\n1011 1.2",
179
+ "test.txt",
180
+ ),
181
+ ]
182
+
183
+ for content, filename in test_cases:
184
+ x, y = parse_spectrum_data(content, filename)
185
+ assert len(x) >= 10
186
+ assert len(y) >= 10
187
+ assert len(x) == len(y)
188
+
189
+
190
+ def test_insufficient_data_points():
191
+ """Test handling of insufficient data points."""
192
+ # Test with too few points
193
+ insufficient_data = "1000 0.1\n1001 0.2" # Only 2 points, need at least 10
194
+
195
+ with pytest.raises(ValueError, match="Insufficient data points"):
196
+ parse_txt_spectrum(insufficient_data, "test.txt")
197
+
198
+
199
+ def test_invalid_json():
200
+ """Test handling of invalid JSON."""
201
+ invalid_json = (
202
+ '{"wavenumbers": [1000, 1001], "intensities": [0.1}' # Missing closing bracket
203
+ )
204
+
205
+ with pytest.raises(ValueError, match="Invalid JSON format"):
206
+ parse_json_spectrum(invalid_json)
207
+
208
+
209
+ def test_empty_file():
210
+ """Test handling of empty files."""
211
+ empty_content = ""
212
+
213
+ with pytest.raises(ValueError, match="No data lines found"):
214
+ parse_txt_spectrum(empty_content, "empty.txt")
215
+
216
+
217
+ if __name__ == "__main__":
218
+ pytest.main([__file__])
utils/multifile.py CHANGED
@@ -1,11 +1,16 @@
1
- """Multi-file processing utiltities for batch inference.
2
- Handles multiple file uploads and iterative processing."""
 
3
 
4
- from typing import List, Dict, Any, Tuple, Optional
5
  import time
6
  import streamlit as st
7
  import numpy as np
8
  import pandas as pd
 
 
 
 
9
 
10
  from .preprocessing import resample_spectrum
11
  from .errors import ErrorHandler, safe_execute
@@ -13,83 +18,230 @@ from .results_manager import ResultsManager
13
  from .confidence import calculate_softmax_confidence
14
 
15
 
16
- def parse_spectrum_data(
17
- text_content: str, filename: str = "unknown"
18
- ) -> Tuple[np.ndarray, np.ndarray]:
19
- """
20
- Parse spectrum data from text content
21
 
22
  Args:
23
- text_content: Raw text content of the spectrum file
24
- filename: Name of the file for error reporting
25
 
26
  Returns:
27
- Tuple of (x_values, y_values) as numpy arrays
28
-
29
- Raises:
30
- ValueError: If the data cannot be parsed
31
  """
32
- try:
33
- lines = text_content.strip().split("\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- # ==Remove empty lines and comments==
36
- data_lines = []
37
- for line in lines:
38
- line = line.strip()
39
- if line and not line.startswith("#") and not line.startswith("%"):
40
- data_lines.append(line)
41
 
42
- if not data_lines:
43
- raise ValueError("No data lines found in file")
44
 
45
- # ==Try to parse==
46
- x_vals, y_vals = [], []
47
 
48
- for i, line in enumerate(data_lines):
49
- try:
50
- # Handle different separators
51
- parts = line.replace(",", " ").split()
52
- numbers = [
53
- p
54
- for p in parts
55
- if p.replace(".", "", 1)
56
- .replace("-", "", 1)
57
- .replace("+", "", 1)
58
- .isdigit()
59
- ]
60
- if len(numbers) >= 2:
61
- x_val = float(numbers[0])
62
- y_val = float(numbers[1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  x_vals.append(x_val)
64
  y_vals.append(y_val)
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  except ValueError:
67
  ErrorHandler.log_warning(
68
- f"Could not parse line {i+1}: {line}", f"Parsing {filename}"
69
  )
70
  continue
71
 
72
- if len(x_vals) < 10: # ==Need minimum points for interpolation==
73
  raise ValueError(
74
  f"Insufficient data points ({len(x_vals)}). Need at least 10 points."
75
  )
76
 
77
- x = np.array(x_vals)
78
- y = np.array(y_vals)
79
 
80
- # Check for NaNs
81
- if np.any(np.isnan(x)) or np.any(np.isnan(y)):
82
- raise ValueError("Input data contains NaN values")
83
 
84
- # Check monotonic increasing x
85
- if not np.all(np.diff(x) > 0):
86
- raise ValueError("Wavenumbers must be strictly increasing")
87
 
88
- # Check reasonable range for Raman spectroscopy
89
- if min(x) < 0 or max(x) > 10000 or (max(x) - min(x)) < 100:
90
- raise ValueError(
91
- f"Invalid wavenumber range: {min(x)} - {max(x)}. Expected ~400-4000 cm⁻¹ with span >100"
92
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  return x, y
95
 
@@ -97,6 +249,95 @@ def parse_spectrum_data(
97
  raise ValueError(f"Failed to parse spectrum data: {str(e)}")
98
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def process_single_file(
101
  filename: str,
102
  text_content: str,
 
1
+ """Multi-file processing utilities for batch inference.
2
+ Handles multiple file uploads and iterative processing.
3
+ Supports TXT, CSV, and JSON file formats with automatic detection."""
4
 
5
+ from typing import List, Dict, Any, Tuple, Optional, Union
6
  import time
7
  import streamlit as st
8
  import numpy as np
9
  import pandas as pd
10
+ import json
11
+ import csv
12
+ import io
13
+ from pathlib import Path
14
 
15
  from .preprocessing import resample_spectrum
16
  from .errors import ErrorHandler, safe_execute
 
18
  from .confidence import calculate_softmax_confidence
19
 
20
 
21
+ def detect_file_format(filename: str, content: str) -> str:
22
+ """Automatically detect file format based on exstention and content
 
 
 
23
 
24
  Args:
25
+ filename: Name of the file
26
+ content: Content of the file
27
 
28
  Returns:
29
+ File format: .'txt', .'csv', .'json'
 
 
 
30
  """
31
+ # First try by extension
32
+ suffix = Path(filename).suffix.lower()
33
+ if suffix == ".json":
34
+ try:
35
+ json.loads(content)
36
+ return "json"
37
+ except:
38
+ pass
39
+ elif suffix == ".csv":
40
+ return "csv"
41
+ elif suffix == ".txt":
42
+ return "txt"
43
+
44
+ # If extension doesn't match or is unclear, try content detection
45
+ content_stripped = content.strip()
46
+
47
+ # Try JSON
48
+ if content_stripped.startswith(("{", "[")):
49
+ try:
50
+ json.loads(content)
51
+ return "json"
52
+ except:
53
+ pass
54
 
55
+ # Try CSV (look for commas in first few lines)
56
+ lines = content_stripped.split("\n")[:5]
57
+ comma_count = sum(line.count(",") for line in lines)
58
+ if comma_count > len(lines): # More commas than lines suggests CSV
59
+ return "csv"
 
60
 
61
+ # Default to TXT
62
+ return "txt"
63
 
 
 
64
 
65
+ # /////////////////////////////////////////////////////
66
+
67
+
68
+ def parse_json_spectrum(
69
+ content: str, filename: str = "unknown"
70
+ ) -> Tuple[np.ndarray, np.ndarray]:
71
+ """
72
+ Parse spectrum data from JSON format.
73
+
74
+ Expected formats:
75
+ - {"wavenumbers": [...], "intensities": [...]}
76
+ - {"x": [...], "y": [...]}
77
+ - [{"wavenumber": val, "intensity": val}, ...]
78
+ """
79
+
80
+ try:
81
+ data = json.load(content)
82
+
83
+ # Format 1: Object with arrays
84
+ if isinstance(data, dict):
85
+ x_key = None
86
+ y_key = None
87
+
88
+ # Try common key names for x-axis
89
+ for key in ["wavenumbers", "wavenumber", "x", "freq", "frequency"]:
90
+ if key in data:
91
+ x_key = key
92
+ break
93
+
94
+ # Try common key names for y-axis
95
+ for key in ["intensities", "intensity", "y", "counts", "absorbance"]:
96
+ if key in data:
97
+ y_key = key
98
+ break
99
+
100
+ if x_key and y_key:
101
+ x_vals = np.array(data[x_key], dtype=float)
102
+ y_vals = np.array(data[y_key], dtype=float)
103
+ return x_vals, y_vals
104
+
105
+ # Format 2: Array of objects
106
+ elif isinstance(data, list) and len(data) > 0 and isinstance(data[0], dict):
107
+ x_vals = []
108
+ y_vals = []
109
+
110
+ for item in data:
111
+ # Try to find x and y values
112
+ x_val = None
113
+ y_val = None
114
+
115
+ for x_key in ["wavenumber", "wavenumbers", "x", "freq"]:
116
+ if x_key in item:
117
+ x_val = float(item[x_key])
118
+ break
119
+
120
+ for y_key in ["intensity", "intensities", "y", "counts"]:
121
+ if y_key in item:
122
+ y_val = float(item[y_key])
123
+ break
124
+
125
+ if x_val is not None and y_val is not None:
126
  x_vals.append(x_val)
127
  y_vals.append(y_val)
128
 
129
+ if x_vals and y_vals:
130
+ return np.array(x_vals), np.array(y_vals)
131
+
132
+ raise ValueError(
133
+ "JSON format not recognized. Expected wavenumber/intensity pairs."
134
+ )
135
+
136
+ except json.JSONDecodeError as e:
137
+ raise ValueError(f"Invalid JSON format: {str(e)}")
138
+ except Exception as e:
139
+ raise ValueError(f"Failed to parse JSON spectrum: {str(e)}")
140
+
141
+
142
+ # /////////////////////////////////////////////////////
143
+
144
+
145
+ def parse_csv_spectrum(
146
+ content: str, filename: str = "unknown"
147
+ ) -> Tuple[np.ndarray, np.ndarray]:
148
+ """
149
+ Parse spectrum data from CSV format.
150
+
151
+ Handles various CSV formats with headers or without.
152
+ """
153
+ try:
154
+ # Use StringIO to treat string as file-like object
155
+ csv_file = io.StringIO(content)
156
+
157
+ # Try to detect delimiter
158
+ sample = content[:1024]
159
+ delimiter = ","
160
+ if sample.count(";") > sample.count(","):
161
+ delimiter = ";"
162
+ elif sample.count("\t") > sample.count(","):
163
+ delimiter = "\t"
164
+
165
+ # Read CSV
166
+ csv_reader = csv.reader(csv_file, delimiter=delimiter)
167
+ rows = list(csv_reader)
168
+
169
+ if not rows:
170
+ raise ValueError("Empty CSV file")
171
+
172
+ # Check if first row is header
173
+ has_header = False
174
+ try:
175
+ # If first row contains non-numeric data, it's likely a header
176
+ float(rows[0][0])
177
+ float(rows[0][1])
178
+ except (ValueError, IndexError):
179
+ has_header = True
180
+
181
+ data_rows = rows[1:] if has_header else rows
182
+
183
+ # Extract x and y values
184
+ x_vals = []
185
+ y_vals = []
186
+
187
+ for i, row in enumerate(data_rows):
188
+ if len(row) < 2:
189
+ continue
190
+
191
+ try:
192
+ x_val = float(row[0])
193
+ y_val = float(row[1])
194
+ x_vals.append(x_val)
195
+ y_vals.append(y_val)
196
  except ValueError:
197
  ErrorHandler.log_warning(
198
+ f"Could not parse CSV row {i+1}: {row}", f"Parsing {filename}"
199
  )
200
  continue
201
 
202
+ if len(x_vals) < 10:
203
  raise ValueError(
204
  f"Insufficient data points ({len(x_vals)}). Need at least 10 points."
205
  )
206
 
207
+ return np.array(x_vals), np.array(y_vals)
 
208
 
209
+ except Exception as e:
210
+ raise ValueError(f"Failed to parse CSV spectrum: {str(e)}")
 
211
 
 
 
 
212
 
213
+ # /////////////////////////////////////////////////////
214
+
215
+
216
+ def parse_spectrum_data(
217
+ text_content: str, filename: str = "unknown", file_format: Optional[str] = None
218
+ ) -> Tuple[np.ndarray, np.ndarray]:
219
+ """
220
+ Parse spectrum data from text content with automatic format detection.
221
+ Args:
222
+ text_content: Raw text content of the spectrum file
223
+ filename: Name of the file for error reporting
224
+ file_format: Force specific format ('txt', 'csv', 'json') or None for auto-detection
225
+ Returns:
226
+ Tuple of (x_values, y_values) as numpy arrays
227
+ Raises:
228
+ ValueError: If the data cannot be parsed
229
+ """
230
+ try:
231
+ # Detect format if not specified
232
+ if file_format is None:
233
+ file_format = detect_file_format(filename, text_content)
234
+
235
+ # Parse based on detected/specified format
236
+ if file_format == "json":
237
+ x, y = parse_json_spectrum(text_content, filename)
238
+ elif file_format == "csv":
239
+ x, y = parse_csv_spectrum(text_content, filename)
240
+ else: # Default to TXT format
241
+ x, y = parse_txt_spectrum(text_content, filename)
242
+
243
+ # Common validation for all formats
244
+ validate_spectrum_data(x, y, filename)
245
 
246
  return x, y
247
 
 
249
  raise ValueError(f"Failed to parse spectrum data: {str(e)}")
250
 
251
 
252
+ # /////////////////////////////////////////////////////
253
+
254
+
255
+ def parse_txt_spectrum(
256
+ content: str, filename: str = "unknown"
257
+ ) -> Tuple[np.ndarray, np.ndarray]:
258
+ """
259
+ Parse spectrum data from TXT format (original implementation).
260
+ """
261
+ lines = content.strip().split("\n")
262
+
263
+ # ==Remove empty lines and comments==
264
+ data_lines = []
265
+ for line in lines:
266
+ line = line.strip()
267
+ if line and not line.startswith("#") and not line.startswith("%"):
268
+ data_lines.append(line)
269
+
270
+ if not data_lines:
271
+ raise ValueError("No data lines found in file")
272
+
273
+ # ==Try to parse==
274
+ x_vals, y_vals = [], []
275
+
276
+ for i, line in enumerate(data_lines):
277
+ try:
278
+ # Handle different separators
279
+ parts = line.replace(",", " ").split()
280
+ numbers = [
281
+ p
282
+ for p in parts
283
+ if p.replace(".", "", 1)
284
+ .replace("-", "", 1)
285
+ .replace("+", "", 1)
286
+ .isdigit()
287
+ ]
288
+ if len(numbers) >= 2:
289
+ x_val = float(numbers[0])
290
+ y_val = float(numbers[1])
291
+ x_vals.append(x_val)
292
+ y_vals.append(y_val)
293
+
294
+ except ValueError:
295
+ ErrorHandler.log_warning(
296
+ f"Could not parse line {i+1}: {line}", f"Parsing {filename}"
297
+ )
298
+ continue
299
+
300
+ if len(x_vals) < 10: # ==Need minimum points for interpolation==
301
+ raise ValueError(
302
+ f"Insufficient data points ({len(x_vals)}). Need at least 10 points."
303
+ )
304
+
305
+ return np.array(x_vals), np.array(y_vals)
306
+
307
+
308
+ # /////////////////////////////////////////////////////
309
+
310
+
311
+ def validate_spectrum_data(x: np.ndarray, y: np.ndarray, filename: str) -> None:
312
+ """
313
+ Validate parsed spectrum data for common issues.
314
+ """
315
+ # Check for NaNs
316
+ if np.any(np.isnan(x)) or np.any(np.isnan(y)):
317
+ raise ValueError("Input data contains NaN values")
318
+
319
+ # Check monotonic increasing x (sort if needed)
320
+ if not np.all(np.diff(x) >= 0):
321
+ # Sort by x values if not monotonic
322
+ sort_idx = np.argsort(x)
323
+ x = x[sort_idx]
324
+ y = y[sort_idx]
325
+ ErrorHandler.log_warning(
326
+ "Wavenumbers were not monotonic - data has been sorted",
327
+ f"Parsing {filename}",
328
+ )
329
+
330
+ # Check reasonable range for spectroscopy
331
+ if min(x) < 0 or max(x) > 10000 or (max(x) - min(x)) < 100:
332
+ ErrorHandler.log_warning(
333
+ f"Unusual wavenumber range: {min(x):.1f} - {max(x):.1f} cm⁻¹",
334
+ f"Parsing {filename}",
335
+ )
336
+
337
+
338
+ # /////////////////////////////////////////////////////
339
+
340
+
341
  def process_single_file(
342
  filename: str,
343
  text_content: str,
utils/performance_tracker.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Performance tracking and logging utilities for POLYMEROS platform."""
2
+
3
+ import time
4
+ import json
5
+ import sqlite3
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from typing import Dict, List, Any, Optional
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+ import streamlit as st
12
+ from dataclasses import dataclass, asdict
13
+ from contextlib import contextmanager
14
+
15
+
16
+ @dataclass
17
+ class PerformanceMetrics:
18
+ """Data class for performance metrics."""
19
+
20
+ model_name: str
21
+ prediction_time: float
22
+ preprocessing_time: float
23
+ total_time: float
24
+ memory_usage_mb: float
25
+ accuracy: Optional[float]
26
+ confidence: float
27
+ timestamp: str
28
+ input_size: int
29
+ modality: str
30
+
31
+ def to_dict(self) -> Dict[str, Any]:
32
+ return asdict(self)
33
+
34
+
35
+ class PerformanceTracker:
36
+ """Automatic performance tracking and logging system."""
37
+
38
+ def __init__(self, db_path: str = "outputs/performance_tracking.db"):
39
+ self.db_path = Path(db_path)
40
+ self.db_path.parent.mkdir(parents=True, exist_ok=True)
41
+ self._init_database()
42
+
43
+ def _init_database(self):
44
+ """Initialize SQLite database for performance tracking."""
45
+ with sqlite3.connect(self.db_path) as conn:
46
+ conn.execute(
47
+ """
48
+ CREATE TABLE IF NOT EXISTS performance_metrics (
49
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
50
+ model_name TEXT NOT NULL,
51
+ prediction_time REAL NOT NULL,
52
+ preprocessing_time REAL NOT NULL,
53
+ total_time REAL NOT NULL,
54
+ memory_usage_mb REAL,
55
+ accuracy REAL,
56
+ confidence REAL NOT NULL,
57
+ timestamp TEXT NOT NULL,
58
+ input_size INTEGER NOT NULL,
59
+ modality TEXT NOT NULL
60
+ )
61
+ """
62
+ )
63
+ conn.commit()
64
+
65
+ def log_performance(self, metrics: PerformanceMetrics):
66
+ """Log performance metrics to database."""
67
+ with sqlite3.connect(self.db_path) as conn:
68
+ conn.execute(
69
+ """
70
+ INSERT INTO performance_metrics
71
+ (model_name, prediction_time, preprocessing_time, total_time,
72
+ memory_usage_mb, accuracy, confidence, timestamp, input_size, modality)
73
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
74
+ """,
75
+ (
76
+ metrics.model_name,
77
+ metrics.prediction_time,
78
+ metrics.preprocessing_time,
79
+ metrics.total_time,
80
+ metrics.memory_usage_mb,
81
+ metrics.accuracy,
82
+ metrics.confidence,
83
+ metrics.timestamp,
84
+ metrics.input_size,
85
+ metrics.modality,
86
+ ),
87
+ )
88
+ conn.commit()
89
+
90
+ @contextmanager
91
+ def track_inference(self, model_name: str, modality: str = "raman"):
92
+ """Context manager for automatic performance tracking."""
93
+ start_time = time.time()
94
+ start_memory = self._get_memory_usage()
95
+
96
+ tracking_data = {
97
+ "model_name": model_name,
98
+ "modality": modality,
99
+ "start_time": start_time,
100
+ "start_memory": start_memory,
101
+ "preprocessing_time": 0.0,
102
+ }
103
+
104
+ try:
105
+ yield tracking_data
106
+ finally:
107
+ end_time = time.time()
108
+ end_memory = self._get_memory_usage()
109
+
110
+ total_time = end_time - start_time
111
+ memory_usage = max(end_memory - start_memory, 0)
112
+
113
+ # Create metrics object if not provided
114
+ if "metrics" not in tracking_data:
115
+ metrics = PerformanceMetrics(
116
+ model_name=model_name,
117
+ prediction_time=tracking_data.get("prediction_time", total_time),
118
+ preprocessing_time=tracking_data.get("preprocessing_time", 0.0),
119
+ total_time=total_time,
120
+ memory_usage_mb=memory_usage,
121
+ accuracy=tracking_data.get("accuracy"),
122
+ confidence=tracking_data.get("confidence", 0.0),
123
+ timestamp=datetime.now().isoformat(),
124
+ input_size=tracking_data.get("input_size", 0),
125
+ modality=modality,
126
+ )
127
+ self.log_performance(metrics)
128
+
129
+ def _get_memory_usage(self) -> float:
130
+ """Get current memory usage in MB."""
131
+ try:
132
+ import psutil
133
+
134
+ process = psutil.Process()
135
+ return process.memory_info().rss / 1024 / 1024 # Convert to MB
136
+ except ImportError:
137
+ return 0.0 # psutil not available
138
+
139
+ def get_recent_metrics(self, limit: int = 100) -> List[Dict[str, Any]]:
140
+ """Get recent performance metrics."""
141
+ with sqlite3.connect(self.db_path) as conn:
142
+ conn.row_factory = sqlite3.Row # Enable column access by name
143
+ cursor = conn.execute(
144
+ """
145
+ SELECT * FROM performance_metrics
146
+ ORDER BY timestamp DESC
147
+ LIMIT ?
148
+ """,
149
+ (limit,),
150
+ )
151
+ return [dict(row) for row in cursor.fetchall()]
152
+
153
+ def get_model_statistics(self, model_name: Optional[str] = None) -> Dict[str, Any]:
154
+ """Get statistical summary of model performance."""
155
+ where_clause = "WHERE model_name = ?" if model_name else ""
156
+ params = (model_name,) if model_name else ()
157
+
158
+ with sqlite3.connect(self.db_path) as conn:
159
+ cursor = conn.execute(
160
+ f"""
161
+ SELECT
162
+ model_name,
163
+ COUNT(*) as total_inferences,
164
+ AVG(prediction_time) as avg_prediction_time,
165
+ AVG(preprocessing_time) as avg_preprocessing_time,
166
+ AVG(total_time) as avg_total_time,
167
+ AVG(memory_usage_mb) as avg_memory_usage,
168
+ AVG(confidence) as avg_confidence,
169
+ MIN(total_time) as fastest_inference,
170
+ MAX(total_time) as slowest_inference
171
+ FROM performance_metrics
172
+ {where_clause}
173
+ GROUP BY model_name
174
+ """,
175
+ params,
176
+ )
177
+
178
+ results = cursor.fetchall()
179
+ if model_name and results:
180
+ # Return single model stats as dict
181
+ row = results[0]
182
+ return {
183
+ "model_name": row[0],
184
+ "total_inferences": row[1],
185
+ "avg_prediction_time": row[2],
186
+ "avg_preprocessing_time": row[3],
187
+ "avg_total_time": row[4],
188
+ "avg_memory_usage": row[5],
189
+ "avg_confidence": row[6],
190
+ "fastest_inference": row[7],
191
+ "slowest_inference": row[8],
192
+ }
193
+ elif not model_name:
194
+ # Return all models stats as dict of dicts
195
+ return {
196
+ row[0]: {
197
+ "model_name": row[0],
198
+ "total_inferences": row[1],
199
+ "avg_prediction_time": row[2],
200
+ "avg_preprocessing_time": row[3],
201
+ "avg_total_time": row[4],
202
+ "avg_memory_usage": row[5],
203
+ "avg_confidence": row[6],
204
+ "fastest_inference": row[7],
205
+ "slowest_inference": row[8],
206
+ }
207
+ for row in results
208
+ }
209
+ else:
210
+ return {}
211
+
212
+ def create_performance_visualization(self) -> plt.Figure:
213
+ """Create performance visualization charts."""
214
+ metrics = self.get_recent_metrics(50)
215
+
216
+ if not metrics:
217
+ return None
218
+
219
+ fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
220
+
221
+ # Convert to convenient format
222
+ models = [m["model_name"] for m in metrics]
223
+ times = [m["total_time"] for m in metrics]
224
+ confidences = [m["confidence"] for m in metrics]
225
+ timestamps = [datetime.fromisoformat(m["timestamp"]) for m in metrics]
226
+
227
+ # 1. Inference Time Over Time
228
+ ax1.plot(timestamps, times, "o-", alpha=0.7)
229
+ ax1.set_title("Inference Time Over Time")
230
+ ax1.set_ylabel("Time (seconds)")
231
+ ax1.tick_params(axis="x", rotation=45)
232
+
233
+ # 2. Performance by Model
234
+ model_stats = self.get_model_statistics()
235
+ if model_stats:
236
+ model_names = list(model_stats.keys())
237
+ avg_times = [model_stats[m]["avg_total_time"] for m in model_names]
238
+
239
+ ax2.bar(model_names, avg_times, alpha=0.7)
240
+ ax2.set_title("Average Inference Time by Model")
241
+ ax2.set_ylabel("Time (seconds)")
242
+ ax2.tick_params(axis="x", rotation=45)
243
+
244
+ # 3. Confidence Distribution
245
+ ax3.hist(confidences, bins=20, alpha=0.7)
246
+ ax3.set_title("Confidence Score Distribution")
247
+ ax3.set_xlabel("Confidence")
248
+ ax3.set_ylabel("Frequency")
249
+
250
+ # 4. Memory Usage if available
251
+ memory_usage = [
252
+ m["memory_usage_mb"] for m in metrics if m["memory_usage_mb"] is not None
253
+ ]
254
+ if memory_usage:
255
+ ax4.plot(range(len(memory_usage)), memory_usage, "o-", alpha=0.7)
256
+ ax4.set_title("Memory Usage")
257
+ ax4.set_xlabel("Inference Number")
258
+ ax4.set_ylabel("Memory (MB)")
259
+ else:
260
+ ax4.text(
261
+ 0.5,
262
+ 0.5,
263
+ "Memory tracking\nnot available",
264
+ ha="center",
265
+ va="center",
266
+ transform=ax4.transAxes,
267
+ )
268
+ ax4.set_title("Memory Usage")
269
+
270
+ plt.tight_layout()
271
+ return fig
272
+
273
+ def export_metrics(self, format: str = "json") -> str:
274
+ """Export performance metrics in specified format."""
275
+ metrics = self.get_recent_metrics(1000) # Get more for export
276
+
277
+ if format == "json":
278
+ return json.dumps(metrics, indent=2, default=str)
279
+ elif format == "csv":
280
+ import pandas as pd
281
+
282
+ df = pd.DataFrame(metrics)
283
+ return df.to_csv(index=False)
284
+ else:
285
+ raise ValueError(f"Unsupported format: {format}")
286
+
287
+
288
+ # Global tracker instance
289
+ _tracker = None
290
+
291
+
292
+ def get_performance_tracker() -> PerformanceTracker:
293
+ """Get global performance tracker instance."""
294
+ global _tracker
295
+ if _tracker is None:
296
+ _tracker = PerformanceTracker()
297
+ return _tracker
298
+
299
+
300
+ def display_performance_dashboard():
301
+ """Display performance tracking dashboard in Streamlit."""
302
+ tracker = get_performance_tracker()
303
+
304
+ st.markdown("### 📈 Performance Dashboard")
305
+
306
+ # Recent metrics summary
307
+ recent_metrics = tracker.get_recent_metrics(20)
308
+
309
+ if not recent_metrics:
310
+ st.info(
311
+ "No performance data available yet. Run some inferences to see metrics."
312
+ )
313
+ return
314
+
315
+ # Summary statistics
316
+ col1, col2, col3, col4 = st.columns(4)
317
+
318
+ total_inferences = len(recent_metrics)
319
+ avg_time = np.mean([m["total_time"] for m in recent_metrics])
320
+ avg_confidence = np.mean([m["confidence"] for m in recent_metrics])
321
+ unique_models = len(set(m["model_name"] for m in recent_metrics))
322
+
323
+ with col1:
324
+ st.metric("Total Inferences", total_inferences)
325
+ with col2:
326
+ st.metric("Avg Time", f"{avg_time:.3f}s")
327
+ with col3:
328
+ st.metric("Avg Confidence", f"{avg_confidence:.3f}")
329
+ with col4:
330
+ st.metric("Models Used", unique_models)
331
+
332
+ # Performance visualization
333
+ fig = tracker.create_performance_visualization()
334
+ if fig:
335
+ st.pyplot(fig)
336
+
337
+ # Model comparison table
338
+ st.markdown("#### Model Performance Comparison")
339
+ model_stats = tracker.get_model_statistics()
340
+
341
+ if model_stats:
342
+ import pandas as pd
343
+
344
+ stats_data = []
345
+ for model_name, stats in model_stats.items():
346
+ stats_data.append(
347
+ {
348
+ "Model": model_name,
349
+ "Total Inferences": stats["total_inferences"],
350
+ "Avg Time (s)": f"{stats['avg_total_time']:.3f}",
351
+ "Avg Confidence": f"{stats['avg_confidence']:.3f}",
352
+ "Fastest (s)": f"{stats['fastest_inference']:.3f}",
353
+ "Slowest (s)": f"{stats['slowest_inference']:.3f}",
354
+ }
355
+ )
356
+
357
+ df = pd.DataFrame(stats_data)
358
+ st.dataframe(df, use_container_width=True)
359
+
360
+ # Export options
361
+ with st.expander("📥 Export Performance Data"):
362
+ col1, col2 = st.columns(2)
363
+
364
+ with col1:
365
+ if st.button("Export JSON"):
366
+ json_data = tracker.export_metrics("json")
367
+ st.download_button(
368
+ "Download JSON",
369
+ json_data,
370
+ "performance_metrics.json",
371
+ "application/json",
372
+ )
373
+
374
+ with col2:
375
+ if st.button("Export CSV"):
376
+ csv_data = tracker.export_metrics("csv")
377
+ st.download_button(
378
+ "Download CSV", csv_data, "performance_metrics.csv", "text/csv"
379
+ )
380
+
381
+
382
+ if __name__ == "__main__":
383
+ # Test the performance tracker
384
+ tracker = PerformanceTracker()
385
+
386
+ # Simulate some metrics
387
+ for i in range(5):
388
+ metrics = PerformanceMetrics(
389
+ model_name=f"test_model_{i%2}",
390
+ prediction_time=0.1 + i * 0.01,
391
+ preprocessing_time=0.05,
392
+ total_time=0.15 + i * 0.01,
393
+ memory_usage_mb=100 + i * 10,
394
+ accuracy=0.8 + i * 0.02,
395
+ confidence=0.7 + i * 0.05,
396
+ timestamp=datetime.now().isoformat(),
397
+ input_size=500,
398
+ modality="raman",
399
+ )
400
+ tracker.log_performance(metrics)
401
+
402
+ print("Performance tracking test completed!")
403
+ print(f"Recent metrics: {len(tracker.get_recent_metrics())}")
404
+ print(f"Model stats: {tracker.get_model_statistics()}")
utils/preprocessing.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  Preprocessing utilities for polymer classification app.
3
  Adapted from the original scripts/preprocess_dataset.py for Hugging Face Spaces deployment.
 
4
  """
5
 
6
  from __future__ import annotations
@@ -9,8 +10,33 @@ from numpy.typing import DTypeLike
9
  from scipy.interpolate import interp1d
10
  from scipy.signal import savgol_filter
11
  from scipy.interpolate import interp1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- TARGET_LENGTH = 500 # Frozen default per PREPROCESSING_BASELINE
14
 
15
  def _ensure_1d_equal(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
16
  x = np.asarray(x, dtype=float)
@@ -19,7 +45,10 @@ def _ensure_1d_equal(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarr
19
  raise ValueError("x and y must be 1D arrays of equal length >= 2")
20
  return x, y
21
 
22
- def resample_spectrum(x: np.ndarray, y: np.ndarray, target_len: int = TARGET_LENGTH) -> tuple[np.ndarray, np.ndarray]:
 
 
 
23
  """Linear re-sampling onto a uniform grid of length target_len."""
24
  x, y = _ensure_1d_equal(x, y)
25
  order = np.argsort(x)
@@ -29,6 +58,7 @@ def resample_spectrum(x: np.ndarray, y: np.ndarray, target_len: int = TARGET_LEN
29
  y_new = f(x_new)
30
  return x_new, y_new
31
 
 
32
  def remove_baseline(y: np.ndarray, degree: int = 2) -> np.ndarray:
33
  """Polynomial baseline subtraction (degree=2 default)"""
34
  y = np.asarray(y, dtype=float)
@@ -37,19 +67,25 @@ def remove_baseline(y: np.ndarray, degree: int = 2) -> np.ndarray:
37
  baseline = np.polyval(coeffs, x_idx)
38
  return y - baseline
39
 
40
- def smooth_spectrum(y: np.ndarray, window_length: int = 11, polyorder: int = 2) -> np.ndarray:
 
 
 
41
  """Savitzky-Golay smoothing with safe/odd window enforcement"""
42
  y = np.asarray(y, dtype=float)
43
  window_length = int(window_length)
44
  polyorder = int(polyorder)
45
  # === window must be odd and >= polyorder+1 ===
46
  if window_length % 2 == 0:
47
- window_length += 1
48
  min_win = polyorder + 1
49
  if min_win % 2 == 0:
50
  min_win += 1
51
  window_length = max(window_length, min_win)
52
- return savgol_filter(y, window_length=window_length, polyorder=polyorder, mode="interp")
 
 
 
53
 
54
  def normalize_spectrum(y: np.ndarray) -> np.ndarray:
55
  """Min-max normalization to [0, 1] with constant-signal guard."""
@@ -60,27 +96,114 @@ def normalize_spectrum(y: np.ndarray) -> np.ndarray:
60
  return np.zeros_like(y)
61
  return (y - y_min) / (y_max - y_min)
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def preprocess_spectrum(
64
  x: np.ndarray,
65
  y: np.ndarray,
66
  *,
67
  target_len: int = TARGET_LENGTH,
 
68
  do_baseline: bool = True,
69
- degree: int = 2,
70
  do_smooth: bool = True,
71
- window_length: int = 11,
72
- polyorder: int = 2,
73
  do_normalize: bool = True,
74
  out_dtype: DTypeLike = np.float32,
 
75
  ) -> tuple[np.ndarray, np.ndarray]:
76
- """Exact CLI baseline: resample -> baseline -> smooth -> normalize"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  x_rs, y_rs = resample_spectrum(x, y, target_len=target_len)
 
78
  if do_baseline:
79
  y_rs = remove_baseline(y_rs, degree=degree)
 
80
  if do_smooth:
81
  y_rs = smooth_spectrum(y_rs, window_length=window_length, polyorder=polyorder)
 
 
 
 
 
 
 
 
 
 
82
  if do_normalize:
83
  y_rs = normalize_spectrum(y_rs)
 
84
  # === Coerce to a real dtype to satisfy static checkers & runtime ===
85
  out_dt = np.dtype(out_dtype)
86
- return x_rs.astype(out_dt, copy=False), y_rs.astype(out_dt, copy=False)
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  Preprocessing utilities for polymer classification app.
3
  Adapted from the original scripts/preprocess_dataset.py for Hugging Face Spaces deployment.
4
+ Supports both Raman and FTIR spectroscopy modalities.
5
  """
6
 
7
  from __future__ import annotations
 
10
  from scipy.interpolate import interp1d
11
  from scipy.signal import savgol_filter
12
  from scipy.interpolate import interp1d
13
+ from typing import Tuple, Literal
14
+
15
+ TARGET_LENGTH = 500 # Frozen default per PREPROCESSING_BASELINE
16
+
17
+ # Modality-specific validation ranges (cm⁻¹)
18
+ MODALITY_RANGES = {
19
+ "raman": (200, 4000), # Typical Raman range
20
+ "ftir": (400, 4000), # FTIR wavenumber range
21
+ }
22
+
23
+ # Modality-specific preprocessing parameters
24
+ MODALITY_PARAMS = {
25
+ "raman": {
26
+ "baseline_degree": 2,
27
+ "smooth_window": 11,
28
+ "smooth_polyorder": 2,
29
+ "cosmic_ray_removal": False,
30
+ },
31
+ "ftir": {
32
+ "baseline_degree": 2,
33
+ "smooth_window": 13, # Slightly larger window for FTIR
34
+ "smooth_polyorder": 2,
35
+ "cosmic_ray_removal": False, # Could add atmospheric correction
36
+ "atmospheric_correction": False, # Placeholder for future implementation
37
+ },
38
+ }
39
 
 
40
 
41
  def _ensure_1d_equal(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
42
  x = np.asarray(x, dtype=float)
 
45
  raise ValueError("x and y must be 1D arrays of equal length >= 2")
46
  return x, y
47
 
48
+
49
+ def resample_spectrum(
50
+ x: np.ndarray, y: np.ndarray, target_len: int = TARGET_LENGTH
51
+ ) -> tuple[np.ndarray, np.ndarray]:
52
  """Linear re-sampling onto a uniform grid of length target_len."""
53
  x, y = _ensure_1d_equal(x, y)
54
  order = np.argsort(x)
 
58
  y_new = f(x_new)
59
  return x_new, y_new
60
 
61
+
62
  def remove_baseline(y: np.ndarray, degree: int = 2) -> np.ndarray:
63
  """Polynomial baseline subtraction (degree=2 default)"""
64
  y = np.asarray(y, dtype=float)
 
67
  baseline = np.polyval(coeffs, x_idx)
68
  return y - baseline
69
 
70
+
71
+ def smooth_spectrum(
72
+ y: np.ndarray, window_length: int = 11, polyorder: int = 2
73
+ ) -> np.ndarray:
74
  """Savitzky-Golay smoothing with safe/odd window enforcement"""
75
  y = np.asarray(y, dtype=float)
76
  window_length = int(window_length)
77
  polyorder = int(polyorder)
78
  # === window must be odd and >= polyorder+1 ===
79
  if window_length % 2 == 0:
80
+ window_length += 1
81
  min_win = polyorder + 1
82
  if min_win % 2 == 0:
83
  min_win += 1
84
  window_length = max(window_length, min_win)
85
+ return savgol_filter(
86
+ y, window_length=window_length, polyorder=polyorder, mode="interp"
87
+ )
88
+
89
 
90
  def normalize_spectrum(y: np.ndarray) -> np.ndarray:
91
  """Min-max normalization to [0, 1] with constant-signal guard."""
 
96
  return np.zeros_like(y)
97
  return (y - y_min) / (y_max - y_min)
98
 
99
+
100
+ def validate_spectrum_range(x: np.ndarray, modality: str = "raman") -> bool:
101
+ """Validate that spectrum wavenumbers are within expected range for modality."""
102
+ if modality not in MODALITY_RANGES:
103
+ raise ValueError(
104
+ f"Unknown modality '{modality}'. Supported: {list(MODALITY_RANGES.keys())}"
105
+ )
106
+
107
+ min_range, max_range = MODALITY_RANGES[modality]
108
+ x_min, x_max = np.min(x), np.max(x)
109
+
110
+ # Check if majority of data points are within range
111
+ in_range = np.sum((x >= min_range) & (x <= max_range))
112
+ total_points = len(x)
113
+
114
+ return (in_range / total_points) >= 0.7 # At least 70% should be in range
115
+
116
+
117
  def preprocess_spectrum(
118
  x: np.ndarray,
119
  y: np.ndarray,
120
  *,
121
  target_len: int = TARGET_LENGTH,
122
+ modality: str = "raman", # New parameter for modality-specific processing
123
  do_baseline: bool = True,
124
+ degree: int | None = None, # Will use modality default if None
125
  do_smooth: bool = True,
126
+ window_length: int | None = None, # Will use modality default if None
127
+ polyorder: int | None = None, # Will use modality default if None
128
  do_normalize: bool = True,
129
  out_dtype: DTypeLike = np.float32,
130
+ validate_range: bool = True,
131
  ) -> tuple[np.ndarray, np.ndarray]:
132
+ """
133
+ Modality-aware preprocessing: resample -> baseline -> smooth -> normalize
134
+
135
+ Args:
136
+ x, y: Input spectrum data
137
+ target_len: Target length for resampling
138
+ modality: 'raman' or 'ftir' for modality-specific processing
139
+ do_baseline: Enable baseline correction
140
+ degree: Polynomial degree for baseline (uses modality default if None)
141
+ do_smooth: Enable smoothing
142
+ window_length: Smoothing window length (uses modality default if None)
143
+ polyorder: Polynomial order for smoothing (uses modality default if None)
144
+ do_normalize: Enable normalization
145
+ out_dtype: Output data type
146
+ validate_range: Check if wavenumbers are in expected range for modality
147
+
148
+ Returns:
149
+ Tuple of (resampled_x, processed_y)
150
+ """
151
+ # Validate modality
152
+ if modality not in MODALITY_PARAMS:
153
+ raise ValueError(
154
+ f"Unsupported modality '{modality}'. Supported: {list(MODALITY_PARAMS.keys())}"
155
+ )
156
+
157
+ # Get modality-specific parameters
158
+ modality_config = MODALITY_PARAMS[modality]
159
+
160
+ # Use modality defaults if parameters not specified
161
+ if degree is None:
162
+ degree = modality_config["baseline_degree"]
163
+ if window_length is None:
164
+ window_length = modality_config["smooth_window"]
165
+ if polyorder is None:
166
+ polyorder = modality_config["smooth_polyorder"]
167
+
168
+ # Validate spectrum range if requested
169
+ if validate_range:
170
+ if not validate_spectrum_range(x, modality):
171
+ print(
172
+ f"Warning: Spectrum wavenumbers may not be optimal for {modality.upper()} analysis"
173
+ )
174
+
175
+ # Standard preprocessing pipeline
176
  x_rs, y_rs = resample_spectrum(x, y, target_len=target_len)
177
+
178
  if do_baseline:
179
  y_rs = remove_baseline(y_rs, degree=degree)
180
+
181
  if do_smooth:
182
  y_rs = smooth_spectrum(y_rs, window_length=window_length, polyorder=polyorder)
183
+
184
+ # FTIR-specific processing (placeholder for future enhancements)
185
+ if modality == "ftir":
186
+ if modality_config.get("atmospheric_correction", False):
187
+ # Placeholder for atmospheric correction
188
+ pass
189
+ if modality_config.get("cosmic_ray_removal", False):
190
+ # Placeholder for cosmic ray removal
191
+ pass
192
+
193
  if do_normalize:
194
  y_rs = normalize_spectrum(y_rs)
195
+
196
  # === Coerce to a real dtype to satisfy static checkers & runtime ===
197
  out_dt = np.dtype(out_dtype)
198
+ return x_rs.astype(out_dt, copy=False), y_rs.astype(out_dt, copy=False)
199
+
200
+
201
+ def get_modality_info(modality: str) -> dict:
202
+ """Get processing parameters and validation ranges for a modality."""
203
+ if modality not in MODALITY_PARAMS:
204
+ raise ValueError(f"Unknown modality '{modality}'")
205
+
206
+ return {
207
+ "range": MODALITY_RANGES[modality],
208
+ "params": MODALITY_PARAMS[modality].copy(),
209
+ }
utils/results_manager.py CHANGED
@@ -1,14 +1,17 @@
1
  """Session results management for multi-file inference.
2
- Handles in-memory results table and export functionality"""
 
3
 
4
  import streamlit as st
5
  import pandas as pd
6
  import json
7
  from datetime import datetime
8
- from typing import Dict, List, Any, Optional
9
  import numpy as np
10
  from pathlib import Path
11
  import io
 
 
12
 
13
 
14
  def local_css(file_name):
@@ -199,6 +202,219 @@ class ResultsManager:
199
 
200
  return len(st.session_state[ResultsManager.RESULTS_KEY]) < original_length
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  @staticmethod
203
  # ==UTILITY FUNCTIONS==
204
  def init_session_state():
 
1
  """Session results management for multi-file inference.
2
+ Handles in-memory results table and export functionality.
3
+ Supports multi-model comparison and statistical analysis."""
4
 
5
  import streamlit as st
6
  import pandas as pd
7
  import json
8
  from datetime import datetime
9
+ from typing import Dict, List, Any, Optional, Tuple
10
  import numpy as np
11
  from pathlib import Path
12
  import io
13
+ from collections import defaultdict
14
+ import matplotlib.pyplot as plt
15
 
16
 
17
  def local_css(file_name):
 
202
 
203
  return len(st.session_state[ResultsManager.RESULTS_KEY]) < original_length
204
 
205
+ @staticmethod
206
+ def add_multi_model_results(
207
+ filename: str,
208
+ model_results: Dict[str, Dict[str, Any]],
209
+ ground_truth: Optional[int] = None,
210
+ metadata: Optional[Dict[str, Any]] = None,
211
+ ) -> None:
212
+ """
213
+ Add results from multiple models for the same file.
214
+
215
+ Args:
216
+ filename: Name of the processed file
217
+ model_results: Dict with model_name -> result dict
218
+ ground_truth: True label if available
219
+ metadata: Additional file metadata
220
+ """
221
+ for model_name, result in model_results.items():
222
+ ResultsManager.add_results(
223
+ filename=filename,
224
+ model_name=model_name,
225
+ prediction=result["prediction"],
226
+ predicted_class=result["predicted_class"],
227
+ confidence=result["confidence"],
228
+ logits=result["logits"],
229
+ ground_truth=ground_truth,
230
+ processing_time=result.get("processing_time", 0.0),
231
+ metadata=metadata,
232
+ )
233
+
234
+ @staticmethod
235
+ def get_comparison_stats() -> Dict[str, Any]:
236
+ """Get comparative statistics across all models."""
237
+ results = ResultsManager.get_results()
238
+ if not results:
239
+ return {}
240
+
241
+ # Group results by model
242
+ model_stats = defaultdict(list)
243
+ for result in results:
244
+ model_stats[result["model"]].append(result)
245
+
246
+ comparison = {}
247
+ for model_name, model_results in model_stats.items():
248
+ stats = {
249
+ "total_predictions": len(model_results),
250
+ "avg_confidence": np.mean([r["confidence"] for r in model_results]),
251
+ "std_confidence": np.std([r["confidence"] for r in model_results]),
252
+ "avg_processing_time": np.mean(
253
+ [r["processing_time"] for r in model_results]
254
+ ),
255
+ "stable_predictions": sum(
256
+ 1 for r in model_results if r["prediction"] == 0
257
+ ),
258
+ "weathered_predictions": sum(
259
+ 1 for r in model_results if r["prediction"] == 1
260
+ ),
261
+ }
262
+
263
+ # Calculate accuracy if ground truth available
264
+ with_gt = [r for r in model_results if r["ground_truth"] is not None]
265
+ if with_gt:
266
+ correct = sum(
267
+ 1 for r in with_gt if r["prediction"] == r["ground_truth"]
268
+ )
269
+ stats["accuracy"] = correct / len(with_gt)
270
+ stats["num_with_ground_truth"] = len(with_gt)
271
+ else:
272
+ stats["accuracy"] = None
273
+ stats["num_with_ground_truth"] = 0
274
+
275
+ comparison[model_name] = stats
276
+
277
+ return comparison
278
+
279
+ @staticmethod
280
+ def get_agreement_matrix() -> pd.DataFrame:
281
+ """
282
+ Calculate agreement matrix between models for the same files.
283
+
284
+ Returns:
285
+ DataFrame showing model agreement rates
286
+ """
287
+ results = ResultsManager.get_results()
288
+ if not results:
289
+ return pd.DataFrame()
290
+
291
+ # Group by filename
292
+ file_results = defaultdict(dict)
293
+ for result in results:
294
+ file_results[result["filename"]][result["model"]] = result["prediction"]
295
+
296
+ # Get unique models
297
+ all_models = list(set(r["model"] for r in results))
298
+
299
+ if len(all_models) < 2:
300
+ return pd.DataFrame()
301
+
302
+ # Calculate agreement matrix
303
+ agreement_matrix = np.zeros((len(all_models), len(all_models)))
304
+
305
+ for i, model1 in enumerate(all_models):
306
+ for j, model2 in enumerate(all_models):
307
+ if i == j:
308
+ agreement_matrix[i, j] = 1.0 # Perfect self-agreement
309
+ else:
310
+ agreements = 0
311
+ comparisons = 0
312
+
313
+ for filename, predictions in file_results.items():
314
+ if model1 in predictions and model2 in predictions:
315
+ comparisons += 1
316
+ if predictions[model1] == predictions[model2]:
317
+ agreements += 1
318
+
319
+ if comparisons > 0:
320
+ agreement_matrix[i, j] = agreements / comparisons
321
+
322
+ return pd.DataFrame(agreement_matrix, index=all_models, columns=all_models)
323
+
324
+ @staticmethod
325
+ def create_comparison_visualization() -> plt.Figure:
326
+ """Create visualization comparing model performance."""
327
+ comparison_stats = ResultsManager.get_comparison_stats()
328
+
329
+ if not comparison_stats:
330
+ return None
331
+
332
+ fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
333
+
334
+ models = list(comparison_stats.keys())
335
+
336
+ # 1. Average Confidence
337
+ confidences = [comparison_stats[m]["avg_confidence"] for m in models]
338
+ conf_stds = [comparison_stats[m]["std_confidence"] for m in models]
339
+ ax1.bar(models, confidences, yerr=conf_stds, capsize=5)
340
+ ax1.set_title("Average Confidence by Model")
341
+ ax1.set_ylabel("Confidence")
342
+ ax1.tick_params(axis="x", rotation=45)
343
+
344
+ # 2. Processing Time
345
+ proc_times = [comparison_stats[m]["avg_processing_time"] for m in models]
346
+ ax2.bar(models, proc_times)
347
+ ax2.set_title("Average Processing Time")
348
+ ax2.set_ylabel("Time (seconds)")
349
+ ax2.tick_params(axis="x", rotation=45)
350
+
351
+ # 3. Prediction Distribution
352
+ stable_counts = [comparison_stats[m]["stable_predictions"] for m in models]
353
+ weathered_counts = [
354
+ comparison_stats[m]["weathered_predictions"] for m in models
355
+ ]
356
+
357
+ x = np.arange(len(models))
358
+ width = 0.35
359
+ ax3.bar(x - width / 2, stable_counts, width, label="Stable", alpha=0.8)
360
+ ax3.bar(x + width / 2, weathered_counts, width, label="Weathered", alpha=0.8)
361
+ ax3.set_title("Prediction Distribution")
362
+ ax3.set_ylabel("Count")
363
+ ax3.set_xticks(x)
364
+ ax3.set_xticklabels(models, rotation=45)
365
+ ax3.legend()
366
+
367
+ # 4. Accuracy (if available)
368
+ accuracies = []
369
+ models_with_acc = []
370
+ for model in models:
371
+ if comparison_stats[model]["accuracy"] is not None:
372
+ accuracies.append(comparison_stats[model]["accuracy"])
373
+ models_with_acc.append(model)
374
+
375
+ if accuracies:
376
+ ax4.bar(models_with_acc, accuracies)
377
+ ax4.set_title("Model Accuracy (where ground truth available)")
378
+ ax4.set_ylabel("Accuracy")
379
+ ax4.set_ylim(0, 1)
380
+ ax4.tick_params(axis="x", rotation=45)
381
+ else:
382
+ ax4.text(
383
+ 0.5,
384
+ 0.5,
385
+ "No ground truth\navailable",
386
+ ha="center",
387
+ va="center",
388
+ transform=ax4.transAxes,
389
+ )
390
+ ax4.set_title("Model Accuracy")
391
+
392
+ plt.tight_layout()
393
+ return fig
394
+
395
+ @staticmethod
396
+ def export_comparison_report() -> str:
397
+ """Export comprehensive comparison report as JSON."""
398
+ comparison_stats = ResultsManager.get_comparison_stats()
399
+ agreement_matrix = ResultsManager.get_agreement_matrix()
400
+
401
+ report = {
402
+ "timestamp": datetime.now().isoformat(),
403
+ "model_comparison": comparison_stats,
404
+ "agreement_matrix": (
405
+ agreement_matrix.to_dict() if not agreement_matrix.empty else {}
406
+ ),
407
+ "summary": {
408
+ "total_models_compared": len(comparison_stats),
409
+ "total_files_processed": len(
410
+ set(r["filename"] for r in ResultsManager.get_results())
411
+ ),
412
+ "overall_statistics": ResultsManager.get_summary_stats(),
413
+ },
414
+ }
415
+
416
+ return json.dumps(report, indent=2, default=str)
417
+
418
  @staticmethod
419
  # ==UTILITY FUNCTIONS==
420
  def init_session_state():