Merge pull request #3 from devjas1/new-space-deploy
- .gitignore +1 -0
- CODEBASE_INVENTORY.md +99 -214
- README.md +54 -60
- app.py +22 -6
- core_logic.py +53 -2
- models/registry.py +114 -11
- modules/ui_components.py +478 -51
- sample_data/ftir-stable-1.txt +75 -0
- sample_data/ftir-weathered-1.txt +75 -0
- sample_data/stable.sample.csv +22 -0
- scripts/run_inference.py +364 -61
- tests/test_ftir_preprocessing.py +179 -0
- tests/test_multi_format.py +218 -0
- utils/multifile.py +297 -56
- utils/performance_tracker.py +404 -0
- utils/preprocessing.py +133 -10
- utils/results_manager.py +218 -2
.gitignore
CHANGED

@@ -26,3 +26,4 @@ datasets/**
 # ---------------------------------------
 
 __pycache__.py
+outputs/performance_tracking.db
CODEBASE_INVENTORY.md
CHANGED
|
@@ -2,40 +2,38 @@

## Executive Summary

-This audit provides a

## 🏗️ System Architecture

### Core Infrastructure

-**
-- `tests/` - Unit testing infrastructure
-- `datasets/` - Data storage directory (content ignored)

## 🤖 Machine Learning Framework

-### Model Registry

```python
_REGISTRY: Dict[str, Callable[[int], object]] = {
@@ -47,35 +45,31 @@ _REGISTRY: Dict[str, Callable[[int], object]] = {
|
|
| 47 |
|
| 48 |
### Neural Network Architectures
|
| 49 |
|
| 50 |
-
|
| 51 |
|
| 52 |
-
|
| 53 |
-
- **Classification Head**: 3 fully connected layers (256→128→2 neurons)
|
| 54 |
-
- **Performance**: 94.80% accuracy, 94.30% F1-score
|
| 55 |
-
- **Designation**: Validated exclusively for Raman spectra input
|
| 56 |
-
- **Parameters**: Dynamic flattened size calculation for input flexibility
|
| 57 |
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
|
| 61 |
-
- **Innovation**: 1D residual connections for spectral feature learning
|
| 62 |
-
- **Performance**: 96.20% accuracy, 95.90% F1-score
|
| 63 |
-
- **Efficiency**: Global average pooling reduces parameter count
|
| 64 |
-
- **Parameters**: Approximately 100K (more efficient than baseline)
|
| 65 |
|
| 66 |
-
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
-
|
| 71 |
-
-
|
|
|
|
| 72 |
|
| 73 |
## 🔧 Data Processing Infrastructure
|
| 74 |
|
| 75 |
### Preprocessing Pipeline
|
| 76 |
|
| 77 |
-
The system implements a **modular preprocessing pipeline** in `utils/preprocessing.py` with five configurable stages:
|
| 78 |
-
|
| 79 |
**1. Input Validation Framework:**
|
| 80 |
|
| 81 |
- File format verification (`.txt` files exclusively)
|
|
@@ -84,16 +78,16 @@ The system implements a **modular preprocessing pipeline** in `utils/preprocessi
|
|
| 84 |
- Monotonic sequence verification for spectral consistency
|
| 85 |
- NaN value detection and automatic rejection
|
| 86 |
|
| 87 |
-
**2. Core Processing Steps:**
|
| 88 |
|
| 89 |
- **Linear Resampling**: Uniform grid interpolation to 500 points using `scipy.interpolate.interp1d`
|
| 90 |
- **Baseline Correction**: Polynomial detrending (configurable degree, default=2)
|
| 91 |
- **Savitzky-Golay Smoothing**: Noise reduction (window=11, order=2, configurable)
|
| 92 |
-
- **Min-Max Normalization**: Scaling to range with constant-signal protection
|
| 93 |
|
| 94 |
### Batch Processing Framework
|
| 95 |
|
| 96 |
-
The `utils/multifile.py` module (12.5 kB) provides **enterprise-grade batch processing** capabilities:
|
| 97 |
|
| 98 |
- **Multi-File Upload**: Streamlit widget supporting simultaneous file selection
|
| 99 |
- **Error-Tolerant Processing**: Individual file failures don't interrupt batch operations
|
|
@@ -123,7 +117,7 @@ The main application implements a **sophisticated two-column layout** with compr
|
|
| 123 |
|
| 124 |
### State Management System
|
| 125 |
|
| 126 |
-
The application employs **advanced session state management**:
|
| 127 |
|
| 128 |
- Persistent state across Streamlit reruns using `st.session_state`
|
| 129 |
- Intelligent caching with content-based hash keys for expensive operations
|
|
@@ -134,46 +128,24 @@ The application employs **advanced session state management**:[^1_2]
|
|
| 134 |
|
| 135 |
### Centralized Error Handling
|
| 136 |
|
| 137 |
-The `utils/errors.py` module
-
-```python
-class ErrorHandler:
-    @staticmethod
-    def log_error(error: Exception, context: str = "", include_traceback: bool = False)
-    @staticmethod
-    def handle_file_error(filename: str, error: Exception) -> str
-    @staticmethod
-    def handle_inference_error(model_name: str, error: Exception) -> str
-```
|
| 148 |
|
| 149 |
-
|
| 150 |
|
| 151 |
-
|
| 152 |
-
- Graceful degradation with fallback modes
|
| 153 |
-
- Structured logging with configurable verbosity
|
| 154 |
-
- User-friendly error translation from technical exceptions
|
| 155 |
|
| 156 |
-
|
|
|
|
|
|
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
:
|
| 161 |
-
|
| 162 |
-
**Softmax-Based Confidence:**
|
| 163 |
-
|
| 164 |
-
- Normalized probability distributions from model logits
|
| 165 |
-
- Three-tier confidence levels: HIGH (≥80%), MEDIUM (≥60%), LOW (<60%)
|
| 166 |
-
- Color-coded visual indicators with emoji representations
|
| 167 |
-
- Legacy compatibility with logit margin calculations
|
| 168 |
-
|
| 169 |
-
### Session Results Management
|
| 170 |
|
| 171 |
-
The `utils/results_manager.py` module
|
| 172 |
|
| 173 |
-
- **In-Memory Storage**:
|
| 174 |
-
- **
|
| 175 |
-
- **
|
| 176 |
-
- **
|
| 177 |
|
| 178 |
## 📜 Command-Line Interface
|
| 179 |
|
|
@@ -194,17 +166,6 @@ The `scripts/train_model.py` module (6.27 kB) implements **robust model training
|
|
| 194 |
- Deterministic CUDA operations when GPU available
|
| 195 |
- Standardized train/validation splitting methodology
|
| 196 |
|
| 197 |
-
### Inference Pipeline
|
| 198 |
-
|
| 199 |
-
The `scripts/run_inference.py` module (5.88 kB) provides **automated inference capabilities**:
|
| 200 |
-
|
| 201 |
-
**CLI Features:**
|
| 202 |
-
|
| 203 |
-
- Preprocessing parity with web interface ensuring consistent results
|
| 204 |
-
- Multiple output formats with detailed metadata inclusion
|
| 205 |
-
- Safe model loading across PyTorch versions with fallback mechanisms
|
| 206 |
-
- Flexible architecture selection via command-line arguments
|
| 207 |
-
|
| 208 |
### Data Utilities
|
| 209 |
|
| 210 |
**File Discovery System:**
|
|
@@ -213,17 +174,6 @@ The `scripts/run_inference.py` module (5.88 kB) provides **automated inference c
|
|
| 213 |
- Filename-based labeling convention (`sta-*` = stable, `wea-*` = weathered)
|
| 214 |
- Dataset inventory generation with statistical summaries
|
| 215 |
|
| 216 |
-
## 🐳 Deployment Infrastructure
|
| 217 |
-
|
| 218 |
-
### Docker Configuration
|
| 219 |
-
|
| 220 |
-
The `Dockerfile` (421 Bytes) implements **optimized containerization**:[^1_12]
|
| 221 |
-
|
| 222 |
-
- **Base Image**: Python 3.13-slim for minimal attack surface
|
| 223 |
-
- **System Dependencies**: Essential build tools and scientific libraries
|
| 224 |
-
- **Health Monitoring**: HTTP endpoint checking for container wellness
|
| 225 |
-
- **Caching Strategy**: Layered builds with dependency caching for faster rebuilds
|
| 226 |
-
|
| 227 |
### Dependency Management
|
| 228 |
|
| 229 |
The `requirements.txt` specifies **core dependencies without version pinning**:[^1_12]
|
|
@@ -234,6 +184,36 @@ The `requirements.txt` specifies **core dependencies without version pinning**:[
|
|
| 234 |
- **Visualization**: `matplotlib` for spectrum plotting
|
| 235 |
- **API Framework**: `fastapi`, `uvicorn` for potential REST API expansion
|
| 236 |
|
| 237 |
## 🧪 Testing Framework
|
| 238 |
|
| 239 |
### Test Infrastructure
|
|
@@ -244,12 +224,12 @@ The `tests/` directory implements **basic validation framework**:
|
|
| 244 |
- **Preprocessing Tests**: Core pipeline functionality validation in `test_preprocessing.py`
|
| 245 |
- **Limited Coverage**: Currently covers preprocessing functions only
|
| 246 |
|
| 247 |
-
**Testing
|
| 248 |
|
| 249 |
-
-
|
| 250 |
-
-
|
| 251 |
-
-
|
| 252 |
-
-
|
| 253 |
|
| 254 |
## 🔍 Security \& Quality Assessment
|
| 255 |
|
|
@@ -271,27 +251,11 @@ The `tests/` directory implements **basic validation framework**:
|
|
| 271 |
- **Error Boundaries**: Multi-level exception handling with graceful degradation
|
| 272 |
- **Logging**: Structured logging with appropriate severity levels
|
| 273 |
|
| 274 |
-
### Security Considerations
|
| 275 |
-
|
| 276 |
-
**Current Protections:**
|
| 277 |
-
|
| 278 |
-
- Input sanitization through strict parsing rules
|
| 279 |
-
- No arbitrary code execution paths
|
| 280 |
-
- Containerized deployment limiting attack surface
|
| 281 |
-
- Session-based storage preventing data persistence attacks
|
| 282 |
-
|
| 283 |
-
**Areas Requiring Enhancement:**
|
| 284 |
-
|
| 285 |
-
- No explicit security headers in web responses
|
| 286 |
-
- Basic authentication/authorization framework absent
|
| 287 |
-
- File upload size limits not explicitly configured
|
| 288 |
-
- No rate limiting mechanisms implemented
|
| 289 |
-
|
| 290 |
## 🚀 Extensibility Analysis
|
| 291 |
|
| 292 |
### Model Architecture Extensibility
|
| 293 |
|
| 294 |
-
The **registry pattern enables seamless model addition**:
|
| 295 |
|
| 296 |
1. **Implementation**: Create new model class with standardized interface
|
| 297 |
2. **Registration**: Add to `models/registry.py` with factory function
|
|
@@ -344,72 +308,15 @@ The **registry pattern enables seamless model addition**:[^1_5]
|
|
| 344 |
- Session state pruning for long-running sessions
|
| 345 |
- Caching with content-based invalidation
|
| 346 |
|
| 347 |
-
## 🎯 Production Readiness Evaluation
|
| 348 |
-
|
| 349 |
-
### Strengths
|
| 350 |
-
|
| 351 |
-
**Architecture Excellence:**
|
| 352 |
-
|
| 353 |
-
- Clean separation of concerns with modular design
|
| 354 |
-
- Production-grade error handling and logging
|
| 355 |
-
- Intuitive user experience with real-time feedback
|
| 356 |
-
- Scalable batch processing with progress tracking
|
| 357 |
-
- Well-documented, type-hinted codebase
|
| 358 |
-
|
| 359 |
-
**Operational Readiness:**
|
| 360 |
-
|
| 361 |
-
- Containerized deployment with health checks
|
| 362 |
-
- Comprehensive preprocessing validation
|
| 363 |
-
- Multiple export formats for integration
|
| 364 |
-
- Session-based results management
|
| 365 |
-
|
| 366 |
-
### Enhancement Opportunities
|
| 367 |
-
|
| 368 |
-
**Testing Infrastructure:**
|
| 369 |
-
|
| 370 |
-
- Expand unit test coverage beyond preprocessing
|
| 371 |
-
- Implement integration tests for UI workflows
|
| 372 |
-
- Add performance regression testing
|
| 373 |
-
- Include security vulnerability scanning
|
| 374 |
-
|
| 375 |
-
**Monitoring \& Observability:**
|
| 376 |
-
|
| 377 |
-
- Application performance monitoring integration
|
| 378 |
-
- User analytics and usage patterns tracking
|
| 379 |
-
- Model performance drift detection
|
| 380 |
-
- Resource utilization monitoring
|
| 381 |
-
|
| 382 |
-
**Security Hardening:**
|
| 383 |
-
|
| 384 |
-
- Implement proper authentication mechanisms
|
| 385 |
-
- Add rate limiting for API endpoints
|
| 386 |
-
- Configure security headers for web responses
|
| 387 |
-
- Establish audit logging for sensitive operations
|
| 388 |
-
|
| 389 |
## 🔮 Strategic Development Roadmap
|
| 390 |
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
**1. Multi-Model Dashboard Evolution**
|
| 394 |
-
|
| 395 |
-
- Comparative model evaluation framework
|
| 396 |
-
- Side-by-side performance reporting
|
| 397 |
-
- Automated model retraining pipelines
|
| 398 |
-
- Model versioning and rollback capabilities
|
| 399 |
-
|
| 400 |
-
**2. Multi-Modal Input Support**
|
| 401 |
-
|
| 402 |
-
- FTIR spectroscopy integration with dedicated preprocessing
|
| 403 |
-
- Image-based polymer classification via computer vision
|
| 404 |
-
- Cross-modal validation and ensemble methods
|
| 405 |
-
- Unified preprocessing pipeline for multiple modalities
|
| 406 |
|
| 407 |
-
**
|
| 408 |
-
|
| 409 |
-
-
|
| 410 |
-
-
|
| 411 |
-
-
|
| 412 |
-
- Audit trails and compliance reporting
|
| 413 |
|
| 414 |
## 💼 Business Logic \& Scientific Workflow
|
| 415 |
|
|
@@ -424,7 +331,7 @@ Based on the documented roadmap in `README.md`, the platform targets three strat
|
|
| 424 |
|
| 425 |
### Scientific Applications
|
| 426 |
|
| 427 |
-
**Research Use Cases:**
|
| 428 |
|
| 429 |
- Material science polymer degradation studies
|
| 430 |
- Recycling viability assessment for circular economy
|
|
@@ -434,7 +341,7 @@ Based on the documented roadmap in `README.md`, the platform targets three strat
|
|
| 434 |
|
| 435 |
### Data Workflow Architecture
|
| 436 |
|
| 437 |
-
```
|
| 438 |
Input Validation → Spectrum Preprocessing → Model Inference →
|
| 439 |
Confidence Analysis → Results Visualization → Export Options
|
| 440 |
```
|
|
@@ -475,10 +382,7 @@ The platform successfully bridges academic research and practical application, p
|
|
| 475 |
|
| 476 |
**Risk Assessment:** Low - The codebase demonstrates mature engineering practices with appropriate validation and error handling for production deployment.
|
| 477 |
|
| 478 |
-
**Recommendation:** This platform is ready for production deployment
|
| 479 |
-
<span style="display:none">[^1_14][^1_15][^1_16][^1_17][^1_18]</span>
|
| 480 |
-
|
| 481 |
-
<div style="text-align: center">⁂</div>
|
| 482 |
|
| 483 |
### EXTRA
|
| 484 |
|
|
@@ -529,22 +433,3 @@ The platform successfully bridges academic research and practical application, p
|
|
| 529 |
Column 1 (Input): Contains the main st.radio for mode selection and the conditional logic to display the single file uploader, batch uploader, or sample selector. It also holds the "Run Analysis" and "Reset All" buttons.
|
| 530 |
Column 2 (Results): Contains all the logic for displaying either the batch results or the detailed, tabbed results for a single file (Details, Technical, Explanation).
|
| 531 |
```
|
| 532 |
-
-[^1_1]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/tree/main
-[^1_2]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/tree/main/datasets
-[^1_3]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml
-[^1_4]: https://github.com/KLab-AI3/ml-polymer-recycling
-[^1_5]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/.gitignore
-[^1_6]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/blob/main/models/resnet_cnn.py
-[^1_7]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/utils/multifile.py
-[^1_8]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/utils/preprocessing.py
-[^1_9]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/utils/audit.py
-[^1_10]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/utils/results_manager.py
-[^1_11]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/blob/main/scripts/train_model.py
-[^1_12]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/requirements.txt
-[^1_13]: https://doi.org/10.1016/j.resconrec.2022.106718
-[^1_14]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/app.py
-[^1_15]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/Dockerfile
-[^1_16]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/utils/errors.py
-[^1_17]: https://huggingface.co/spaces/dev-jas/polymer-aging-ml/raw/main/utils/confidence.py
-[^1_18]: https://ppl-ai-code-interpreter-files.s3.amazonaws.com/web/direct-files/9fd1eb2028a28085942cb82c9241b5ae/a25e2c38-813f-4d8b-89b3-713f7d24f1fe/3e70b172.md
|
|
|
|
| 2 |
|
| 3 |
## Executive Summary
|
| 4 |
|
| 5 |
+
This audit provides a technical inventory of the dev-jas/polymer-aging-ml repository—a modular machine learning platform for polymer degradation classification using Raman and FTIR spectroscopy. The system features robust error handling, multi-format batch processing, and persistent performance tracking, making it suitable for research, education, and industrial applications.
|
| 6 |
|
| 7 |
## 🏗️ System Architecture
|
| 8 |
|
| 9 |
### Core Infrastructure
|
| 10 |
|
| 11 |
+
- **Streamlit-based web app** (`app.py`) as the main interface
|
| 12 |
+
- **PyTorch** for deep learning
|
| 13 |
+
- **Docker** for deployment
|
| 14 |
+
- **SQLite** (`outputs/performance_tracking.db`) for performance metrics
|
| 15 |
+
- **Plugin-based model registry** for extensibility
|
| 16 |
+
|
| 17 |
+
### Directory Structure
|
| 18 |
+
|
| 19 |
+
- **app.py**: Main Streamlit application
|
| 20 |
+
- **README.md**: Project documentation
|
| 21 |
+
- **Dockerfile**: Containerization (Python 3.13-slim)
|
| 22 |
+
- **requirements.txt**: Dependency management
|
| 23 |
+
- **models/**: Neural network architectures and registry
|
| 24 |
+
- **utils/**: Shared utilities (preprocessing, batch, results, performance, errors, confidence)
|
| 25 |
+
- **scripts/**: CLI tools for training, inference, data management
|
| 26 |
+
- **outputs/**: Model weights, inference results, performance DB
|
| 27 |
+
- **sample_data/**: Demo spectrum files
|
| 28 |
+
- **tests/**: Unit tests (PyTest)
|
| 29 |
+
- **datasets/**: Data storage
|
| 30 |
+
- **pages/**: Streamlit dashboard pages
|
|
|
|
|
|
|
| 31 |
|
| 32 |
## 🤖 Machine Learning Framework
|
| 33 |
|
| 34 |
+
### Model Registry
|
| 35 |
|
| 36 |
+
Factory pattern in `models/registry.py` enables dynamic model selection:
|
| 37 |
|
| 38 |
```python
|
| 39 |
_REGISTRY: Dict[str, Callable[[int], object]] = {
|
|
|
|
| 45 |
|
| 46 |
### Neural Network Architectures
|
| 47 |
|
| 48 |
+
The platform supports three architectures, offering diverse options for spectral analysis:
|
| 49 |
|
| 50 |
+
**Figure2CNN (Baseline Model):**
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
+
- Architecture: 4 convolutional layers (1→16→32→64→128), 3 fully connected layers (256→128→2).
|
| 53 |
+
- Performance: 94.80% accuracy, 94.30% F1-score (Raman-only).
|
| 54 |
+
- Parameters: ~500K, supports dynamic input handling.
|
| 55 |
|
| 56 |
+
**ResNet1D (Advanced Model):**
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
+
- Architecture: 3 residual blocks with 1D skip connections.
|
| 59 |
+
- Performance: 96.20% accuracy, 95.90% F1-score.
|
| 60 |
+
- Parameters: ~100K, efficient via global average pooling.
|
| 61 |
|
| 62 |
+
**ResNet18Vision (Experimental):**
|
| 63 |
+
|
| 64 |
+
- Architecture: 1D-adapted ResNet-18 with 4 layers (2 blocks each).
|
| 65 |
+
- Status: Under evaluation, ~11M parameters.
|
| 66 |
+
- Opportunity: Expand validation for broader spectral applications.
|
| 67 |
|
| 68 |
## 🔧 Data Processing Infrastructure
|
| 69 |
|
| 70 |
### Preprocessing Pipeline
|
| 71 |
|
| 72 |
+
The system implements a **modular preprocessing pipeline** in `utils/preprocessing.py` with five configurable stages:
|
|
|
|
| 73 |
**1. Input Validation Framework:**
|
| 74 |
|
| 75 |
- File format verification (`.txt` files exclusively)
|
|
|
|
| 78 |
- Monotonic sequence verification for spectral consistency
|
| 79 |
- NaN value detection and automatic rejection
|
| 80 |
|
| 81 |
+
**2. Core Processing Steps:**
|
| 82 |
|
| 83 |
- **Linear Resampling**: Uniform grid interpolation to 500 points using `scipy.interpolate.interp1d`
|
| 84 |
- **Baseline Correction**: Polynomial detrending (configurable degree, default=2)
|
| 85 |
- **Savitzky-Golay Smoothing**: Noise reduction (window=11, order=2, configurable)
|
| 86 |
+
- **Min-Max Normalization**: Scaling to range with constant-signal protection
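A minimal sketch of these four steps, assuming conventional NumPy/SciPy usage (the helper name and defaults below are illustrative, not the exact `utils/preprocessing.py` API):

```python
import numpy as np
from scipy.interpolate import interp1d
from scipy.signal import savgol_filter

def preprocess_spectrum(x, y, target_len=500, baseline_degree=2, window=11, order=2):
    """Illustrative pipeline: resample -> baseline correction -> smoothing -> min-max scaling."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)

    # Linear resampling onto a uniform grid of target_len points
    grid = np.linspace(x.min(), x.max(), target_len)
    y = interp1d(x, y, kind="linear")(grid)

    # Polynomial baseline correction (detrending, default degree 2)
    y = y - np.polyval(np.polyfit(grid, y, deg=baseline_degree), grid)

    # Savitzky-Golay smoothing (window=11, order=2)
    y = savgol_filter(y, window_length=window, polyorder=order)

    # Min-max normalization with constant-signal protection
    span = y.max() - y.min()
    y = (y - y.min()) / span if span > 0 else np.zeros_like(y)
    return grid, y
```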
|
| 87 |
|
| 88 |
### Batch Processing Framework
|
| 89 |
|
| 90 |
+
The `utils/multifile.py` module (12.5 kB) provides **enterprise-grade batch processing** capabilities:
|
| 91 |
|
| 92 |
- **Multi-File Upload**: Streamlit widget supporting simultaneous file selection
|
| 93 |
- **Error-Tolerant Processing**: Individual file failures don't interrupt batch operations
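The error-tolerant pattern can be sketched as a per-file try/except loop (illustrative only; `process_multiple_files` in `utils/multifile.py` is richer than this):

```python
def process_batch(files, parse_fn):
    """Process each uploaded file independently; one failure never aborts the batch."""
    results, errors = [], []
    for f in files:
        try:
            results.append({"filename": f.name, "data": parse_fn(f)})
        except Exception as exc:  # collect the error and keep going
            errors.append({"filename": f.name, "error": str(exc)})
    return results, errors
```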
|
|
|
|
| 117 |
|
| 118 |
### State Management System
|
| 119 |
|
| 120 |
+
The application employs **advanced session state management**:
|
| 121 |
|
| 122 |
- Persistent state across Streamlit reruns using `st.session_state`
|
| 123 |
- Intelligent caching with content-based hash keys for expensive operations
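A small sketch of the content-based caching idea (helper names are hypothetical; the real app combines `st.session_state` with `@st.cache_data` keys derived from file content):

```python
import hashlib
import streamlit as st

def content_key(raw_bytes: bytes) -> str:
    """Stable key derived from the uploaded file's bytes, not its name."""
    return hashlib.sha256(raw_bytes).hexdigest()

def get_or_parse(raw_bytes: bytes, parse_fn):
    """Parse once per unique content hash; the result survives Streamlit reruns."""
    key = f"parsed_{content_key(raw_bytes)}"
    if key not in st.session_state:
        st.session_state[key] = parse_fn(raw_bytes.decode("utf-8"))
    return st.session_state[key]
```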
|
|
|
|
| 128 |
|
| 129 |
### Centralized Error Handling
|
| 130 |
|
| 131 |
+
The `utils/errors.py` module provides **context-aware** logging and user-friendly error messages.
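The static-method surface shown in the earlier inventory (above in this diff) suggests a shape like the following; the method bodies here are illustrative placeholders rather than the actual implementation:

```python
import logging
import traceback

logger = logging.getLogger("polymer_app")  # hypothetical logger name

class ErrorHandler:
    @staticmethod
    def log_error(error: Exception, context: str = "", include_traceback: bool = False) -> None:
        logger.error(f"{context}: {error}" if context else str(error))
        if include_traceback:
            logger.error(traceback.format_exc())

    @staticmethod
    def handle_file_error(filename: str, error: Exception) -> str:
        ErrorHandler.log_error(error, context=f"file={filename}")
        return f"Could not process '{filename}'. Check the file format and try again."

    @staticmethod
    def handle_inference_error(model_name: str, error: Exception) -> str:
        ErrorHandler.log_error(error, context=f"model={model_name}", include_traceback=True)
        return f"Inference with '{model_name}' failed; see the log for details."
```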
| 132 |
|
| 133 |
+
### Performance Tracking System
|
| 134 |
|
| 135 |
+
The `utils/performance_tracker.py` module provides a robust system for logging and analyzing performance metrics.
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
+
- **Database Logging**: Persists metrics to a SQLite database.
|
| 138 |
+
- **Automated Tracking**: Uses a context manager to automatically track inference time, preprocessing time, and memory usage.
|
| 139 |
+
- **Dashboarding**: Includes functions to generate performance visualizations and summary statistics for the UI.
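A condensed sketch of the context-manager + SQLite idea (schema and helper name are illustrative; the actual module exposes `PerformanceMetrics`, `get_performance_tracker`, and `log_performance`, as used in `core_logic.py` later in this diff):

```python
import os
import sqlite3
import time
from contextlib import contextmanager

DB_PATH = "outputs/performance_tracking.db"  # path added to .gitignore in this PR

@contextmanager
def track(model_name: str, db_path: str = DB_PATH):
    """Time the wrapped block and persist the elapsed seconds to SQLite."""
    start = time.time()
    try:
        yield
    finally:
        elapsed = time.time() - start
        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        with sqlite3.connect(db_path) as conn:
            conn.execute(
                "CREATE TABLE IF NOT EXISTS perf "
                "(model TEXT, seconds REAL, ts TEXT DEFAULT CURRENT_TIMESTAMP)"
            )
            conn.execute("INSERT INTO perf (model, seconds) VALUES (?, ?)", (model_name, elapsed))

# Usage: with track("resnet"): prediction = model(input_tensor)
```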
|
| 140 |
|
| 141 |
+
### Enhanced Results Management
| 142 |
|
| 143 |
+
The `utils/results_manager.py` module enables comprehensive session and persistent results tracking.
|
| 144 |
|
| 145 |
+
- **In-Memory Storage**: Manages results for the current session.
|
| 146 |
+
- **Multi-Model Handling**: Aggregates results from multiple models for comparison.
|
| 147 |
+
- **Export Capabilities**: Exports results to CSV and JSON.
|
| 148 |
+
- **Statistical Analysis**: Calculates accuracy, confidence, and other metrics.
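As a rough sketch of the in-memory store with export and accuracy helpers (a hypothetical class shape; the real `ResultsManager` API may differ):

```python
import json
import pandas as pd

class SessionResults:
    """Holds per-file predictions for the current session and exports them."""

    def __init__(self):
        self.rows = []

    def add(self, filename, model, prediction, confidence, ground_truth=None):
        self.rows.append({"filename": filename, "model": model, "prediction": prediction,
                          "confidence": confidence, "ground_truth": ground_truth})

    def accuracy(self):
        labeled = [r for r in self.rows if r["ground_truth"] is not None]
        if not labeled:
            return None  # no ground truth available in this session
        return sum(r["prediction"] == r["ground_truth"] for r in labeled) / len(labeled)

    def to_csv(self) -> str:
        return pd.DataFrame(self.rows).to_csv(index=False)

    def to_json(self) -> str:
        return json.dumps(self.rows, indent=2)
```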
|
| 149 |
|
| 150 |
## 📜 Command-Line Interface
|
| 151 |
|
|
|
|
| 166 |
- Deterministic CUDA operations when GPU available
|
| 167 |
- Standardized train/validation splitting methodology
|
| 168 |
|
|
| 169 |
### Data Utilities
|
| 170 |
|
| 171 |
**File Discovery System:**
|
|
|
|
| 174 |
- Filename-based labeling convention (`sta-*` = stable, `wea-*` = weathered)
|
| 175 |
- Dataset inventory generation with statistical summaries
|
| 176 |
|
| 177 |
### Dependency Management
|
| 178 |
|
| 179 |
The `requirements.txt` specifies **core dependencies without version pinning**:[^1_12]
|
|
|
|
| 184 |
- **Visualization**: `matplotlib` for spectrum plotting
|
| 185 |
- **API Framework**: `fastapi`, `uvicorn` for potential REST API expansion
|
| 186 |
|
| 187 |
+
## 🐳 Deployment Infrastructure
|
| 188 |
+
|
| 189 |
+
### Docker Configuration
|
| 190 |
+
|
| 191 |
+
The Dockerfile uses Python 3.13-slim for efficient containerization:
|
| 192 |
+
|
| 193 |
+
- Includes essential build tools and scientific libraries.
|
| 194 |
+
- Supports health checks for container wellness.
|
| 195 |
+
- **Roadmap**: Implement multi-stage builds and environment variables for streamlined deployments.
|
| 196 |
+
|
| 197 |
+
### Confidence Analysis System
|
| 198 |
+
|
| 199 |
+
The `utils/confidence.py` module provides **scientific confidence metrics**
|
| 200 |
+
|
| 201 |
+
**Softmax-Based Confidence:**
|
| 202 |
+
|
| 203 |
+
- Normalized probability distributions from model logits
|
| 204 |
+
- Three-tier confidence levels: HIGH (≥80%), MEDIUM (≥60%), LOW (<60%)
|
| 205 |
+
- Color-coded visual indicators with emoji representations
|
| 206 |
+
- Legacy compatibility with logit margin calculations
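The tiering logic reduces to a few lines (a sketch consistent with the thresholds above, not necessarily the exact `utils/confidence.py` code):

```python
import numpy as np

def softmax_confidence(logits: np.ndarray):
    """Return class probabilities, the top-class confidence, and its tier label."""
    z = logits - logits.max()              # shift for numerical stability
    probs = np.exp(z) / np.exp(z).sum()
    confidence = float(probs.max())
    if confidence >= 0.80:
        tier = "HIGH 🟢"
    elif confidence >= 0.60:
        tier = "MEDIUM 🟡"
    else:
        tier = "LOW 🔴"
    return probs, confidence, tier
```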
|
| 207 |
+
|
| 208 |
+
### Session Results Management
|
| 209 |
+
|
| 210 |
+
The `utils/results_manager.py` module (8.16 kB) enables **comprehensive session tracking**:
|
| 211 |
+
|
| 212 |
+
- **In-Memory Storage**: Session-wide results persistence
|
| 213 |
+
- **Export Capabilities**: CSV and JSON download with timestamp formatting
|
| 214 |
+
- **Statistical Analysis**: Automatic accuracy calculation when ground truth available
|
| 215 |
+
- **Data Integrity**: Results survive page refreshes within session boundaries
|
| 216 |
+
|
| 217 |
## 🧪 Testing Framework
|
| 218 |
|
| 219 |
### Test Infrastructure
|
|
|
|
| 224 |
- **Preprocessing Tests**: Core pipeline functionality validation in `test_preprocessing.py`
|
| 225 |
- **Limited Coverage**: Currently covers preprocessing functions only
|
| 226 |
|
| 227 |
+
**Testing Coming Soon:**
|
| 228 |
|
| 229 |
+
- Add model architecture unit tests
|
| 230 |
+
- Integration tests for UI components
|
| 231 |
+
- Performance benchmarking tests
|
| 232 |
+
- Improved error handling validation
|
| 233 |
|
| 234 |
## 🔍 Security \& Quality Assessment
|
| 235 |
|
|
|
|
| 251 |
- **Error Boundaries**: Multi-level exception handling with graceful degradation
|
| 252 |
- **Logging**: Structured logging with appropriate severity levels
|
| 253 |
| 254 |
## 🚀 Extensibility Analysis
|
| 255 |
|
| 256 |
### Model Architecture Extensibility
|
| 257 |
|
| 258 |
+
The **registry pattern enables seamless model addition**:
|
| 259 |
|
| 260 |
1. **Implementation**: Create new model class with standardized interface
|
| 261 |
2. **Registration**: Add to `models/registry.py` with factory function
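For example, a hypothetical DenseNet1D (listed as a planned model in the registry) could be wired in through the `register_model` helper shown later in this diff; the module path and spec values below are placeholders:

```python
from models.registry import register_model, build
from models.densenet1d import DenseNet1D  # hypothetical new architecture module

# Register a factory plus metadata; the key then shows up in registry.choices()
register_model(
    "densenet1d",
    lambda L: DenseNet1D(input_length=L),
    {
        "input_length": 500,
        "num_classes": 2,
        "description": "DenseNet1D for spectroscopy",
        "modalities": ["raman", "ftir"],
    },
)

model = build("densenet1d", input_length=500)  # usable like any registered model
```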
|
|
|
|
| 308 |
- Session state pruning for long-running sessions
|
| 309 |
- Caching with content-based invalidation
|
| 310 |
|
| 311 |
## 🔮 Strategic Development Roadmap
|
| 312 |
|
| 313 |
+
The project roadmap has been updated to reflect recent progress:
|
| 314 |
|
| 315 |
+
- [x] **FTIR Support**: Modular integration of FTIR spectroscopy is complete.
|
| 316 |
+
- [x] **Multi-Model Dashboard**: A model comparison tab has been implemented.
|
| 317 |
+
- [ ] **Image-based Inference**: Future work to include image-based polymer classification.
|
| 318 |
+
- [x] **Performance Tracking**: A performance tracking dashboard has been implemented.
|
| 319 |
+
- [ ] **Enterprise Integration**: Future work to include a RESTful API and more advanced database integration.
|
|
|
|
| 320 |
|
| 321 |
## 💼 Business Logic \& Scientific Workflow
|
| 322 |
|
|
|
|
| 331 |
|
| 332 |
### Scientific Applications
|
| 333 |
|
| 334 |
+
**Research Use Cases:**
|
| 335 |
|
| 336 |
- Material science polymer degradation studies
|
| 337 |
- Recycling viability assessment for circular economy
|
|
|
|
| 341 |
|
| 342 |
### Data Workflow Architecture
|
| 343 |
|
| 344 |
+
```text
|
| 345 |
Input Validation → Spectrum Preprocessing → Model Inference →
|
| 346 |
Confidence Analysis → Results Visualization → Export Options
|
| 347 |
```
|
|
|
|
| 382 |
|
| 383 |
**Risk Assessment:** Low - The codebase demonstrates mature engineering practices with appropriate validation and error handling for production deployment.
|
| 384 |
|
| 385 |
+
**Recommendation:** This platform is ready for production deployment, representing a solid foundation for polymer classification research and industrial applications.
|
| 386 |
|
| 387 |
### EXTRA
|
| 388 |
|
|
|
|
| 433 |
Column 1 (Input): Contains the main st.radio for mode selection and the conditional logic to display the single file uploader, batch uploader, or sample selector. It also holds the "Run Analysis" and "Reset All" buttons.
|
| 434 |
Column 2 (Results): Contains all the logic for displaying either the batch results or the detailed, tabbed results for a single file (Details, Technical, Explanation).
|
| 435 |
```
|
README.md
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
---
|
| 2 |
-
title: AI Polymer Classification
|
| 3 |
emoji: 🔬
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: yellow
|
|
@@ -8,19 +8,21 @@ app_file: app.py
|
|
| 8 |
pinned: false
|
| 9 |
license: apache-2.0
|
| 10 |
---
|
| 11 |
-
## AI-Driven Polymer Aging Prediction and Classification (v0.1)
|
| 12 |
|
| 13 |
-
|
| 14 |
|
| 15 |
-
|
|
|
|
| 16 |
|
| 17 |
---
|
| 18 |
|
| 19 |
## 🧪 Current Scope
|
| 20 |
|
| 21 |
-
- 🔬 **
|
| 22 |
-
-
|
|
|
|
| 23 |
- 📊 **Task**: Binary classification — Stable vs Weathered polymers
|
|
|
|
| 24 |
- 🛠️ **Architecture**: PyTorch + Streamlit
|
| 25 |
|
| 26 |
---
|
|
@@ -29,84 +31,76 @@ It was developed as part of the AIRE 2025 internship project at the Imageomics I
|
|
| 29 |
|
| 30 |
- [x] Inference from Raman `.txt` files
|
| 31 |
- [x] Model selection (Figure2CNN, ResNet1D)
|
|
|
|
|
|
|
|
|
|
| 32 |
- [ ] Add more trained CNNs for comparison
|
| 33 |
-
- [ ] FTIR support (modular integration planned)
|
| 34 |
- [ ] Image-based inference (future modality)
|
|
|
|
| 35 |
|
| 36 |
---
|
| 37 |
|
| 38 |
## 🧭 How to Use
|
| 39 |
|
| 40 |
-
|
| 41 |
-
2. Choose a model from the sidebar
|
| 42 |
-
3. Run analysis
|
| 43 |
-
4. View prediction, logits, and technical information
|
| 44 |
|
| 45 |
-
|
| 46 |
|
| 47 |
-
-
|
| 48 |
-
-
|
| 49 |
-
-
|
| 50 |
-
- Automatically resampled to 500 points
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
## Contributors
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
-
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
|
| 69 |
---
|
| 70 |
|
| 71 |
-
##
|
| 72 |
-
|
| 73 |
-
- 💻 **Live App**: [Hugging Face Space](https://huggingface.co/spaces/dev-jas/polymer-aging-ml)
|
| 74 |
-
- 📂 **GitHub Repo**: [ml-polymer-recycling](https://github.com/KLab-AI3/ml-polymer-recycling)
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
## 🎯 Strategic Expansion Objectives (Roadmap)
|
| 78 |
-
|
| 79 |
-
**The roadmap defines three major expansion paths designed to broaden the system’s capabilities and impact:**
|
| 80 |
-
|
| 81 |
-
1. **Model Expansion: Multi-Model Dashboard**
|
| 82 |
-
|
| 83 |
-
> The dashboard will evolve into a hub for multiple model architectures rather than being tied to a single baseline. Planned work includes:
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
- **Reproducible Integration**: Maintaining modular scripts and pipelines so each model’s results can be replicated without conflict.
|
| 89 |
|
| 90 |
-
|
| 91 |
|
| 92 |
-
|
| 93 |
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
-
|
| 97 |
-
- **Multi-Model Execution**: Selected models from the registry can be applied to all uploaded images simultaneously.
|
| 98 |
-
- **Batch Results**: Output will be returned in a structured, accessible way, showing both individual predictions and aggregate statistics.
|
| 99 |
-
- **Enhanced Feedback**: Outputs will include predicted class, model confidence, and potentially annotated image previews.
|
| 100 |
|
| 101 |
-
|
| 102 |
|
| 103 |
-
|
|
|
|
| 104 |
|
| 105 |
-
|
| 106 |
|
| 107 |
-
|
| 108 |
-
- **Architecture Compatibility**: Ensuring existing and retrained models can process FTIR data without mixing it with Raman workflows.
|
| 109 |
-
- **UI Integration**: Introducing FTIR as a separate option in the modality selector, keeping Raman, Image, and FTIR workflows clearly delineated.
|
| 110 |
-
- **Phased Development**: Implementation details to be refined during meetings to ensure scientific rigor.
|
| 111 |
|
| 112 |
-
|
|
| 1 |
---
|
| 2 |
+
title: AI Polymer Classification (Raman & FTIR)
|
| 3 |
emoji: 🔬
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: yellow
|
|
|
|
| 8 |
pinned: false
|
| 9 |
license: apache-2.0
|
| 10 |
---
|
|
|
|
| 11 |
|
| 12 |
+
## AI-Driven Polymer Aging Prediction and Classification (v0.1)
|
| 13 |
|
| 14 |
+
This web application classifies the degradation state of polymers using **Raman and FTIR spectroscopy** and deep learning.
|
| 15 |
+
It is a prototype pipeline for evaluating multiple convolutional neural networks (CNNs) on spectral data.
|
| 16 |
|
| 17 |
---
|
| 18 |
|
| 19 |
## 🧪 Current Scope
|
| 20 |
|
| 21 |
+
- 🔬 **Modalities**: Raman & FTIR spectroscopy
|
| 22 |
+
- 💾 **Input Formats**: `.txt`, `.csv`, `.json` (with auto-detection)
|
| 23 |
+
- 🧠 **Models**: Figure2CNN (baseline), ResNet1D, ResNet18Vision
|
| 24 |
- 📊 **Task**: Binary classification — Stable vs Weathered polymers
|
| 25 |
+
- 🚀 **Features**: Multi-model comparison, performance tracking dashboard
|
| 26 |
- 🛠️ **Architecture**: PyTorch + Streamlit
|
| 27 |
|
| 28 |
---
|
|
|
|
| 31 |
|
| 32 |
- [x] Inference from Raman `.txt` files
|
| 33 |
- [x] Model selection (Figure2CNN, ResNet1D)
|
| 34 |
+
- [x] **FTIR support** (modular integration complete)
|
| 35 |
+
- [x] **Multi-model comparison dashboard**
|
| 36 |
+
- [x] **Performance tracking dashboard**
|
| 37 |
- [ ] Add more trained CNNs for comparison
|
|
|
|
| 38 |
- [ ] Image-based inference (future modality)
|
| 39 |
+
- [ ] RESTful API for programmatic access
|
| 40 |
|
| 41 |
---
|
| 42 |
|
| 43 |
## 🧭 How to Use
|
| 44 |
|
| 45 |
+
The application provides three main analysis modes in a tabbed interface:
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
+
1. **Standard Analysis**:
|
| 48 |
|
| 49 |
+
- Upload a single spectrum file (`.txt`, `.csv`, `.json`) or a batch of files.
|
| 50 |
+
- Choose a model from the sidebar.
|
| 51 |
+
- Run analysis and view the prediction, confidence, and technical details.
|
|
|
|
| 52 |
|
| 53 |
+
2. **Model Comparison**:
|
|
|
|
|
|
|
| 54 |
|
| 55 |
+
- Upload a single spectrum file.
|
| 56 |
+
- The app runs inference with all available models.
|
| 57 |
+
- View a side-by-side comparison of the models' predictions and performance.
|
| 58 |
|
| 59 |
+
3. **Performance Tracking**:
|
| 60 |
+
- Explore a dashboard with visualizations of historical performance data.
|
| 61 |
+
- Compare model performance across different metrics.
|
| 62 |
+
- Export performance data in CSV or JSON format.
|
| 63 |
|
| 64 |
+
### Supported Input
|
| 65 |
|
| 66 |
+
- Plaintext `.txt`, `.csv`, or `.json` files.
|
| 67 |
+
- Data can be space-, comma-, or tab-separated.
|
| 68 |
+
- Comment lines (`#`, `%`) are ignored.
|
| 69 |
+
- The app automatically detects the file format and resamples the data to a standard length.
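A minimal parsing sketch for the plaintext formats under these rules (the helper name and two-column assumption are illustrative, not the app's exact parser):

```python
import numpy as np

def read_spectrum(text: str):
    """Parse a wavenumber/intensity spectrum; skip #/% comments; accept space, comma, or tab separators."""
    xs, ys = [], []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith(("#", "%")):
            continue
        parts = line.replace(",", " ").replace("\t", " ").split()
        if len(parts) >= 2:
            xs.append(float(parts[0]))
            ys.append(float(parts[1]))
    return np.asarray(xs), np.asarray(ys)
```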
|
| 70 |
|
| 71 |
---
|
| 72 |
|
| 73 |
+
## Contributors
|
|
| 74 |
|
| 75 |
+
Dr. Sanmukh Kuppannagari (Mentor)
|
| 76 |
+
Dr. Metin Karailyan (Mentor)
|
| 77 |
+
Jaser Hasan (Author/Developer)
|
|
|
|
| 78 |
|
| 79 |
+
## Model Credit
|
| 80 |
|
| 81 |
+
Baseline model inspired by:
|
| 82 |
|
| 83 |
+
Neo, E.R.K., Low, J.S.C., Goodship, V., Debattista, K. (2023).
|
| 84 |
+
_Deep learning for chemometric analysis of plastic spectral data from infrared and Raman databases._
|
| 85 |
+
_Resources, Conservation & Recycling_, **188**, 106718.
|
| 86 |
+
[https://doi.org/10.1016/j.resconrec.2022.106718](https://doi.org/10.1016/j.resconrec.2022.106718)
|
| 87 |
|
| 88 |
+
---
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
+
## 🔗 Links
|
| 91 |
|
| 92 |
+
- **Live App**: [Hugging Face Space](https://huggingface.co/spaces/dev-jas/polymer-aging-ml)
|
| 93 |
+
- **GitHub Repo**: [ml-polymer-recycling](https://github.com/KLab-AI3/ml-polymer-recycling)
|
| 94 |
|
| 95 |
+
## 🚀 Technical Architecture
|
| 96 |
|
| 97 |
+
**The system is built on a modular, production-ready architecture designed for scalability and maintainability.**
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
+
- **Frontend**: A Streamlit-based web application (`app.py`) provides an interactive, multi-tab user interface.
|
| 100 |
+
- **Backend**: PyTorch handles all deep learning operations, including model loading and inference.
|
| 101 |
+
- **Model Management**: A registry pattern (`models/registry.py`) allows for dynamic model loading and easy integration of new architectures.
|
| 102 |
+
- **Data Processing**: A robust, modality-aware preprocessing pipeline (`utils/preprocessing.py`) ensures data integrity and standardization for both Raman and FTIR data.
|
| 103 |
+
- **Multi-Format Parsing**: The `utils/multifile.py` module handles parsing of `.txt`, `.csv`, and `.json` files.
|
| 104 |
+
- **Results Management**: The `utils/results_manager.py` module manages session and persistent results, with support for multi-model comparison and data export.
|
| 105 |
+
- **Performance Tracking**: The `utils/performance_tracker.py` module logs performance metrics to a SQLite database and provides a dashboard for visualization.
|
| 106 |
+
- **Deployment**: The application is containerized using Docker (`Dockerfile`) for reproducible, cross-platform execution.
|
app.py
CHANGED
|
@@ -8,6 +8,8 @@ from modules.ui_components import (
|
|
| 8 |
render_sidebar,
|
| 9 |
render_results_column,
|
| 10 |
render_input_column,
|
|
|
|
|
|
|
| 11 |
load_css,
|
| 12 |
)
|
| 13 |
|
|
@@ -27,14 +29,28 @@ def main():
|
|
| 27 |
load_css("static/style.css")
|
| 28 |
init_session_state()
|
| 29 |
|
| 30 |
-
# Render UI components
|
| 31 |
render_sidebar()
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
| 38 |
|
| 39 |
|
| 40 |
if __name__ == "__main__":
|
|
|
|
| 8 |
render_sidebar,
|
| 9 |
render_results_column,
|
| 10 |
render_input_column,
|
| 11 |
+
render_comparison_tab,
|
| 12 |
+
render_performance_tab,
|
| 13 |
load_css,
|
| 14 |
)
|
| 15 |
|
|
|
|
| 29 |
load_css("static/style.css")
|
| 30 |
init_session_state()
|
| 31 |
|
|
|
|
| 32 |
render_sidebar()
|
| 33 |
|
| 34 |
+
# Create main tabs for the different analysis modes
|
| 35 |
+
tab1, tab2, tab3 = st.tabs(
|
| 36 |
+
["Standard Analysis", "Model Comparison", "Peformance Tracking"]
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
with tab1:
|
| 40 |
+
# Standard single-model analysis
|
| 41 |
+
col1, col2 = st.columns([1, 1.35], gap="small")
|
| 42 |
+
with col1:
|
| 43 |
+
render_input_column()
|
| 44 |
+
with col2:
|
| 45 |
+
render_results_column()
|
| 46 |
+
|
| 47 |
+
with tab2:
|
| 48 |
+
# Multi-model comparison interface
|
| 49 |
+
render_comparison_tab()
|
| 50 |
+
|
| 51 |
+
with tab3:
|
| 52 |
+
# Performance tracking interface
|
| 53 |
+
render_performance_tab()
|
| 54 |
|
| 55 |
|
| 56 |
if __name__ == "__main__":
|
core_logic.py
CHANGED
|
@@ -10,6 +10,7 @@ import numpy as np
|
|
| 10 |
import streamlit as st
|
| 11 |
from pathlib import Path
|
| 12 |
from config import SAMPLE_DATA_DIR
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
def label_file(filename: str) -> int:
|
|
@@ -89,16 +90,26 @@ def cleanup_memory():
|
|
| 89 |
|
| 90 |
@st.cache_data
|
| 91 |
def run_inference(y_resampled, model_choice, _cache_key=None):
|
| 92 |
-
"""Run model inference and cache results"""
|
|
|
|
|
|
|
|
|
|
| 93 |
model, model_loaded = load_model(model_choice)
|
| 94 |
if not model_loaded:
|
| 95 |
return None, None, None, None, None
|
| 96 |
|
|
|
|
|
|
|
|
|
|
| 97 |
input_tensor = (
|
| 98 |
torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
|
| 99 |
)
|
|
|
|
|
|
|
| 100 |
start_time = time.time()
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
with torch.no_grad():
|
| 103 |
if model is None:
|
| 104 |
raise ValueError(
|
|
@@ -108,11 +119,51 @@ def run_inference(y_resampled, model_choice, _cache_key=None):
|
|
| 108 |
prediction = torch.argmax(logits, dim=1).item()
|
| 109 |
logits_list = logits.detach().numpy().tolist()[0]
|
| 110 |
probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
|
|
|
|
| 111 |
inference_time = time.time() - start_time
|
|
| 112 |
cleanup_memory()
|
| 113 |
return prediction, logits_list, probs, inference_time, logits
|
| 114 |
|
| 115 |
|
| 116 |
@st.cache_data
|
| 117 |
def get_sample_files():
|
| 118 |
"""Get list of sample files if available"""
|
|
|
|
| 10 |
import streamlit as st
|
| 11 |
from pathlib import Path
|
| 12 |
from config import SAMPLE_DATA_DIR
|
| 13 |
+
from datetime import datetime
|
| 14 |
|
| 15 |
|
| 16 |
def label_file(filename: str) -> int:
|
|
|
|
| 90 |
|
| 91 |
@st.cache_data
|
| 92 |
def run_inference(y_resampled, model_choice, _cache_key=None):
|
| 93 |
+
"""Run model inference and cache results with performance tracking"""
|
| 94 |
+
from utils.performance_tracker import get_performance_tracker, PerformanceMetrics
|
| 95 |
+
from datetime import datetime
|
| 96 |
+
|
| 97 |
model, model_loaded = load_model(model_choice)
|
| 98 |
if not model_loaded:
|
| 99 |
return None, None, None, None, None
|
| 100 |
|
| 101 |
+
# Performance tracking setup
|
| 102 |
+
tracker = get_performance_tracker()
|
| 103 |
+
|
| 104 |
input_tensor = (
|
| 105 |
torch.tensor(y_resampled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
|
| 106 |
)
|
| 107 |
+
|
| 108 |
+
# Track inference performance
|
| 109 |
start_time = time.time()
|
| 110 |
+
start_memory = _get_memory_usage()
|
| 111 |
+
|
| 112 |
+
model.eval() # type: ignore
|
| 113 |
with torch.no_grad():
|
| 114 |
if model is None:
|
| 115 |
raise ValueError(
|
|
|
|
| 119 |
prediction = torch.argmax(logits, dim=1).item()
|
| 120 |
logits_list = logits.detach().numpy().tolist()[0]
|
| 121 |
probs = F.softmax(logits.detach(), dim=1).cpu().numpy().flatten()
|
| 122 |
+
|
| 123 |
inference_time = time.time() - start_time
|
| 124 |
+
end_memory = _get_memory_usage()
|
| 125 |
+
memory_usage = max(end_memory - start_memory, 0)
|
| 126 |
+
|
| 127 |
+
# Log performance metrics
|
| 128 |
+
try:
|
| 129 |
+
modality = st.session_state.get("modality_select", "raman")
|
| 130 |
+
confidence = float(max(probs)) if probs is not None and len(probs) > 0 else 0.0
|
| 131 |
+
|
| 132 |
+
metrics = PerformanceMetrics(
|
| 133 |
+
model_name=model_choice,
|
| 134 |
+
prediction_time=inference_time,
|
| 135 |
+
preprocessing_time=0.0, # Will be updated by calling function if available
|
| 136 |
+
total_time=inference_time,
|
| 137 |
+
memory_usage_mb=memory_usage,
|
| 138 |
+
accuracy=None, # Will be updated if ground truth is available
|
| 139 |
+
confidence=confidence,
|
| 140 |
+
timestamp=datetime.now().isoformat(),
|
| 141 |
+
input_size=(
|
| 142 |
+
len(y_resampled) if hasattr(y_resampled, "__len__") else TARGET_LEN
|
| 143 |
+
),
|
| 144 |
+
modality=modality,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
tracker.log_performance(metrics)
|
| 148 |
+
except (AttributeError, ValueError, KeyError) as e:
|
| 149 |
+
# Don't fail inference if performance tracking fails
|
| 150 |
+
print(f"Performance tracking failed: {e}")
|
| 151 |
+
|
| 152 |
cleanup_memory()
|
| 153 |
return prediction, logits_list, probs, inference_time, logits
|
| 154 |
|
| 155 |
|
| 156 |
+
def _get_memory_usage() -> float:
|
| 157 |
+
"""Get current memory usage in MB"""
|
| 158 |
+
try:
|
| 159 |
+
import psutil
|
| 160 |
+
|
| 161 |
+
process = psutil.Process()
|
| 162 |
+
return process.memory_info().rss / 1024 / 1024 # Convert to MB
|
| 163 |
+
except ImportError:
|
| 164 |
+
return 0.0 # psutil not available
|
| 165 |
+
|
| 166 |
+
|
| 167 |
@st.cache_data
|
| 168 |
def get_sample_files():
|
| 169 |
"""Get list of sample files if available"""
|
models/registry.py
CHANGED
|
@@ -1,35 +1,138 @@
|
|
| 1 |
# models/registry.py
|
| 2 |
-
from typing import Callable, Dict
|
| 3 |
from models.figure2_cnn import Figure2CNN
|
| 4 |
from models.resnet_cnn import ResNet1D
|
| 5 |
-
from models.resnet18_vision import ResNet18Vision
|
| 6 |
|
| 7 |
# Internal registry of model builders keyed by short name.
|
| 8 |
_REGISTRY: Dict[str, Callable[[int], object]] = {
|
| 9 |
"figure2": lambda L: Figure2CNN(input_length=L),
|
| 10 |
"resnet": lambda L: ResNet1D(input_length=L),
|
| 11 |
-
"resnet18vision": lambda L: ResNet18Vision(input_length=L)
|
| 12 |
}
|
| 13 |
|
| 14 |
def choices():
|
| 15 |
"""Return the list of available model keys."""
|
| 16 |
return list(_REGISTRY.keys())
|
| 17 |
|
| 18 |
def build(name: str, input_length: int):
|
| 19 |
"""Instantiate a model by short name with the given input length."""
|
| 20 |
if name not in _REGISTRY:
|
| 21 |
raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
|
| 22 |
return _REGISTRY[name](input_length)
|
| 23 |
|
| 24 |
def spec(name: str):
|
| 25 |
"""Return expected input length and number of classes for a model key."""
|
| 26 |
-
if name
|
| 27 |
-
return
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
|
| 34 |
|
| 35 |
-
__all__ = [
|
| 1 |
# models/registry.py
|
| 2 |
+
from typing import Callable, Dict, List, Any
|
| 3 |
from models.figure2_cnn import Figure2CNN
|
| 4 |
from models.resnet_cnn import ResNet1D
|
| 5 |
+
from models.resnet18_vision import ResNet18Vision
|
| 6 |
|
| 7 |
# Internal registry of model builders keyed by short name.
|
| 8 |
_REGISTRY: Dict[str, Callable[[int], object]] = {
|
| 9 |
"figure2": lambda L: Figure2CNN(input_length=L),
|
| 10 |
"resnet": lambda L: ResNet1D(input_length=L),
|
| 11 |
+
"resnet18vision": lambda L: ResNet18Vision(input_length=L),
|
| 12 |
}
|
| 13 |
|
| 14 |
+
# Model specifications with metadata for enhanced features
|
| 15 |
+
_MODEL_SPECS: Dict[str, Dict[str, Any]] = {
|
| 16 |
+
"figure2": {
|
| 17 |
+
"input_length": 500,
|
| 18 |
+
"num_classes": 2,
|
| 19 |
+
"description": "Figure 2 baseline custom implemetation",
|
| 20 |
+
"modalities": ["raman", "ftir"],
|
| 21 |
+
"citation": "Neo et al., 2023, Resour. Conserv. Recycl., 188, 106718",
|
| 22 |
+
},
|
| 23 |
+
"resnet": {
|
| 24 |
+
"input_length": 500,
|
| 25 |
+
"num_classes": 2,
|
| 26 |
+
"description": "(Residual Network) uses skip connections to train much deeper networks",
|
| 27 |
+
"modalities": ["raman", "ftir"],
|
| 28 |
+
"citation": "Custom ResNet implementation",
|
| 29 |
+
},
|
| 30 |
+
"resnet18vision": {
|
| 31 |
+
"input_length": 500,
|
| 32 |
+
"num_classes": 2,
|
| 33 |
+
"description": "excels at image recognition tasks by using 'residual blocks' to train more efficiently",
|
| 34 |
+
"modalities": ["raman", "ftir"],
|
| 35 |
+
"citation": "ResNet18 Vision adaptation",
|
| 36 |
+
},
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
# Placeholder for future model expansions
|
| 40 |
+
_FUTURE_MODELS = {
|
| 41 |
+
"densenet1d": {
|
| 42 |
+
"description": "DenseNet1D for spectroscopy (placeholder)",
|
| 43 |
+
"status": "planned",
|
| 44 |
+
},
|
| 45 |
+
"ensemble_cnn": {
|
| 46 |
+
"description": "Ensemble of CNN variants (placeholder)",
|
| 47 |
+
"status": "planned",
|
| 48 |
+
},
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
def choices():
|
| 53 |
"""Return the list of available model keys."""
|
| 54 |
return list(_REGISTRY.keys())
|
| 55 |
|
| 56 |
+
|
| 57 |
+
def planned_models():
|
| 58 |
+
"""Return the list of planned future model keys."""
|
| 59 |
+
return list(_FUTURE_MODELS.keys())
|
| 60 |
+
|
| 61 |
+
|
| 62 |
def build(name: str, input_length: int):
|
| 63 |
"""Instantiate a model by short name with the given input length."""
|
| 64 |
if name not in _REGISTRY:
|
| 65 |
raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
|
| 66 |
return _REGISTRY[name](input_length)
|
| 67 |
|
| 68 |
+
|
| 69 |
+
def build_multiple(names: List[str], input_length: int) -> Dict[str, Any]:
|
| 70 |
+
"""Nuild multiple models for comparison."""
|
| 71 |
+
models = {}
|
| 72 |
+
for name in names:
|
| 73 |
+
if name in _REGISTRY:
|
| 74 |
+
models[name] = build(name, input_length)
|
| 75 |
+
else:
|
| 76 |
+
raise ValueError(f"Unknown model '{name}'. Available: {choices()}")
|
| 77 |
+
return models
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def register_model(
|
| 81 |
+
name: str, builder: Callable[[int], object], spec: Dict[str, Any]
|
| 82 |
+
) -> None:
|
| 83 |
+
"""Dynamically register a new model."""
|
| 84 |
+
if name in _REGISTRY:
|
| 85 |
+
raise ValueError(f"Model '{name}' already registered.")
|
| 86 |
+
if not callable(builder):
|
| 87 |
+
raise TypeError("Builder must be a callable that accepts an integer argument.")
|
| 88 |
+
_REGISTRY[name] = builder
|
| 89 |
+
_MODEL_SPECS[name] = spec
|
| 90 |
+
|
| 91 |
+
|
| 92 |
def spec(name: str):
|
| 93 |
"""Return expected input length and number of classes for a model key."""
|
| 94 |
+
if name in _MODEL_SPECS:
|
| 95 |
+
return _MODEL_SPECS[name].copy()
|
| 96 |
+
raise KeyError(f"Unknown model '{name}'. Available: {choices()}")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def get_model_info(name: str) -> Dict[str, Any]:
|
| 100 |
+
"""Get comprehensive model information including metadata."""
|
| 101 |
+
if name in _MODEL_SPECS:
|
| 102 |
+
return _MODEL_SPECS[name].copy()
|
| 103 |
+
elif name in _FUTURE_MODELS:
|
| 104 |
+
return _FUTURE_MODELS[name].copy()
|
| 105 |
+
else:
|
| 106 |
+
raise KeyError(f"Unknown model '{name}'")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def models_for_modality(modality: str) -> List[str]:
|
| 110 |
+
"""Get list of models that support a specific modality."""
|
| 111 |
+
compatible = []
|
| 112 |
+
for name, spec_info in _MODEL_SPECS.items():
|
| 113 |
+
if modality in spec_info.get("modalities", []):
|
| 114 |
+
compatible.append(name)
|
| 115 |
+
return compatible
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def validate_model_list(names: List[str]) -> List[str]:
|
| 119 |
+
"""Validate and return list of available models from input list."""
|
| 120 |
+
available = choices()
|
| 121 |
+
valid_models = []
|
| 122 |
+
for name in names:
|
| 123 |
+
if name in available:
|
| 124 |
+
valid_models.append(name)
|
| 125 |
+
return valid_models
|
| 126 |
|
| 127 |
|
| 128 |
+
__all__ = [
|
| 129 |
+
"choices",
|
| 130 |
+
"build",
|
| 131 |
+
"spec",
|
| 132 |
+
"build_multiple",
|
| 133 |
+
"register_model",
|
| 134 |
+
"get_model_info",
|
| 135 |
+
"models_for_modality",
|
| 136 |
+
"validate_model_list",
|
| 137 |
+
"planned_models",
|
| 138 |
+
]
|
modules/ui_components.py
CHANGED
|
@@ -13,9 +13,9 @@ from modules.callbacks import (
|
|
| 13 |
on_model_change,
|
| 14 |
on_input_mode_change,
|
| 15 |
on_sample_change,
|
|
|
|
| 16 |
reset_ephemeral_state,
|
| 17 |
log_message,
|
| 18 |
-
clear_batch_results,
|
| 19 |
)
|
| 20 |
from core_logic import (
|
| 21 |
get_sample_files,
|
|
@@ -24,7 +24,6 @@ from core_logic import (
|
|
| 24 |
parse_spectrum_data,
|
| 25 |
label_file,
|
| 26 |
)
|
| 27 |
-
from modules.callbacks import reset_results
|
| 28 |
from utils.results_manager import ResultsManager
|
| 29 |
from utils.confidence import calculate_softmax_confidence
|
| 30 |
from utils.multifile import process_multiple_files, display_batch_results
|
|
@@ -41,7 +40,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
|
|
| 41 |
"""Create spectrum visualization plot"""
|
| 42 |
fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
|
| 43 |
|
| 44 |
-
#
|
| 45 |
ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
|
| 46 |
ax[0].set_title("Raw Input Spectrum")
|
| 47 |
ax[0].set_xlabel("Wavenumber (cm⁻¹)")
|
|
@@ -49,7 +48,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
|
|
| 49 |
ax[0].grid(True, alpha=0.3)
|
| 50 |
ax[0].legend()
|
| 51 |
|
| 52 |
-
#
|
| 53 |
ax[1].plot(
|
| 54 |
x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1
|
| 55 |
)
|
|
@@ -60,7 +59,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
|
|
| 60 |
ax[1].legend()
|
| 61 |
|
| 62 |
fig.tight_layout()
|
| 63 |
-
#
|
| 64 |
buf = io.BytesIO()
|
| 65 |
plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
|
| 66 |
buf.seek(0)
|
|
@@ -69,6 +68,9 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
|
|
| 69 |
return Image.open(buf)
|
| 70 |
|
| 71 |
|
|
|
|
|
|
|
|
|
|
| 72 |
def render_confidence_progress(
|
| 73 |
probs: np.ndarray,
|
| 74 |
labels: list[str] = ["Stable", "Weathered"],
|
|
@@ -114,7 +116,10 @@ def render_confidence_progress(
|
|
| 114 |
st.markdown("")
|
| 115 |
|
| 116 |
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
| 118 |
if d is None:
|
| 119 |
d = {}
|
| 120 |
if not d:
|
|
@@ -126,6 +131,9 @@ def render_kv_grid(d: dict = {}, ncols: int = 2):
|
|
| 126 |
st.caption(f"**{k}:** {v}")
|
| 127 |
|
| 128 |
|
|
|
|
|
|
|
|
|
|
| 129 |
def render_model_meta(model_choice: str):
|
| 130 |
info = MODEL_CONFIG.get(model_choice, {})
|
| 131 |
emoji = info.get("emoji", "")
|
|
@@ -143,6 +151,9 @@ def render_model_meta(model_choice: str):
|
|
| 143 |
st.caption(desc)
|
| 144 |
|
| 145 |
|
|
|
|
|
|
|
|
|
|
| 146 |
def get_confidence_description(logit_margin):
|
| 147 |
"""Get human-readable confidence description"""
|
| 148 |
if logit_margin > 1000:
|
|
@@ -155,13 +166,35 @@ def get_confidence_description(logit_margin):
|
|
| 155 |
return "LOW", "🔴"
|
| 156 |
|
| 157 |
|
|
|
|
|
|
|
|
|
|
| 158 |
def render_sidebar():
|
| 159 |
with st.sidebar:
|
| 160 |
# Header
|
| 161 |
st.header("AI-Driven Polymer Classification")
|
| 162 |
st.caption(
|
| 163 |
-
"Predict polymer degradation (Stable vs Weathered) from Raman spectra using validated CNN models. — v0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
model_labels = [
|
| 166 |
f"{MODEL_CONFIG[name]['emoji']} {name}" for name in MODEL_CONFIG.keys()
|
| 167 |
]
|
|
@@ -173,10 +206,10 @@ def render_sidebar():
|
|
| 173 |
)
|
| 174 |
model_choice = selected_label.split(" ", 1)[1]
|
| 175 |
|
| 176 |
-
#
|
| 177 |
render_model_meta(model_choice)
|
| 178 |
|
| 179 |
-
#
|
| 180 |
with st.expander("About This App", icon=":material/info:", expanded=False):
|
| 181 |
st.markdown(
|
| 182 |
"""
|
|
@@ -184,8 +217,9 @@ def render_sidebar():
|
|
| 184 |
|
| 185 |
**Purpose**: Classify polymer degradation using AI<br>
|
| 186 |
**Input**: Raman spectroscopy .txt files<br>
|
| 187 |
-
**Models**: CNN architectures for
|
| 188 |
-
**
|
|
|
|
| 189 |
|
| 190 |
|
| 191 |
**Contributors**<br>
|
|
@@ -207,11 +241,7 @@ def render_sidebar():
|
|
| 207 |
)
|
| 208 |
|
| 209 |
|
| 210 |
-
#
|
| 211 |
-
|
| 212 |
-
# In modules/ui_components.py
|
| 213 |
-
|
| 214 |
-
|
| 215 |
def render_input_column():
|
| 216 |
st.markdown("##### Data Input")
|
| 217 |
|
|
@@ -224,22 +254,20 @@ def render_input_column():
|
|
| 224 |
)
|
| 225 |
|
| 226 |
# == Input Mode Logic ==
|
| 227 |
-
# ... (The if/elif/else block for Upload, Batch, and Sample modes remains exactly the same) ...
|
| 228 |
-
# ==Upload tab==
|
| 229 |
if mode == "Upload File":
|
| 230 |
upload_key = st.session_state["current_upload_key"]
|
| 231 |
up = st.file_uploader(
|
| 232 |
-
"Upload
|
| 233 |
-
type="txt",
|
| 234 |
-
help="Upload
|
| 235 |
key=upload_key, # ← versioned key
|
| 236 |
)
|
| 237 |
|
| 238 |
-
#
|
| 239 |
if up is not None:
|
| 240 |
raw = up.read()
|
| 241 |
text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
|
| 242 |
-
#
|
| 243 |
if (
|
| 244 |
st.session_state.get("filename") != getattr(up, "name", None)
|
| 245 |
or st.session_state.get("input_source") != "upload"
|
|
@@ -255,23 +283,20 @@ def render_input_column():
|
|
| 255 |
st.session_state["status_type"] = "success"
|
| 256 |
reset_results("New file uploaded")
|
| 257 |
|
| 258 |
-
#
|
| 259 |
elif mode == "Batch Upload":
|
| 260 |
st.session_state["batch_mode"] = True
|
| 261 |
-
# --- START: BUG 1 & 3 FIX ---
|
| 262 |
# Use a versioned key to ensure the file uploader resets properly.
|
| 263 |
batch_upload_key = f"batch_upload_{st.session_state['uploader_version']}"
|
| 264 |
uploaded_files = st.file_uploader(
|
| 265 |
-
"Upload multiple
|
| 266 |
-
type="txt",
|
| 267 |
accept_multiple_files=True,
|
| 268 |
-
help="Upload
|
| 269 |
key=batch_upload_key,
|
| 270 |
)
|
| 271 |
-
# --- END: BUG 1 & 3 FIX ---
|
| 272 |
|
| 273 |
if uploaded_files:
|
| 274 |
-
# --- START: Bug 1 Fix ---
|
| 275 |
# Use a dictionary to keep only unique files based on name and size
|
| 276 |
unique_files = {(file.name, file.size): file for file in uploaded_files}
|
| 277 |
unique_file_list = list(unique_files.values())
|
|
@@ -281,9 +306,7 @@ def render_input_column():
|
|
| 281 |
|
| 282 |
# Optionally, inform the user that duplicates were removed
|
| 283 |
if num_uploaded > num_unique:
|
| 284 |
-
st.info(
|
| 285 |
-
f"ℹ️ {num_uploaded - num_unique} duplicate file(s) were removed."
|
| 286 |
-
)
|
| 287 |
|
| 288 |
# Use the unique list
|
| 289 |
st.session_state["batch_files"] = unique_file_list
|
|
@@ -291,7 +314,6 @@ def render_input_column():
|
|
| 291 |
f"{num_unique} ready for batch analysis"
|
| 292 |
)
|
| 293 |
st.session_state["status_type"] = "success"
|
| 294 |
-
# --- END: Bug 1 Fix ---
|
| 295 |
else:
|
| 296 |
st.session_state["batch_files"] = []
|
| 297 |
# This check prevents resetting the status if files are already staged
|
|
@@ -301,7 +323,7 @@ def render_input_column():
|
|
| 301 |
)
|
| 302 |
st.session_state["status_type"] = "info"
|
| 303 |
|
| 304 |
-
#
|
| 305 |
elif mode == "Sample Data":
|
| 306 |
st.session_state["batch_mode"] = False
|
| 307 |
sample_files = get_sample_files()
|
|
@@ -330,9 +352,6 @@ def render_input_column():
|
|
| 330 |
else:
|
| 331 |
st.info(msg)
|
| 332 |
|
| 333 |
-
# --- DE-NESTED LOGIC STARTS HERE ---
|
| 334 |
-
# This code now runs on EVERY execution, guaranteeing the buttons will appear.
|
| 335 |
-
|
| 336 |
# Safely get model choice from session state
|
| 337 |
model_choice = st.session_state.get("model_select", " ").split(" ", 1)[1]
|
| 338 |
model = load_model(model_choice)
|
|
@@ -388,7 +407,7 @@ def render_input_column():
|
|
| 388 |
st.error(f"Error processing spectrum data: {e}")
|
| 389 |
|
| 390 |
|
| 391 |
-
#
|
| 392 |
|
| 393 |
|
| 394 |
def render_results_column():
|
|
@@ -410,7 +429,7 @@ def render_results_column():
|
|
| 410 |
filename = st.session_state.get("filename", "Unknown")
|
| 411 |
|
| 412 |
if all(v is not None for v in [x_raw, y_raw, y_resampled]):
|
| 413 |
-
#
|
| 414 |
if y_resampled is None:
|
| 415 |
raise ValueError(
|
| 416 |
"y_resampled is None. Ensure spectrum data is properly resampled before proceeding."
|
|
@@ -437,14 +456,14 @@ def render_results_column():
|
|
| 437 |
f"Inference completed in {inference_time:.2f}s, prediction: {prediction}"
|
| 438 |
)
|
| 439 |
|
| 440 |
-
#
|
| 441 |
true_label_idx = label_file(filename)
|
| 442 |
true_label_str = (
|
| 443 |
LABEL_MAP.get(true_label_idx, "Unknown")
|
| 444 |
if true_label_idx is not None
|
| 445 |
else "Unknown"
|
| 446 |
)
|
| 447 |
-
#
|
| 448 |
predicted_class = LABEL_MAP.get(int(prediction), f"Class {int(prediction)}")
|
| 449 |
|
| 450 |
# Enhanced confidence calculation
|
|
@@ -455,7 +474,7 @@ def render_results_column():
|
|
| 455 |
)
|
| 456 |
confidence_desc = confidence_level
|
| 457 |
else:
|
| 458 |
-
# Fallback to
|
| 459 |
logit_margin = abs(
|
| 460 |
(logits_list[0] - logits_list[1])
|
| 461 |
if logits_list is not None and len(logits_list) >= 2
|
|
@@ -487,7 +506,7 @@ def render_results_column():
|
|
| 487 |
},
|
| 488 |
)
|
| 489 |
|
| 490 |
-
#
|
| 491 |
model_choice = (
|
| 492 |
st.session_state.get("model_select", "").split(" ", 1)[1]
|
| 493 |
if "model_select" in st.session_state
|
|
@@ -505,7 +524,6 @@ def render_results_column():
|
|
| 505 |
if os.path.exists(model_path)
|
| 506 |
else "N/A"
|
| 507 |
)
|
| 508 |
-
# Removed unused variable 'input_tensor'
|
| 509 |
|
| 510 |
start_render = time.time()
|
| 511 |
|
|
@@ -590,17 +608,13 @@ def render_results_column():
|
|
| 590 |
""",
|
| 591 |
unsafe_allow_html=True,
|
| 592 |
)
|
| 593 |
-
# --- END: CONSOLIDATED CONFIDENCE ANALYSIS ---
|
| 594 |
|
| 595 |
st.divider()
|
| 596 |
|
| 597 |
-
#
|
| 598 |
-
# Secondary info is now a clean, single-line caption
|
| 599 |
st.caption(
|
| 600 |
f"Analyzed with **{st.session_state.get('model_select', 'Unknown')}** in **{inference_time:.2f}s**."
|
| 601 |
)
|
| 602 |
-
# --- END: CLEAN METADATA FOOTER ---
|
| 603 |
-
|
| 604 |
st.markdown("</div>", unsafe_allow_html=True)
|
| 605 |
|
| 606 |
elif active_tab == "Technical":
|
|
@@ -918,7 +932,7 @@ def render_results_column():
|
|
| 918 |
"""
|
| 919 |
)
|
| 920 |
else:
|
| 921 |
-
#
|
| 922 |
st.markdown(
|
| 923 |
"""
|
| 924 |
##### How to Get Started
|
|
@@ -948,3 +962,416 @@ def render_results_column():
|
|
| 948 |
- 🏭 Quality control in manufacturing
|
| 949 |
"""
|
| 950 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
on_model_change,
|
| 14 |
on_input_mode_change,
|
| 15 |
on_sample_change,
|
| 16 |
+
reset_results,
|
| 17 |
reset_ephemeral_state,
|
| 18 |
log_message,
|
|
|
|
| 19 |
)
|
| 20 |
from core_logic import (
|
| 21 |
get_sample_files,
|
|
|
|
| 24 |
parse_spectrum_data,
|
| 25 |
label_file,
|
| 26 |
)
|
|
|
|
| 27 |
from utils.results_manager import ResultsManager
|
| 28 |
from utils.confidence import calculate_softmax_confidence
|
| 29 |
from utils.multifile import process_multiple_files, display_batch_results
|
|
|
|
| 40 |
"""Create spectrum visualization plot"""
|
| 41 |
fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
|
| 42 |
|
| 43 |
+
# Raw spectrum
|
| 44 |
ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
|
| 45 |
ax[0].set_title("Raw Input Spectrum")
|
| 46 |
ax[0].set_xlabel("Wavenumber (cm⁻¹)")
|
|
|
|
| 48 |
ax[0].grid(True, alpha=0.3)
|
| 49 |
ax[0].legend()
|
| 50 |
|
| 51 |
+
# Resampled spectrum
|
| 52 |
ax[1].plot(
|
| 53 |
x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1
|
| 54 |
)
|
|
|
|
| 59 |
ax[1].legend()
|
| 60 |
|
| 61 |
fig.tight_layout()
|
| 62 |
+
# Convert to image
|
| 63 |
buf = io.BytesIO()
|
| 64 |
plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
|
| 65 |
buf.seek(0)
|
|
|
|
| 68 |
return Image.open(buf)
|
| 69 |
|
| 70 |
|
| 71 |
+
# //////////////////////////////////////////
|
| 72 |
+
|
| 73 |
+
|
| 74 |
def render_confidence_progress(
|
| 75 |
probs: np.ndarray,
|
| 76 |
labels: list[str] = ["Stable", "Weathered"],
|
|
|
|
| 116 |
st.markdown("")
|
| 117 |
|
| 118 |
|
| 119 |
+
from typing import Optional
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def render_kv_grid(d: Optional[dict] = None, ncols: int = 2):
|
| 123 |
if d is None:
|
| 124 |
d = {}
|
| 125 |
if not d:
|
|
|
|
| 131 |
st.caption(f"**{k}:** {v}")
|
| 132 |
|
| 133 |
|
| 134 |
+
# //////////////////////////////////////////
|
| 135 |
+
|
| 136 |
+
|
| 137 |
def render_model_meta(model_choice: str):
|
| 138 |
info = MODEL_CONFIG.get(model_choice, {})
|
| 139 |
emoji = info.get("emoji", "")
|
|
|
|
| 151 |
st.caption(desc)
|
| 152 |
|
| 153 |
|
| 154 |
+
# //////////////////////////////////////////
|
| 155 |
+
|
| 156 |
+
|
| 157 |
def get_confidence_description(logit_margin):
|
| 158 |
"""Get human-readable confidence description"""
|
| 159 |
if logit_margin > 1000:
|
|
|
|
| 166 |
return "LOW", "🔴"
|
| 167 |
|
| 168 |
|
| 169 |
+
# //////////////////////////////////////////
|
| 170 |
+
|
| 171 |
+
|
| 172 |
def render_sidebar():
|
| 173 |
with st.sidebar:
|
| 174 |
# Header
|
| 175 |
st.header("AI-Driven Polymer Classification")
|
| 176 |
st.caption(
|
| 177 |
+
"Predict polymer degradation (Stable vs Weathered) from Raman/FTIR spectra using validated CNN models. — v0.01"
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Modality Selection
|
| 181 |
+
st.markdown("##### Spectroscopy Modality")
|
| 182 |
+
modality = st.selectbox(
|
| 183 |
+
"Choose Modality",
|
| 184 |
+
["raman", "ftir"],
|
| 185 |
+
index=0,
|
| 186 |
+
key="modality_select",
|
| 187 |
+
format_func=lambda x: f"{'Raman' if x == 'raman' else 'FTIR'}",
|
| 188 |
)
|
| 189 |
+
|
| 190 |
+
# Display modality info
|
| 191 |
+
if modality == "ftir":
|
| 192 |
+
st.info("FTIR mode: 400-4000 cm-1 range with atmospheric correction")
|
| 193 |
+
else:
|
| 194 |
+
st.info("Raman mode: 200-4000 cm-1 range with standard preprocessing")
|
| 195 |
+
|
| 196 |
+
# Model selection
|
| 197 |
+
st.markdown("##### AI Model Selection")
|
| 198 |
model_labels = [
|
| 199 |
f"{MODEL_CONFIG[name]['emoji']} {name}" for name in MODEL_CONFIG.keys()
|
| 200 |
]
|
|
|
|
| 206 |
)
|
| 207 |
model_choice = selected_label.split(" ", 1)[1]
|
| 208 |
|
| 209 |
+
# Compact metadata directly under dropdown
|
| 210 |
render_model_meta(model_choice)
|
| 211 |
|
| 212 |
+
# Collapsed info to reduce clutter
|
| 213 |
with st.expander("About This App", icon=":material/info:", expanded=False):
|
| 214 |
st.markdown(
|
| 215 |
"""
|
|
|
|
| 217 |
|
| 218 |
**Purpose**: Classify polymer degradation using AI<br>
|
| 219 |
**Input**: Raman spectroscopy .txt files<br>
|
| 220 |
+
**Models**: CNN architectures for classification<br>
|
| 221 |
+
**Modalities**: Raman and FTIR spectroscopy support<br>
|
| 222 |
+
**Features**: Multi-model comparison and analysis<br>
|
| 223 |
|
| 224 |
|
| 225 |
**Contributors**<br>
|
|
|
|
| 241 |
)
|
| 242 |
|
| 243 |
|
| 244 |
+
# //////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
def render_input_column():
|
| 246 |
st.markdown("##### Data Input")
|
| 247 |
|
|
|
|
| 254 |
)
|
| 255 |
|
| 256 |
# == Input Mode Logic ==
|
|
|
|
|
|
|
| 257 |
if mode == "Upload File":
|
| 258 |
upload_key = st.session_state["current_upload_key"]
|
| 259 |
up = st.file_uploader(
|
| 260 |
+
"Upload spectrum file (.txt, .csv, .json)",
|
| 261 |
+
type=["txt", "csv", "json"],
|
| 262 |
+
help="Upload spectroscopy data: TXT (2-column), CSV (with headers), or JSON format",
|
| 263 |
key=upload_key, # ← versioned key
|
| 264 |
)
|
| 265 |
|
| 266 |
+
# Process change immediately
|
| 267 |
if up is not None:
|
| 268 |
raw = up.read()
|
| 269 |
text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
|
| 270 |
+
# only reparse if its a different file|source
|
| 271 |
if (
|
| 272 |
st.session_state.get("filename") != getattr(up, "name", None)
|
| 273 |
or st.session_state.get("input_source") != "upload"
|
|
|
|
| 283 |
st.session_state["status_type"] = "success"
|
| 284 |
reset_results("New file uploaded")
|
| 285 |
|
| 286 |
+
# Batch Upload tab
|
| 287 |
elif mode == "Batch Upload":
|
| 288 |
st.session_state["batch_mode"] = True
|
|
|
|
| 289 |
# Use a versioned key to ensure the file uploader resets properly.
|
| 290 |
batch_upload_key = f"batch_upload_{st.session_state['uploader_version']}"
|
| 291 |
uploaded_files = st.file_uploader(
|
| 292 |
+
"Upload multiple spectrum files (.txt, .csv, .json)",
|
| 293 |
+
type=["txt", "csv", "json"],
|
| 294 |
accept_multiple_files=True,
|
| 295 |
+
help="Upload spectroscopy files in TXT, CSV, or JSON format.",
|
| 296 |
key=batch_upload_key,
|
| 297 |
)
|
|
|
|
| 298 |
|
| 299 |
if uploaded_files:
|
|
|
|
| 300 |
# Use a dictionary to keep only unique files based on name and size
|
| 301 |
unique_files = {(file.name, file.size): file for file in uploaded_files}
|
| 302 |
unique_file_list = list(unique_files.values())
|
|
|
|
| 306 |
|
| 307 |
# Optionally, inform the user that duplicates were removed
|
| 308 |
if num_uploaded > num_unique:
|
| 309 |
+
st.info(f"{num_uploaded - num_unique} duplicate file(s) were removed.")
|
|
|
|
|
|
|
| 310 |
|
| 311 |
# Use the unique list
|
| 312 |
st.session_state["batch_files"] = unique_file_list
|
|
|
|
| 314 |
f"{num_unique} ready for batch analysis"
|
| 315 |
)
|
| 316 |
st.session_state["status_type"] = "success"
|
|
|
|
| 317 |
else:
|
| 318 |
st.session_state["batch_files"] = []
|
| 319 |
# This check prevents resetting the status if files are already staged
|
|
|
|
| 323 |
)
|
| 324 |
st.session_state["status_type"] = "info"
|
| 325 |
|
| 326 |
+
# Sample tab
|
| 327 |
elif mode == "Sample Data":
|
| 328 |
st.session_state["batch_mode"] = False
|
| 329 |
sample_files = get_sample_files()
|
|
|
|
| 352 |
else:
|
| 353 |
st.info(msg)
|
| 354 |
|
|
|
|
|
|
|
|
|
|
| 355 |
# Safely get model choice from session state
|
| 356 |
model_choice = st.session_state.get("model_select", " ").split(" ", 1)[1]
|
| 357 |
model = load_model(model_choice)
|
|
|
|
| 407 |
st.error(f"Error processing spectrum data: {e}")
|
| 408 |
|
| 409 |
|
| 410 |
+
# //////////////////////////////////////////
|
| 411 |
|
| 412 |
|
| 413 |
def render_results_column():
|
|
|
|
| 429 |
filename = st.session_state.get("filename", "Unknown")
|
| 430 |
|
| 431 |
if all(v is not None for v in [x_raw, y_raw, y_resampled]):
|
| 432 |
+
# Run inference
|
| 433 |
if y_resampled is None:
|
| 434 |
raise ValueError(
|
| 435 |
"y_resampled is None. Ensure spectrum data is properly resampled before proceeding."
|
|
|
|
| 456 |
f"Inference completed in {inference_time:.2f}s, prediction: {prediction}"
|
| 457 |
)
|
| 458 |
|
| 459 |
+
# Get ground truth
|
| 460 |
true_label_idx = label_file(filename)
|
| 461 |
true_label_str = (
|
| 462 |
LABEL_MAP.get(true_label_idx, "Unknown")
|
| 463 |
if true_label_idx is not None
|
| 464 |
else "Unknown"
|
| 465 |
)
|
| 466 |
+
# Get prediction
|
| 467 |
predicted_class = LABEL_MAP.get(int(prediction), f"Class {int(prediction)}")
|
| 468 |
|
| 469 |
# Enhanced confidence calculation
|
|
|
|
| 474 |
)
|
| 475 |
confidence_desc = confidence_level
|
| 476 |
else:
|
| 477 |
+
# Fallback to legacy method
|
| 478 |
logit_margin = abs(
|
| 479 |
(logits_list[0] - logits_list[1])
|
| 480 |
if logits_list is not None and len(logits_list) >= 2
|
|
|
|
| 506 |
},
|
| 507 |
)
|
| 508 |
|
| 509 |
+
# Precompute Stats
|
| 510 |
model_choice = (
|
| 511 |
st.session_state.get("model_select", "").split(" ", 1)[1]
|
| 512 |
if "model_select" in st.session_state
|
|
|
|
| 524 |
if os.path.exists(model_path)
|
| 525 |
else "N/A"
|
| 526 |
)
|
|
|
|
| 527 |
|
| 528 |
start_render = time.time()
|
| 529 |
|
|
|
|
| 608 |
""",
|
| 609 |
unsafe_allow_html=True,
|
| 610 |
)
|
|
|
|
| 611 |
|
| 612 |
st.divider()
|
| 613 |
|
| 614 |
+
# METADATA FOOTER
|
|
|
|
| 615 |
st.caption(
|
| 616 |
f"Analyzed with **{st.session_state.get('model_select', 'Unknown')}** in **{inference_time:.2f}s**."
|
| 617 |
)
|
|
|
|
|
|
|
| 618 |
st.markdown("</div>", unsafe_allow_html=True)
|
| 619 |
|
| 620 |
elif active_tab == "Technical":
|
|
|
|
| 932 |
"""
|
| 933 |
)
|
| 934 |
else:
|
| 935 |
+
# Getting Started
|
| 936 |
st.markdown(
|
| 937 |
"""
|
| 938 |
##### How to Get Started
|
|
|
|
| 962 |
- 🏭 Quality control in manufacturing
|
| 963 |
"""
|
| 964 |
)
|
| 965 |
+
|
| 966 |
+
|
| 967 |
+
# //////////////////////////////////////////
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
def render_comparison_tab():
|
| 971 |
+
"""Render the multi-model comparison interface"""
|
| 972 |
+
import streamlit as st
|
| 973 |
+
import matplotlib.pyplot as plt
|
| 974 |
+
from models.registry import choices, validate_model_list
|
| 975 |
+
from utils.results_manager import ResultsManager
|
| 976 |
+
from core_logic import get_sample_files, run_inference, parse_spectrum_data
|
| 977 |
+
from utils.preprocessing import preprocess_spectrum
|
| 978 |
+
from utils.multifile import parse_spectrum_data
|
| 979 |
+
import numpy as np
|
| 980 |
+
import time
|
| 981 |
+
|
| 982 |
+
st.markdown("### Multi-Model Comparison Analysis")
|
| 983 |
+
st.markdown(
|
| 984 |
+
"Compare predictions across different AI models for comprehensive analysis."
|
| 985 |
+
)
|
| 986 |
+
|
| 987 |
+
# Model selection for comparison
|
| 988 |
+
st.markdown("##### Select Models for Comparison")
|
| 989 |
+
|
| 990 |
+
available_models = choices()
|
| 991 |
+
selected_models = st.multiselect(
|
| 992 |
+
"Choose models to compare",
|
| 993 |
+
available_models,
|
| 994 |
+
default=(
|
| 995 |
+
available_models[:2] if len(available_models) >= 2 else available_models
|
| 996 |
+
),
|
| 997 |
+
help="Select 2 or more models to compare their predictions side-by-side",
|
| 998 |
+
)
|
| 999 |
+
|
| 1000 |
+
if len(selected_models) < 2:
|
| 1001 |
+
st.warning("⚠️ Please select at least 2 models for comparison.")
|
| 1002 |
+
|
| 1003 |
+
# Input selection for comparison
|
| 1004 |
+
col1, col2 = st.columns([1, 1.5])
|
| 1005 |
+
|
| 1006 |
+
with col1:
|
| 1007 |
+
st.markdown("###### Input Data")
|
| 1008 |
+
|
| 1009 |
+
# File upload for comparison
|
| 1010 |
+
comparison_file = st.file_uploader(
|
| 1011 |
+
"Upload spectrum for comparison",
|
| 1012 |
+
type=["txt", "csv", "json"],
|
| 1013 |
+
key="comparison_file_upload",
|
| 1014 |
+
help="Upload a spectrum file to test across all selected models",
|
| 1015 |
+
)
|
| 1016 |
+
|
| 1017 |
+
# Or select sample data
|
| 1018 |
+
selected_sample = None # Initialize with a default value
|
| 1019 |
+
sample_files = get_sample_files()
|
| 1020 |
+
if sample_files:
|
| 1021 |
+
sample_options = ["-- Select Sample --"] + [p.name for p in sample_files]
|
| 1022 |
+
selected_sample = st.selectbox(
|
| 1023 |
+
"Or choose sample data", sample_options, key="comparison_sample_select"
|
| 1024 |
+
)
|
| 1025 |
+
|
| 1026 |
+
# Get modality from session state
|
| 1027 |
+
modality = st.session_state.get("modality_select", "raman")
|
| 1028 |
+
st.info(f"Using {modality.upper()} preprocessing parameters")
|
| 1029 |
+
|
| 1030 |
+
# Run comparison button
|
| 1031 |
+
run_comparison = st.button(
|
| 1032 |
+
"Run Multi-Model Comparison",
|
| 1033 |
+
type="primary",
|
| 1034 |
+
disabled=not (
|
| 1035 |
+
comparison_file
|
| 1036 |
+
or (sample_files and selected_sample != "-- Select Sample --")
|
| 1037 |
+
),
|
| 1038 |
+
)
|
| 1039 |
+
|
| 1040 |
+
with col2:
|
| 1041 |
+
st.markdown("###### Comparison Results")
|
| 1042 |
+
|
| 1043 |
+
if run_comparison:
|
| 1044 |
+
# Determine input source
|
| 1045 |
+
input_text = None
|
| 1046 |
+
filename = "unknown"
|
| 1047 |
+
|
| 1048 |
+
if comparison_file:
|
| 1049 |
+
raw = comparison_file.read()
|
| 1050 |
+
input_text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
|
| 1051 |
+
filename = comparison_file.name
|
| 1052 |
+
elif sample_files and selected_sample != "-- Select Sample --":
|
| 1053 |
+
sample_path = next(p for p in sample_files if p.name == selected_sample)
|
| 1054 |
+
with open(sample_path, "r") as f:
|
| 1055 |
+
input_text = f.read()
|
| 1056 |
+
filename = selected_sample
|
| 1057 |
+
|
| 1058 |
+
if input_text:
|
| 1059 |
+
try:
|
| 1060 |
+
# Parse spectrum data
|
| 1061 |
+
x_raw, y_raw = parse_spectrum_data(
|
| 1062 |
+
str(input_text), filename or "unknown_filename"
|
| 1063 |
+
)
|
| 1064 |
+
|
| 1065 |
+
# Store results
|
| 1066 |
+
comparison_results = {}
|
| 1067 |
+
processing_times = {}
|
| 1068 |
+
|
| 1069 |
+
progress_bar = st.progress(0)
|
| 1070 |
+
status_text = st.empty()
|
| 1071 |
+
|
| 1072 |
+
for i, model_name in enumerate(selected_models):
|
| 1073 |
+
status_text.text(f"Running inference with {model_name}...")
|
| 1074 |
+
|
| 1075 |
+
start_time = time.time()
|
| 1076 |
+
|
| 1077 |
+
# Preprocess spectrum with modality-specific parameters
|
| 1078 |
+
_, y_processed = preprocess_spectrum(
|
| 1079 |
+
x_raw, y_raw, modality=modality, target_len=500
|
| 1080 |
+
)
|
| 1081 |
+
|
| 1082 |
+
# Run inference
|
| 1083 |
+
prediction, logits_list, probs, inference_time, logits = (
|
| 1084 |
+
run_inference(y_processed, model_name)
|
| 1085 |
+
)
|
| 1086 |
+
|
| 1087 |
+
processing_time = time.time() - start_time
|
| 1088 |
+
|
| 1089 |
+
if prediction is not None:
|
| 1090 |
+
# Map prediction to class name
|
| 1091 |
+
class_names = ["Stable", "Weathered"]
|
| 1092 |
+
predicted_class = (
|
| 1093 |
+
class_names[int(prediction)]
|
| 1094 |
+
if prediction < len(class_names)
|
| 1095 |
+
else f"Class_{prediction}"
|
| 1096 |
+
)
|
| 1097 |
+
confidence = (
|
| 1098 |
+
max(probs)
|
| 1099 |
+
if probs is not None and len(probs) > 0
|
| 1100 |
+
else 0.0
|
| 1101 |
+
)
|
| 1102 |
+
|
| 1103 |
+
comparison_results[model_name] = {
|
| 1104 |
+
"prediction": prediction,
|
| 1105 |
+
"predicted_class": predicted_class,
|
| 1106 |
+
"confidence": confidence,
|
| 1107 |
+
"probs": probs if probs is not None else [],
|
| 1108 |
+
"logits": (
|
| 1109 |
+
logits_list if logits_list is not None else []
|
| 1110 |
+
),
|
| 1111 |
+
"processing_time": processing_time,
|
| 1112 |
+
}
|
| 1113 |
+
processing_times[model_name] = processing_time
|
| 1114 |
+
|
| 1115 |
+
progress_bar.progress((i + 1) / len(selected_models))
|
| 1116 |
+
|
| 1117 |
+
status_text.text("Comparison complete!")
|
| 1118 |
+
|
| 1119 |
+
# Display results
|
| 1120 |
+
if comparison_results:
|
| 1121 |
+
st.markdown("###### Model Predictions")
|
| 1122 |
+
|
| 1123 |
+
# Create comparison table
|
| 1124 |
+
import pandas as pd
|
| 1125 |
+
|
| 1126 |
+
table_data = []
|
| 1127 |
+
for model_name, result in comparison_results.items():
|
| 1128 |
+
row = {
|
| 1129 |
+
"Model": model_name,
|
| 1130 |
+
"Prediction": result["predicted_class"],
|
| 1131 |
+
"Confidence": f"{result['confidence']:.3f}",
|
| 1132 |
+
"Processing Time (s)": f"{result['processing_time']:.3f}",
|
| 1133 |
+
}
|
| 1134 |
+
table_data.append(row)
|
| 1135 |
+
|
| 1136 |
+
df = pd.DataFrame(table_data)
|
| 1137 |
+
st.dataframe(df, use_container_width=True)
|
| 1138 |
+
|
| 1139 |
+
# Show confidence comparison
|
| 1140 |
+
st.markdown("##### Confidence Comparison")
|
| 1141 |
+
conf_col1, conf_col2 = st.columns(2)
|
| 1142 |
+
|
| 1143 |
+
with conf_col1:
|
| 1144 |
+
# Bar chart of confidences
|
| 1145 |
+
models = list(comparison_results.keys())
|
| 1146 |
+
confidences = [
|
| 1147 |
+
comparison_results[m]["confidence"] for m in models
|
| 1148 |
+
]
|
| 1149 |
+
|
| 1150 |
+
fig, ax = plt.subplots(figsize=(8, 5))
|
| 1151 |
+
bars = ax.bar(
|
| 1152 |
+
models,
|
| 1153 |
+
confidences,
|
| 1154 |
+
alpha=0.7,
|
| 1155 |
+
color=["steelblue", "orange", "green", "red"][
|
| 1156 |
+
: len(models)
|
| 1157 |
+
],
|
| 1158 |
+
)
|
| 1159 |
+
ax.set_ylabel("Confidence")
|
| 1160 |
+
ax.set_title("Model Confidence Comparison")
|
| 1161 |
+
ax.set_ylim(0, 1)
|
| 1162 |
+
plt.xticks(rotation=45)
|
| 1163 |
+
|
| 1164 |
+
# Add value labels on bars
|
| 1165 |
+
for bar, conf in zip(bars, confidences):
|
| 1166 |
+
height = bar.get_height()
|
| 1167 |
+
ax.text(
|
| 1168 |
+
bar.get_x() + bar.get_width() / 2.0,
|
| 1169 |
+
height + 0.01,
|
| 1170 |
+
f"{conf:.3f}",
|
| 1171 |
+
ha="center",
|
| 1172 |
+
va="bottom",
|
| 1173 |
+
)
|
| 1174 |
+
|
| 1175 |
+
plt.tight_layout()
|
| 1176 |
+
st.pyplot(fig)
|
| 1177 |
+
|
| 1178 |
+
with conf_col2:
|
| 1179 |
+
# Agreement analysis
|
| 1180 |
+
predictions = [
|
| 1181 |
+
comparison_results[m]["prediction"] for m in models
|
| 1182 |
+
]
|
| 1183 |
+
unique_predictions = set(predictions)
|
| 1184 |
+
|
| 1185 |
+
if len(unique_predictions) == 1:
|
| 1186 |
+
st.success("✅ All models agree on the prediction!")
|
| 1187 |
+
else:
|
| 1188 |
+
st.warning("⚠️ Models disagree on the prediction")
|
| 1189 |
+
|
| 1190 |
+
# Show prediction distribution
|
| 1191 |
+
from collections import Counter
|
| 1192 |
+
|
| 1193 |
+
pred_counts = Counter(predictions)
|
| 1194 |
+
|
| 1195 |
+
st.markdown("**Prediction Distribution:**")
|
| 1196 |
+
for pred, count in pred_counts.items():
|
| 1197 |
+
class_name = (
|
| 1198 |
+
["Stable", "Weathered"][pred]
|
| 1199 |
+
if pred < 2
|
| 1200 |
+
else f"Class_{pred}"
|
| 1201 |
+
)
|
| 1202 |
+
percentage = (count / len(predictions)) * 100
|
| 1203 |
+
st.write(
|
| 1204 |
+
f"- {class_name}: {count}/{len(predictions)} models ({percentage:.1f}%)"
|
| 1205 |
+
)
|
| 1206 |
+
|
| 1207 |
+
# Performance metrics
|
| 1208 |
+
st.markdown("##### Performance Metrics")
|
| 1209 |
+
perf_col1, perf_col2 = st.columns(2)
|
| 1210 |
+
|
| 1211 |
+
with perf_col1:
|
| 1212 |
+
avg_time = np.mean(list(processing_times.values()))
|
| 1213 |
+
fastest_model = min(
|
| 1214 |
+
processing_times.keys(),
|
| 1215 |
+
key=lambda k: processing_times[k],
|
| 1216 |
+
)
|
| 1217 |
+
slowest_model = max(
|
| 1218 |
+
processing_times.keys(),
|
| 1219 |
+
key=lambda k: processing_times[k],
|
| 1220 |
+
)
|
| 1221 |
+
|
| 1222 |
+
st.metric("Average Processing Time", f"{avg_time:.3f}s")
|
| 1223 |
+
st.metric(
|
| 1224 |
+
"Fastest Model",
|
| 1225 |
+
f"{fastest_model}",
|
| 1226 |
+
f"{processing_times[fastest_model]:.3f}s",
|
| 1227 |
+
)
|
| 1228 |
+
st.metric(
|
| 1229 |
+
"Slowest Model",
|
| 1230 |
+
f"{slowest_model}",
|
| 1231 |
+
f"{processing_times[slowest_model]:.3f}s",
|
| 1232 |
+
)
|
| 1233 |
+
|
| 1234 |
+
with perf_col2:
|
| 1235 |
+
most_confident = max(
|
| 1236 |
+
comparison_results.keys(),
|
| 1237 |
+
key=lambda k: comparison_results[k]["confidence"],
|
| 1238 |
+
)
|
| 1239 |
+
least_confident = min(
|
| 1240 |
+
comparison_results.keys(),
|
| 1241 |
+
key=lambda k: comparison_results[k]["confidence"],
|
| 1242 |
+
)
|
| 1243 |
+
|
| 1244 |
+
st.metric(
|
| 1245 |
+
"Most Confident",
|
| 1246 |
+
f"{most_confident}",
|
| 1247 |
+
f"{comparison_results[most_confident]['confidence']:.3f}",
|
| 1248 |
+
)
|
| 1249 |
+
st.metric(
|
| 1250 |
+
"Least Confident",
|
| 1251 |
+
f"{least_confident}",
|
| 1252 |
+
f"{comparison_results[least_confident]['confidence']:.3f}",
|
| 1253 |
+
)
|
| 1254 |
+
|
| 1255 |
+
# Store results in session state for potential export
|
| 1256 |
+
# Store results in session state for potential export
|
| 1257 |
+
st.session_state["last_comparison_results"] = {
|
| 1258 |
+
"filename": filename,
|
| 1259 |
+
"modality": modality,
|
| 1260 |
+
"models": comparison_results,
|
| 1261 |
+
"summary": {
|
| 1262 |
+
"agreement": len(unique_predictions) == 1,
|
| 1263 |
+
"avg_processing_time": avg_time,
|
| 1264 |
+
"fastest_model": fastest_model,
|
| 1265 |
+
"most_confident": most_confident,
|
| 1266 |
+
},
|
| 1267 |
+
}
|
| 1268 |
+
|
| 1269 |
+
except Exception as e:
|
| 1270 |
+
st.error(f"Error during comparison: {str(e)}")
|
| 1271 |
+
|
| 1272 |
+
# Show recent comparison results if available
|
| 1273 |
+
elif "last_comparison_results" in st.session_state:
|
| 1274 |
+
st.info(
|
| 1275 |
+
"Previous comparison results available. Upload a new file or select a sample to run new comparison."
|
| 1276 |
+
)
|
| 1277 |
+
|
| 1278 |
+
# Show comparison history
|
| 1279 |
+
comparison_stats = ResultsManager.get_comparison_stats()
|
| 1280 |
+
if comparison_stats:
|
| 1281 |
+
st.markdown("#### Comparison History")
|
| 1282 |
+
|
| 1283 |
+
with st.expander("View detailed comparison statistics", expanded=False):
|
| 1284 |
+
# Show model statistics table
|
| 1285 |
+
stats_data = []
|
| 1286 |
+
for model_name, stats in comparison_stats.items():
|
| 1287 |
+
row = {
|
| 1288 |
+
"Model": model_name,
|
| 1289 |
+
"Total Predictions": stats["total_predictions"],
|
| 1290 |
+
"Avg Confidence": f"{stats['avg_confidence']:.3f}",
|
| 1291 |
+
"Avg Processing Time": f"{stats['avg_processing_time']:.3f}s",
|
| 1292 |
+
"Accuracy": (
|
| 1293 |
+
f"{stats['accuracy']:.3f}"
|
| 1294 |
+
if stats["accuracy"] is not None
|
| 1295 |
+
else "N/A"
|
| 1296 |
+
),
|
| 1297 |
+
}
|
| 1298 |
+
stats_data.append(row)
|
| 1299 |
+
|
| 1300 |
+
if stats_data:
|
| 1301 |
+
import pandas as pd
|
| 1302 |
+
|
| 1303 |
+
stats_df = pd.DataFrame(stats_data)
|
| 1304 |
+
st.dataframe(stats_df, use_container_width=True)
|
| 1305 |
+
|
| 1306 |
+
# Show agreement matrix if multiple models
|
| 1307 |
+
agreement_matrix = ResultsManager.get_agreement_matrix()
|
| 1308 |
+
if not agreement_matrix.empty and len(agreement_matrix) > 1:
|
| 1309 |
+
st.markdown("**Model Agreement Matrix**")
|
| 1310 |
+
st.dataframe(agreement_matrix.round(3), use_container_width=True)
|
| 1311 |
+
|
| 1312 |
+
# Plot agreement heatmap
|
| 1313 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
| 1314 |
+
im = ax.imshow(
|
| 1315 |
+
agreement_matrix.values, cmap="RdYlGn", vmin=0, vmax=1
|
| 1316 |
+
)
|
| 1317 |
+
|
| 1318 |
+
# Add text annotations
|
| 1319 |
+
for i in range(len(agreement_matrix)):
|
| 1320 |
+
for j in range(len(agreement_matrix.columns)):
|
| 1321 |
+
text = ax.text(
|
| 1322 |
+
j,
|
| 1323 |
+
i,
|
| 1324 |
+
f"{agreement_matrix.iloc[i, j]:.2f}",
|
| 1325 |
+
ha="center",
|
| 1326 |
+
va="center",
|
| 1327 |
+
color="black",
|
| 1328 |
+
)
|
| 1329 |
+
|
| 1330 |
+
ax.set_xticks(range(len(agreement_matrix.columns)))
|
| 1331 |
+
ax.set_yticks(range(len(agreement_matrix)))
|
| 1332 |
+
ax.set_xticklabels(agreement_matrix.columns, rotation=45)
|
| 1333 |
+
ax.set_yticklabels(agreement_matrix.index)
|
| 1334 |
+
ax.set_title("Model Agreement Matrix")
|
| 1335 |
+
|
| 1336 |
+
plt.colorbar(im, ax=ax, label="Agreement Rate")
|
| 1337 |
+
plt.tight_layout()
|
| 1338 |
+
st.pyplot(fig)
|
| 1339 |
+
|
| 1340 |
+
# Export functionality
|
| 1341 |
+
if "last_comparison_results" in st.session_state:
|
| 1342 |
+
st.markdown("##### Export Results")
|
| 1343 |
+
|
| 1344 |
+
export_col1, export_col2 = st.columns(2)
|
| 1345 |
+
|
| 1346 |
+
with export_col1:
|
| 1347 |
+
if st.button("📥 Export Comparison (JSON)"):
|
| 1348 |
+
import json
|
| 1349 |
+
|
| 1350 |
+
results = st.session_state["last_comparison_results"]
|
| 1351 |
+
json_str = json.dumps(results, indent=2, default=str)
|
| 1352 |
+
st.download_button(
|
| 1353 |
+
label="Download JSON",
|
| 1354 |
+
data=json_str,
|
| 1355 |
+
file_name=f"comparison_{results['filename'].split('.')[0]}.json",
|
| 1356 |
+
mime="application/json",
|
| 1357 |
+
)
|
| 1358 |
+
|
| 1359 |
+
with export_col2:
|
| 1360 |
+
if st.button("📊 Export Full Report"):
|
| 1361 |
+
report = ResultsManager.export_comparison_report()
|
| 1362 |
+
st.download_button(
|
| 1363 |
+
label="Download Full Report",
|
| 1364 |
+
data=report,
|
| 1365 |
+
file_name="model_comparison_report.json",
|
| 1366 |
+
mime="application/json",
|
| 1367 |
+
)
|
| 1368 |
+
|
| 1369 |
+
|
| 1370 |
+
# //////////////////////////////////////////
|
| 1371 |
+
|
| 1372 |
+
|
| 1373 |
+
def render_performance_tab():
|
| 1374 |
+
"""Render the performance tracking and analysis tab."""
|
| 1375 |
+
from utils.performance_tracker import display_performance_dashboard
|
| 1376 |
+
|
| 1377 |
+
display_performance_dashboard()
|
sample_data/ftir-stable-1.txt
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Sample FTIR spectrum data - Stable polymer
|
| 2 |
+
# Wavenumber (cm^-1) Absorbance
|
| 3 |
+
400.0 0.045
|
| 4 |
+
450.0 0.048
|
| 5 |
+
500.0 0.052
|
| 6 |
+
550.0 0.056
|
| 7 |
+
600.0 0.061
|
| 8 |
+
650.0 0.065
|
| 9 |
+
700.0 0.070
|
| 10 |
+
750.0 0.075
|
| 11 |
+
800.0 0.082
|
| 12 |
+
850.0 0.089
|
| 13 |
+
900.0 0.096
|
| 14 |
+
950.0 0.104
|
| 15 |
+
1000.0 0.112
|
| 16 |
+
1050.0 0.121
|
| 17 |
+
1100.0 0.130
|
| 18 |
+
1150.0 0.140
|
| 19 |
+
1200.0 0.151
|
| 20 |
+
1250.0 0.162
|
| 21 |
+
1300.0 0.174
|
| 22 |
+
1350.0 0.187
|
| 23 |
+
1400.0 0.200
|
| 24 |
+
1450.0 0.215
|
| 25 |
+
1500.0 0.230
|
| 26 |
+
1550.0 0.246
|
| 27 |
+
1600.0 0.263
|
| 28 |
+
1650.0 0.281
|
| 29 |
+
1700.0 0.300
|
| 30 |
+
1750.0 0.320
|
| 31 |
+
1800.0 0.341
|
| 32 |
+
1850.0 0.363
|
| 33 |
+
1900.0 0.386
|
| 34 |
+
1950.0 0.410
|
| 35 |
+
2000.0 0.435
|
| 36 |
+
2050.0 0.461
|
| 37 |
+
2100.0 0.488
|
| 38 |
+
2150.0 0.516
|
| 39 |
+
2200.0 0.545
|
| 40 |
+
2250.0 0.575
|
| 41 |
+
2300.0 0.606
|
| 42 |
+
2350.0 0.638
|
| 43 |
+
2400.0 0.671
|
| 44 |
+
2450.0 0.705
|
| 45 |
+
2500.0 0.740
|
| 46 |
+
2550.0 0.776
|
| 47 |
+
2600.0 0.813
|
| 48 |
+
2650.0 0.851
|
| 49 |
+
2700.0 0.890
|
| 50 |
+
2750.0 0.930
|
| 51 |
+
2800.0 0.971
|
| 52 |
+
2850.0 1.013
|
| 53 |
+
2900.0 1.056
|
| 54 |
+
2950.0 1.100
|
| 55 |
+
3000.0 1.145
|
| 56 |
+
3050.0 1.191
|
| 57 |
+
3100.0 1.238
|
| 58 |
+
3150.0 1.286
|
| 59 |
+
3200.0 1.335
|
| 60 |
+
3250.0 1.385
|
| 61 |
+
3300.0 1.436
|
| 62 |
+
3350.0 1.488
|
| 63 |
+
3400.0 1.541
|
| 64 |
+
3450.0 1.595
|
| 65 |
+
3500.0 1.650
|
| 66 |
+
3550.0 1.706
|
| 67 |
+
3600.0 1.763
|
| 68 |
+
3650.0 1.821
|
| 69 |
+
3700.0 1.880
|
| 70 |
+
3750.0 1.940
|
| 71 |
+
3800.0 2.001
|
| 72 |
+
3850.0 2.063
|
| 73 |
+
3900.0 2.126
|
| 74 |
+
3950.0 2.190
|
| 75 |
+
4000.0 2.255
|
sample_data/ftir-weathered-1.txt
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Sample FTIR spectrum data - Weathered polymer
|
| 2 |
+
# Wavenumber (cm^-1) Absorbance
|
| 3 |
+
400.0 0.062
|
| 4 |
+
450.0 0.069
|
| 5 |
+
500.0 0.077
|
| 6 |
+
550.0 0.086
|
| 7 |
+
600.0 0.095
|
| 8 |
+
650.0 0.105
|
| 9 |
+
700.0 0.116
|
| 10 |
+
750.0 0.128
|
| 11 |
+
800.0 0.141
|
| 12 |
+
850.0 0.155
|
| 13 |
+
900.0 0.170
|
| 14 |
+
950.0 0.186
|
| 15 |
+
1000.0 0.203
|
| 16 |
+
1050.0 0.221
|
| 17 |
+
1100.0 0.240
|
| 18 |
+
1150.0 0.260
|
| 19 |
+
1200.0 0.281
|
| 20 |
+
1250.0 0.303
|
| 21 |
+
1300.0 0.326
|
| 22 |
+
1350.0 0.350
|
| 23 |
+
1400.0 0.375
|
| 24 |
+
1450.0 0.401
|
| 25 |
+
1500.0 0.428
|
| 26 |
+
1550.0 0.456
|
| 27 |
+
1600.0 0.485
|
| 28 |
+
1650.0 0.515
|
| 29 |
+
1700.0 0.546
|
| 30 |
+
1750.0 0.578
|
| 31 |
+
1800.0 0.611
|
| 32 |
+
1850.0 0.645
|
| 33 |
+
1900.0 0.680
|
| 34 |
+
1950.0 0.716
|
| 35 |
+
2000.0 0.753
|
| 36 |
+
2050.0 0.791
|
| 37 |
+
2100.0 0.830
|
| 38 |
+
2150.0 0.870
|
| 39 |
+
2200.0 0.911
|
| 40 |
+
2250.0 0.953
|
| 41 |
+
2300.0 0.996
|
| 42 |
+
2350.0 1.040
|
| 43 |
+
2400.0 1.085
|
| 44 |
+
2450.0 1.131
|
| 45 |
+
2500.0 1.178
|
| 46 |
+
2550.0 1.226
|
| 47 |
+
2600.0 1.275
|
| 48 |
+
2650.0 1.325
|
| 49 |
+
2700.0 1.376
|
| 50 |
+
2750.0 1.428
|
| 51 |
+
2800.0 1.481
|
| 52 |
+
2850.0 1.535
|
| 53 |
+
2900.0 1.590
|
| 54 |
+
2950.0 1.646
|
| 55 |
+
3000.0 1.703
|
| 56 |
+
3050.0 1.761
|
| 57 |
+
3100.0 1.820
|
| 58 |
+
3150.0 1.880
|
| 59 |
+
3200.0 1.941
|
| 60 |
+
3250.0 2.003
|
| 61 |
+
3300.0 2.066
|
| 62 |
+
3350.0 2.130
|
| 63 |
+
3400.0 2.195
|
| 64 |
+
3450.0 2.261
|
| 65 |
+
3500.0 2.328
|
| 66 |
+
3550.0 2.396
|
| 67 |
+
3600.0 2.465
|
| 68 |
+
3650.0 2.535
|
| 69 |
+
3700.0 2.606
|
| 70 |
+
3750.0 2.678
|
| 71 |
+
3800.0 2.751
|
| 72 |
+
3850.0 2.825
|
| 73 |
+
3900.0 2.900
|
| 74 |
+
3950.0 2.976
|
| 75 |
+
4000.0 3.053
|
sample_data/stable.sample.csv
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wavenumber,intensity
|
| 2 |
+
200.0,1542.3
|
| 3 |
+
205.0,1543.1
|
| 4 |
+
210.0,1544.8
|
| 5 |
+
215.0,1546.2
|
| 6 |
+
220.0,1547.9
|
| 7 |
+
225.0,1549.1
|
| 8 |
+
230.0,1550.4
|
| 9 |
+
235.0,1551.8
|
| 10 |
+
240.0,1553.2
|
| 11 |
+
245.0,1554.6
|
| 12 |
+
250.0,1556.1
|
| 13 |
+
255.0,1557.6
|
| 14 |
+
260.0,1559.1
|
| 15 |
+
265.0,1560.7
|
| 16 |
+
270.0,1562.3
|
| 17 |
+
275.0,1563.9
|
| 18 |
+
280.0,1565.6
|
| 19 |
+
285.0,1567.3
|
| 20 |
+
290.0,1569.0
|
| 21 |
+
295.0,1570.8
|
| 22 |
+
300.0,1572.6
|
scripts/run_inference.py
CHANGED
|
@@ -17,144 +17,447 @@ python scripts/run_inference.py --input ... --arch resnet --weights ... --disabl
|
|
| 17 |
|
| 18 |
import os
|
| 19 |
import sys
|
|
|
|
| 20 |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
| 21 |
|
| 22 |
import argparse
|
| 23 |
import json
|
|
|
|
| 24 |
import logging
|
| 25 |
from pathlib import Path
|
| 26 |
-
from typing import cast
|
| 27 |
from torch import nn
|
|
|
|
| 28 |
|
| 29 |
import numpy as np
|
| 30 |
import torch
|
| 31 |
import torch.nn.functional as F
|
| 32 |
|
| 33 |
-
from models.registry import build, choices
|
| 34 |
from utils.preprocessing import preprocess_spectrum, TARGET_LENGTH
|
|
|
|
| 35 |
from scripts.plot_spectrum import load_spectrum
|
| 36 |
from scripts.discover_raman_files import label_file
|
| 37 |
|
| 38 |
|
| 39 |
def parse_args():
|
| 40 |
-
p = argparse.ArgumentParser(
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
p.add_argument(
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
# Default = ON; use disable- flags to turn steps off explicitly.
|
| 47 |
-
p.add_argument(
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
p.add_argument("--output", default=None, help="Optional output JSON path (defaults to outputs/inference/<name>.json).")
|
| 52 |
-
p.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Compute device (default: cpu).")
|
| 53 |
return p.parse_args()
|
| 54 |
|
| 55 |
|
|
|
|
|
|
|
|
|
|
| 56 |
def _load_state_dict_safe(path: str):
|
| 57 |
"""Load a state dict safely across torch versions & checkpoint formats."""
|
| 58 |
try:
|
| 59 |
obj = torch.load(path, map_location="cpu", weights_only=True) # newer torch
|
| 60 |
except TypeError:
|
| 61 |
obj = torch.load(path, map_location="cpu") # fallback for older torch
|
| 62 |
-
|
| 63 |
# Accept either a plain state_dict or a checkpoint dict that contains one
|
| 64 |
if isinstance(obj, dict):
|
| 65 |
for k in ("state_dict", "model_state_dict", "model"):
|
| 66 |
if k in obj and isinstance(obj[k], dict):
|
| 67 |
obj = obj[k]
|
| 68 |
break
|
| 69 |
-
|
| 70 |
if not isinstance(obj, dict):
|
| 71 |
raise ValueError(
|
| 72 |
"Loaded object is not a state_dict or checkpoint with a state_dict. "
|
| 73 |
f"Type={type(obj)} from file={path}"
|
| 74 |
)
|
| 75 |
-
|
| 76 |
# Strip DataParallel 'module.' prefixes if present
|
| 77 |
if any(key.startswith("module.") for key in obj.keys()):
|
| 78 |
obj = {key.replace("module.", "", 1): val for key, val in obj.items()}
|
| 79 |
-
|
| 80 |
return obj
|
| 81 |
|
| 82 |
|
| 83 |
-
|
| 84 |
-
logging.basicConfig(level=logging.INFO, format="INFO: %(message)s")
|
| 85 |
-
args = parse_args()
|
| 86 |
|
| 87 |
-
in_path = Path(args.input)
|
| 88 |
-
if not in_path.exists():
|
| 89 |
-
raise FileNotFoundError(f"Input file not found: {in_path}")
|
| 90 |
|
| 91 |
-
|
| 92 |
-
x_raw,
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
-
#
|
| 97 |
_, y_proc = preprocess_spectrum(
|
| 98 |
-
|
| 99 |
-
|
| 100 |
target_len=args.target_len,
|
|
|
|
| 101 |
do_baseline=not args.disable_baseline,
|
| 102 |
do_smooth=not args.disable_smooth,
|
| 103 |
do_normalize=not args.disable_normalize,
|
| 104 |
out_dtype="float32",
|
| 105 |
)
|
| 106 |
|
| 107 |
-
#
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
state = _load_state_dict_safe(args.weights)
|
| 111 |
missing, unexpected = model.load_state_dict(state, strict=False)
|
| 112 |
if missing or unexpected:
|
| 113 |
-
logging.info(
|
|
|
|
|
|
|
| 114 |
|
| 115 |
model.eval()
|
| 116 |
|
| 117 |
-
#
|
| 118 |
x_tensor = torch.from_numpy(y_proc[None, None, :]).to(device)
|
| 119 |
|
| 120 |
with torch.no_grad():
|
| 121 |
-
logits = model(x_tensor).float().cpu()
|
| 122 |
probs = F.softmax(logits, dim=1)
|
| 123 |
|
|
|
|
| 124 |
probs_np = probs.numpy().ravel().tolist()
|
| 125 |
logits_np = logits.numpy().ravel().tolist()
|
| 126 |
pred_label = int(np.argmax(probs_np))
|
| 127 |
|
| 128 |
-
#
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
"
|
| 138 |
-
"
|
| 139 |
-
"
|
| 140 |
-
"
|
| 141 |
-
"preprocessing": {
|
| 142 |
-
"baseline": not args.disable_baseline,
|
| 143 |
-
"smooth": not args.disable_smooth,
|
| 144 |
-
"normalize": not args.disable_normalize,
|
| 145 |
-
},
|
| 146 |
-
"predicted_label": pred_label,
|
| 147 |
-
"true_label": true_label,
|
| 148 |
"probs": probs_np,
|
| 149 |
"logits": logits_np,
|
|
|
|
| 150 |
}
|
| 151 |
|
| 152 |
-
with open(out_path, "w", encoding="utf-8") as f:
|
| 153 |
-
json.dump(result, f, indent=2)
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
|
| 160 |
if __name__ == "__main__":
|
|
|
|
| 17 |
|
| 18 |
import os
|
| 19 |
import sys
|
| 20 |
+
|
| 21 |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
| 22 |
|
| 23 |
import argparse
|
| 24 |
import json
|
| 25 |
+
import csv
|
| 26 |
import logging
|
| 27 |
from pathlib import Path
|
| 28 |
+
from typing import cast, Dict, List, Any
|
| 29 |
from torch import nn
|
| 30 |
+
import time
|
| 31 |
|
| 32 |
import numpy as np
|
| 33 |
import torch
|
| 34 |
import torch.nn.functional as F
|
| 35 |
|
| 36 |
+
from models.registry import build, choices, build_multiple, validate_model_list
|
| 37 |
from utils.preprocessing import preprocess_spectrum, TARGET_LENGTH
|
| 38 |
+
from utils.multifile import parse_spectrum_data, detect_file_format
|
| 39 |
from scripts.plot_spectrum import load_spectrum
|
| 40 |
from scripts.discover_raman_files import label_file
|
| 41 |
|
| 42 |
|
| 43 |
def parse_args():
|
| 44 |
+
p = argparse.ArgumentParser(
|
| 45 |
+
description="Raman/FTIR spectrum inference with multi-model support."
|
| 46 |
+
)
|
| 47 |
+
p.add_argument(
|
| 48 |
+
"--input",
|
| 49 |
+
required=True,
|
| 50 |
+
help="Path to spectrum file (.txt, .csv, .json) or directory for batch processing.",
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Model selection - either single or multiple
|
| 54 |
+
group = p.add_mutually_exclusive_group(required=True)
|
| 55 |
+
group.add_argument(
|
| 56 |
+
"--arch", choices=choices(), help="Single model architecture key."
|
| 57 |
+
)
|
| 58 |
+
group.add_argument(
|
| 59 |
+
"--models",
|
| 60 |
+
help="Comma-separated list of models for comparison (e.g., 'figure2,resnet,resnet18vision').",
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
p.add_argument(
|
| 64 |
+
"--weights",
|
| 65 |
+
help="Path to model weights (.pth). For multi-model, use pattern with {model} placeholder.",
|
| 66 |
+
)
|
| 67 |
+
p.add_argument(
|
| 68 |
+
"--target-len",
|
| 69 |
+
type=int,
|
| 70 |
+
default=TARGET_LENGTH,
|
| 71 |
+
help="Resample length (default: 500).",
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# Modality support
|
| 75 |
+
p.add_argument(
|
| 76 |
+
"--modality",
|
| 77 |
+
choices=["raman", "ftir"],
|
| 78 |
+
default="raman",
|
| 79 |
+
help="Spectroscopy modality for preprocessing (default: raman).",
|
| 80 |
+
)
|
| 81 |
|
| 82 |
# Default = ON; use disable- flags to turn steps off explicitly.
|
| 83 |
+
p.add_argument(
|
| 84 |
+
"--disable-baseline", action="store_true", help="Disable baseline correction."
|
| 85 |
+
)
|
| 86 |
+
p.add_argument(
|
| 87 |
+
"--disable-smooth",
|
| 88 |
+
action="store_true",
|
| 89 |
+
help="Disable Savitzky–Golay smoothing.",
|
| 90 |
+
)
|
| 91 |
+
p.add_argument(
|
| 92 |
+
"--disable-normalize",
|
| 93 |
+
action="store_true",
|
| 94 |
+
help="Disable min-max normalization.",
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
p.add_argument(
|
| 98 |
+
"--output",
|
| 99 |
+
default=None,
|
| 100 |
+
help="Output path - JSON for single file, CSV for multi-model comparison.",
|
| 101 |
+
)
|
| 102 |
+
p.add_argument(
|
| 103 |
+
"--output-format",
|
| 104 |
+
choices=["json", "csv"],
|
| 105 |
+
default="json",
|
| 106 |
+
help="Output format for results.",
|
| 107 |
+
)
|
| 108 |
+
p.add_argument(
|
| 109 |
+
"--device",
|
| 110 |
+
default="cpu",
|
| 111 |
+
choices=["cpu", "cuda"],
|
| 112 |
+
help="Compute device (default: cpu).",
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# File format options
|
| 116 |
+
p.add_argument(
|
| 117 |
+
"--file-format",
|
| 118 |
+
choices=["auto", "txt", "csv", "json"],
|
| 119 |
+
default="auto",
|
| 120 |
+
help="Input file format (auto-detect by default).",
|
| 121 |
+
)
|
| 122 |
|
|
|
|
|
|
|
| 123 |
return p.parse_args()
|
| 124 |
|
| 125 |
|
| 126 |
+
# /////////////////////////////////////////////////////////
|
| 127 |
+
|
| 128 |
+
|
| 129 |
def _load_state_dict_safe(path: str):
|
| 130 |
"""Load a state dict safely across torch versions & checkpoint formats."""
|
| 131 |
try:
|
| 132 |
obj = torch.load(path, map_location="cpu", weights_only=True) # newer torch
|
| 133 |
except TypeError:
|
| 134 |
obj = torch.load(path, map_location="cpu") # fallback for older torch
|
|
|
|
| 135 |
# Accept either a plain state_dict or a checkpoint dict that contains one
|
| 136 |
if isinstance(obj, dict):
|
| 137 |
for k in ("state_dict", "model_state_dict", "model"):
|
| 138 |
if k in obj and isinstance(obj[k], dict):
|
| 139 |
obj = obj[k]
|
| 140 |
break
|
|
|
|
| 141 |
if not isinstance(obj, dict):
|
| 142 |
raise ValueError(
|
| 143 |
"Loaded object is not a state_dict or checkpoint with a state_dict. "
|
| 144 |
f"Type={type(obj)} from file={path}"
|
| 145 |
)
|
|
|
|
| 146 |
# Strip DataParallel 'module.' prefixes if present
|
| 147 |
if any(key.startswith("module.") for key in obj.keys()):
|
| 148 |
obj = {key.replace("module.", "", 1): val for key, val in obj.items()}
|
|
|
|
| 149 |
return obj
|
| 150 |
|
| 151 |
|
| 152 |
+
# /////////////////////////////////////////////////////////
|
|
|
|
|
|
|
| 153 |
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
+
def run_single_model_inference(
|
| 156 |
+
x_raw: np.ndarray,
|
| 157 |
+
y_raw: np.ndarray,
|
| 158 |
+
model_name: str,
|
| 159 |
+
weights_path: str,
|
| 160 |
+
args: argparse.Namespace,
|
| 161 |
+
device: torch.device,
|
| 162 |
+
) -> Dict[str, Any]:
|
| 163 |
+
"""Run inference with a single model."""
|
| 164 |
+
start_time = time.time()
|
| 165 |
|
| 166 |
+
# Preprocess spectrum
|
| 167 |
_, y_proc = preprocess_spectrum(
|
| 168 |
+
x_raw,
|
| 169 |
+
y_raw,
|
| 170 |
target_len=args.target_len,
|
| 171 |
+
modality=args.modality,
|
| 172 |
do_baseline=not args.disable_baseline,
|
| 173 |
do_smooth=not args.disable_smooth,
|
| 174 |
do_normalize=not args.disable_normalize,
|
| 175 |
out_dtype="float32",
|
| 176 |
)
|
| 177 |
|
| 178 |
+
# Build model & load weights
|
| 179 |
+
model = cast(nn.Module, build(model_name, args.target_len)).to(device)
|
| 180 |
+
state = _load_state_dict_safe(weights_path)
|
|
|
|
| 181 |
missing, unexpected = model.load_state_dict(state, strict=False)
|
| 182 |
if missing or unexpected:
|
| 183 |
+
logging.info(
|
| 184 |
+
f"Model {model_name}: Loaded with non-strict keys. missing={len(missing)} unexpected={len(unexpected)}"
|
| 185 |
+
)
|
| 186 |
|
| 187 |
model.eval()
|
| 188 |
|
| 189 |
+
# Run inference
|
| 190 |
x_tensor = torch.from_numpy(y_proc[None, None, :]).to(device)
|
| 191 |
|
| 192 |
with torch.no_grad():
|
| 193 |
+
logits = model(x_tensor).float().cpu()
|
| 194 |
probs = F.softmax(logits, dim=1)
|
| 195 |
|
| 196 |
+
processing_time = time.time() - start_time
|
| 197 |
probs_np = probs.numpy().ravel().tolist()
|
| 198 |
logits_np = logits.numpy().ravel().tolist()
|
| 199 |
pred_label = int(np.argmax(probs_np))
|
| 200 |
|
| 201 |
+
# Map prediction to class name
|
| 202 |
+
class_names = ["Stable", "Weathered"]
|
| 203 |
+
predicted_class = (
|
| 204 |
+
class_names[pred_label]
|
| 205 |
+
if pred_label < len(class_names)
|
| 206 |
+
else f"Class_{pred_label}"
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
return {
|
| 210 |
+
"model": model_name,
|
| 211 |
+
"prediction": pred_label,
|
| 212 |
+
"predicted_class": predicted_class,
|
| 213 |
+
"confidence": max(probs_np),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
"probs": probs_np,
|
| 215 |
"logits": logits_np,
|
| 216 |
+
"processing_time": processing_time,
|
| 217 |
}
|
| 218 |
|
|
|
|
|
|
|
| 219 |
|
| 220 |
+
# /////////////////////////////////////////////////////////
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def run_multi_model_inference(
|
| 224 |
+
x_raw: np.ndarray,
|
| 225 |
+
y_raw: np.ndarray,
|
| 226 |
+
model_names: List[str],
|
| 227 |
+
args: argparse.Namespace,
|
| 228 |
+
device: torch.device,
|
| 229 |
+
) -> Dict[str, Dict[str, Any]]:
|
| 230 |
+
"""Run inference with multiple models for comparison."""
|
| 231 |
+
results = {}
|
| 232 |
+
|
| 233 |
+
for model_name in model_names:
|
| 234 |
+
try:
|
| 235 |
+
# Generate weights path - either use pattern or assume same weights for all
|
| 236 |
+
if args.weights and "{model}" in args.weights:
|
| 237 |
+
weights_path = args.weights.format(model=model_name)
|
| 238 |
+
elif args.weights:
|
| 239 |
+
weights_path = args.weights
|
| 240 |
+
else:
|
| 241 |
+
# Default weights path pattern
|
| 242 |
+
weights_path = f"outputs/{model_name}_model.pth"
|
| 243 |
+
|
| 244 |
+
if not Path(weights_path).exists():
|
| 245 |
+
logging.warning(f"Weights not found for {model_name}: {weights_path}")
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
result = run_single_model_inference(
|
| 249 |
+
x_raw, y_raw, model_name, weights_path, args, device
|
| 250 |
+
)
|
| 251 |
+
results[model_name] = result
|
| 252 |
+
|
| 253 |
+
except Exception as e:
|
| 254 |
+
logging.error(f"Failed to run inference with {model_name}: {str(e)}")
|
| 255 |
+
continue
|
| 256 |
+
|
| 257 |
+
return results
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
```python
# /////////////////////////////////////////////////////////


def save_results(
    results: Dict[str, Any], output_path: Path, format: str = "json"
) -> None:
    """Save results to file in specified format"""
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if format == "json":
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
    elif format == "csv":
        # Convert to tabular format for CSV
        if "models" in results:  # Multi-model results
            rows = []
            for model_name, model_result in results["models"].items():
                row = {
                    "model": model_name,
                    "prediction": model_result["prediction"],
                    "predicted_class": model_result["predicted_class"],
                    "confidence": model_result["confidence"],
                    "processing_time": model_result["processing_time"],
                }
                # Add individual class probabilities
                if "probs" in model_result:
                    for i, prob in enumerate(model_result["probs"]):
                        row[f"prob_class_{i}"] = prob
                rows.append(row)

            # Write CSV
            with open(output_path, "w", newline="", encoding="utf-8") as f:
                if rows:
                    writer = csv.DictWriter(f, fieldnames=rows[0].keys())
                    writer.writeheader()
                    writer.writerows(rows)
        else:  # Single model result
            with open(output_path, "w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=results.keys())
                writer.writeheader()
                writer.writerow(results)


def main():
    logging.basicConfig(level=logging.INFO, format="INFO: %(message)s")
    args = parse_args()

    # Input validation
    in_path = Path(args.input)
    if not in_path.exists():
        raise FileNotFoundError(f"Input file not found: {in_path}")

    # Determine if this is single or multi-model inference
    if args.models:
        model_names = [m.strip() for m in args.models.split(",")]
        model_names = validate_model_list(model_names)
        if not model_names:
            raise ValueError(f"No valid models found in: {args.models}")
        multi_model = True
    else:
        model_names = [args.arch]
        multi_model = False

    # Load and parse spectrum data
    if args.file_format == "auto":
        file_format = None  # Auto-detect
    else:
        file_format = args.file_format

    try:
        # Read file content
        with open(in_path, "r", encoding="utf-8") as f:
            content = f.read()

        # Parse spectrum data with format detection
        x_raw, y_raw = parse_spectrum_data(content, str(in_path))
        x_raw = np.array(x_raw, dtype=np.float32)
        y_raw = np.array(y_raw, dtype=np.float32)

    except Exception as e:
        x_raw, y_raw = load_spectrum(str(in_path))
        x_raw = np.array(x_raw, dtype=np.float32)
        y_raw = np.array(y_raw, dtype=np.float32)
        logging.warning(
            f"Failed to parse with new parser, falling back to original: {e}"
        )
        x_raw, y_raw = load_spectrum(str(in_path))

    if len(x_raw) < 10:
        raise ValueError("Input spectrum has too few points (<10).")

    # Setup device
    device = torch.device(
        args.device if (args.device == "cuda" and torch.cuda.is_available()) else "cpu"
    )

    # Run inference
    model_results = {}  # Initialize to avoid unbound variable error
    if multi_model:
        model_results = run_multi_model_inference(
            np.array(x_raw, dtype=np.float32),
            np.array(y_raw, dtype=np.float32),
            model_names,
            args,
            device,
        )

        # Get ground truth if available
        true_label = label_file(str(in_path))

        # Prepare combined results
        results = {
            "input_file": str(in_path),
            "modality": args.modality,
            "models": model_results,
            "true_label": true_label,
            "preprocessing": {
                "baseline": not args.disable_baseline,
                "smooth": not args.disable_smooth,
                "normalize": not args.disable_normalize,
                "target_len": args.target_len,
            },
            "comparison": {
                "total_models": len(model_results),
                "agreements": (
                    sum(
                        1
                        for i, (_, r1) in enumerate(model_results.items())
                        for j, (_, r2) in enumerate(
                            list(model_results.items())[i + 1 :]
                        )
                        if r1["prediction"] == r2["prediction"]
                    )
                    if len(model_results) > 1
                    else 0
                ),
            },
        }

        # Default output path for multi-model
        default_output = (
            Path("outputs")
            / "inference"
            / f"{in_path.stem}_comparison.{args.output_format}"
        )

    else:
        # Single model inference
        model_result = run_single_model_inference(
            x_raw, y_raw, model_names[0], args.weights, args, device
        )
        true_label = label_file(str(in_path))

        results = {
            "input_file": str(in_path),
            "modality": args.modality,
            "arch": model_names[0],
            "weights": str(args.weights),
            "target_len": args.target_len,
            "preprocessing": {
                "baseline": not args.disable_baseline,
                "smooth": not args.disable_smooth,
                "normalize": not args.disable_normalize,
            },
            "predicted_label": model_result["prediction"],
            "predicted_class": model_result["predicted_class"],
            "true_label": true_label,
            "confidence": model_result["confidence"],
            "probs": model_result["probs"],
            "logits": model_result["logits"],
            "processing_time": model_result["processing_time"],
        }

        # Default output path for single model
        default_output = (
            Path("outputs")
            / "inference"
            / f"{in_path.stem}_{model_names[0]}.{args.output_format}"
        )

    # Save results
    output_path = Path(args.output) if args.output else default_output
    save_results(results, output_path, args.output_format)

    # Log summary
    if multi_model:
        logging.info(
            f"Multi-model inference completed with {len(model_results)} models"
        )
        for model_name, result in model_results.items():
            logging.info(
                f"{model_name}: {result['predicted_class']} (confidence: {result['confidence']:.3f})"
            )
        logging.info(f"Results saved to {output_path}")
    else:
        logging.info(
            f"Predicted Label: {results['predicted_label']} ({results['predicted_class']})"
        )
        logging.info(f"Confidence: {results['confidence']:.3f}")
        logging.info(f"True Label: {results['true_label']}")
        logging.info(f"Result saved to {output_path}")


if __name__ == "__main__":
```
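For reference, a minimal sketch of driving `save_results` directly; the dictionary keys mirror the single-model result structure assembled in `main()`, while the file path, class label, and values are illustrative only:

```python
from pathlib import Path

# Hypothetical single-model result; keys follow the structure built in main().
results = {
    "input_file": "sample_data/ftir-stable-1.txt",
    "arch": "figure2",
    "predicted_label": 0,
    "predicted_class": "stable",
    "confidence": 0.97,
    "processing_time": 0.12,
}

# JSON dumps the dict as-is; CSV writes one row keyed by the dict fields.
save_results(results, Path("outputs/inference/example.json"), format="json")
save_results(results, Path("outputs/inference/example.csv"), format="csv")
```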
tests/test_ftir_preprocessing.py
ADDED
```python
"""Tests for FTIR preprocessing functionality."""

import pytest
import numpy as np
from utils.preprocessing import (
    preprocess_spectrum,
    validate_spectrum_range,
    get_modality_info,
    MODALITY_RANGES,
    MODALITY_PARAMS,
)


def test_modality_ranges():
    """Test that modality ranges are correctly defined."""
    assert "raman" in MODALITY_RANGES
    assert "ftir" in MODALITY_RANGES

    raman_range = MODALITY_RANGES["raman"]
    ftir_range = MODALITY_RANGES["ftir"]

    assert raman_range[0] < raman_range[1]  # Valid range
    assert ftir_range[0] < ftir_range[1]  # Valid range
    assert ftir_range[0] >= 400  # FTIR starts at 400 cm⁻¹
    assert ftir_range[1] <= 4000  # FTIR ends at 4000 cm⁻¹


def test_validate_spectrum_range():
    """Test spectrum range validation for different modalities."""
    # Test Raman range validation
    raman_x = np.linspace(300, 3500, 100)  # Typical Raman range
    assert validate_spectrum_range(raman_x, "raman") == True

    # Test FTIR range validation
    ftir_x = np.linspace(500, 3800, 100)  # Typical FTIR range
    assert validate_spectrum_range(ftir_x, "ftir") == True

    # Test out-of-range data
    out_of_range_x = np.linspace(50, 150, 100)  # Too low for either
    assert validate_spectrum_range(out_of_range_x, "raman") == False
    assert validate_spectrum_range(out_of_range_x, "ftir") == False


def test_ftir_preprocessing():
    """Test FTIR-specific preprocessing parameters."""
    # Generate synthetic FTIR spectrum
    x = np.linspace(400, 4000, 200)  # FTIR range
    y = np.sin(x / 500) + 0.1 * np.random.randn(len(x)) + 2.0  # Synthetic absorbance

    # Test FTIR preprocessing
    x_proc, y_proc = preprocess_spectrum(x, y, modality="ftir", target_len=500)

    assert x_proc.shape == (500,)
    assert y_proc.shape == (500,)
    assert np.all(np.diff(x_proc) > 0)  # Monotonic increasing
    assert np.min(y_proc) >= 0.0  # Normalized to [0, 1]
    assert np.max(y_proc) <= 1.0


def test_raman_preprocessing():
    """Test Raman-specific preprocessing parameters."""
    # Generate synthetic Raman spectrum
    x = np.linspace(200, 3500, 200)  # Raman range
    y = np.exp(-(((x - 1500) / 200) ** 2)) + 0.05 * np.random.randn(
        len(x)
    )  # Gaussian peak

    # Test Raman preprocessing
    x_proc, y_proc = preprocess_spectrum(x, y, modality="raman", target_len=500)

    assert x_proc.shape == (500,)
    assert y_proc.shape == (500,)
    assert np.all(np.diff(x_proc) > 0)  # Monotonic increasing
    assert np.min(y_proc) >= 0.0  # Normalized to [0, 1]
    assert np.max(y_proc) <= 1.0


def test_modality_specific_parameters():
    """Test that different modalities use different default parameters."""
    x = np.linspace(400, 4000, 200)
    y = np.sin(x / 500) + 1.0

    # Test that FTIR uses different window length than Raman
    ftir_params = MODALITY_PARAMS["ftir"]
    raman_params = MODALITY_PARAMS["raman"]

    assert ftir_params["smooth_window"] != raman_params["smooth_window"]

    # Preprocess with both modalities (should use different parameters)
    x_raman, y_raman = preprocess_spectrum(x, y, modality="raman")
    x_ftir, y_ftir = preprocess_spectrum(x, y, modality="ftir")

    # Results should be slightly different due to different parameters
    assert not np.allclose(y_raman, y_ftir, rtol=1e-10)


def test_get_modality_info():
    """Test modality information retrieval."""
    raman_info = get_modality_info("raman")
    ftir_info = get_modality_info("ftir")

    assert "range" in raman_info
    assert "params" in raman_info
    assert "range" in ftir_info
    assert "params" in ftir_info

    # Check that ranges match expected values
    assert raman_info["range"] == MODALITY_RANGES["raman"]
    assert ftir_info["range"] == MODALITY_RANGES["ftir"]

    # Check that parameters are present
    assert "baseline_degree" in raman_info["params"]
    assert "smooth_window" in ftir_info["params"]


def test_invalid_modality():
    """Test handling of invalid modality."""
    x = np.linspace(1000, 2000, 100)
    y = np.sin(x / 100)

    with pytest.raises(ValueError, match="Unsupported modality"):
        preprocess_spectrum(x, y, modality="invalid")

    with pytest.raises(ValueError, match="Unknown modality"):
        validate_spectrum_range(x, "invalid")

    with pytest.raises(ValueError, match="Unknown modality"):
        get_modality_info("invalid")


def test_modality_parameter_override():
    """Test that modality defaults can be overridden."""
    x = np.linspace(400, 4000, 100)
    y = np.sin(x / 500) + 1.0

    # Override FTIR default window length
    custom_window = 21  # Different from FTIR default (13)

    x_proc, y_proc = preprocess_spectrum(
        x, y, modality="ftir", window_length=custom_window
    )

    assert x_proc.shape[0] > 0
    assert y_proc.shape[0] > 0


def test_range_validation_warning():
    """Test that range validation warnings work correctly."""
    # Create spectrum outside typical FTIR range
    x_bad = np.linspace(100, 300, 50)  # Too low for FTIR
    y_bad = np.ones_like(x_bad)

    # Should still process but with validation disabled
    x_proc, y_proc = preprocess_spectrum(
        x_bad, y_bad, modality="ftir", validate_range=False  # Disable validation
    )

    assert len(x_proc) > 0
    assert len(y_proc) > 0


def test_backwards_compatibility():
    """Test that old preprocessing calls still work (defaults to Raman)."""
    x = np.linspace(1000, 2000, 100)
    y = np.sin(x / 100)

    # Old style call (should default to Raman)
    x_old, y_old = preprocess_spectrum(x, y)

    # New style call with explicit Raman
    x_new, y_new = preprocess_spectrum(x, y, modality="raman")

    # Should be identical
    np.testing.assert_array_equal(x_old, x_new)
    np.testing.assert_array_equal(y_old, y_new)


if __name__ == "__main__":
    pytest.main([__file__])
```
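Taken together, these tests document the intended call pattern for modality-aware preprocessing. A minimal sketch of that pattern (synthetic data; the `modality`, `target_len`, and `window_length` arguments are the ones exercised by the tests above):

```python
import numpy as np
from utils.preprocessing import preprocess_spectrum, get_modality_info

# Synthetic FTIR-like trace spanning the 400-4000 cm⁻¹ validation range.
x = np.linspace(400, 4000, 200)
y = np.sin(x / 500) + 2.0

# Defaults come from MODALITY_PARAMS["ftir"]; window_length overrides the smoothing window.
x_proc, y_proc = preprocess_spectrum(x, y, modality="ftir", target_len=500, window_length=21)

print(x_proc.shape, y_proc.shape)          # (500,) (500,)
print(get_modality_info("ftir")["range"])  # (400, 4000)
```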
tests/test_multi_format.py
ADDED
```python
"""Tests for multi-format file parsing functionality."""

import pytest
import numpy as np
from utils.multifile import (
    parse_spectrum_data,
    detect_file_format,
    parse_json_spectrum,
    parse_csv_spectrum,
    parse_txt_spectrum,
)


def test_detect_file_format():
    """Test automatic file format detection."""
    # JSON detection
    json_content = '{"wavenumbers": [1, 2, 3], "intensities": [0.1, 0.2, 0.3]}'
    assert detect_file_format("test.json", json_content) == "json"

    # CSV detection
    csv_content = "wavenumber,intensity\n1000,0.5\n1001,0.6"
    assert detect_file_format("test.csv", csv_content) == "csv"

    # TXT detection (default)
    txt_content = "1000 0.5\n1001 0.6"
    assert detect_file_format("test.txt", txt_content) == "txt"


def test_parse_json_spectrum():
    """Test JSON spectrum parsing."""
    # Test object format
    json_content = '{"wavenumbers": [1000, 1001, 1002], "intensities": [0.1, 0.2, 0.3]}'
    x, y = parse_json_spectrum(json_content)

    expected_x = np.array([1000, 1001, 1002])
    expected_y = np.array([0.1, 0.2, 0.3])

    np.testing.assert_array_equal(x, expected_x)
    np.testing.assert_array_equal(y, expected_y)

    # Test alternative key names
    json_content_alt = '{"x": [1000, 1001, 1002], "y": [0.1, 0.2, 0.3]}'
    x_alt, y_alt = parse_json_spectrum(json_content_alt)
    np.testing.assert_array_equal(x_alt, expected_x)
    np.testing.assert_array_equal(y_alt, expected_y)

    # Test array of objects format
    json_array = """[
    {"wavenumber": 1000, "intensity": 0.1},
    {"wavenumber": 1001, "intensity": 0.2},
    {"wavenumber": 1002, "intensity": 0.3}
    ]"""
    x_arr, y_arr = parse_json_spectrum(json_array)
    np.testing.assert_array_equal(x_arr, expected_x)
    np.testing.assert_array_equal(y_arr, expected_y)


def test_parse_csv_spectrum():
    """Test CSV spectrum parsing."""
    # Test with headers
    csv_with_headers = """wavenumber,intensity
1000,0.1
1001,0.2
1002,0.3
1003,0.4
1004,0.5
1005,0.6
1006,0.7
1007,0.8
1008,0.9
1009,1.0
1010,1.1
1011,1.2"""

    x, y = parse_csv_spectrum(csv_with_headers)
    expected_x = np.array(
        [1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011]
    )
    expected_y = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2])

    np.testing.assert_array_equal(x, expected_x)
    np.testing.assert_array_equal(y, expected_y)

    # Test without headers
    csv_no_headers = """1000,0.1
1001,0.2
1002,0.3
1003,0.4
1004,0.5
1005,0.6
1006,0.7
1007,0.8
1008,0.9
1009,1.0
1010,1.1
1011,1.2"""

    x_no_h, y_no_h = parse_csv_spectrum(csv_no_headers)
    np.testing.assert_array_equal(x_no_h, expected_x)
    np.testing.assert_array_equal(y_no_h, expected_y)

    # Test semicolon delimiter
    csv_semicolon = """1000;0.1
1001;0.2
1002;0.3
1003;0.4
1004;0.5
1005;0.6
1006;0.7
1007;0.8
1008;0.9
1009;1.0
1010;1.1
1011;1.2"""

    x_semi, y_semi = parse_csv_spectrum(csv_semicolon)
    np.testing.assert_array_equal(x_semi, expected_x)
    np.testing.assert_array_equal(y_semi, expected_y)


def test_parse_txt_spectrum():
    """Test TXT spectrum parsing."""
    txt_content = """# Comment line
1000 0.1
1001 0.2
1002 0.3
1003 0.4
1004 0.5
1005 0.6
1006 0.7
1007 0.8
1008 0.9
1009 1.0
1010 1.1
1011 1.2"""

    x, y = parse_txt_spectrum(txt_content)
    expected_x = np.array(
        [1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011]
    )
    expected_y = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2])

    np.testing.assert_array_equal(x, expected_x)
    np.testing.assert_array_equal(y, expected_y)

    # Test comma-separated
    txt_comma = """1000,0.1
1001,0.2
1002,0.3
1003,0.4
1004,0.5
1005,0.6
1006,0.7
1007,0.8
1008,0.9
1009,1.0
1010,1.1
1011,1.2"""

    x_comma, y_comma = parse_txt_spectrum(txt_comma)
    np.testing.assert_array_equal(x_comma, expected_x)
    np.testing.assert_array_equal(y_comma, expected_y)


def test_parse_spectrum_data_integration():
    """Test integrated spectrum data parsing with format detection."""
    # Test automatic format detection and parsing
    test_cases = [
        (
            '{"wavenumbers": [1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011], "intensities": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]}',
            "test.json",
        ),
        (
            "wavenumber,intensity\n1000,0.1\n1001,0.2\n1002,0.3\n1003,0.4\n1004,0.5\n1005,0.6\n1006,0.7\n1007,0.8\n1008,0.9\n1009,1.0\n1010,1.1\n1011,1.2",
            "test.csv",
        ),
        (
            "1000 0.1\n1001 0.2\n1002 0.3\n1003 0.4\n1004 0.5\n1005 0.6\n1006 0.7\n1007 0.8\n1008 0.9\n1009 1.0\n1010 1.1\n1011 1.2",
            "test.txt",
        ),
    ]

    for content, filename in test_cases:
        x, y = parse_spectrum_data(content, filename)
        assert len(x) >= 10
        assert len(y) >= 10
        assert len(x) == len(y)


def test_insufficient_data_points():
    """Test handling of insufficient data points."""
    # Test with too few points
    insufficient_data = "1000 0.1\n1001 0.2"  # Only 2 points, need at least 10

    with pytest.raises(ValueError, match="Insufficient data points"):
        parse_txt_spectrum(insufficient_data, "test.txt")


def test_invalid_json():
    """Test handling of invalid JSON."""
    invalid_json = (
        '{"wavenumbers": [1000, 1001], "intensities": [0.1}'  # Missing closing bracket
    )

    with pytest.raises(ValueError, match="Invalid JSON format"):
        parse_json_spectrum(invalid_json)


def test_empty_file():
    """Test handling of empty files."""
    empty_content = ""

    with pytest.raises(ValueError, match="No data lines found"):
        parse_txt_spectrum(empty_content, "empty.txt")


if __name__ == "__main__":
    pytest.main([__file__])
```
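As these tests imply, every parser enforces a minimum of ten data points, so hand-rolled fixtures need a dozen or so pairs. A small sketch of building one valid payload per supported format (values are arbitrary):

```python
import json

# Twelve synthetic (wavenumber, intensity) pairs -- enough to clear the 10-point minimum.
pairs = [(1000 + i, 0.1 * (i + 1)) for i in range(12)]

json_payload = json.dumps(
    {"wavenumbers": [p[0] for p in pairs], "intensities": [p[1] for p in pairs]}
)
csv_payload = "wavenumber,intensity\n" + "\n".join(f"{x},{y}" for x, y in pairs)
txt_payload = "\n".join(f"{x} {y}" for x, y in pairs)
```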
utils/multifile.py
CHANGED
```python
"""Multi-file processing utilities for batch inference.
Handles multiple file uploads and iterative processing.
Supports TXT, CSV, and JSON file formats with automatic detection."""

from typing import List, Dict, Any, Tuple, Optional, Union
import time
import streamlit as st
import numpy as np
import pandas as pd
import json
import csv
import io
from pathlib import Path

from .preprocessing import resample_spectrum
from .errors import ErrorHandler, safe_execute
from .results_manager import ResultsManager
from .confidence import calculate_softmax_confidence


def detect_file_format(filename: str, content: str) -> str:
    """Automatically detect file format based on extension and content.

    Args:
        filename: Name of the file
        content: Content of the file

    Returns:
        File format: one of 'txt', 'csv', or 'json'
    """
    # First try by extension
    suffix = Path(filename).suffix.lower()
    if suffix == ".json":
        try:
            json.loads(content)
            return "json"
        except ValueError:
            pass
    elif suffix == ".csv":
        return "csv"
    elif suffix == ".txt":
        return "txt"

    # If extension doesn't match or is unclear, try content detection
    content_stripped = content.strip()

    # Try JSON
    if content_stripped.startswith(("{", "[")):
        try:
            json.loads(content)
            return "json"
        except ValueError:
            pass

    # Try CSV (look for commas in first few lines)
    lines = content_stripped.split("\n")[:5]
    comma_count = sum(line.count(",") for line in lines)
    if comma_count > len(lines):  # More commas than lines suggests CSV
        return "csv"

    # Default to TXT
    return "txt"


# /////////////////////////////////////////////////////


def parse_json_spectrum(
    content: str, filename: str = "unknown"
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Parse spectrum data from JSON format.

    Expected formats:
    - {"wavenumbers": [...], "intensities": [...]}
    - {"x": [...], "y": [...]}
    - [{"wavenumber": val, "intensity": val}, ...]
    """

    try:
        data = json.loads(content)

        # Format 1: Object with arrays
        if isinstance(data, dict):
            x_key = None
            y_key = None

            # Try common key names for x-axis
            for key in ["wavenumbers", "wavenumber", "x", "freq", "frequency"]:
                if key in data:
                    x_key = key
                    break

            # Try common key names for y-axis
            for key in ["intensities", "intensity", "y", "counts", "absorbance"]:
                if key in data:
                    y_key = key
                    break

            if x_key and y_key:
                x_vals = np.array(data[x_key], dtype=float)
                y_vals = np.array(data[y_key], dtype=float)
                return x_vals, y_vals

        # Format 2: Array of objects
        elif isinstance(data, list) and len(data) > 0 and isinstance(data[0], dict):
            x_vals = []
            y_vals = []

            for item in data:
                # Try to find x and y values
                x_val = None
                y_val = None

                for x_key in ["wavenumber", "wavenumbers", "x", "freq"]:
                    if x_key in item:
                        x_val = float(item[x_key])
                        break

                for y_key in ["intensity", "intensities", "y", "counts"]:
                    if y_key in item:
                        y_val = float(item[y_key])
                        break

                if x_val is not None and y_val is not None:
                    x_vals.append(x_val)
                    y_vals.append(y_val)

            if x_vals and y_vals:
                return np.array(x_vals), np.array(y_vals)

        raise ValueError(
            "JSON format not recognized. Expected wavenumber/intensity pairs."
        )

    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON format: {str(e)}")
    except Exception as e:
        raise ValueError(f"Failed to parse JSON spectrum: {str(e)}")


# /////////////////////////////////////////////////////


def parse_csv_spectrum(
    content: str, filename: str = "unknown"
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Parse spectrum data from CSV format.

    Handles various CSV formats with headers or without.
    """
    try:
        # Use StringIO to treat string as file-like object
        csv_file = io.StringIO(content)

        # Try to detect delimiter
        sample = content[:1024]
        delimiter = ","
        if sample.count(";") > sample.count(","):
            delimiter = ";"
        elif sample.count("\t") > sample.count(","):
            delimiter = "\t"

        # Read CSV
        csv_reader = csv.reader(csv_file, delimiter=delimiter)
        rows = list(csv_reader)

        if not rows:
            raise ValueError("Empty CSV file")

        # Check if first row is header
        has_header = False
        try:
            # If first row contains non-numeric data, it's likely a header
            float(rows[0][0])
            float(rows[0][1])
        except (ValueError, IndexError):
            has_header = True

        data_rows = rows[1:] if has_header else rows

        # Extract x and y values
        x_vals = []
        y_vals = []

        for i, row in enumerate(data_rows):
            if len(row) < 2:
                continue

            try:
                x_val = float(row[0])
                y_val = float(row[1])
                x_vals.append(x_val)
                y_vals.append(y_val)
            except ValueError:
                ErrorHandler.log_warning(
                    f"Could not parse CSV row {i+1}: {row}", f"Parsing {filename}"
                )
                continue

        if len(x_vals) < 10:
            raise ValueError(
                f"Insufficient data points ({len(x_vals)}). Need at least 10 points."
            )

        return np.array(x_vals), np.array(y_vals)

    except Exception as e:
        raise ValueError(f"Failed to parse CSV spectrum: {str(e)}")


# /////////////////////////////////////////////////////


def parse_spectrum_data(
    text_content: str, filename: str = "unknown", file_format: Optional[str] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Parse spectrum data from text content with automatic format detection.
    Args:
        text_content: Raw text content of the spectrum file
        filename: Name of the file for error reporting
        file_format: Force specific format ('txt', 'csv', 'json') or None for auto-detection
    Returns:
        Tuple of (x_values, y_values) as numpy arrays
    Raises:
        ValueError: If the data cannot be parsed
    """
    try:
        # Detect format if not specified
        if file_format is None:
            file_format = detect_file_format(filename, text_content)

        # Parse based on detected/specified format
        if file_format == "json":
            x, y = parse_json_spectrum(text_content, filename)
        elif file_format == "csv":
            x, y = parse_csv_spectrum(text_content, filename)
        else:  # Default to TXT format
            x, y = parse_txt_spectrum(text_content, filename)

        # Common validation for all formats
        validate_spectrum_data(x, y, filename)

        return x, y

    except Exception as e:
        raise ValueError(f"Failed to parse spectrum data: {str(e)}")


# /////////////////////////////////////////////////////


def parse_txt_spectrum(
    content: str, filename: str = "unknown"
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Parse spectrum data from TXT format (original implementation).
    """
    lines = content.strip().split("\n")

    # ==Remove empty lines and comments==
    data_lines = []
    for line in lines:
        line = line.strip()
        if line and not line.startswith("#") and not line.startswith("%"):
            data_lines.append(line)

    if not data_lines:
        raise ValueError("No data lines found in file")

    # ==Try to parse==
    x_vals, y_vals = [], []

    for i, line in enumerate(data_lines):
        try:
            # Handle different separators
            parts = line.replace(",", " ").split()
            numbers = [
                p
                for p in parts
                if p.replace(".", "", 1)
                .replace("-", "", 1)
                .replace("+", "", 1)
                .isdigit()
            ]
            if len(numbers) >= 2:
                x_val = float(numbers[0])
                y_val = float(numbers[1])
                x_vals.append(x_val)
                y_vals.append(y_val)

        except ValueError:
            ErrorHandler.log_warning(
                f"Could not parse line {i+1}: {line}", f"Parsing {filename}"
            )
            continue

    if len(x_vals) < 10:  # ==Need minimum points for interpolation==
        raise ValueError(
            f"Insufficient data points ({len(x_vals)}). Need at least 10 points."
        )

    return np.array(x_vals), np.array(y_vals)


# /////////////////////////////////////////////////////


def validate_spectrum_data(x: np.ndarray, y: np.ndarray, filename: str) -> None:
    """
    Validate parsed spectrum data for common issues.
    """
    # Check for NaNs
    if np.any(np.isnan(x)) or np.any(np.isnan(y)):
        raise ValueError("Input data contains NaN values")

    # Check monotonic increasing x (sort if needed)
    if not np.all(np.diff(x) >= 0):
        # Sort by x values if not monotonic
        sort_idx = np.argsort(x)
        x = x[sort_idx]
        y = y[sort_idx]
        ErrorHandler.log_warning(
            "Wavenumbers were not monotonic - data has been sorted",
            f"Parsing {filename}",
        )

    # Check reasonable range for spectroscopy
    if min(x) < 0 or max(x) > 10000 or (max(x) - min(x)) < 100:
        ErrorHandler.log_warning(
            f"Unusual wavenumber range: {min(x):.1f} - {max(x):.1f} cm⁻¹",
            f"Parsing {filename}",
        )


# /////////////////////////////////////////////////////


def process_single_file(
    filename: str,
    text_content: str,
```
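A short sketch of the end-to-end parse path these helpers provide, as exercised by the tests (the file name and contents are illustrative):

```python
from utils.multifile import detect_file_format, parse_spectrum_data

# Twelve CSV rows with a header -- enough points to pass validation.
content = "wavenumber,intensity\n" + "\n".join(
    f"{1000 + i},{0.1 * (i + 1):.1f}" for i in range(12)
)

fmt = detect_file_format("upload.csv", content)    # -> "csv"
x, y = parse_spectrum_data(content, "upload.csv")  # auto-detects, parses, validates
print(fmt, x.shape, y.shape)                       # csv (12,) (12,)
```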
utils/performance_tracker.py
ADDED
```python
"""Performance tracking and logging utilities for POLYMEROS platform."""

import time
import json
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional
import numpy as np
import matplotlib.pyplot as plt
import streamlit as st
from dataclasses import dataclass, asdict
from contextlib import contextmanager


@dataclass
class PerformanceMetrics:
    """Data class for performance metrics."""

    model_name: str
    prediction_time: float
    preprocessing_time: float
    total_time: float
    memory_usage_mb: float
    accuracy: Optional[float]
    confidence: float
    timestamp: str
    input_size: int
    modality: str

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


class PerformanceTracker:
    """Automatic performance tracking and logging system."""

    def __init__(self, db_path: str = "outputs/performance_tracking.db"):
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_database()

    def _init_database(self):
        """Initialize SQLite database for performance tracking."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS performance_metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    prediction_time REAL NOT NULL,
                    preprocessing_time REAL NOT NULL,
                    total_time REAL NOT NULL,
                    memory_usage_mb REAL,
                    accuracy REAL,
                    confidence REAL NOT NULL,
                    timestamp TEXT NOT NULL,
                    input_size INTEGER NOT NULL,
                    modality TEXT NOT NULL
                )
                """
            )
            conn.commit()

    def log_performance(self, metrics: PerformanceMetrics):
        """Log performance metrics to database."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                """
                INSERT INTO performance_metrics
                (model_name, prediction_time, preprocessing_time, total_time,
                 memory_usage_mb, accuracy, confidence, timestamp, input_size, modality)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    metrics.model_name,
                    metrics.prediction_time,
                    metrics.preprocessing_time,
                    metrics.total_time,
                    metrics.memory_usage_mb,
                    metrics.accuracy,
                    metrics.confidence,
                    metrics.timestamp,
                    metrics.input_size,
                    metrics.modality,
                ),
            )
            conn.commit()

    @contextmanager
    def track_inference(self, model_name: str, modality: str = "raman"):
        """Context manager for automatic performance tracking."""
        start_time = time.time()
        start_memory = self._get_memory_usage()

        tracking_data = {
            "model_name": model_name,
            "modality": modality,
            "start_time": start_time,
            "start_memory": start_memory,
            "preprocessing_time": 0.0,
        }

        try:
            yield tracking_data
        finally:
            end_time = time.time()
            end_memory = self._get_memory_usage()

            total_time = end_time - start_time
            memory_usage = max(end_memory - start_memory, 0)

            # Create metrics object if not provided
            if "metrics" not in tracking_data:
                metrics = PerformanceMetrics(
                    model_name=model_name,
                    prediction_time=tracking_data.get("prediction_time", total_time),
                    preprocessing_time=tracking_data.get("preprocessing_time", 0.0),
                    total_time=total_time,
                    memory_usage_mb=memory_usage,
                    accuracy=tracking_data.get("accuracy"),
                    confidence=tracking_data.get("confidence", 0.0),
                    timestamp=datetime.now().isoformat(),
                    input_size=tracking_data.get("input_size", 0),
                    modality=modality,
                )
                self.log_performance(metrics)

    def _get_memory_usage(self) -> float:
        """Get current memory usage in MB."""
        try:
            import psutil

            process = psutil.Process()
            return process.memory_info().rss / 1024 / 1024  # Convert to MB
        except ImportError:
            return 0.0  # psutil not available

    def get_recent_metrics(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get recent performance metrics."""
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row  # Enable column access by name
            cursor = conn.execute(
                """
                SELECT * FROM performance_metrics
                ORDER BY timestamp DESC
                LIMIT ?
                """,
                (limit,),
            )
            return [dict(row) for row in cursor.fetchall()]

    def get_model_statistics(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Get statistical summary of model performance."""
        where_clause = "WHERE model_name = ?" if model_name else ""
        params = (model_name,) if model_name else ()

        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                f"""
                SELECT
                    model_name,
                    COUNT(*) as total_inferences,
                    AVG(prediction_time) as avg_prediction_time,
                    AVG(preprocessing_time) as avg_preprocessing_time,
                    AVG(total_time) as avg_total_time,
                    AVG(memory_usage_mb) as avg_memory_usage,
                    AVG(confidence) as avg_confidence,
                    MIN(total_time) as fastest_inference,
                    MAX(total_time) as slowest_inference
                FROM performance_metrics
                {where_clause}
                GROUP BY model_name
                """,
                params,
            )

            results = cursor.fetchall()
            if model_name and results:
                # Return single model stats as dict
                row = results[0]
                return {
                    "model_name": row[0],
                    "total_inferences": row[1],
                    "avg_prediction_time": row[2],
                    "avg_preprocessing_time": row[3],
                    "avg_total_time": row[4],
                    "avg_memory_usage": row[5],
                    "avg_confidence": row[6],
                    "fastest_inference": row[7],
                    "slowest_inference": row[8],
                }
            elif not model_name:
                # Return all models stats as dict of dicts
                return {
                    row[0]: {
                        "model_name": row[0],
                        "total_inferences": row[1],
                        "avg_prediction_time": row[2],
                        "avg_preprocessing_time": row[3],
                        "avg_total_time": row[4],
                        "avg_memory_usage": row[5],
                        "avg_confidence": row[6],
                        "fastest_inference": row[7],
                        "slowest_inference": row[8],
                    }
                    for row in results
                }
            else:
                return {}

    def create_performance_visualization(self) -> plt.Figure:
        """Create performance visualization charts."""
        metrics = self.get_recent_metrics(50)

        if not metrics:
            return None

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))

        # Convert to convenient format
        models = [m["model_name"] for m in metrics]
        times = [m["total_time"] for m in metrics]
        confidences = [m["confidence"] for m in metrics]
        timestamps = [datetime.fromisoformat(m["timestamp"]) for m in metrics]

        # 1. Inference Time Over Time
        ax1.plot(timestamps, times, "o-", alpha=0.7)
        ax1.set_title("Inference Time Over Time")
        ax1.set_ylabel("Time (seconds)")
        ax1.tick_params(axis="x", rotation=45)

        # 2. Performance by Model
        model_stats = self.get_model_statistics()
        if model_stats:
            model_names = list(model_stats.keys())
            avg_times = [model_stats[m]["avg_total_time"] for m in model_names]

            ax2.bar(model_names, avg_times, alpha=0.7)
            ax2.set_title("Average Inference Time by Model")
            ax2.set_ylabel("Time (seconds)")
            ax2.tick_params(axis="x", rotation=45)

        # 3. Confidence Distribution
        ax3.hist(confidences, bins=20, alpha=0.7)
        ax3.set_title("Confidence Score Distribution")
        ax3.set_xlabel("Confidence")
        ax3.set_ylabel("Frequency")

        # 4. Memory Usage if available
        memory_usage = [
            m["memory_usage_mb"] for m in metrics if m["memory_usage_mb"] is not None
        ]
        if memory_usage:
            ax4.plot(range(len(memory_usage)), memory_usage, "o-", alpha=0.7)
            ax4.set_title("Memory Usage")
            ax4.set_xlabel("Inference Number")
            ax4.set_ylabel("Memory (MB)")
        else:
            ax4.text(
                0.5,
                0.5,
                "Memory tracking\nnot available",
                ha="center",
                va="center",
                transform=ax4.transAxes,
            )
            ax4.set_title("Memory Usage")

        plt.tight_layout()
        return fig

    def export_metrics(self, format: str = "json") -> str:
        """Export performance metrics in specified format."""
        metrics = self.get_recent_metrics(1000)  # Get more for export

        if format == "json":
            return json.dumps(metrics, indent=2, default=str)
        elif format == "csv":
            import pandas as pd

            df = pd.DataFrame(metrics)
            return df.to_csv(index=False)
        else:
            raise ValueError(f"Unsupported format: {format}")


# Global tracker instance
_tracker = None


def get_performance_tracker() -> PerformanceTracker:
    """Get global performance tracker instance."""
    global _tracker
    if _tracker is None:
        _tracker = PerformanceTracker()
    return _tracker


def display_performance_dashboard():
    """Display performance tracking dashboard in Streamlit."""
    tracker = get_performance_tracker()

    st.markdown("### 📈 Performance Dashboard")

    # Recent metrics summary
    recent_metrics = tracker.get_recent_metrics(20)

    if not recent_metrics:
        st.info(
            "No performance data available yet. Run some inferences to see metrics."
        )
        return

    # Summary statistics
    col1, col2, col3, col4 = st.columns(4)

    total_inferences = len(recent_metrics)
    avg_time = np.mean([m["total_time"] for m in recent_metrics])
    avg_confidence = np.mean([m["confidence"] for m in recent_metrics])
    unique_models = len(set(m["model_name"] for m in recent_metrics))

    with col1:
        st.metric("Total Inferences", total_inferences)
    with col2:
        st.metric("Avg Time", f"{avg_time:.3f}s")
    with col3:
        st.metric("Avg Confidence", f"{avg_confidence:.3f}")
    with col4:
        st.metric("Models Used", unique_models)

    # Performance visualization
    fig = tracker.create_performance_visualization()
    if fig:
        st.pyplot(fig)

    # Model comparison table
    st.markdown("#### Model Performance Comparison")
    model_stats = tracker.get_model_statistics()

    if model_stats:
        import pandas as pd

        stats_data = []
        for model_name, stats in model_stats.items():
            stats_data.append(
                {
                    "Model": model_name,
                    "Total Inferences": stats["total_inferences"],
                    "Avg Time (s)": f"{stats['avg_total_time']:.3f}",
                    "Avg Confidence": f"{stats['avg_confidence']:.3f}",
                    "Fastest (s)": f"{stats['fastest_inference']:.3f}",
                    "Slowest (s)": f"{stats['slowest_inference']:.3f}",
                }
            )

        df = pd.DataFrame(stats_data)
        st.dataframe(df, use_container_width=True)

    # Export options
    with st.expander("📥 Export Performance Data"):
        col1, col2 = st.columns(2)

        with col1:
            if st.button("Export JSON"):
                json_data = tracker.export_metrics("json")
                st.download_button(
                    "Download JSON",
                    json_data,
                    "performance_metrics.json",
                    "application/json",
                )

        with col2:
            if st.button("Export CSV"):
                csv_data = tracker.export_metrics("csv")
                st.download_button(
                    "Download CSV", csv_data, "performance_metrics.csv", "text/csv"
                )


if __name__ == "__main__":
    # Test the performance tracker
    tracker = PerformanceTracker()

    # Simulate some metrics
    for i in range(5):
        metrics = PerformanceMetrics(
            model_name=f"test_model_{i%2}",
            prediction_time=0.1 + i * 0.01,
            preprocessing_time=0.05,
            total_time=0.15 + i * 0.01,
            memory_usage_mb=100 + i * 10,
            accuracy=0.8 + i * 0.02,
            confidence=0.7 + i * 0.05,
            timestamp=datetime.now().isoformat(),
            input_size=500,
            modality="raman",
        )
        tracker.log_performance(metrics)

    print("Performance tracking test completed!")
    print(f"Recent metrics: {len(tracker.get_recent_metrics())}")
    print(f"Model stats: {tracker.get_model_statistics()}")
```
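A minimal sketch of how the `track_inference` context manager is meant to wrap an inference call; the model name comes from the registry, but the timing and confidence values written into `tracking_data` here are placeholders:

```python
from utils.performance_tracker import get_performance_tracker

tracker = get_performance_tracker()

with tracker.track_inference("figure2", modality="ftir") as tracking_data:
    # ... run preprocessing and the model forward pass here ...
    tracking_data["preprocessing_time"] = 0.04  # seconds spent in preprocessing
    tracking_data["prediction_time"] = 0.11     # seconds spent in the forward pass
    tracking_data["confidence"] = 0.93          # softmax confidence of the prediction
    tracking_data["input_size"] = 500           # resampled spectrum length

# Metrics are written to outputs/performance_tracking.db when the block exits.
print(tracker.get_model_statistics("figure2"))
```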
utils/preprocessing.py
CHANGED
@@ -1,6 +1,7 @@
"""
Preprocessing utilities for polymer classification app.
Adapted from the original scripts/preprocess_dataset.py for Hugging Face Spaces deployment.
"""

from __future__ import annotations

@@ -9,8 +10,33 @@ from numpy.typing import DTypeLike
from scipy.interpolate import interp1d
from scipy.signal import savgol_filter
from scipy.interpolate import interp1d

-TARGET_LENGTH = 500  # Frozen default per PREPROCESSING_BASELINE

def _ensure_1d_equal(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    x = np.asarray(x, dtype=float)

@@ -19,7 +45,10 @@ def _ensure_1d_equal(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarr
        raise ValueError("x and y must be 1D arrays of equal length >= 2")
    return x, y

-def resample_spectrum(x: np.ndarray, y: np.ndarray, target_len: int = TARGET_LEN
    """Linear re-sampling onto a uniform grid of length target_len."""
    x, y = _ensure_1d_equal(x, y)
    order = np.argsort(x)

@@ -29,6 +58,7 @@ def resample_spectrum(x: np.ndarray, y: np.ndarray, target_len: int = TARGET_LEN
    y_new = f(x_new)
    return x_new, y_new

def remove_baseline(y: np.ndarray, degree: int = 2) -> np.ndarray:
    """Polynomial baseline subtraction (degree=2 default)"""
    y = np.asarray(y, dtype=float)

@@ -37,19 +67,25 @@
    baseline = np.polyval(coeffs, x_idx)
    return y - baseline

    """Savitzky-Golay smoothing with safe/odd window enforcement"""
    y = np.asarray(y, dtype=float)
    window_length = int(window_length)
    polyorder = int(polyorder)
    # === window must be odd and >= polyorder+1 ===
    if window_length % 2 == 0:
-        window_length += 1
    min_win = polyorder + 1
    if min_win % 2 == 0:
        min_win += 1
    window_length = max(window_length, min_win)
-    return savgol_filter(

def normalize_spectrum(y: np.ndarray) -> np.ndarray:
    """Min-max normalization to [0, 1] with constant-signal guard."""

@@ -60,27 +96,114 @@
        return np.zeros_like(y)
    return (y - y_min) / (y_max - y_min)

def preprocess_spectrum(
    x: np.ndarray,
    y: np.ndarray,
    *,
    target_len: int = TARGET_LENGTH,
    do_baseline: bool = True,
-    degree: int =
    do_smooth: bool = True,
-    window_length: int =
-    polyorder: int =
    do_normalize: bool = True,
    out_dtype: DTypeLike = np.float32,
) -> tuple[np.ndarray, np.ndarray]:
-    """
    x_rs, y_rs = resample_spectrum(x, y, target_len=target_len)
    if do_baseline:
        y_rs = remove_baseline(y_rs, degree=degree)
    if do_smooth:
        y_rs = smooth_spectrum(y_rs, window_length=window_length, polyorder=polyorder)
    if do_normalize:
        y_rs = normalize_spectrum(y_rs)
    # === Coerce to a real dtype to satisfy static checkers & runtime ===
    out_dt = np.dtype(out_dtype)
-    return x_rs.astype(out_dt, copy=False), y_rs.astype(out_dt, copy=False)
"""
Preprocessing utilities for polymer classification app.
Adapted from the original scripts/preprocess_dataset.py for Hugging Face Spaces deployment.
+Supports both Raman and FTIR spectroscopy modalities.
"""

from __future__ import annotations
...
from scipy.interpolate import interp1d
from scipy.signal import savgol_filter
from scipy.interpolate import interp1d
+from typing import Tuple, Literal
+
+TARGET_LENGTH = 500  # Frozen default per PREPROCESSING_BASELINE
+
+# Modality-specific validation ranges (cm⁻¹)
+MODALITY_RANGES = {
+    "raman": (200, 4000),  # Typical Raman range
+    "ftir": (400, 4000),  # FTIR wavenumber range
+}
+
+# Modality-specific preprocessing parameters
+MODALITY_PARAMS = {
+    "raman": {
+        "baseline_degree": 2,
+        "smooth_window": 11,
+        "smooth_polyorder": 2,
+        "cosmic_ray_removal": False,
+    },
+    "ftir": {
+        "baseline_degree": 2,
+        "smooth_window": 13,  # Slightly larger window for FTIR
+        "smooth_polyorder": 2,
+        "cosmic_ray_removal": False,  # Could add atmospheric correction
+        "atmospheric_correction": False,  # Placeholder for future implementation
+    },
+}


def _ensure_1d_equal(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    x = np.asarray(x, dtype=float)
    ...
        raise ValueError("x and y must be 1D arrays of equal length >= 2")
    return x, y

+
+def resample_spectrum(
+    x: np.ndarray, y: np.ndarray, target_len: int = TARGET_LENGTH
+) -> tuple[np.ndarray, np.ndarray]:
    """Linear re-sampling onto a uniform grid of length target_len."""
    x, y = _ensure_1d_equal(x, y)
    order = np.argsort(x)
    ...
    y_new = f(x_new)
    return x_new, y_new

+
def remove_baseline(y: np.ndarray, degree: int = 2) -> np.ndarray:
    """Polynomial baseline subtraction (degree=2 default)"""
    y = np.asarray(y, dtype=float)
    ...
    baseline = np.polyval(coeffs, x_idx)
    return y - baseline

+
+def smooth_spectrum(
+    y: np.ndarray, window_length: int = 11, polyorder: int = 2
+) -> np.ndarray:
    """Savitzky-Golay smoothing with safe/odd window enforcement"""
    y = np.asarray(y, dtype=float)
    window_length = int(window_length)
    polyorder = int(polyorder)
    # === window must be odd and >= polyorder+1 ===
    if window_length % 2 == 0:
+        window_length += 1
    min_win = polyorder + 1
    if min_win % 2 == 0:
        min_win += 1
    window_length = max(window_length, min_win)
+    return savgol_filter(
+        y, window_length=window_length, polyorder=polyorder, mode="interp"
+    )
+

def normalize_spectrum(y: np.ndarray) -> np.ndarray:
    """Min-max normalization to [0, 1] with constant-signal guard."""
    ...
        return np.zeros_like(y)
    return (y - y_min) / (y_max - y_min)

+
+def validate_spectrum_range(x: np.ndarray, modality: str = "raman") -> bool:
+    """Validate that spectrum wavenumbers are within expected range for modality."""
+    if modality not in MODALITY_RANGES:
+        raise ValueError(
+            f"Unknown modality '{modality}'. Supported: {list(MODALITY_RANGES.keys())}"
+        )
+
+    min_range, max_range = MODALITY_RANGES[modality]
+    x_min, x_max = np.min(x), np.max(x)
+
+    # Check if majority of data points are within range
+    in_range = np.sum((x >= min_range) & (x <= max_range))
+    total_points = len(x)
+
+    return (in_range / total_points) >= 0.7  # At least 70% should be in range
+
+
def preprocess_spectrum(
    x: np.ndarray,
    y: np.ndarray,
    *,
    target_len: int = TARGET_LENGTH,
+    modality: str = "raman",  # New parameter for modality-specific processing
    do_baseline: bool = True,
+    degree: int | None = None,  # Will use modality default if None
    do_smooth: bool = True,
+    window_length: int | None = None,  # Will use modality default if None
+    polyorder: int | None = None,  # Will use modality default if None
    do_normalize: bool = True,
    out_dtype: DTypeLike = np.float32,
+    validate_range: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
+    """
+    Modality-aware preprocessing: resample -> baseline -> smooth -> normalize
+
+    Args:
+        x, y: Input spectrum data
+        target_len: Target length for resampling
+        modality: 'raman' or 'ftir' for modality-specific processing
+        do_baseline: Enable baseline correction
+        degree: Polynomial degree for baseline (uses modality default if None)
+        do_smooth: Enable smoothing
+        window_length: Smoothing window length (uses modality default if None)
+        polyorder: Polynomial order for smoothing (uses modality default if None)
+        do_normalize: Enable normalization
+        out_dtype: Output data type
+        validate_range: Check if wavenumbers are in expected range for modality
+
+    Returns:
+        Tuple of (resampled_x, processed_y)
+    """
+    # Validate modality
+    if modality not in MODALITY_PARAMS:
+        raise ValueError(
+            f"Unsupported modality '{modality}'. Supported: {list(MODALITY_PARAMS.keys())}"
+        )
+
+    # Get modality-specific parameters
+    modality_config = MODALITY_PARAMS[modality]
+
+    # Use modality defaults if parameters not specified
+    if degree is None:
+        degree = modality_config["baseline_degree"]
+    if window_length is None:
+        window_length = modality_config["smooth_window"]
+    if polyorder is None:
+        polyorder = modality_config["smooth_polyorder"]
+
+    # Validate spectrum range if requested
+    if validate_range:
+        if not validate_spectrum_range(x, modality):
+            print(
+                f"Warning: Spectrum wavenumbers may not be optimal for {modality.upper()} analysis"
+            )
+
+    # Standard preprocessing pipeline
    x_rs, y_rs = resample_spectrum(x, y, target_len=target_len)
+
    if do_baseline:
        y_rs = remove_baseline(y_rs, degree=degree)
+
    if do_smooth:
        y_rs = smooth_spectrum(y_rs, window_length=window_length, polyorder=polyorder)
+
+    # FTIR-specific processing (placeholder for future enhancements)
+    if modality == "ftir":
+        if modality_config.get("atmospheric_correction", False):
+            # Placeholder for atmospheric correction
+            pass
+        if modality_config.get("cosmic_ray_removal", False):
+            # Placeholder for cosmic ray removal
+            pass
+
    if do_normalize:
        y_rs = normalize_spectrum(y_rs)
+
    # === Coerce to a real dtype to satisfy static checkers & runtime ===
    out_dt = np.dtype(out_dtype)
+    return x_rs.astype(out_dt, copy=False), y_rs.astype(out_dt, copy=False)
+
+
+def get_modality_info(modality: str) -> dict:
+    """Get processing parameters and validation ranges for a modality."""
+    if modality not in MODALITY_PARAMS:
+        raise ValueError(f"Unknown modality '{modality}'")
+
+    return {
+        "range": MODALITY_RANGES[modality],
+        "params": MODALITY_PARAMS[modality].copy(),
+    }
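For orientation, a minimal usage sketch of the modality-aware entry point added above; the synthetic spectrum, the unseeded noise, and the project-root import path are assumptions, and only functions defined in this file (`preprocess_spectrum`, `get_modality_info`) are called.

```python
# Illustrative call into the modality-aware pipeline (not part of the diff).
import numpy as np

from utils.preprocessing import get_modality_info, preprocess_spectrum

# Synthetic FTIR-like spectrum: 600 unevenly spaced wavenumbers in 400-4000 cm^-1
# with a single absorption band plus noise.
x = np.sort(np.random.uniform(400, 4000, size=600))
y = np.exp(-((x - 1715) ** 2) / (2 * 40**2)) + 0.05 * np.random.randn(x.size)

# Omitting degree/window_length/polyorder picks up the FTIR defaults from MODALITY_PARAMS.
x_proc, y_proc = preprocess_spectrum(x, y, modality="ftir")

print(x_proc.shape, y_proc.shape)           # both (500,) with the default target_len
print(get_modality_info("ftir")["params"])  # the MODALITY_PARAMS entry for FTIR
```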
utils/results_manager.py
CHANGED
@@ -1,14 +1,17 @@
"""Session results management for multi-file inference.
-Handles in-memory results table and export functionality

import streamlit as st
import pandas as pd
import json
from datetime import datetime
-from typing import Dict, List, Any, Optional
import numpy as np
from pathlib import Path
import io


def local_css(file_name):

@@ -199,6 +202,219 @@ class ResultsManager:

        return len(st.session_state[ResultsManager.RESULTS_KEY]) < original_length

    @staticmethod
    # ==UTILITY FUNCTIONS==
    def init_session_state():
"""Session results management for multi-file inference.
+Handles in-memory results table and export functionality.
+Supports multi-model comparison and statistical analysis."""

import streamlit as st
import pandas as pd
import json
from datetime import datetime
+from typing import Dict, List, Any, Optional, Tuple
import numpy as np
from pathlib import Path
import io
+from collections import defaultdict
+import matplotlib.pyplot as plt


def local_css(file_name):
...

        return len(st.session_state[ResultsManager.RESULTS_KEY]) < original_length

+    @staticmethod
+    def add_multi_model_results(
+        filename: str,
+        model_results: Dict[str, Dict[str, Any]],
+        ground_truth: Optional[int] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Add results from multiple models for the same file.
+
+        Args:
+            filename: Name of the processed file
+            model_results: Dict with model_name -> result dict
+            ground_truth: True label if available
+            metadata: Additional file metadata
+        """
+        for model_name, result in model_results.items():
+            ResultsManager.add_results(
+                filename=filename,
+                model_name=model_name,
+                prediction=result["prediction"],
+                predicted_class=result["predicted_class"],
+                confidence=result["confidence"],
+                logits=result["logits"],
+                ground_truth=ground_truth,
+                processing_time=result.get("processing_time", 0.0),
+                metadata=metadata,
+            )
+
+    @staticmethod
+    def get_comparison_stats() -> Dict[str, Any]:
+        """Get comparative statistics across all models."""
+        results = ResultsManager.get_results()
+        if not results:
+            return {}
+
+        # Group results by model
+        model_stats = defaultdict(list)
+        for result in results:
+            model_stats[result["model"]].append(result)
+
+        comparison = {}
+        for model_name, model_results in model_stats.items():
+            stats = {
+                "total_predictions": len(model_results),
+                "avg_confidence": np.mean([r["confidence"] for r in model_results]),
+                "std_confidence": np.std([r["confidence"] for r in model_results]),
+                "avg_processing_time": np.mean(
+                    [r["processing_time"] for r in model_results]
+                ),
+                "stable_predictions": sum(
+                    1 for r in model_results if r["prediction"] == 0
+                ),
+                "weathered_predictions": sum(
+                    1 for r in model_results if r["prediction"] == 1
+                ),
+            }
+
+            # Calculate accuracy if ground truth available
+            with_gt = [r for r in model_results if r["ground_truth"] is not None]
+            if with_gt:
+                correct = sum(
+                    1 for r in with_gt if r["prediction"] == r["ground_truth"]
+                )
+                stats["accuracy"] = correct / len(with_gt)
+                stats["num_with_ground_truth"] = len(with_gt)
+            else:
+                stats["accuracy"] = None
+                stats["num_with_ground_truth"] = 0
+
+            comparison[model_name] = stats
+
+        return comparison
+
+    @staticmethod
+    def get_agreement_matrix() -> pd.DataFrame:
+        """
+        Calculate agreement matrix between models for the same files.
+
+        Returns:
+            DataFrame showing model agreement rates
+        """
+        results = ResultsManager.get_results()
+        if not results:
+            return pd.DataFrame()
+
+        # Group by filename
+        file_results = defaultdict(dict)
+        for result in results:
+            file_results[result["filename"]][result["model"]] = result["prediction"]
+
+        # Get unique models
+        all_models = list(set(r["model"] for r in results))
+
+        if len(all_models) < 2:
+            return pd.DataFrame()
+
+        # Calculate agreement matrix
+        agreement_matrix = np.zeros((len(all_models), len(all_models)))
+
+        for i, model1 in enumerate(all_models):
+            for j, model2 in enumerate(all_models):
+                if i == j:
+                    agreement_matrix[i, j] = 1.0  # Perfect self-agreement
+                else:
+                    agreements = 0
+                    comparisons = 0
+
+                    for filename, predictions in file_results.items():
+                        if model1 in predictions and model2 in predictions:
+                            comparisons += 1
+                            if predictions[model1] == predictions[model2]:
+                                agreements += 1
+
+                    if comparisons > 0:
+                        agreement_matrix[i, j] = agreements / comparisons
+
+        return pd.DataFrame(agreement_matrix, index=all_models, columns=all_models)
+
+    @staticmethod
+    def create_comparison_visualization() -> plt.Figure:
+        """Create visualization comparing model performance."""
+        comparison_stats = ResultsManager.get_comparison_stats()
+
+        if not comparison_stats:
+            return None
+
+        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
+
+        models = list(comparison_stats.keys())
+
+        # 1. Average Confidence
+        confidences = [comparison_stats[m]["avg_confidence"] for m in models]
+        conf_stds = [comparison_stats[m]["std_confidence"] for m in models]
+        ax1.bar(models, confidences, yerr=conf_stds, capsize=5)
+        ax1.set_title("Average Confidence by Model")
+        ax1.set_ylabel("Confidence")
+        ax1.tick_params(axis="x", rotation=45)
+
+        # 2. Processing Time
+        proc_times = [comparison_stats[m]["avg_processing_time"] for m in models]
+        ax2.bar(models, proc_times)
+        ax2.set_title("Average Processing Time")
+        ax2.set_ylabel("Time (seconds)")
+        ax2.tick_params(axis="x", rotation=45)
+
+        # 3. Prediction Distribution
+        stable_counts = [comparison_stats[m]["stable_predictions"] for m in models]
+        weathered_counts = [
+            comparison_stats[m]["weathered_predictions"] for m in models
+        ]
+
+        x = np.arange(len(models))
+        width = 0.35
+        ax3.bar(x - width / 2, stable_counts, width, label="Stable", alpha=0.8)
+        ax3.bar(x + width / 2, weathered_counts, width, label="Weathered", alpha=0.8)
+        ax3.set_title("Prediction Distribution")
+        ax3.set_ylabel("Count")
+        ax3.set_xticks(x)
+        ax3.set_xticklabels(models, rotation=45)
+        ax3.legend()
+
+        # 4. Accuracy (if available)
+        accuracies = []
+        models_with_acc = []
+        for model in models:
+            if comparison_stats[model]["accuracy"] is not None:
+                accuracies.append(comparison_stats[model]["accuracy"])
+                models_with_acc.append(model)
+
+        if accuracies:
+            ax4.bar(models_with_acc, accuracies)
+            ax4.set_title("Model Accuracy (where ground truth available)")
+            ax4.set_ylabel("Accuracy")
+            ax4.set_ylim(0, 1)
+            ax4.tick_params(axis="x", rotation=45)
+        else:
+            ax4.text(
+                0.5,
+                0.5,
+                "No ground truth\navailable",
+                ha="center",
+                va="center",
+                transform=ax4.transAxes,
+            )
+            ax4.set_title("Model Accuracy")
+
+        plt.tight_layout()
+        return fig
+
+    @staticmethod
+    def export_comparison_report() -> str:
+        """Export comprehensive comparison report as JSON."""
+        comparison_stats = ResultsManager.get_comparison_stats()
+        agreement_matrix = ResultsManager.get_agreement_matrix()
+
+        report = {
+            "timestamp": datetime.now().isoformat(),
+            "model_comparison": comparison_stats,
+            "agreement_matrix": (
+                agreement_matrix.to_dict() if not agreement_matrix.empty else {}
+            ),
+            "summary": {
+                "total_models_compared": len(comparison_stats),
+                "total_files_processed": len(
+                    set(r["filename"] for r in ResultsManager.get_results())
+                ),
+                "overall_statistics": ResultsManager.get_summary_stats(),
+            },
+        }
+
+        return json.dumps(report, indent=2, default=str)
+
    @staticmethod
    # ==UTILITY FUNCTIONS==
    def init_session_state():