(FEAT)[Enhanced Results Widget]: Integrate advanced probability breakdown, QC, and provenance export
- Updated ui_components.py and 2_Enhanced_Analysis.py to include a research-grade probability breakdown widget.
- Added entropy, margin, calibration, and provenance export to results column.
- Ensured QC summary and preprocessing parameters are computed and stored in session_state for meaningful diagnostics.
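
For reference, the entropy and margin surfaced by the new widget are a few lines of arithmetic over the two-class softmax output. The sketch below mirrors the `_entropy` helper and the 1e-12 clamp that appear in the ui_components.py diff further down; the example probabilities are illustrative only.

```python
import math

def _entropy(ps):
    # Clamp probabilities away from 0 so log() stays finite.
    ps = [max(min(float(p), 1.0), 1e-12) for p in ps]
    return -sum(p * math.log(p) for p in ps)

stable_prob, weathered_prob = 0.82, 0.18           # example softmax output
margin = abs(stable_prob - weathered_prob)          # decision margin metric
entropy = _entropy([stable_prob, weathered_prob])   # low entropy = confident prediction
```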
(FIX)[QC Summary Assignment]: Ensure valid spectrum QC metrics in results
- Modified core_logic.py and app.py to compute and assign n_points, x_min, x_max, monotonic_x, nan_free, and variance_proxy after spectrum parsing.
- Eliminated NULL values in QC summary for reliable reporting.
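
As a quick illustration, the QC fields named above can be derived directly from the parsed arrays. This is a minimal sketch assuming `x_raw` and `y_raw` are the NumPy arrays returned by the spectrum parser; it matches the shape of the `qc_summary` dict stored in `st.session_state` later in this diff.

```python
import numpy as np

def build_qc_summary(x_raw: np.ndarray, y_raw: np.ndarray) -> dict:
    # Basic sanity metrics so the results view never reports NULLs.
    return {
        "n_points": len(x_raw),
        "x_min": f"{np.min(x_raw):.1f}",
        "x_max": f"{np.max(x_raw):.1f}",
        "monotonic_x": bool(np.all(np.diff(x_raw) > 0)),
        "nan_free": not (np.any(np.isnan(x_raw)) or np.any(np.isnan(y_raw))),
        "variance_proxy": f"{np.var(y_raw):.2e}",
    }
```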
(DOCS)[Markdown & Citation]: Improve README formatting and citation style
- Updated README.md to replace bare DOI URLs with markdown link syntax, resolving markdownlint MD034 and improving citation readability.
(FEAT)[Model Inspection Utility]: Add inspect_weights.py for model weight analysis
- Introduced inspect_weights.py to support model weight inspection and debugging for training and inference workflows.
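
The core of such an inspection is iterating over the checkpoint's state dict and printing per-tensor statistics; near-zero standard deviations across every tensor usually indicate a corrupted or "dead" checkpoint. A minimal sketch of that check (the full script appears later in this diff):

```python
import torch

def quick_weight_check(path: str) -> None:
    # Load on CPU and report the spread of each parameter tensor.
    state_dict = torch.load(path, map_location="cpu")
    for name, param in state_dict.items():
        if isinstance(param, torch.Tensor):
            t = param.float()
            # A healthy layer normally shows non-trivial spread.
            print(f"{name}: mean={t.mean():.3e} std={t.std():.3e}")
```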
(FEAT+FIX)[Batch & Image Processing]: Refine batch utilities and image processing
- Enhanced multifile.py and figure2_model.py for multi-format batch support, error resilience, and improved spectrum parsing.
- Improved training_engine.py for robust batch processing and model training integration.
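
The multi-format support amounts to sniffing the file type before parsing. The sketch below follows the `detect_file_format` logic visible in the utils/multifile.py diff (extension first, then a JSON parse attempt, then a comma heuristic); it is illustrative rather than the exact function body.

```python
import json
from pathlib import Path

def detect_file_format(filename: str, content: str) -> str:
    suffix = Path(filename).suffix.lower()
    if suffix == ".json":
        try:
            json.loads(content)
            return "json"
        except json.JSONDecodeError:
            pass
    elif suffix == ".csv":
        return "csv"
    # Fall back to content sniffing for unknown extensions.
    try:
        json.loads(content)
        return "json"
    except json.JSONDecodeError:
        pass
    if any("," in line for line in content.splitlines()[:5]):
        return "csv"
    return "txt"
```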
Refactor training architecture and enhance model training capabilities
- Unified training logic by introducing a central `TrainingEngine` class in `utils/training_engine.py` (a usage sketch follows this change list).
- Decoupled data structures for training configuration and status into `utils/training_types.py`.
- Updated CLI script (`scripts/train_model.py`) to utilize the new training engine, improving maintainability.
- Enhanced `TrainingManager` in `utils/training_manager.py` to support the new training engine and provide better job management.
- Added diagnostics script (`inspect_weights.py`) for inspecting model weights and identifying potential issues.
- Improved data parsing by modifying `utils/multifile.py` to streamline spectrum data handling.
- Updated model weight handling and logging mechanisms to ensure better tracking of training progress and results.
- Created comprehensive README documentation for the training modules, detailing usage and project structure.
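
Putting the pieces together, a CLI-style invocation of the unified engine looks roughly like the following. The sketch assumes the `TrainingConfig` fields and the `TrainingEngine.run(X, y, progress_callback=...)` signature that appear in the scripts/train_model.py diff below, and uses placeholder arrays in place of the real dataset loader; treat it as an outline, not the exact API.

```python
import numpy as np

from utils.training_engine import TrainingEngine
from utils.training_types import TrainingConfig

config = TrainingConfig(
    model_name="figure2",
    dataset_path="datasets/rdwp",
    target_len=500,
    batch_size=16,
    epochs=10,
    learning_rate=1e-3,
    num_folds=10,
    baseline_correction=True,
    smoothing=True,
    normalization=True,
)

def on_progress(event: dict) -> None:
    # The engine reports fold/epoch events as plain dicts.
    if event["type"] == "fold_end":
        print(f"Fold {event['fold']} accuracy: {event['accuracy']:.2%}")

# Placeholder data; the real script loads and preprocesses the RDWP dataset.
X = np.random.rand(120, config.target_len).astype("float32")
y = np.random.randint(0, 2, size=120)

engine = TrainingEngine(config)
results = engine.run(X, y, progress_callback=on_progress)
```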
- app.py +0 -72
- core_logic.py +3 -58
- inspect_weights.py +116 -0
- modules/TRAINING_MODELS_README.md +94 -0
- modules/training_ui.py +2 -6
- modules/ui_components.py +218 -108
- outputs/figure2_model.pth +1 -1
- pages/2_Enhanced_Analysis.py +2 -1
- scripts/train_model.py +92 -1
- utils/multifile.py +17 -33
- utils/training_engine.py +160 -0
- utils/training_manager.py +66 -245
- utils/training_types.py +0 -0
app.py:

```diff
@@ -1,72 +0,0 @@
-# In App.py
-import streamlit as st
-
-from modules.callbacks import init_session_state
-
-from modules.ui_components import (
-    render_sidebar,
-    render_results_column,
-    render_input_column,
-    render_comparison_tab,
-    render_performance_tab,
-    load_css,
-)
-
-from modules.training_ui import render_training_tab
-
-from utils.image_processing import render_image_upload_interface
-
-st.set_page_config(
-    page_title="ML Polymer Classification",
-    page_icon="🔬",
-    layout="wide",
-    initial_sidebar_state="expanded",
-    menu_items=None,
-)
-
-
-def main():
-    """Modularized main content to other scripts to clean the main app"""
-    load_css("static/style.css")
-    init_session_state()
-
-    render_sidebar()
-
-    # Create main tabs for different analysis modes
-    tab1, tab2, tab3, tab4, tab5 = st.tabs(
-        [
-            "Standard Analysis",
-            "Model Comparison",
-            "Model Training",
-            "Image Analysis",
-            "Performance Tracking",
-        ]
-    )
-
-    with tab1:
-        # Standard single-model analysis
-        col1, col2 = st.columns([1, 1.35], gap="small")
-        with col1:
-            render_input_column()
-        with col2:
-            render_results_column()
-
-    with tab2:
-        # Multi-model comparison interface
-        render_comparison_tab()
-
-    with tab3:
-        # Model training interface
-        render_training_tab()
-
-    with tab4:
-        # Image analysis interface
-        render_image_upload_interface()
-
-    with tab5:
-        # Performance tracking interface
-        render_performance_tab()
-
-
-if __name__ == "__main__":
-    main()
```
core_logic.py:

```diff
@@ -10,7 +10,6 @@ import numpy as np
 import streamlit as st
 from pathlib import Path
 from config import SAMPLE_DATA_DIR
-from datetime import datetime
 from models.registry import build, choices
 
 
@@ -27,7 +26,7 @@ def label_file(filename: str) -> int:
 
 
 @st.cache_data
-def load_state_dict(
+def load_state_dict(mtime, model_path):
     """Load state dict with mtime in cache key to detect file changes"""
     try:
         return torch.load(model_path, map_location="cpu")
@@ -61,6 +60,7 @@ def load_model(model_name):
             model.load_state_dict(state_dict, strict=True)
             model.eval()
             weights_loaded = True
+            break  # Exit loop after successful load
 
         except (OSError, RuntimeError):
             continue
@@ -88,7 +88,7 @@ def cleanup_memory():
 
 
 @st.cache_data
-def run_inference(y_resampled, model_choice, modality: str,
+def run_inference(y_resampled, model_choice, modality: str, cache_key=None):
     """Run model inference and cache results with performance tracking"""
     from utils.performance_tracker import get_performance_tracker, PerformanceMetrics
     from datetime import datetime
@@ -169,58 +169,3 @@ def get_sample_files():
     if sample_dir.exists():
         return sorted(list(sample_dir.glob("*.txt")))
     return []
-
-
-def parse_spectrum_data(raw_text):
-    """Parse spectrum data from text with robust error handling and validation"""
-    x_vals, y_vals = [], []
-
-    for line in raw_text.splitlines():
-        line = line.strip()
-        if not line or line.startswith("#"):  # Skip empty lines and comments
-            continue
-
-        try:
-            # Handle different separators
-            parts = line.replace(",", " ").split()
-            numbers = [
-                p
-                for p in parts
-                if p.replace(".", "", 1)
-                .replace("-", "", 1)
-                .replace("+", "", 1)
-                .isdigit()
-            ]
-
-            if len(numbers) >= 2:
-                x, y = float(numbers[0]), float(numbers[1])
-                x_vals.append(x)
-                y_vals.append(y)
-
-        except ValueError:
-            # Skip problematic lines but don't fail completely
-            continue
-
-    if len(x_vals) < 10:  # Minimum reasonable spectrum length
-        raise ValueError(
-            f"Insufficient data points: {len(x_vals)}. Need at least 10 points."
-        )
-
-    x = np.array(x_vals)
-    y = np.array(y_vals)
-
-    # Check for NaNs
-    if np.any(np.isnan(x)) or np.any(np.isnan(y)):
-        raise ValueError("Input data contains NaN values")
-
-    # Check monotonic increasing x
-    if not np.all(np.diff(x) > 0):
-        raise ValueError("Wavenumbers must be strictly increasing")
-
-    # Check reasonable range for Raman spectroscopy
-    if min(x) < 0 or max(x) > 10000 or (max(x) - min(x)) < 100:
-        raise ValueError(
-            f"Invalid wavenumber range: {min(x)} - {max(x)}. Expected ~400-4000 cm⁻¹ with span >100"
-        )
-
-    return x, y
```
inspect_weights.py (new file):

```diff
@@ -0,0 +1,116 @@
+"""
+Diagnostic script to inspect the weights within a PyTorch .pth file.
+
+This utility loads a model's state dictionary and prints summary statistics
+(mean, std, min, max) for each parameter tensor. It helps diagnose issues
+like corrupted weights from failed or interrupted training runs, which might
+result in a model producing constant, incorrect outputs.
+
+Usage:
+    python scripts/inspect_weights.py path/to/your/model_weights.pth
+"""
+
+import torch
+import argparse
+import os
+from pathlib import Path
+import sys
+
+# Add project root to path to allow imports from other modules
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+
+
+def inspect_weights(file_path: str):
+    """
+    Loads a model state_dict from a .pth file and prints statistics
+    for each parameter tensor to help diagnose corrupted weights.
+    """
+    if not os.path.exists(file_path):
+        print(f"❌ Error: File not found at {file_path}")
+        return
+
+    print(f"🔍 Inspecting weights for: {file_path}\n")
+
+    try:
+        # Load the state dictionary
+        # Use weights_only=True for security and to supress the warning
+        try:
+            state_dict = torch.load(
+                file_path, map_location=torch.device("cpu"), weights_only=True
+            )
+        except TypeError:  # Fallback for older torch versions
+            state_dict = torch.load(file_path, map_location=torch.device("cpu"))
+
+        # Handle checkpoints that save the model in a sub-dictionary
+        if "model_state_dict" in state_dict:
+            state_dict = state_dict["model_state_dict"]
+        elif "model" in state_dict:
+            state_dict = state_dict["model"]
+
+        if not state_dict:
+            print("⚠️ State dictionary is empty.")
+            return
+
+        print(
+            f"{'Parameter Name':<40} {'Shape':<20} {'Mean':<15} {'Std Dev':<15} {'Min':<15} {'Max':<15}"
+        )
+        print("-" * 120)
+        all_stds = []
+
+        for name, param in state_dict.items():
+            if isinstance(param, torch.Tensor):
+                # Ensure tensor is float for stats, but don't fail if not
+                try:
+                    param_float = param.float()
+                    mean_val = f"{param_float.mean().item():.4e}"
+                    std_val_float = param_float.std().item()
+                    std_val = f"{std_val_float:.4e}"
+                    min_val = f"{param_float.min().item():.4e}"
+                    max_val = f"{param_float.max().item():.4e}"
+                    all_stds.append(std_val_float)
+                except (RuntimeError, TypeError):
+                    mean_val, std_val, min_val, max_val = "N/A", "N/A", "N/A", "N/A"
+
+                shape_str = str(list(param.shape))
+                print(
+                    f"{name:<40} {shape_str:<20} {mean_val:<15} {std_val:<15} {min_val:<15} {max_val:<15}"
+                )
+            else:
+                print(f"{name:<40} {'Non-Tensor':<20} {str(param):<60}")
+
+        print("\n" + "-" * 120)
+        print("✅ Inspection complete.")
+        print("\nDiagnosis:")
+        print(
+            "- If you see all zeros, NaNs, or very small (e.g., e-38) uniform values, the weights file is likely corrupted."
+        )
+        if all(s < 1e-6 for s in all_stds if s is not None):
+            print(
+                "- WARNING: All parameter standard deviations are extremely low. The model may be 'dead' and insensitive to input."
+            )
+        else:
+            print(
+                "- The weight statistics appear varied, suggesting the file is not corrupted with zeros/NaNs."
+            )
+            print(
+                "- If the model still produces constant output, it is likely poorly trained."
+            )
+
+        print("\nRecommendation: Retraining the model is the correct solution.")
+
+    except Exception as e:
+        print(f"❌ An error occurred while inspecting the weights file: {e}")
+        import traceback
+
+        traceback.print_exc()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Inspect PyTorch model weights in a .pth file."
+    )
+    parser.add_argument(
+        "file_path", type=str, help="Path to the .pth model weights file."
+    )
+    args = parser.parse_args()
+    inspect_weights(args.file_path)
```
modules/TRAINING_MODELS_README.md (new file, +94 lines):

````markdown
# POLYMEROS: AI-Driven Polymer Aging Analysis & Classification

POLYMEROS is an advanced, AI-driven platform for analyzing and classifying polymer degradation using spectroscopic data. This project extends a baseline CNN model to incorporate multi-modal analysis (Raman & FTIR), modern machine learning architectures, a comprehensive data pipeline, and an interactive educational framework.

[](https://huggingface.co/spaces/dev-jas/polymer-aging-ml)

---

## 🚀 Key Features & Recent Enhancements

This platform has been significantly enhanced with a suite of research-grade features. Recent architectural improvements have focused on creating a robust, maintainable, and unified training system.

### Unified Training Architecture

Previously, the project contained two separate implementations of the model training logic: one in the command-line script (`scripts/train_model.py`) and another powering the backend of the web UI (`utils/training_manager.py`). This duplication led to inconsistencies and made maintenance difficult.

The system has been refactored to follow the **Don't Repeat Yourself (DRY)** principle:

1. **Central `TrainingEngine`**: A new `utils/training_engine.py` module was created to house the core, canonical training and cross-validation loop. This engine is now the single source of truth for how models are trained.

2. **Decoupled Data Structures**: Shared data classes like `TrainingConfig` and `TrainingStatus` were moved to a dedicated `utils/training_types.py` file. This resolved circular import errors and improved modularity.

3. **Refactored Interfaces**:
   - The **CLI script** (`scripts/train_model.py`) is now a lightweight wrapper that parses command-line arguments and calls the `TrainingEngine`.
   - The **UI backend** (`utils/training_manager.py`) now also uses the `TrainingEngine` to run training jobs submitted from the "Model Training Hub".

This unified architecture ensures that any improvements to the training process are immediately available to both developers using the CLI and users interacting with the web UI.

---

## 🛠️ How to Train Models

With the new unified architecture, you can train models using either the command line or the interactive web UI, depending on your needs.

### 1. CLI Training (For Developers & Automation)

The command-line interface is the ideal method for reproducible experiments, automated workflows, or training on a remote server. It provides full control over all training hyperparameters.

**Why use the CLI?**

- For scripting multiple training runs.
- For integration into CI/CD pipelines.
- When working in a non-GUI environment.

**Example Command:**
To run a 10-fold cross-validation for the `figure2` model, run the following from the project's root directory:

```bash
python scripts/train_model.py --model figure2 --epochs 15 --baseline --smooth --normalize
```

This command will:

- Load the default dataset from `datasets/rdwp`.
- Apply the specified preprocessing steps.
- Run the training using the central `TrainingEngine`.
- Save the final model weights to `outputs/weights/figure2_model.pth` and a detailed JSON log to `outputs/logs/`.

### 2. UI Training Hub (For Interactive Use)

The "Model Training Hub" within the web application provides a user-friendly, graphical interface for training models. It's designed for interactive experimentation and for users who may not be comfortable with the command line.

**Why use the UI?**

- To easily train models on your own uploaded datasets.
- To interactively tweak hyperparameters and see their effect.
- To monitor training progress in real-time with visual feedback.

**How to use it:**

1. Navigate to the **Model Training Hub** tab in the application.
2. **Configure Your Job**:
   - Select a model architecture.
   - Upload a new dataset or choose an existing one.
   - Adjust training parameters like epochs, learning rate, and batch size.
3. Click **"🚀 Start Training"**.
4. The job will run in the background, and you can monitor its progress in the "Training Status" and "Training Progress" sections. Completed models and logs can be downloaded directly from the UI.

---

## Project Structure Overview

- `app.py`: Main Streamlit application entry point.
- `modules/`: Contains all major feature modules.
  - `training_ui.py`: Renders the "Model Training Hub" tab.
- `scripts/`: Contains command-line tools.
  - `train_model.py`: The CLI for running training jobs.
  - `inspect_weights.py`: A diagnostic tool to check model weight files.
- `utils/`: Core utilities for the application.
  - `training_engine.py`: **The new central training logic.**
  - `training_manager.py`: The backend manager for UI-based training jobs.
  - `training_types.py`: **New file for shared training data structures.**
- `models/`: Model definitions and the central model registry.
- `outputs/`: Default directory for saved model weights and training logs.
````
modules/training_ui.py:

```diff
@@ -17,12 +17,8 @@ import json
 from datetime import datetime, timedelta
 
 from models.registry import choices as model_choices, get_model_info
-from utils.training_manager import
-    TrainingConfig,
-    TrainingStatus,
-    TrainingJob,
-)
+from utils.training_manager import get_training_manager, TrainingJob
+from utils.training_types import TrainingConfig, TrainingStatus
 
 
 def render_training_tab():
```
modules/ui_components.py (some removed lines in the larger hunks did not survive extraction; only the recoverable lines are shown):

```diff
@@ -7,6 +7,7 @@ from PIL import Image
 import numpy as np
 import matplotlib.pyplot as plt
 from typing import Union
+import uuid
 import time
 from config import TARGET_LEN, LABEL_MAP, MODEL_WEIGHTS_DIR
 from models.registry import choices, get_model_info
@@ -18,16 +19,13 @@ from modules.callbacks import (
     reset_ephemeral_state,
     log_message,
 )
-from core_logic import
-    get_sample_files,
-    load_model,
-    run_inference,
-    parse_spectrum_data,
-    label_file,
-)
+from core_logic import get_sample_files, load_model, run_inference, label_file
 from utils.results_manager import ResultsManager
-from utils.multifile import process_multiple_files
-from utils.preprocessing import
+from utils.multifile import process_multiple_files, parse_spectrum_data
+from utils.preprocessing import (
+    validate_spectrum_modality,
+    preprocess_spectrum,
+)
 from utils.confidence import calculate_softmax_confidence
 
 
@@ -69,9 +67,6 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
     return Image.open(buf)
 
 
-# //////////////////////////////////////////
-
-
 def render_confidence_progress(
     probs: np.ndarray,
     labels: list[str] = ["Stable", "Weathered"],
@@ -132,9 +127,6 @@ def render_kv_grid(d: Optional[dict] = None, ncols: int = 2):
         st.caption(f"**{k}:** {v}")
 
 
-# //////////////////////////////////////////
-
-
 def render_model_meta(model_choice: str):
     info = get_model_info(model_choice)
     emoji = info.get("emoji", "")
@@ -152,9 +144,6 @@ def render_model_meta(model_choice: str):
     st.caption(desc)
 
 
-# //////////////////////////////////////////
-
-
 def get_confidence_description(logit_margin):
     """Get human-readable confidence description"""
     if logit_margin > 1000:
@@ -167,9 +156,6 @@ def get_confidence_description(logit_margin):
         return "LOW", "🔴"
 
 
-# //////////////////////////////////////////
-
-
 def render_sidebar():
     with st.sidebar:
         # Header
@@ -254,7 +240,6 @@ def render_sidebar():
         )
 
 
-# //////////////////////////////////////////
 def render_input_column():
     st.markdown("##### Data Input")
 
@@ -393,6 +378,7 @@ def render_input_column():
 
     # Handle form submission
     if submitted and inference_ready:
+        st.session_state["run_uuid"] = uuid.uuid4().hex[:8]
        if st.session_state.get("batch_mode"):
            batch_files = st.session_state.get("batch_files", [])
            with st.spinner(f"Processing {len(batch_files)} files ..."):
@@ -405,7 +391,31 @@ def render_input_column():
                )
        else:
            try:
-                x_raw, y_raw = parse_spectrum_data(
+                x_raw, y_raw = parse_spectrum_data(
+                    st.session_state["input_text"],
+                    filename=st.session_state.get("filename", "unknown"),
+                )
+
+                # QC Summary
+                st.session_state["qc_summary"] = {
+                    "n_points": len(x_raw),
+                    "x_min": f"{np.min(x_raw):.1f}",
+                    "x_max": f"{np.max(x_raw):.1f}",
+                    "monotonic_x": bool(np.all(np.diff(x_raw) > 0)),
+                    "nan_free": not (
+                        np.any(np.isnan(x_raw)) or np.any(np.isnan(y_raw))
+                    ),
+                    "variance_proxy": f"{np.var(y_raw):.2e}",
+                }
+
+                # Preprocessing parameters
+                preproc_params = {
+                    "target_len": TARGET_LEN,
+                    "modality": st.session_state.get("modality_select", "raman"),
+                    "do_baseline": True,
+                    "do_smooth": True,
+                    "do_normalize": True,
+                }
 
                # Validate that spectrum matches selected modality
                selected_modality = st.session_state.get("modality_select", "raman")
@@ -430,7 +440,10 @@ def render_input_column():
                    else:
                        st.stop()  # Stop processing until user confirms
 
-                x_resampled, y_resampled =
+                x_resampled, y_resampled = preprocess_spectrum(
+                    x_raw, y_raw, **preproc_params
+                )
+                st.session_state["preproc_params"] = preproc_params
                st.session_state.update(
                    {
                        "x_raw": x_raw,
@@ -444,9 +457,6 @@ def render_input_column():
                st.error(f"Error processing spectrum data: {e}")
 
 
-# //////////////////////////////////////////
-
-
 def render_results_column():
     # Get the current mode and check for batch results
     is_batch_mode = st.session_state.get("batch_mode", False)
@@ -483,7 +493,7 @@ def render_results_column():
                    else None
                ),
                modality=st.session_state.get("modality_select", "raman"),
-
+                cache_key=cache_key,
            )
            if prediction is None:
                st.error(
@@ -491,6 +501,11 @@ def render_results_column():
                )
                st.stop()  # prevents the rest of the code in this block from executing
 
+            # Store results in session state for the Details tab
+            st.session_state["prediction"] = prediction
+            st.session_state["probs"] = probs
+            st.session_state["inference_time"] = inference_time
+
            log_message(
                f"Inference completed in {inference_time:.2f}s, prediction: {prediction}"
            )
@@ -556,6 +571,8 @@ def render_results_column():
            "⚠️ Model choice is not defined. Please select a model from the sidebar."
        )
        st.stop()
+    model_info = get_model_info(model_choice)
+    st.session_state["model_info"] = model_info
    model_path = os.path.join(MODEL_WEIGHTS_DIR, f"{model_choice}_model.pth")
    mtime = os.path.getmtime(model_path) if os.path.exists(model_path) else None
    file_hash = (
@@ -573,88 +590,188 @@ def render_results_column():
        )
 
    if active_tab == "Details":
-        st.markdown('<div class="expander-results">', unsafe_allow_html=True)
        # Use a dynamic and informative title for the expander
        with st.expander(f"Results for {filename}", expanded=True):
-            #
-                else "🟡" if max_confidence >= 0.6 else "🔴"
-            )
-            st.
-            # --- END: STREAMLINED METRICS ---
-            )
-                # Default values to prevent further errors
-                stable_prob, weathered_prob = 0.0, 0.0
-                is_stable_predicted, is_weathered_predicted = (
-                    int(prediction) == 0
-                ), (int(prediction) == 1)
-
-            st.markdown(
-                f"""
-                <div style="font-family: 'Fira Code', monospace;">
-                Stable (Unweathered)<br>
-                {create_bullet_bar(stable_prob, predicted=is_stable_predicted)}<br><br>
-                Weathered (Degraded)<br>
-                {create_bullet_bar(weathered_prob, predicted=is_weathered_predicted)}
-                </div>
-                """,
-                unsafe_allow_html=True,
-            )
-            st.
+            # ...inside the Details tab, after metrics...
+
+            import json, math, uuid
+
+            st.subheader("Probability Breakdown")
+
+            def _entropy(ps):
+                ps = [max(min(float(p), 1.0), 1e-12) for p in ps]
+                return -sum(p * math.log(p) for p in ps)
+
+            def _badge(text, kind="info"):
+                palette = {
+                    "info": ("#334155", "#e2e8f0"),
+                    "warn": ("#7c2d12", "#fde68a"),
+                    "good": ("#064e3b", "#bbf7d0"),
+                    "bad": ("#7f1d1d", "#fecaca"),
+                }
+                bg, fg = palette.get(kind, palette["info"])
+                st.markdown(
+                    f"<span style='background:{bg};color:{fg};padding:4px 8px;"
+                    f"border-radius:6px;font-size:0.80rem;white-space:nowrap'>{text}</span>",
+                    unsafe_allow_html=True,
+                )
+
+            def _render_prob_row(label: str, prob: float, is_pred: bool):
+                c1, c2, c3 = st.columns([2, 7, 3])
+                with c1:
+                    st.write(label)
+                with c2:
+                    st.progress(min(max(prob, 0.0), 1.0))
+                with c3:
+                    suffix = " \u2190 Predicted" if is_pred else ""
+                    st.write(f"{prob:.1%}{suffix}")
+
+            probs = st.session_state.get("probs")
+            prediction = st.session_state.get("prediction")
+            inference_time = float(st.session_state.get("inference_time", 0.0))
+
+            if probs is None or len(probs) != 2:
+                st.error(
+                    "❌ Probability values are missing or invalid. Check the inference process."
+                )
+                stable_prob, weathered_prob = 0.0, 0.0
+            else:
+                stable_prob, weathered_prob = float(probs[0]), float(probs[1])
+
+            is_stable_predicted = (
+                (int(prediction) == 0)
+                if prediction is not None
+                else (stable_prob >= weathered_prob)
+            )
+            is_weathered_predicted = (
+                (int(prediction) == 1)
+                if prediction is not None
+                else (weathered_prob > stable_prob)
+            )
+
+            margin = abs(stable_prob - weathered_prob)
+            entropy = _entropy([stable_prob, weathered_prob])
+            thresh = float(st.session_state.get("decision_threshold", 0.5))
+            cal = st.session_state.get("calibration", {}) or {}
+            cal_enabled = bool(cal.get("enabled", False))
+            ece = cal.get("ece", None)
+
+            ABSTAIN_TAU = 0.10
+            OOD_MAX_SOFT = 0.60
+            max_softmax = max(stable_prob, weathered_prob)
+
+            colA, colB, colC, colD = st.columns([3, 3, 3, 3])
+            with colA:
+                st.metric(
+                    "Predicted",
+                    "Stable" if is_stable_predicted else "Weathered",
+                )
+            with colB:
+                st.metric("Decision Margin", f"{margin:.2f}")
+            with colC:
+                st.metric("Entropy", f"{entropy:.3f}")
+            with colD:
+                st.metric("Threshold", f"{thresh:.2f}")
+
+            row = st.columns([3, 3, 6])
+            with row[0]:
+                if margin < ABSTAIN_TAU:
+                    _badge("Low margin — consider abstain / re-measure", "warn")
+            with row[1]:
+                if max_softmax < OOD_MAX_SOFT:
+                    _badge("Low confidence — possible OOD", "bad")
+            with row[2]:
+                if cal_enabled:
+                    _badge(
+                        (
+                            f"Calibrated (ECE={ece:.2%})"
+                            if isinstance(ece, (int, float))
+                            else "Calibrated"
+                        ),
+                        "good",
+                    )
+                else:
+                    _badge(
+                        "Uncalibrated — probabilities may be miscalibrated",
+                        "info",
+                    )
+
+            st.write("")
+
+            _render_prob_row(
+                "Stable (Unweathered)", stable_prob, is_stable_predicted
+            )
+            _render_prob_row(
+                "Weathered (Degraded)", weathered_prob, is_weathered_predicted
+            )
+
+            qc = st.session_state.get("qc_summary", {}) or {}
+            pp = st.session_state.get("preproc_params", {}) or {}
+            model_info = st.session_state.get("model_info", {}) or {}
+            run_info = {
+                "model": model_choice,
+                "inference_time_s": inference_time,
+                "run_uuid": st.session_state.get("run_uuid", ""),
+                "app_commit": st.session_state.get("app_commit", "unknown"),
+            }
+
+            with st.expander("Input QC"):
+                st.write(
+                    {
+                        "n_points": qc.get("n_points", "N/A"),
+                        "x_min_cm-1": qc.get("x_min", "N/A"),
+                        "x_max_cm-1": qc.get("x_max", "N/A"),
+                        "monotonic_x": qc.get("monotonic_x", "N/A"),
+                        "nan_free": qc.get("nan_free", "N/A"),
+                        "variance_proxy": qc.get("variance_proxy", "N/A"),
+                    }
+                )
+
+            with st.expander("Preprocessing (applied)"):
+                st.write(pp)
+
+            with st.expander("Model & Run"):
+                st.write(
+                    {
+                        "model_name": model_info.get("name", model_choice),
+                        "version": model_info.get("version", "n/a"),
+                        "weights_mtime": model_info.get("weights_mtime", "n/a"),
+                        "cv_accuracy": model_info.get("cv_accuracy", "n/a"),
+                        "class_priors": model_info.get("class_priors", "n/a"),
+                        **run_info,
+                    }
+                )
+
+            export_payload = {
+                "prediction": "stable" if is_stable_predicted else "weathered",
+                "probs": {"stable": stable_prob, "weathered": weathered_prob},
+                "margin": margin,
+                "entropy": entropy,
+                "threshold": thresh,
+                "calibration": {
+                    "enabled": cal_enabled,
+                    "ece": ece,
+                    "method": cal.get("method"),
+                    "T": cal.get("T"),
+                },
+                "qc": qc,
+                "preprocessing": pp,
+                "model_info": model_info,
+                "run_info": run_info,
+            }
+            fname = f"result_{run_info['run_uuid'] or uuid.uuid4().hex}.json"
+            st.download_button(
+                "Download result JSON",
+                json.dumps(export_payload, indent=2),
+                file_name=fname,
+                mime="application/json",
+            )
 
            # METADATA FOOTER
            st.caption(
-                f"Analyzed with **{
+                f"Analyzed with **{run_info['model']}** in **{inference_time:.2f}s**."
            )
-        st.markdown("</div>", unsafe_allow_html=True)
 
    elif active_tab == "Technical":
        with st.container():
@@ -879,9 +996,6 @@ def render_results_column():
 
            # Technical details
            # MODIFIED: Wrap the expander in a div with the 'expander-advanced' class
-            st.markdown(
-                '<div class="expander-advanced">', unsafe_allow_html=True
-            )
            with st.expander("🔧 Technical Details", expanded=False):
                st.markdown(
                    """
@@ -902,9 +1016,6 @@ def render_results_column():
                - Normalization: None (preserves intensity relationships)
                """
                )
-            st.markdown(
-                "</div>", unsafe_allow_html=True
-            )  # Close the wrapper div
 
    render_time = time.time() - start_render
    log_message(
@@ -987,9 +1098,6 @@ def render_results_column():
    )
 
 
-# //////////////////////////////////////////
-
-
 def render_comparison_tab():
     """Render the multi-model comparison interface"""
     import streamlit as st
@@ -1001,7 +1109,7 @@ def render_comparison_tab():
        get_models_metadata,
    )
    from utils.results_manager import ResultsManager
-    from core_logic import get_sample_files, run_inference
+    from core_logic import get_sample_files, run_inference
    from utils.preprocessing import preprocess_spectrum
    from utils.multifile import parse_spectrum_data
    import numpy as np
@@ -1159,8 +1267,16 @@ def render_comparison_tab():
            start_time = time.time()
 
            # Run inference
+            cache_key = hashlib.md5(
+                f"{y_processed.tobytes()}{model_name}".encode()
+            ).hexdigest()
            prediction, logits_list, probs, inference_time, logits = (
-                run_inference(
+                run_inference(
+                    y_processed,
+                    model_name,
+                    modality=modality,
+                    cache_key=cache_key,
+                )
            )
 
            processing_time = time.time() - start_time
@@ -1587,15 +1703,9 @@ def render_comparison_tab():
    )
 
 
-# //////////////////////////////////////////
-
-
 from utils.performance_tracker import display_performance_dashboard
 
 
 def render_performance_tab():
     """Render the performance tracking and analysis tab."""
     display_performance_dashboard()
-
-
-# //////////////////////////////////////////
```
outputs/figure2_model.pth (Git LFS pointer):

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:852247bf0540aa947c9887a7e004c0858d622cfa0413e9b26bd9f5dab359ad5e
 size 4418520
```
pages/2_Enhanced_Analysis.py:

```diff
@@ -27,7 +27,8 @@ from modules.modern_ml_architecture import (
     ModernMLPipeline,
 )
 from modules.enhanced_data_pipeline import EnhancedDataPipeline
-from core_logic import load_model
+from core_logic import load_model
+from utils.multifile import parse_spectrum_data
 from models.registry import choices
 from config import TARGET_LEN
 
```
scripts/train_model.py (a few removed lines in the larger hunks did not survive extraction; only the recoverable lines are shown):

```diff
@@ -1,5 +1,6 @@
 import os
 import sys
+
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 from datetime import datetime
 import argparse, numpy as np, torch
@@ -10,6 +11,9 @@ from sklearn.metrics import confusion_matrix
 import random
 import json
 
+from utils.training_engine import TrainingEngine
+from utils.training_manager import TrainingConfig
+
 # Reproducibility
 SEED = 42
 random.seed(SEED)
@@ -36,8 +40,26 @@ parser.add_argument("--batch-size", type=int, default=16)
 parser.add_argument("--epochs", type=int, default=10)
 parser.add_argument("--learning-rate", type=float, default=1e-3)
 parser.add_argument("--model", type=str, default="figure2", choices=model_choices())
+def parse_args():
+    """Parses command-line arguments for training."""
+    parser = argparse.ArgumentParser(
+        description="Run 10-fold CV on Raman data with optional preprocessing."
+    )
+    parser.add_argument("--target-len", type=int, default=500)
+    parser.add_argument("--baseline", action="store_true")
+    parser.add_argument("--smooth", action="store_true")
+    parser.add_argument("--normalize", action="store_true")
+    parser.add_argument("--batch-size", type=int, default=16)
+    parser.add_argument("--epochs", type=int, default=10)
+    parser.add_argument("--learning-rate", type=float, default=1e-3)
+    parser.add_argument("--model", type=str, default="figure2", choices=model_choices())
+    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "cuda"])
+    parser.add_argument("--dataset-path", type=str, default="datasets/rdwp")
+    parser.add_argument("--num-folds", type=int, default=10)
+    parser.add_argument("--cv-strategy", type=str, default="stratified_kfold", choices=["stratified_kfold", "kfold"])
 
 args = parser.parse_args()
+    return parser.parse_args()
 
 # Constants
 # Raman-only dataset (RDWP)
@@ -48,6 +70,18 @@ NUM_FOLDS = 10
 # Ensure output dirs exist
 os.makedirs("outputs", exist_ok=True)
 os.makedirs("outputs/logs", exist_ok=True)
+def cli_progress_callback(progress_data: dict):
+    """A simple callback to print progress to the console."""
+    if progress_data["type"] == "fold_start":
+        print(f"\n🔁 Fold {progress_data['fold']}/{progress_data['total_folds']}")
+    elif progress_data["type"] == "epoch_end":
+        # Print progress on the same line
+        print(
+            f" Epoch {progress_data['epoch']}/{progress_data['total_epochs']} | Loss: {progress_data['loss']:.4f}",
+            end="\r",
+        )
+    elif progress_data["type"] == "fold_end":
+        print(f"\n✅ Fold {progress_data['fold']} Accuracy: {progress_data['accuracy'] * 100:.2f}%")
 
 print("Preprocessing Configuration:")
 print(f" Resample to : {args.target_len}")
@@ -55,6 +89,27 @@ print(f" Resample to : {args.target_len}")
 print(f" Baseline Correct: {'✅' if args.baseline else '❌'}")
 print(f" Smoothing : {'✅' if args.smooth else '❌'}")
 print(f" Normalization : {'✅' if args.normalize else '❌'}")
+def save_diagnostics_log(results: dict, config: TrainingConfig, output_path: str):
+    """Saves a JSON log file with training diagnostics."""
+    fold_metrics = [
+        {"fold": i + 1, "accuracy": float(acc), "confusion_matrix": cm}
+        for i, (acc, cm) in enumerate(
+            zip(results["fold_accuracies"], results["confusion_matrices"])
+        )
+    ]
+    log = {
+        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+        "model_name": config.model_name,
+        "config": config.to_dict(),
+        "fold_metrics": fold_metrics,
+        "overall": {
+            "mean_accuracy": results["mean_accuracy"],
+            "std_accuracy": results["std_accuracy"],
+        },
+    }
+    with open(output_path, "w", encoding="utf-8") as f:
+        json.dump(log, f, indent=2)
+    print(f"🧠 Diagnostics written to {output_path}")
 
 # Load + Preprocess data
 print("🔄 Loading and preprocessing data ...")
@@ -73,24 +128,52 @@ print(f"🔍 Using model: {args.model}")
 skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)
 fold_accuracies = []
 all_conf_matrices = []
+def main():
+    """Main function to run the training process from the CLI."""
+    args = parse_args()
 
 for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
     print(f"\n🔁 Fold {fold}/{NUM_FOLDS}")
+    # Ensure output dirs exist
+    os.makedirs("outputs/weights", exist_ok=True)
+    os.makedirs("outputs/logs", exist_ok=True)
 
     X_train, X_val = X[train_idx], X[val_idx]
     y_train, y_val = y[train_idx], y[val_idx]
+    # Create TrainingConfig from CLI args
+    config = TrainingConfig(
+        model_name=args.model,
+        dataset_path=args.dataset_path,
+        target_len=args.target_len,
+        batch_size=args.batch_size,
+        epochs=args.epochs,
+        learning_rate=args.learning_rate,
+        num_folds=args.num_folds,
+        baseline_correction=args.baseline,
+        smoothing=args.smooth,
+        normalization=args.normalize,
+        device=args.device,
+        cv_strategy=args.cv_strategy,
+    )
 
     train_loader = DataLoader(
        TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.long)),
        batch_size=args.batch_size, shuffle=True)
     val_loader = DataLoader(
        TensorDataset(torch.tensor(X_val, dtype=torch.float32), torch.tensor(y_val, dtype=torch.long)))
+    print("🔄 Loading and preprocessing data...")
+    X, y = preprocess_dataset(config.dataset_path, target_len=config.target_len)
+    print(f"✅ Data Loaded: {X.shape[0]} samples, {X.shape[1]} features each.")
+    print(f"🔍 Using model: {config.model_name}")
 
     # Model selection
     model = build_model(args.model, args.target_len).to(DEVICE)
 
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
     criterion = torch.nn.CrossEntropyLoss()
+    # Run training
+    engine = TrainingEngine(config)
+    results = engine.run(X, y, progress_callback=cli_progress_callback)
 
     for epoch in range(args.epochs):
         model.train()
@@ -98,12 +181,18 @@ for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
         for inputs, labels in train_loader:
             inputs = inputs.unsqueeze(1).to(DEVICE)
             labels = labels.to(DEVICE)
+    # Save final model and logs
+    model_path = f"outputs/weights/{config.model_name}_model.pth"
+    torch.save(results["model_state_dict"], model_path)
+    print(f"\n✅ Model saved to {model_path}")
 
             optimizer.zero_grad()
             loss = criterion(model(inputs), labels)
             loss.backward()
             optimizer.step()
             RUNNING_LOSS += loss.item()
+    log_path = f"outputs/logs/{config.model_name}_cli_diagnostics.json"
+    save_diagnostics_log(results, config, log_path)
 
     # After fold loop (outside the epoch loop), print 1 line:
     print(f"✅ Fold {fold} done. Final loss: {RUNNING_LOSS:.4f}")
@@ -169,4 +258,6 @@ def save_diagnostics_log(fold_acc, confs, args_param, output_path):
     print(f"🧠 Diagnostics written to {output_path}")
 
 log_path = f"outputs/logs/raman_{args.model}_diagnostics.json"
-save_diagnostics_log(fold_accuracies, all_conf_matrices, args, log_path)
+save_diagnostics_log(fold_accuracies, all_conf_matrices, args, log_path)
+if __name__ == "__main__":
+    main()
```
utils/multifile.py

@@ -11,6 +11,7 @@ import json
 import csv
 import io
 from pathlib import Path
+import hashlib
 
 from .preprocessing import preprocess_spectrum
 from .errors import ErrorHandler, safe_execute
@@ -35,7 +36,7 @@ def detect_file_format(filename: str, content: str) -> str:
         try:
             json.loads(content)
             return "json"
-        except:
+        except json.JSONDecodeError:
             pass
     elif suffix == ".csv":
         return "csv"
@@ -50,7 +51,7 @@ def detect_file_format(filename: str, content: str) -> str:
     try:
         json.loads(content)
         return "json"
-    except:
+    except json.JSONDecodeError:
         pass
 
     # Try CSV (look for commas in first few lines)
@@ -63,12 +64,7 @@ def detect_file_format(filename: str, content: str) -> str:
     return "txt"
 
 
-
-
-
-def parse_json_spectrum(
-    content: str, filename: str = "unknown"
-) -> Tuple[np.ndarray, np.ndarray]:
+def parse_json_spectrum(content: str) -> Tuple[np.ndarray, np.ndarray]:
     """
     Parse spectrum data from JSON format.
 
@@ -79,7 +75,7 @@ def parse_json_spectrum(
     """
 
     try:
-        data = json.
+        data = json.loads(content)
 
         # Format 1: Object with arrays
         if isinstance(data, dict):
@@ -135,12 +131,9 @@ def parse_json_spectrum(
         )
 
     except json.JSONDecodeError as e:
-        raise ValueError(f"Invalid JSON format: {str(e)}")
+        raise ValueError(f"Invalid JSON format: {str(e)}") from e
     except Exception as e:
-        raise ValueError(f"Failed to parse JSON spectrum: {str(e)}")
-
-
-# /////////////////////////////////////////////////////
+        raise ValueError(f"Failed to parse JSON spectrum: {str(e)}") from e
 
 
 def parse_csv_spectrum(
@@ -208,10 +201,7 @@ def parse_csv_spectrum(
         return np.array(x_vals), np.array(y_vals)
 
     except Exception as e:
-        raise ValueError(f"Failed to parse CSV spectrum: {str(e)}")
-
-
-# /////////////////////////////////////////////////////
+        raise ValueError(f"Failed to parse CSV spectrum: {str(e)}") from e
 
 
 def parse_spectrum_data(
@@ -235,7 +225,7 @@ def parse_spectrum_data(
 
         # Parse based on detected/specified format
         if file_format == "json":
-            x, y = parse_json_spectrum(text_content
+            x, y = parse_json_spectrum(text_content)
         elif file_format == "csv":
             x, y = parse_csv_spectrum(text_content, filename)
         else:  # Default to TXT format
@@ -247,10 +237,7 @@ def parse_spectrum_data(
         return x, y
 
     except Exception as e:
-        raise ValueError(f"Failed to parse spectrum data: {str(e)}")
-
-
-# /////////////////////////////////////////////////////
+        raise ValueError(f"Failed to parse spectrum data: {str(e)}") from e
 
 
 def parse_txt_spectrum(
@@ -287,7 +274,7 @@ def parse_txt_spectrum(
                     f"Parsing {filename}",
                 )
 
-            except
+            except ValueError as e:
                 ErrorHandler.log_warning(
                     f"Error parsing line {i+1}: '{line}'. Error: {e}",
                     f"Parsing {filename}",
@@ -302,9 +289,6 @@ def parse_txt_spectrum(
     return np.array(x_vals), np.array(y_vals)
 
 
-# /////////////////////////////////////////////////////
-
-
 def validate_spectrum_data(x: np.ndarray, y: np.ndarray, filename: str) -> None:
     """
     Validate parsed spectrum data for common issues.
@@ -332,9 +316,6 @@ def validate_spectrum_data(x: np.ndarray, y: np.ndarray, filename: str) -> None:
     )
 
 
-# /////////////////////////////////////////////////////
-
-
 def process_single_file(
     filename: str,
     text_content: str,
@@ -369,8 +350,11 @@ def process_single_file(
         )
 
         # 3. Run inference, passing modality
+        cache_key = hashlib.md5(
+            f"{y_resampled.tobytes()}{model_choice}".encode()
+        ).hexdigest()
         prediction, logits_list, probs, inference_time, logits = run_inference_func(
-            y_resampled, model_choice, modality=modality
+            y_resampled, model_choice, modality=modality, cache_key=cache_key
        )
 
         if prediction is None:
@@ -418,7 +402,7 @@ def process_single_file(
             "y_resampled": y_resampled,
         }
 
-    except
+    except ValueError as e:
         ErrorHandler.log_error(e, f"processing {filename}")
         return {
             "filename": filename,
@@ -501,7 +485,7 @@ def process_multiple_files(
             },
         )
 
-    except
+    except ValueError as e:
         ErrorHandler.log_error(e, f"reading file {uploaded_file.name}")
         results.append(
             {
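The parser changes above replace bare `except:` clauses with specific exception types and chain the re-raised errors with `from e`. A small self-contained illustration of the chaining behavior (independent of the project code): the original `JSONDecodeError` stays attached as `__cause__`, so tracebacks show both the wrapper and the root cause.

import json

def parse(content: str) -> dict:
    try:
        return json.loads(content)
    except json.JSONDecodeError as e:
        # "from e" keeps the original JSONDecodeError attached as __cause__.
        raise ValueError(f"Invalid JSON format: {e}") from e

try:
    parse("{not json}")
except ValueError as err:
    print(type(err.__cause__).__name__)  # JSONDecodeError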
utils/training_engine.py

@@ -0,0 +1,160 @@
+"""
+Core Training Engine for the POLYMEROS project.
+
+This module contains the primary logic for model training and validation,
+encapsulated in a reusable `TrainingEngine` class. It is designed to be
+called by different interfaces, such as the command-line script
+(train_model.py) and the web UI's TrainingManager.
+
+This approach ensures that the core training process is consistent,
+maintainable, and follows the DRY (Don't Repeat Yourself) principle.
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.utils.data import TensorDataset, DataLoader
+from sklearn.metrics import confusion_matrix, accuracy_score
+
+from .training_types import (
+    TrainingConfig,
+    TrainingProgress,
+    get_cv_splitter,
+    augment_spectral_data,
+)
+from models.registry import build as build_model
+
+
+class TrainingEngine:
+    """Encapsulates the core model training and validation logic."""
+
+    def __init__(self, config: TrainingConfig):
+        """
+        Initializes the TrainingEngine with a given configuration.
+
+        Args:
+            config (TrainingConfig): The configuration object for the training run.
+        """
+        self.config = config
+        self.device = self._get_device()
+
+    def _get_device(self) -> torch.device:
+        """Selects the appropriate compute device."""
+        if self.config.device == "auto":
+            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        return torch.device(self.config.device)
+
+    def run(
+        self, X: np.ndarray, y: np.ndarray, progress_callback: callable = None
+    ) -> dict:
+        """
+        Executes the full cross-validation training and evaluation loop.
+
+        Args:
+            X (np.ndarray): Feature data.
+            y (np.ndarray): Label data.
+            progress_callback (callable, optional): A function to call with
+                progress updates. Defaults to None.
+
+        Returns:
+            dict: A dictionary containing the final results and metrics.
+        """
+        cv_splitter = get_cv_splitter(self.config.cv_strategy, self.config.num_folds)
+
+        fold_accuracies = []
+        all_conf_matrices = []
+        final_model_state = None
+
+        for fold, (train_idx, val_idx) in enumerate(cv_splitter.split(X, y), 1):
+            if progress_callback:
+                progress_callback(
+                    {
+                        "type": "fold_start",
+                        "fold": fold,
+                        "total_folds": self.config.num_folds,
+                    }
+                )
+
+            X_train, X_val = X[train_idx], X[val_idx]
+            y_train, y_val = y[train_idx], y[val_idx]
+
+            # Apply data augmentation if enabled
+            if self.config.enable_augmentation:
+                X_train, y_train = augment_spectral_data(
+                    X_train, y_train, noise_level=self.config.noise_level
+                )
+
+            train_loader = DataLoader(
+                TensorDataset(
+                    torch.tensor(X_train, dtype=torch.float32),
+                    torch.tensor(y_train, dtype=torch.long),
+                ),
+                batch_size=self.config.batch_size,
+                shuffle=True,
+            )
+            val_loader = DataLoader(
+                TensorDataset(
+                    torch.tensor(X_val, dtype=torch.float32),
+                    torch.tensor(y_val, dtype=torch.long),
+                )
+            )
+
+            model = build_model(self.config.model_name, self.config.target_len).to(
+                self.device
+            )
+            optimizer = torch.optim.Adam(
+                model.parameters(), lr=self.config.learning_rate
+            )
+            criterion = nn.CrossEntropyLoss()
+
+            for epoch in range(self.config.epochs):
+                model.train()
+                running_loss = 0.0
+                for inputs, labels in train_loader:
+                    inputs = inputs.unsqueeze(1).to(self.device)
+                    labels = labels.to(self.device)
+
+                    optimizer.zero_grad()
+                    outputs = model(inputs)
+                    loss = criterion(outputs, labels)
+                    loss.backward()
+                    optimizer.step()
+                    running_loss += loss.item()
+
+                if progress_callback:
+                    progress_callback(
+                        {
+                            "type": "epoch_end",
+                            "fold": fold,
+                            "epoch": epoch + 1,
+                            "total_epochs": self.config.epochs,
+                            "loss": running_loss / len(train_loader),
+                        }
+                    )
+
+            # Validation
+            model.eval()
+            all_true, all_pred = [], []
+            with torch.no_grad():
+                for inputs, labels in val_loader:
+                    inputs = inputs.unsqueeze(1).to(self.device)
+                    outputs = model(inputs)
+                    _, predicted = torch.max(outputs, 1)
+                    all_true.extend(labels.cpu().numpy())
+                    all_pred.extend(predicted.cpu().numpy())
+
+            acc = accuracy_score(all_true, all_pred)
+            fold_accuracies.append(acc)
+            all_conf_matrices.append(confusion_matrix(all_true, all_pred).tolist())
+            final_model_state = model.state_dict()
+
+            if progress_callback:
+                progress_callback({"type": "fold_end", "fold": fold, "accuracy": acc})
+
+        return {
+            "fold_accuracies": fold_accuracies,
+            "confusion_matrices": all_conf_matrices,
+            "mean_accuracy": np.mean(fold_accuracies),
+            "std_accuracy": np.std(fold_accuracies),
+            "model_state_dict": final_model_state,
+        }
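A minimal usage sketch of the new engine, assuming TrainingConfig exposes the fields the engine reads (model_name, target_len, batch_size, epochs, learning_rate, num_folds, cv_strategy, device, enable_augmentation, noise_level). The dataset is stubbed with random arrays and the model name "figure2" is an assumed registry entry, purely for illustration:

# Hypothetical usage sketch of TrainingEngine; field names and the "figure2"
# registry name are assumptions, and the random data is a stand-in for a dataset.
import numpy as np

from utils.training_types import TrainingConfig
from utils.training_engine import TrainingEngine

config = TrainingConfig(
    model_name="figure2",   # assumed model registry name
    target_len=500,         # assumed resampled spectrum length
    batch_size=16,
    epochs=5,
    num_folds=3,
)

X = np.random.rand(120, 500).astype(np.float32)  # 120 fake spectra
y = np.random.randint(0, 2, size=120)            # fake binary labels

engine = TrainingEngine(config)
results = engine.run(X, y, progress_callback=print)
print(f"mean accuracy: {results['mean_accuracy']:.3f}")

Note that the engine reads `config.enable_augmentation` and `config.noise_level` directly, so TrainingConfig is assumed to still expose those flags even though the manager below nests them inside an AugmentationConfig.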
utils/training_manager.py

@@ -12,16 +12,14 @@ import threading
 import concurrent.futures
 import multiprocessing
 from datetime import datetime, timedelta
-from dataclasses import dataclass, asdict, field
-from enum import Enum
 from typing import Dict, List, Optional, Callable, Any, Tuple
 from pathlib import Path
+from dataclasses import dataclass, field
 
 import torch
 import torch.nn as nn
 import numpy as np
 from torch.utils.data import TensorDataset, DataLoader
-from sklearn.model_selection import StratifiedKFold, KFold, TimeSeriesSplit
 from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
 from sklearn.metrics.pairwise import cosine_similarity
 from scipy.signal import find_peaks
@@ -30,6 +28,14 @@ from scipy.spatial.distance import euclidean
 # Add project-specific imports
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 from models.registry import choices as model_choices, build as build_model
+from utils.training_engine import TrainingEngine
+from utils.training_types import (
+    TrainingConfig,
+    TrainingProgress,
+    TrainingStatus,
+    CVStrategy,
+    get_cv_splitter,
+)
 from utils.preprocessing import preprocess_spectrum
 
 
@@ -143,74 +149,21 @@ def calculate_spectroscopy_metrics(
     return metrics
 
 
-
-
-
-        return StratifiedKFold(
-            n_splits=n_splits, shuffle=True, random_state=random_state
-        )
-    elif strategy == "kfold":
-        return KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
-    elif strategy == "time_series_split":
-        return TimeSeriesSplit(n_splits=n_splits)
-    else:
-        # Default to stratified k-fold
-        return StratifiedKFold(
-            n_splits=n_splits, shuffle=True, random_state=random_state
-        )
-
-
-def augment_spectral_data(
-    X: np.ndarray,
-    y: np.ndarray,
-    noise_level: float = 0.01,
-    augmentation_factor: int = 2,
-) -> Tuple[np.ndarray, np.ndarray]:
-    """Augment spectral data with realistic noise and variations"""
-    if augmentation_factor <= 1:
-        return X, y
-
-    augmented_X = [X]
-    augmented_y = [y]
-
-    for i in range(augmentation_factor - 1):
-        # Add Gaussian noise
-        noise = np.random.normal(0, noise_level, X.shape)
-        X_noisy = X + noise
-
-        # Add baseline drift (common in spectroscopy)
-        baseline_drift = np.random.normal(0, noise_level * 0.5, (X.shape[0], 1))
-        X_drift = X_noisy + baseline_drift
-
-        # Add intensity scaling variation
-        intensity_scale = np.random.normal(1.0, 0.05, (X.shape[0], 1))
-        X_scaled = X_drift * intensity_scale
-
-        # Ensure no negative values
-        X_scaled = np.maximum(X_scaled, 0)
-
-        augmented_X.append(X_scaled)
-        augmented_y.append(y)
-
-    return np.vstack(augmented_X), np.hstack(augmented_y)
-
-
-class TrainingStatus(Enum):
-    """Training job status enumeration"""
+@dataclass
+class AugmentationConfig:
+    """Data augmentation configuration"""
 
-
-    COMPLETED = "completed"
-    FAILED = "failed"
-    CANCELLED = "cancelled"
+    enable_augmentation: bool = False
+    noise_level: float = 0.01  # Noise level for augmentation
 
 
-
-
+@dataclass
+class PreprocessingConfig:
+    """Preprocessing configuration"""
 
-
-
-
+    baseline_correction: bool = True
+    smoothing: bool = True
+    normalization: bool = True
 
 
 @dataclass
@@ -224,15 +177,12 @@ class TrainingConfig:
     epochs: int = 10
     learning_rate: float = 1e-3
     num_folds: int = 10
-    baseline_correction: bool = True
-    smoothing: bool = True
-    normalization: bool = True
     modality: str = "raman"
     device: str = "auto"  # auto, cpu, cuda
     cv_strategy: str = "stratified_kfold"  # New field for CV strategy
     spectral_weight: float = 0.1  # Weight for spectroscopy-specific metrics
-
-
+    augmentation: AugmentationConfig = field(default_factory=AugmentationConfig)
+    preprocessing: PreprocessingConfig = field(default_factory=PreprocessingConfig)
 
     def to_dict(self) -> Dict[str, Any]:
         """Convert to dictionary for serialization"""
@@ -308,10 +258,6 @@ class TrainingManager:
         self.output_dir = Path(output_dir)
         self.output_dir.mkdir(exist_ok=True)
         (self.output_dir / "weights").mkdir(exist_ok=True)
-        (self.output_dir / "logs").mkdir(exist_ok=True)
-
-        # Progress callbacks for UI updates
-        self.progress_callbacks: Dict[str, List[Callable]] = {}
 
     def generate_job_id(self) -> str:
         """Generate unique job ID"""
@@ -324,20 +270,12 @@ class TrainingManager:
         job_id = self.generate_job_id()
         job = TrainingJob(job_id=job_id, config=config)
 
-        # Set up output paths
-        job.weights_path = str(self.output_dir / "weights" / f"{job_id}_model.pth")
-        job.logs_path = str(self.output_dir / "logs" / f"{job_id}_log.json")
-
         self.jobs[job_id] = job
 
-        # Register progress callback
-        if progress_callback:
-            if job_id not in self.progress_callbacks:
-                self.progress_callbacks[job_id] = []
-            self.progress_callbacks[job_id].append(progress_callback)
-
         # Submit to thread pool
-        self.executor.submit(
+        self.executor.submit(
+            self._run_training_job, job, progress_callback=progress_callback
+        )
 
         return job_id
 
@@ -346,25 +284,39 @@ class TrainingManager:
         try:
            job.status = TrainingStatus.RUNNING
            job.started_at = datetime.now()
-            job.progress
-
-            self._notify_progress(job.job_id, job)
+            if job.progress:
+                job.progress.start_time = job.started_at
 
-
-
+            if progress_callback:
+                progress_callback(job)
 
            # Load and preprocess data
            X, y = self._load_and_preprocess_data(job)
            if X is None or y is None:
                raise ValueError("Failed to load dataset")
 
-            #
-
-
-
-
-
-
+            # Define a callback to update the job's progress object
+            def engine_progress_callback(progress_data: dict):
+                if job.progress:
+                    if progress_data["type"] == "fold_start":
+                        job.progress.current_fold = progress_data["fold"]
+                    elif progress_data["type"] == "epoch_end":
+                        job.progress.current_epoch = progress_data["epoch"]
+                        job.progress.current_loss = progress_data["loss"]
+                if progress_callback:
+                    progress_callback(job)
+
+            # Instantiate and run the training engine
+            engine = TrainingEngine(job.config)
+            results = engine.run(X, y, progress_callback=engine_progress_callback)
+
+            # Update job with results
+            if job.progress:
+                job.progress.fold_accuracies = results["fold_accuracies"]
+                job.progress.confusion_matrices = results["confusion_matrices"]
+
+            # Save model weights and logs
+            self._save_model_weights(job, results["model_state_dict"])
            self._save_training_results(job)
 
            job.status = TrainingStatus.COMPLETED
@@ -377,16 +329,8 @@ class TrainingManager:
            job.completed_at = datetime.now()
 
        finally:
-
-
-    def _get_device(self, device_preference: str) -> torch.device:
-        """Get appropriate device for training"""
-        if device_preference == "auto":
-            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        elif device_preference == "cuda" and torch.cuda.is_available():
-            return torch.device("cuda")
-        else:
-            return torch.device("cpu")
+            if progress_callback:
+                progress_callback(job)
 
     def _load_and_preprocess_data(
         self, job: TrainingJob
@@ -576,134 +520,19 @@ class TrainingManager:
            print(f"Error loading dataset: {e}")
            return None, None
 
-    def
-        """
-
-
-
-
-        torch.cuda.manual_seed_all(SEED)
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
-
-    def _run_cross_validation(
-        self, job: TrainingJob, X: np.ndarray, y: np.ndarray, device: torch.device
-    ):
-        """Run configurable cross-validation training with spectroscopy metrics"""
-        config = job.config
-
-        # Apply data augmentation if enabled
-        if config.enable_augmentation:
-            X, y = augment_spectral_data(
-                X, y, noise_level=config.noise_level, augmentation_factor=2
-            )
-
-        # Get appropriate CV splitter
-        cv_splitter = get_cv_splitter(config.cv_strategy, config.num_folds)
-
-        fold_accuracies = []
-        confusion_matrices = []
-        spectroscopy_metrics = []
-
-        for fold, (train_idx, val_idx) in enumerate(cv_splitter.split(X, y), 1):
-            job.progress.current_fold = fold
-            job.progress.current_epoch = 0
-
-            # Prepare data
-            X_train, X_val = X[train_idx], X[val_idx]
-            y_train, y_val = y[train_idx], y[val_idx]
-
-            train_loader = DataLoader(
-                TensorDataset(
-                    torch.tensor(X_train, dtype=torch.float32),
-                    torch.tensor(y_train, dtype=torch.long),
-                ),
-                batch_size=config.batch_size,
-                shuffle=True,
-            )
-            val_loader = DataLoader(
-                TensorDataset(
-                    torch.tensor(X_val, dtype=torch.float32),
-                    torch.tensor(y_val, dtype=torch.long),
-                ),
-                batch_size=config.batch_size,
-                shuffle=False,
-            )
-
-            # Initialize model
-            model = build_model(config.model_name, config.target_len).to(device)
-            optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
-            criterion = nn.CrossEntropyLoss()
-
-            # Training loop
-            for epoch in range(config.epochs):
-                job.progress.current_epoch = epoch + 1
-                model.train()
-                running_loss = 0.0
-                correct = 0
-                total = 0
-
-                for inputs, labels in train_loader:
-                    inputs = inputs.unsqueeze(1).to(device)
-                    labels = labels.to(device)
-
-                    optimizer.zero_grad()
-                    outputs = model(inputs)
-                    loss = criterion(outputs, labels)
-                    loss.backward()
-                    optimizer.step()
-
-                    running_loss += loss.item()
-                    _, predicted = torch.max(outputs.data, 1)
-                    total += labels.size(0)
-                    correct += (predicted == labels).sum().item()
-
-                job.progress.current_loss = running_loss / len(train_loader)
-                job.progress.current_accuracy = correct / total
-
-                self._notify_progress(job.job_id, job)
-
-            # Validation with comprehensive metrics
-            model.eval()
-            val_predictions = []
-            val_true = []
-            val_probabilities = []
-
-            with torch.no_grad():
-                for inputs, labels in val_loader:
-                    inputs = inputs.unsqueeze(1).to(device)
-                    outputs = model(inputs)
-                    probabilities = torch.softmax(outputs, dim=1)
-                    _, predicted = torch.max(outputs, 1)
-
-                    val_predictions.extend(predicted.cpu().numpy())
-                    val_true.extend(labels.numpy())
-                    val_probabilities.extend(probabilities.cpu().numpy())
-
-            # Calculate standard metrics
-            fold_accuracy = accuracy_score(val_true, val_predictions)
-            fold_cm = confusion_matrix(val_true, val_predictions).tolist()
-
-            # Calculate spectroscopy-specific metrics
-            val_probabilities = np.array(val_probabilities)
-            spectro_metrics = calculate_spectroscopy_metrics(
-                np.array(val_true), np.array(val_predictions), val_probabilities
-            )
-
-            fold_accuracies.append(fold_accuracy)
-            confusion_matrices.append(fold_cm)
-            spectroscopy_metrics.append(spectro_metrics)
-
-            # Save best model weights (from last fold for now)
-            if fold == config.num_folds:
-                torch.save(model.state_dict(), job.weights_path)
-
-        job.progress.fold_accuracies = fold_accuracies
-        job.progress.confusion_matrices = confusion_matrices
-        job.progress.spectroscopy_metrics = spectroscopy_metrics
+    def _save_model_weights(self, job: TrainingJob, model_state_dict: dict):
+        """Saves the model's state dictionary to a file."""
+        weights_dir = self.output_dir / "weights"
+        weights_dir.mkdir(exist_ok=True)
+        job.weights_path = str(weights_dir / f"{job.config.model_name}_model.pth")
+        torch.save(model_state_dict, job.weights_path)
 
     def _save_training_results(self, job: TrainingJob):
         """Save training results and logs with enhanced metrics"""
+        logs_dir = self.output_dir / "logs"
+        logs_dir.mkdir(exist_ok=True)
+        job.logs_path = str(logs_dir / f"{job.job_id}_log.json")
+
        # Calculate comprehensive summary metrics
        spectro_summary = {}
        if job.progress.spectroscopy_metrics:
@@ -744,17 +573,9 @@ class TrainingManager:
            "error_message": job.error_message,
        }
 
-
-
-
-    def _notify_progress(self, job_id: str, job: TrainingJob):
-        """Notify registered callbacks about progress updates"""
-        if job_id in self.progress_callbacks:
-            for callback in self.progress_callbacks[job_id]:
-                try:
-                    callback(job)
-                except Exception as e:
-                    print(f"Error in progress callback: {e}")
+        if job.logs_path:
+            with open(job.logs_path, "w") as f:
+                json.dump(results, f, indent=2)
 
     def get_job_status(self, job_id: str) -> Optional[TrainingJob]:
         """Get current status of a training job"""
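For context, a minimal sketch of how a caller (for example the Streamlit training UI) might drive this manager. The method name `start_training`, the `output_dir` constructor argument, and the exact TrainingJob fields used here are assumptions based on the fragments above, not confirmed by the visible hunks:

# Hypothetical caller sketch; assumes a start_training(config, progress_callback)
# method (the submitting method's name is not visible in this diff) and that
# TrainingJob carries .status and .progress as used in _run_training_job above.
from utils.training_manager import TrainingManager
from utils.training_types import TrainingConfig

manager = TrainingManager(output_dir="outputs")

def on_update(job):
    # Called at job start, after every epoch, and in the finally block.
    if job.progress:
        print(job.status, job.progress.current_fold,
              job.progress.current_epoch, job.progress.current_loss)

config = TrainingConfig(epochs=5, num_folds=3)
job_id = manager.start_training(config, progress_callback=on_update)
print(manager.get_job_status(job_id))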
utils/training_types.py: File without changes
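Although utils/training_types.py shows no changes in this view, both the manager and the engine import TrainingConfig, TrainingStatus, CVStrategy, and get_cv_splitter from it. A sketch of the pieces that appear to have moved there, reconstructed from the code removed from training_manager.py above; the actual module layout may differ:

# Hypothetical sketch of utils/training_types.py contents, inferred from the
# imports and the removed code; enum members marked "assumed" and the
# random_state default are assumptions.
from enum import Enum

from sklearn.model_selection import StratifiedKFold, KFold, TimeSeriesSplit


class TrainingStatus(Enum):
    """Training job status enumeration"""
    PENDING = "pending"      # assumed; only RUNNING/COMPLETED/FAILED/CANCELLED are visible above
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


def get_cv_splitter(strategy: str, n_splits: int, random_state: int = 42):
    """Return a scikit-learn CV splitter for the requested strategy."""
    if strategy == "stratified_kfold":
        return StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    if strategy == "kfold":
        return KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    if strategy == "time_series_split":
        return TimeSeriesSplit(n_splits=n_splits)
    # Default to stratified k-fold
    return StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)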