devjas1 committed on
Commit
fe030dd
·
1 Parent(s): 7184c06

(FEAT)[Enhanced Results Widget]: Integrate advanced probability breakdown, QC, and provenance export

Browse files

- Updated ui_components.py and 2_Enhanced_Analysis.py to include a research-grade probability breakdown widget.
- Added entropy, margin, calibration, and provenance export to results column.
- Ensured QC summary and preprocessing parameters are computed and stored in session_state for meaningful diagnostics.

(FIX)[QC Summary Assignment]: Ensure valid spectrum QC metrics in results
- Modified core_logic.py and app.py to compute and assign n_points, x_min, x_max, monotonic_x, nan_free, and variance_proxy after spectrum parsing.
- Eliminated NULL values in QC summary for reliable reporting.

(DOCS)[Markdown & Citation]: Improve README formatting and citation style
- Updated README.md to replace bare DOI URLs with markdown link syntax, resolving markdownlint MD034 and improving citation readability.

(FEAT)[Model Inspection Utility]: Add inspect_weights.py for model weight analysis
- Introduced inspect_weights.py to support model weight inspection and debugging for training and inference workflows.

(FEAT+FIX)[Batch & Image Processing]: Refine batch utilities and image processing
- Enhanced multifile.py and figure2_model.py for multi-format batch support, error resilience, and improved spectrum parsing.
- Improved training_engine.py for robust batch processing and model training integration.

Refactor training architecture and enhance model training capabilities

- Unified training logic by introducing a central `TrainingEngine` class in `utils/training_engine.py`.
- Decoupled data structures for training configuration and status into `utils/training_types.py`.
- Updated CLI script (`scripts/train_model.py`) to utilize the new training engine, improving maintainability.
- Enhanced `TrainingManager` in `utils/training_manager.py` to support the new training engine and provide better job management.
- Added diagnostics script (`inspect_weights.py`) for inspecting model weights and identifying potential issues.
- Improved data parsing by modifying `utils/multifile.py` to streamline spectrum data handling.
- Updated model weight handling and logging mechanisms to ensure better tracking of training progress and results.
- Created comprehensive README documentation for the training modules, detailing usage and project structure.

app.py DELETED
@@ -1,72 +0,0 @@
1
- # In App.py
2
- import streamlit as st
3
-
4
- from modules.callbacks import init_session_state
5
-
6
- from modules.ui_components import (
7
- render_sidebar,
8
- render_results_column,
9
- render_input_column,
10
- render_comparison_tab,
11
- render_performance_tab,
12
- load_css,
13
- )
14
-
15
- from modules.training_ui import render_training_tab
16
-
17
- from utils.image_processing import render_image_upload_interface
18
-
19
- st.set_page_config(
20
- page_title="ML Polymer Classification",
21
- page_icon="🔬",
22
- layout="wide",
23
- initial_sidebar_state="expanded",
24
- menu_items=None,
25
- )
26
-
27
-
28
- def main():
29
- """Modularized main content to other scripts to clean the main app"""
30
- load_css("static/style.css")
31
- init_session_state()
32
-
33
- render_sidebar()
34
-
35
- # Create main tabs for different analysis modes
36
- tab1, tab2, tab3, tab4, tab5 = st.tabs(
37
- [
38
- "Standard Analysis",
39
- "Model Comparison",
40
- "Model Training",
41
- "Image Analysis",
42
- "Performance Tracking",
43
- ]
44
- )
45
-
46
- with tab1:
47
- # Standard single-model analysis
48
- col1, col2 = st.columns([1, 1.35], gap="small")
49
- with col1:
50
- render_input_column()
51
- with col2:
52
- render_results_column()
53
-
54
- with tab2:
55
- # Multi-model comparison interface
56
- render_comparison_tab()
57
-
58
- with tab3:
59
- # Model training interface
60
- render_training_tab()
61
-
62
- with tab4:
63
- # Image analysis interface
64
- render_image_upload_interface()
65
-
66
- with tab5:
67
- # Performance tracking interface
68
- render_performance_tab()
69
-
70
-
71
- if __name__ == "__main__":
72
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core_logic.py CHANGED
@@ -10,7 +10,6 @@ import numpy as np
10
  import streamlit as st
11
  from pathlib import Path
12
  from config import SAMPLE_DATA_DIR
13
- from datetime import datetime
14
  from models.registry import build, choices
15
 
16
 
@@ -27,7 +26,7 @@ def label_file(filename: str) -> int:
27
 
28
 
29
  @st.cache_data
30
- def load_state_dict(_mtime, model_path):
31
  """Load state dict with mtime in cache key to detect file changes"""
32
  try:
33
  return torch.load(model_path, map_location="cpu")
@@ -61,6 +60,7 @@ def load_model(model_name):
61
  model.load_state_dict(state_dict, strict=True)
62
  model.eval()
63
  weights_loaded = True
 
64
 
65
  except (OSError, RuntimeError):
66
  continue
@@ -88,7 +88,7 @@ def cleanup_memory():
88
 
89
 
90
  @st.cache_data
91
- def run_inference(y_resampled, model_choice, modality: str, _cache_key=None):
92
  """Run model inference and cache results with performance tracking"""
93
  from utils.performance_tracker import get_performance_tracker, PerformanceMetrics
94
  from datetime import datetime
@@ -169,58 +169,3 @@ def get_sample_files():
169
  if sample_dir.exists():
170
  return sorted(list(sample_dir.glob("*.txt")))
171
  return []
172
-
173
-
174
- def parse_spectrum_data(raw_text):
175
- """Parse spectrum data from text with robust error handling and validation"""
176
- x_vals, y_vals = [], []
177
-
178
- for line in raw_text.splitlines():
179
- line = line.strip()
180
- if not line or line.startswith("#"): # Skip empty lines and comments
181
- continue
182
-
183
- try:
184
- # Handle different separators
185
- parts = line.replace(",", " ").split()
186
- numbers = [
187
- p
188
- for p in parts
189
- if p.replace(".", "", 1)
190
- .replace("-", "", 1)
191
- .replace("+", "", 1)
192
- .isdigit()
193
- ]
194
-
195
- if len(numbers) >= 2:
196
- x, y = float(numbers[0]), float(numbers[1])
197
- x_vals.append(x)
198
- y_vals.append(y)
199
-
200
- except ValueError:
201
- # Skip problematic lines but don't fail completely
202
- continue
203
-
204
- if len(x_vals) < 10: # Minimum reasonable spectrum length
205
- raise ValueError(
206
- f"Insufficient data points: {len(x_vals)}. Need at least 10 points."
207
- )
208
-
209
- x = np.array(x_vals)
210
- y = np.array(y_vals)
211
-
212
- # Check for NaNs
213
- if np.any(np.isnan(x)) or np.any(np.isnan(y)):
214
- raise ValueError("Input data contains NaN values")
215
-
216
- # Check monotonic increasing x
217
- if not np.all(np.diff(x) > 0):
218
- raise ValueError("Wavenumbers must be strictly increasing")
219
-
220
- # Check reasonable range for Raman spectroscopy
221
- if min(x) < 0 or max(x) > 10000 or (max(x) - min(x)) < 100:
222
- raise ValueError(
223
- f"Invalid wavenumber range: {min(x)} - {max(x)}. Expected ~400-4000 cm⁻¹ with span >100"
224
- )
225
-
226
- return x, y
 
10
  import streamlit as st
11
  from pathlib import Path
12
  from config import SAMPLE_DATA_DIR
 
13
  from models.registry import build, choices
14
 
15
 
 
26
 
27
 
28
  @st.cache_data
29
+ def load_state_dict(mtime, model_path):
30
  """Load state dict with mtime in cache key to detect file changes"""
31
  try:
32
  return torch.load(model_path, map_location="cpu")
 
60
  model.load_state_dict(state_dict, strict=True)
61
  model.eval()
62
  weights_loaded = True
63
+ break # Exit loop after successful load
64
 
65
  except (OSError, RuntimeError):
66
  continue
 
88
 
89
 
90
  @st.cache_data
91
+ def run_inference(y_resampled, model_choice, modality: str, cache_key=None):
92
  """Run model inference and cache results with performance tracking"""
93
  from utils.performance_tracker import get_performance_tracker, PerformanceMetrics
94
  from datetime import datetime
 
169
  if sample_dir.exists():
170
  return sorted(list(sample_dir.glob("*.txt")))
171
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inspect_weights.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Diagnostic script to inspect the weights within a PyTorch .pth file.
3
+
4
+ This utility loads a model's state dictionary and prints summary statistics
5
+ (mean, std, min, max) for each parameter tensor. It helps diagnose issues
6
+ like corrupted weights from failed or interrupted training runs, which might
7
+ result in a model producing constant, incorrect outputs.
8
+
9
+ Usage:
10
+ python scripts/inspect_weights.py path/to/your/model_weights.pth
11
+ """
12
+
13
+ import torch
14
+ import argparse
15
+ import os
16
+ from pathlib import Path
17
+ import sys
18
+
19
+ # Add project root to path to allow imports from other modules
20
+ sys.path.append(str(Path(__file__).resolve().parent.parent))
21
+
22
+
23
+ def inspect_weights(file_path: str):
24
+ """
25
+ Loads a model state_dict from a .pth file and prints statistics
26
+ for each parameter tensor to help diagnose corrupted weights.
27
+ """
28
+ if not os.path.exists(file_path):
29
+ print(f"❌ Error: File not found at {file_path}")
30
+ return
31
+
32
+ print(f"🔍 Inspecting weights for: {file_path}\n")
33
+
34
+ try:
35
+ # Load the state dictionary
36
+ # Use weights_only=True for security and to suppress the warning
37
+ try:
38
+ state_dict = torch.load(
39
+ file_path, map_location=torch.device("cpu"), weights_only=True
40
+ )
41
+ except TypeError: # Fallback for older torch versions
42
+ state_dict = torch.load(file_path, map_location=torch.device("cpu"))
43
+
44
+ # Handle checkpoints that save the model in a sub-dictionary
45
+ if "model_state_dict" in state_dict:
46
+ state_dict = state_dict["model_state_dict"]
47
+ elif "model" in state_dict:
48
+ state_dict = state_dict["model"]
49
+
50
+ if not state_dict:
51
+ print("⚠️ State dictionary is empty.")
52
+ return
53
+
54
+ print(
55
+ f"{'Parameter Name':<40} {'Shape':<20} {'Mean':<15} {'Std Dev':<15} {'Min':<15} {'Max':<15}"
56
+ )
57
+ print("-" * 120)
58
+ all_stds = []
59
+
60
+ for name, param in state_dict.items():
61
+ if isinstance(param, torch.Tensor):
62
+ # Ensure tensor is float for stats, but don't fail if not
63
+ try:
64
+ param_float = param.float()
65
+ mean_val = f"{param_float.mean().item():.4e}"
66
+ std_val_float = param_float.std().item()
67
+ std_val = f"{std_val_float:.4e}"
68
+ min_val = f"{param_float.min().item():.4e}"
69
+ max_val = f"{param_float.max().item():.4e}"
70
+ all_stds.append(std_val_float)
71
+ except (RuntimeError, TypeError):
72
+ mean_val, std_val, min_val, max_val = "N/A", "N/A", "N/A", "N/A"
73
+
74
+ shape_str = str(list(param.shape))
75
+ print(
76
+ f"{name:<40} {shape_str:<20} {mean_val:<15} {std_val:<15} {min_val:<15} {max_val:<15}"
77
+ )
78
+ else:
79
+ print(f"{name:<40} {'Non-Tensor':<20} {str(param):<60}")
80
+
81
+ print("\n" + "-" * 120)
82
+ print("✅ Inspection complete.")
83
+ print("\nDiagnosis:")
84
+ print(
85
+ "- If you see all zeros, NaNs, or very small (e.g., e-38) uniform values, the weights file is likely corrupted."
86
+ )
87
+ if all(s < 1e-6 for s in all_stds if s is not None):
88
+ print(
89
+ "- WARNING: All parameter standard deviations are extremely low. The model may be 'dead' and insensitive to input."
90
+ )
91
+ else:
92
+ print(
93
+ "- The weight statistics appear varied, suggesting the file is not corrupted with zeros/NaNs."
94
+ )
95
+ print(
96
+ "- If the model still produces constant output, it is likely poorly trained."
97
+ )
98
+
99
+ print("\nRecommendation: Retraining the model is the correct solution.")
100
+
101
+ except Exception as e:
102
+ print(f"❌ An error occurred while inspecting the weights file: {e}")
103
+ import traceback
104
+
105
+ traceback.print_exc()
106
+
107
+
108
+ if __name__ == "__main__":
109
+ parser = argparse.ArgumentParser(
110
+ description="Inspect PyTorch model weights in a .pth file."
111
+ )
112
+ parser.add_argument(
113
+ "file_path", type=str, help="Path to the .pth model weights file."
114
+ )
115
+ args = parser.parse_args()
116
+ inspect_weights(args.file_path)
modules/TRAINING_MODELS_README.md ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # POLYMEROS: AI-Driven Polymer Aging Analysis & Classification
2
+
3
+ POLYMEROS is an advanced, AI-driven platform for analyzing and classifying polymer degradation using spectroscopic data. This project extends a baseline CNN model to incorporate multi-modal analysis (Raman & FTIR), modern machine learning architectures, a comprehensive data pipeline, and an interactive educational framework.
4
+
5
+ [![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://huggingface.co/spaces/dev-jas/polymer-aging-ml)
6
+
7
+ ---
8
+
9
+ ## 🚀 Key Features & Recent Enhancements
10
+
11
+ This platform has been significantly enhanced with a suite of research-grade features. Recent architectural improvements have focused on creating a robust, maintainable, and unified training system.
12
+
13
+ ### Unified Training Architecture
14
+
15
+ Previously, the project contained two separate implementations of the model training logic: one in the command-line script (`scripts/train_model.py`) and another powering the backend of the web UI (`utils/training_manager.py`). This duplication led to inconsistencies and made maintenance difficult.
16
+
17
+ The system has been refactored to follow the **Don't Repeat Yourself (DRY)** principle:
18
+
19
+ 1. **Central `TrainingEngine`**: A new `utils/training_engine.py` module was created to house the core, canonical training and cross-validation loop. This engine is now the single source of truth for how models are trained.
20
+
21
+ 2. **Decoupled Data Structures**: Shared data classes like `TrainingConfig` and `TrainingStatus` were moved to a dedicated `utils/training_types.py` file. This resolved circular import errors and improved modularity.
22
+
23
+ 3. **Refactored Interfaces**:
24
+ - The **CLI script** (`scripts/train_model.py`) is now a lightweight wrapper that parses command-line arguments and calls the `TrainingEngine`.
25
+ - The **UI backend** (`utils/training_manager.py`) now also uses the `TrainingEngine` to run training jobs submitted from the "Model Training Hub".
26
+
27
+ This unified architecture ensures that any improvements to the training process are immediately available to both developers using the CLI and users interacting with the web UI.
28
+
29
+ ---
30
+
31
+ ## 🛠️ How to Train Models
32
+
33
+ With the new unified architecture, you can train models using either the command line or the interactive web UI, depending on your needs.
34
+
35
+ ### 1. CLI Training (For Developers & Automation)
36
+
37
+ The command-line interface is the ideal method for reproducible experiments, automated workflows, or training on a remote server. It provides full control over all training hyperparameters.
38
+
39
+ **Why use the CLI?**
40
+
41
+ - For scripting multiple training runs.
42
+ - For integration into CI/CD pipelines.
43
+ - When working in a non-GUI environment.
44
+
45
+ **Example Command:**
46
+ To perform a 10-fold cross-validation for the `figure2` model, run the following command from the project's root directory:
47
+
48
+ ```bash
49
+ python scripts/train_model.py --model figure2 --epochs 15 --baseline --smooth --normalize
50
+ ```
51
+
52
+ This command will:
53
+
54
+ - Load the default dataset from `datasets/rdwp`.
55
+ - Apply the specified preprocessing steps.
56
+ - Run the training using the central `TrainingEngine`.
57
+ - Save the final model weights to `outputs/weights/figure2_model.pth` and a detailed JSON log to `outputs/logs/`.
58
+
59
+ ### 2. UI Training Hub (For Interactive Use)
60
+
61
+ The "Model Training Hub" within the web application provides a user-friendly, graphical interface for training models. It's designed for interactive experimentation and for users who may not be comfortable with the command line.
62
+
63
+ **Why use the UI?**
64
+
65
+ - To easily train models on your own uploaded datasets.
66
+ - To interactively tweak hyperparameters and see their effect.
67
+ - To monitor training progress in real-time with visual feedback.
68
+
69
+ **How to use it:**
70
+
71
+ 1. Navigate to the **Model Training Hub** tab in the application.
72
+ 2. **Configure Your Job**:
73
+ - Select a model architecture.
74
+ - Upload a new dataset or choose an existing one.
75
+ - Adjust training parameters like epochs, learning rate, and batch size.
76
+ 3. Click **"🚀 Start Training"**.
77
+ 4. The job will run in the background, and you can monitor its progress in the "Training Status" and "Training Progress" sections. Completed models and logs can be downloaded directly from the UI.
78
+
79
+ ---
80
+
81
+ ## Project Structure Overview
82
+
83
+ - `app.py`: Main Streamlit application entry point.
84
+ - `modules/`: Contains all major feature modules.
85
+ - `training_ui.py`: Renders the "Model Training Hub" tab.
86
+ - `scripts/`: Contains command-line tools.
87
+ - `train_model.py`: The CLI for running training jobs.
88
+ - `inspect_weights.py`: A diagnostic tool to check model weight files.
89
+ - `utils/`: Core utilities for the application.
90
+ - `training_engine.py`: **The new central training logic.**
91
+ - `training_manager.py`: The backend manager for UI-based training jobs.
92
+ - `training_types.py`: **New file for shared training data structures.**
93
+ - `models/`: Model definitions and the central model registry.
94
+ - `outputs/`: Default directory for saved model weights and training logs.
modules/training_ui.py CHANGED
@@ -17,12 +17,8 @@ import json
17
  from datetime import datetime, timedelta
18
 
19
  from models.registry import choices as model_choices, get_model_info
20
- from utils.training_manager import (
21
- get_training_manager,
22
- TrainingConfig,
23
- TrainingStatus,
24
- TrainingJob,
25
- )
26
 
27
 
28
  def render_training_tab():
 
17
  from datetime import datetime, timedelta
18
 
19
  from models.registry import choices as model_choices, get_model_info
20
+ from utils.training_manager import get_training_manager, TrainingJob
21
+ from utils.training_types import TrainingConfig, TrainingStatus
 
 
 
 
22
 
23
 
24
  def render_training_tab():
modules/ui_components.py CHANGED
@@ -7,6 +7,7 @@ from PIL import Image
7
  import numpy as np
8
  import matplotlib.pyplot as plt
9
  from typing import Union
 
10
  import time
11
  from config import TARGET_LEN, LABEL_MAP, MODEL_WEIGHTS_DIR
12
  from models.registry import choices, get_model_info
@@ -18,16 +19,13 @@ from modules.callbacks import (
18
  reset_ephemeral_state,
19
  log_message,
20
  )
21
- from core_logic import (
22
- get_sample_files,
23
- load_model,
24
- run_inference,
25
- parse_spectrum_data,
26
- label_file,
27
- )
28
  from utils.results_manager import ResultsManager
29
- from utils.multifile import process_multiple_files
30
- from utils.preprocessing import resample_spectrum, validate_spectrum_modality
 
 
 
31
  from utils.confidence import calculate_softmax_confidence
32
 
33
 
@@ -69,9 +67,6 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
69
  return Image.open(buf)
70
 
71
 
72
- # //////////////////////////////////////////
73
-
74
-
75
  def render_confidence_progress(
76
  probs: np.ndarray,
77
  labels: list[str] = ["Stable", "Weathered"],
@@ -132,9 +127,6 @@ def render_kv_grid(d: Optional[dict] = None, ncols: int = 2):
132
  st.caption(f"**{k}:** {v}")
133
 
134
 
135
- # //////////////////////////////////////////
136
-
137
-
138
  def render_model_meta(model_choice: str):
139
  info = get_model_info(model_choice)
140
  emoji = info.get("emoji", "")
@@ -152,9 +144,6 @@ def render_model_meta(model_choice: str):
152
  st.caption(desc)
153
 
154
 
155
- # //////////////////////////////////////////
156
-
157
-
158
  def get_confidence_description(logit_margin):
159
  """Get human-readable confidence description"""
160
  if logit_margin > 1000:
@@ -167,9 +156,6 @@ def get_confidence_description(logit_margin):
167
  return "LOW", "🔴"
168
 
169
 
170
- # //////////////////////////////////////////
171
-
172
-
173
  def render_sidebar():
174
  with st.sidebar:
175
  # Header
@@ -254,7 +240,6 @@ def render_sidebar():
254
  )
255
 
256
 
257
- # //////////////////////////////////////////
258
  def render_input_column():
259
  st.markdown("##### Data Input")
260
 
@@ -393,6 +378,7 @@ def render_input_column():
393
 
394
  # Handle form submission
395
  if submitted and inference_ready:
 
396
  if st.session_state.get("batch_mode"):
397
  batch_files = st.session_state.get("batch_files", [])
398
  with st.spinner(f"Processing {len(batch_files)} files ..."):
@@ -405,7 +391,31 @@ def render_input_column():
405
  )
406
  else:
407
  try:
408
- x_raw, y_raw = parse_spectrum_data(st.session_state["input_text"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
 
410
  # Validate that spectrum matches selected modality
411
  selected_modality = st.session_state.get("modality_select", "raman")
@@ -430,7 +440,10 @@ def render_input_column():
430
  else:
431
  st.stop() # Stop processing until user confirms
432
 
433
- x_resampled, y_resampled = resample_spectrum(x_raw, y_raw, TARGET_LEN)
 
 
 
434
  st.session_state.update(
435
  {
436
  "x_raw": x_raw,
@@ -444,9 +457,6 @@ def render_input_column():
444
  st.error(f"Error processing spectrum data: {e}")
445
 
446
 
447
- # //////////////////////////////////////////
448
-
449
-
450
  def render_results_column():
451
  # Get the current mode and check for batch results
452
  is_batch_mode = st.session_state.get("batch_mode", False)
@@ -483,7 +493,7 @@ def render_results_column():
483
  else None
484
  ),
485
  modality=st.session_state.get("modality_select", "raman"),
486
- _cache_key=cache_key,
487
  )
488
  if prediction is None:
489
  st.error(
@@ -491,6 +501,11 @@ def render_results_column():
491
  )
492
  st.stop() # prevents the rest of the code in this block from executing
493
 
 
 
 
 
 
494
  log_message(
495
  f"Inference completed in {inference_time:.2f}s, prediction: {prediction}"
496
  )
@@ -556,6 +571,8 @@ def render_results_column():
556
  "⚠️ Model choice is not defined. Please select a model from the sidebar."
557
  )
558
  st.stop()
 
 
559
  model_path = os.path.join(MODEL_WEIGHTS_DIR, f"{model_choice}_model.pth")
560
  mtime = os.path.getmtime(model_path) if os.path.exists(model_path) else None
561
  file_hash = (
@@ -573,88 +590,188 @@ def render_results_column():
573
  )
574
 
575
  if active_tab == "Details":
576
- st.markdown('<div class="expander-results">', unsafe_allow_html=True)
577
  # Use a dynamic and informative title for the expander
578
  with st.expander(f"Results for {filename}", expanded=True):
579
 
580
- # --- START: STREAMLINED METRICS ---
581
- # A single, powerful row for the most important results.
582
- key_metric_cols = st.columns(3)
 
 
583
 
584
- # Metric 1: The Prediction
585
- key_metric_cols[0].metric("Prediction", predicted_class)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
586
 
587
- # Metric 2: The Confidence (with level in tooltip)
588
- confidence_icon = (
589
- "🟢"
590
- if max_confidence >= 0.8
591
- else "🟡" if max_confidence >= 0.6 else "🔴"
592
  )
593
- key_metric_cols[1].metric(
594
- "Confidence",
595
- f"{confidence_icon} {max_confidence:.1%}",
596
- help=f"Confidence Level: {confidence_desc}",
597
  )
598
 
599
- # Metric 3: Ground Truth + Correctness (Combined)
600
- if true_label_idx is not None:
601
- is_correct = predicted_class == true_label_str
602
- delta_text = "✅ Correct" if is_correct else "❌ Incorrect"
603
- # Use delta_color="normal" to let the icon provide the visual cue
604
- key_metric_cols[2].metric(
605
- "Ground Truth",
606
- true_label_str,
607
- delta=delta_text,
608
- delta_color="normal",
 
 
 
 
 
 
609
  )
610
- else:
611
- key_metric_cols[2].metric("Ground Truth", "N/A")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612
 
613
- st.divider()
614
- # --- END: STREAMLINED METRICS ---
615
 
616
- # --- START: CONSOLIDATED CONFIDENCE ANALYSIS ---
617
- st.markdown("##### Probability Breakdown")
 
 
 
 
618
 
619
- # This custom bullet bar logic remains as it is highly specific and valuable
620
- def create_bullet_bar(probability, width=20, predicted=False):
621
- filled_count = int(probability * width)
622
- bar = "▤" * filled_count + "▢" * (width - filled_count)
623
- percentage = f"{probability:.1%}"
624
- pred_marker = "↩ Predicted" if predicted else ""
625
- return f"{bar} {percentage} {pred_marker}"
 
 
626
 
627
- if probs is not None:
628
- stable_prob, weathered_prob = probs[0], probs[1]
629
- else:
630
- st.error(
631
- " Probability values are missing. Please check the inference process."
 
 
 
 
 
632
  )
633
- # Default values to prevent further errors
634
- stable_prob, weathered_prob = 0.0, 0.0
635
- is_stable_predicted, is_weathered_predicted = (
636
- int(prediction) == 0
637
- ), (int(prediction) == 1)
638
-
639
- st.markdown(
640
- f"""
641
- <div style="font-family: 'Fira Code', monospace;">
642
- Stable (Unweathered)<br>
643
- {create_bullet_bar(stable_prob, predicted=is_stable_predicted)}<br><br>
644
- Weathered (Degraded)<br>
645
- {create_bullet_bar(weathered_prob, predicted=is_weathered_predicted)}
646
- </div>
647
- """,
648
- unsafe_allow_html=True,
649
- )
650
 
651
- st.divider()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652
 
653
  # METADATA FOOTER
654
  st.caption(
655
- f"Analyzed with **{st.session_state.get('model_select', 'Unknown')}** in **{inference_time:.2f}s**."
656
  )
657
- st.markdown("</div>", unsafe_allow_html=True)
658
 
659
  elif active_tab == "Technical":
660
  with st.container():
@@ -879,9 +996,6 @@ def render_results_column():
879
 
880
  # Technical details
881
  # MODIFIED: Wrap the expander in a div with the 'expander-advanced' class
882
- st.markdown(
883
- '<div class="expander-advanced">', unsafe_allow_html=True
884
- )
885
  with st.expander("🔧 Technical Details", expanded=False):
886
  st.markdown(
887
  """
@@ -902,9 +1016,6 @@ def render_results_column():
902
  - Normalization: None (preserves intensity relationships)
903
  """
904
  )
905
- st.markdown(
906
- "</div>", unsafe_allow_html=True
907
- ) # Close the wrapper div
908
 
909
  render_time = time.time() - start_render
910
  log_message(
@@ -987,9 +1098,6 @@ def render_results_column():
987
  )
988
 
989
 
990
- # //////////////////////////////////////////
991
-
992
-
993
  def render_comparison_tab():
994
  """Render the multi-model comparison interface"""
995
  import streamlit as st
@@ -1001,7 +1109,7 @@ def render_comparison_tab():
1001
  get_models_metadata,
1002
  )
1003
  from utils.results_manager import ResultsManager
1004
- from core_logic import get_sample_files, run_inference, parse_spectrum_data
1005
  from utils.preprocessing import preprocess_spectrum
1006
  from utils.multifile import parse_spectrum_data
1007
  import numpy as np
@@ -1159,8 +1267,16 @@ def render_comparison_tab():
1159
  start_time = time.time()
1160
 
1161
  # Run inference
 
 
 
1162
  prediction, logits_list, probs, inference_time, logits = (
1163
- run_inference(y_processed, model_name, modality=modality)
 
 
 
 
 
1164
  )
1165
 
1166
  processing_time = time.time() - start_time
@@ -1587,15 +1703,9 @@ def render_comparison_tab():
1587
  )
1588
 
1589
 
1590
- # //////////////////////////////////////////
1591
-
1592
-
1593
  from utils.performance_tracker import display_performance_dashboard
1594
 
1595
 
1596
  def render_performance_tab():
1597
  """Render the performance tracking and analysis tab."""
1598
  display_performance_dashboard()
1599
-
1600
-
1601
- # //////////////////////////////////////////
 
7
  import numpy as np
8
  import matplotlib.pyplot as plt
9
  from typing import Union
10
+ import uuid
11
  import time
12
  from config import TARGET_LEN, LABEL_MAP, MODEL_WEIGHTS_DIR
13
  from models.registry import choices, get_model_info
 
19
  reset_ephemeral_state,
20
  log_message,
21
  )
22
+ from core_logic import get_sample_files, load_model, run_inference, label_file
 
 
 
 
 
 
23
  from utils.results_manager import ResultsManager
24
+ from utils.multifile import process_multiple_files, parse_spectrum_data
25
+ from utils.preprocessing import (
26
+ validate_spectrum_modality,
27
+ preprocess_spectrum,
28
+ )
29
  from utils.confidence import calculate_softmax_confidence
30
 
31
 
 
67
  return Image.open(buf)
68
 
69
 
 
 
 
70
  def render_confidence_progress(
71
  probs: np.ndarray,
72
  labels: list[str] = ["Stable", "Weathered"],
 
127
  st.caption(f"**{k}:** {v}")
128
 
129
 
 
 
 
130
  def render_model_meta(model_choice: str):
131
  info = get_model_info(model_choice)
132
  emoji = info.get("emoji", "")
 
144
  st.caption(desc)
145
 
146
 
 
 
 
147
  def get_confidence_description(logit_margin):
148
  """Get human-readable confidence description"""
149
  if logit_margin > 1000:
 
156
  return "LOW", "🔴"
157
 
158
 
 
 
 
159
  def render_sidebar():
160
  with st.sidebar:
161
  # Header
 
240
  )
241
 
242
 
 
243
  def render_input_column():
244
  st.markdown("##### Data Input")
245
 
 
378
 
379
  # Handle form submission
380
  if submitted and inference_ready:
381
+ st.session_state["run_uuid"] = uuid.uuid4().hex[:8]
382
  if st.session_state.get("batch_mode"):
383
  batch_files = st.session_state.get("batch_files", [])
384
  with st.spinner(f"Processing {len(batch_files)} files ..."):
 
391
  )
392
  else:
393
  try:
394
+ x_raw, y_raw = parse_spectrum_data(
395
+ st.session_state["input_text"],
396
+ filename=st.session_state.get("filename", "unknown"),
397
+ )
398
+
399
+ # QC Summary
400
+ st.session_state["qc_summary"] = {
401
+ "n_points": len(x_raw),
402
+ "x_min": f"{np.min(x_raw):.1f}",
403
+ "x_max": f"{np.max(x_raw):.1f}",
404
+ "monotonic_x": bool(np.all(np.diff(x_raw) > 0)),
405
+ "nan_free": not (
406
+ np.any(np.isnan(x_raw)) or np.any(np.isnan(y_raw))
407
+ ),
408
+ "variance_proxy": f"{np.var(y_raw):.2e}",
409
+ }
410
+
411
+ # Preprocessing parameters
412
+ preproc_params = {
413
+ "target_len": TARGET_LEN,
414
+ "modality": st.session_state.get("modality_select", "raman"),
415
+ "do_baseline": True,
416
+ "do_smooth": True,
417
+ "do_normalize": True,
418
+ }
419
 
420
  # Validate that spectrum matches selected modality
421
  selected_modality = st.session_state.get("modality_select", "raman")
 
440
  else:
441
  st.stop() # Stop processing until user confirms
442
 
443
+ x_resampled, y_resampled = preprocess_spectrum(
444
+ x_raw, y_raw, **preproc_params
445
+ )
446
+ st.session_state["preproc_params"] = preproc_params
447
  st.session_state.update(
448
  {
449
  "x_raw": x_raw,
 
457
  st.error(f"Error processing spectrum data: {e}")
458
 
459
 
 
 
 
460
  def render_results_column():
461
  # Get the current mode and check for batch results
462
  is_batch_mode = st.session_state.get("batch_mode", False)
 
493
  else None
494
  ),
495
  modality=st.session_state.get("modality_select", "raman"),
496
+ cache_key=cache_key,
497
  )
498
  if prediction is None:
499
  st.error(
 
501
  )
502
  st.stop() # prevents the rest of the code in this block from executing
503
 
504
+ # Store results in session state for the Details tab
505
+ st.session_state["prediction"] = prediction
506
+ st.session_state["probs"] = probs
507
+ st.session_state["inference_time"] = inference_time
508
+
509
  log_message(
510
  f"Inference completed in {inference_time:.2f}s, prediction: {prediction}"
511
  )
 
571
  "⚠️ Model choice is not defined. Please select a model from the sidebar."
572
  )
573
  st.stop()
574
+ model_info = get_model_info(model_choice)
575
+ st.session_state["model_info"] = model_info
576
  model_path = os.path.join(MODEL_WEIGHTS_DIR, f"{model_choice}_model.pth")
577
  mtime = os.path.getmtime(model_path) if os.path.exists(model_path) else None
578
  file_hash = (
 
590
  )
591
 
592
  if active_tab == "Details":
 
593
  # Use a dynamic and informative title for the expander
594
  with st.expander(f"Results for {filename}", expanded=True):
595
 
596
+ # Details tab: probability breakdown, input QC, preprocessing/model provenance, and JSON export.
597
+
598
+ import json, math, uuid
599
+
600
+ st.subheader("Probability Breakdown")
601
 
602
+ def _entropy(ps):
603
+ ps = [max(min(float(p), 1.0), 1e-12) for p in ps]
604
+ return -sum(p * math.log(p) for p in ps)
605
+
606
def _badge(text, kind="info"):
    """Render ``text`` as a small colored inline pill through ``st.markdown``.

    ``kind`` picks the (background, foreground) color pair; any kind that is
    not listed falls back to the ``"info"`` colors.
    """
    colors_by_kind = {
        "info": ("#334155", "#e2e8f0"),
        "warn": ("#7c2d12", "#fde68a"),
        "good": ("#064e3b", "#bbf7d0"),
        "bad": ("#7f1d1d", "#fecaca"),
    }
    if kind in colors_by_kind:
        background, foreground = colors_by_kind[kind]
    else:
        background, foreground = colors_by_kind["info"]

    pill_html = (
        f"<span style='background:{background};color:{foreground};padding:4px 8px;"
        f"border-radius:6px;font-size:0.80rem;white-space:nowrap'>{text}</span>"
    )
    # The span is raw HTML, so Streamlit must be told not to escape it.
    st.markdown(pill_html, unsafe_allow_html=True)
619
+
620
def _render_prob_row(label: str, prob: float, is_pred: bool):
    """Draw one class row: label, progress bar, and percentage text.

    The bar value is ``prob`` clamped to ``[0.0, 1.0]``; the text shows the
    unclamped ``prob`` as a percentage, followed by an arrow marker when
    ``is_pred`` is true.
    """
    bar_value = min(max(prob, 0.0), 1.0)
    if is_pred:
        marker = " \u2190 Predicted"
    else:
        marker = ""

    label_col, bar_col, value_col = st.columns([2, 7, 3])
    with label_col:
        st.write(label)
    with bar_col:
        st.progress(bar_value)
    with value_col:
        st.write(f"{prob:.1%}{marker}")
629
+
630
+ probs = st.session_state.get("probs")
631
+ prediction = st.session_state.get("prediction")
632
+ inference_time = float(st.session_state.get("inference_time", 0.0))
633
+
634
+ if probs is None or len(probs) != 2:
635
+ st.error(
636
+ "❌ Probability values are missing or invalid. Check the inference process."
637
+ )
638
+ stable_prob, weathered_prob = 0.0, 0.0
639
+ else:
640
+ stable_prob, weathered_prob = float(probs[0]), float(probs[1])
641
 
642
+ is_stable_predicted = (
643
+ (int(prediction) == 0)
644
+ if prediction is not None
645
+ else (stable_prob >= weathered_prob)
 
646
  )
647
+ is_weathered_predicted = (
648
+ (int(prediction) == 1)
649
+ if prediction is not None
650
+ else (weathered_prob > stable_prob)
651
  )
652
 
653
+ margin = abs(stable_prob - weathered_prob)
654
+ entropy = _entropy([stable_prob, weathered_prob])
655
+ thresh = float(st.session_state.get("decision_threshold", 0.5))
656
+ cal = st.session_state.get("calibration", {}) or {}
657
+ cal_enabled = bool(cal.get("enabled", False))
658
+ ece = cal.get("ece", None)
659
+
660
+ ABSTAIN_TAU = 0.10
661
+ OOD_MAX_SOFT = 0.60
662
+ max_softmax = max(stable_prob, weathered_prob)
663
+
664
+ colA, colB, colC, colD = st.columns([3, 3, 3, 3])
665
+ with colA:
666
+ st.metric(
667
+ "Predicted",
668
+ "Stable" if is_stable_predicted else "Weathered",
669
  )
670
+ with colB:
671
+ st.metric("Decision Margin", f"{margin:.2f}")
672
+ with colC:
673
+ st.metric("Entropy", f"{entropy:.3f}")
674
+ with colD:
675
+ st.metric("Threshold", f"{thresh:.2f}")
676
+
677
+ row = st.columns([3, 3, 6])
678
+ with row[0]:
679
+ if margin < ABSTAIN_TAU:
680
+ _badge("Low margin — consider abstain / re-measure", "warn")
681
+ with row[1]:
682
+ if max_softmax < OOD_MAX_SOFT:
683
+ _badge("Low confidence — possible OOD", "bad")
684
+ with row[2]:
685
+ if cal_enabled:
686
+ _badge(
687
+ (
688
+ f"Calibrated (ECE={ece:.2%})"
689
+ if isinstance(ece, (int, float))
690
+ else "Calibrated"
691
+ ),
692
+ "good",
693
+ )
694
+ else:
695
+ _badge(
696
+ "Uncalibrated — probabilities may be miscalibrated",
697
+ "info",
698
+ )
699
 
700
+ st.write("")
 
701
 
702
+ _render_prob_row(
703
+ "Stable (Unweathered)", stable_prob, is_stable_predicted
704
+ )
705
+ _render_prob_row(
706
+ "Weathered (Degraded)", weathered_prob, is_weathered_predicted
707
+ )
708
 
709
+ qc = st.session_state.get("qc_summary", {}) or {}
710
+ pp = st.session_state.get("preproc_params", {}) or {}
711
+ model_info = st.session_state.get("model_info", {}) or {}
712
+ run_info = {
713
+ "model": model_choice,
714
+ "inference_time_s": inference_time,
715
+ "run_uuid": st.session_state.get("run_uuid", ""),
716
+ "app_commit": st.session_state.get("app_commit", "unknown"),
717
+ }
718
 
719
+ with st.expander("Input QC"):
720
+ st.write(
721
+ {
722
+ "n_points": qc.get("n_points", "N/A"),
723
+ "x_min_cm-1": qc.get("x_min", "N/A"),
724
+ "x_max_cm-1": qc.get("x_max", "N/A"),
725
+ "monotonic_x": qc.get("monotonic_x", "N/A"),
726
+ "nan_free": qc.get("nan_free", "N/A"),
727
+ "variance_proxy": qc.get("variance_proxy", "N/A"),
728
+ }
729
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730
 
731
+ with st.expander("Preprocessing (applied)"):
732
+ st.write(pp)
733
+
734
+ with st.expander("Model & Run"):
735
+ st.write(
736
+ {
737
+ "model_name": model_info.get("name", model_choice),
738
+ "version": model_info.get("version", "n/a"),
739
+ "weights_mtime": model_info.get("weights_mtime", "n/a"),
740
+ "cv_accuracy": model_info.get("cv_accuracy", "n/a"),
741
+ "class_priors": model_info.get("class_priors", "n/a"),
742
+ **run_info,
743
+ }
744
+ )
745
+
746
+ export_payload = {
747
+ "prediction": "stable" if is_stable_predicted else "weathered",
748
+ "probs": {"stable": stable_prob, "weathered": weathered_prob},
749
+ "margin": margin,
750
+ "entropy": entropy,
751
+ "threshold": thresh,
752
+ "calibration": {
753
+ "enabled": cal_enabled,
754
+ "ece": ece,
755
+ "method": cal.get("method"),
756
+ "T": cal.get("T"),
757
+ },
758
+ "qc": qc,
759
+ "preprocessing": pp,
760
+ "model_info": model_info,
761
+ "run_info": run_info,
762
+ }
763
+ fname = f"result_{run_info['run_uuid'] or uuid.uuid4().hex}.json"
764
+ st.download_button(
765
+ "Download result JSON",
766
+ json.dumps(export_payload, indent=2),
767
+ file_name=fname,
768
+ mime="application/json",
769
+ )
770
 
771
  # METADATA FOOTER
772
  st.caption(
773
+ f"Analyzed with **{run_info['model']}** in **{inference_time:.2f}s**."
774
  )
 
775
 
776
  elif active_tab == "Technical":
777
  with st.container():
 
996
 
997
  # Technical details
998
  # The expander is rendered directly; the former 'expander-advanced' wrapper div was removed.
 
 
 
999
  with st.expander("🔧 Technical Details", expanded=False):
1000
  st.markdown(
1001
  """
 
1016
  - Normalization: None (preserves intensity relationships)
1017
  """
1018
  )
 
 
 
1019
 
1020
  render_time = time.time() - start_render
1021
  log_message(
 
1098
  )
1099
 
1100
 
 
 
 
1101
  def render_comparison_tab():
1102
  """Render the multi-model comparison interface"""
1103
  import streamlit as st
 
1109
  get_models_metadata,
1110
  )
1111
  from utils.results_manager import ResultsManager
1112
+ from core_logic import get_sample_files, run_inference
1113
  from utils.preprocessing import preprocess_spectrum
1114
  from utils.multifile import parse_spectrum_data
1115
  import numpy as np
 
1267
  start_time = time.time()
1268
 
1269
  # Run inference
1270
+ cache_key = hashlib.md5(
1271
+ f"{y_processed.tobytes()}{model_name}".encode()
1272
+ ).hexdigest()
1273
  prediction, logits_list, probs, inference_time, logits = (
1274
+ run_inference(
1275
+ y_processed,
1276
+ model_name,
1277
+ modality=modality,
1278
+ cache_key=cache_key,
1279
+ )
1280
  )
1281
 
1282
  processing_time = time.time() - start_time
 
1703
  )
1704
 
1705
 
 
 
 
1706
  from utils.performance_tracker import display_performance_dashboard
1707
 
1708
 
1709
  def render_performance_tab():
1710
  """Render the performance tracking and analysis tab."""
1711
  display_performance_dashboard()
 
 
 
outputs/figure2_model.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:95d706b97f4eee611c48983b13f72e8684ae4ca78f9b68976e07d01891225241
3
  size 4418520
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:852247bf0540aa947c9887a7e004c0858d622cfa0413e9b26bd9f5dab359ad5e
3
  size 4418520
pages/2_Enhanced_Analysis.py CHANGED
@@ -27,7 +27,8 @@ from modules.modern_ml_architecture import (
27
  ModernMLPipeline,
28
  )
29
  from modules.enhanced_data_pipeline import EnhancedDataPipeline
30
- from core_logic import load_model, parse_spectrum_data
 
31
  from models.registry import choices
32
  from config import TARGET_LEN
33
 
 
27
  ModernMLPipeline,
28
  )
29
  from modules.enhanced_data_pipeline import EnhancedDataPipeline
30
+ from core_logic import load_model
31
+ from utils.multifile import parse_spectrum_data
32
  from models.registry import choices
33
  from config import TARGET_LEN
34
 
scripts/train_model.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import sys
 
3
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
4
  from datetime import datetime
5
  import argparse, numpy as np, torch
@@ -10,6 +11,9 @@ from sklearn.metrics import confusion_matrix
10
  import random
11
  import json
12
 
 
 
 
13
  # Reproducibility
14
  SEED = 42
15
  random.seed(SEED)
@@ -36,8 +40,26 @@ parser.add_argument("--batch-size", type=int, default=16)
36
  parser.add_argument("--epochs", type=int, default=10)
37
  parser.add_argument("--learning-rate", type=float, default=1e-3)
38
  parser.add_argument("--model", type=str, default="figure2", choices=model_choices())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  args = parser.parse_args()
 
41
 
42
  # Constants
43
  # Raman-only dataset (RDWP)
@@ -48,6 +70,18 @@ NUM_FOLDS = 10
48
  # Ensure output dirs exist
49
  os.makedirs("outputs", exist_ok=True)
50
  os.makedirs("outputs/logs", exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  print("Preprocessing Configuration:")
53
  print(f" Resample to : {args.target_len}")
@@ -55,6 +89,27 @@ print(f" Resample to : {args.target_len}")
55
  print(f" Baseline Correct: {'✅' if args.baseline else '❌'}")
56
  print(f" Smoothing : {'✅' if args.smooth else '❌'}")
57
  print(f" Normalization : {'✅' if args.normalize else '❌'}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # Load + Preprocess data
60
  print("🔄 Loading and preprocessing data ...")
@@ -73,24 +128,52 @@ print(f"🔍 Using model: {args.model}")
73
  skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)
74
  fold_accuracies = []
75
  all_conf_matrices = []
 
 
 
76
 
77
  for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
78
  print(f"\n🔁 Fold {fold}/{NUM_FOLDS}")
 
 
 
79
 
80
  X_train, X_val = X[train_idx], X[val_idx]
81
  y_train, y_val = y[train_idx], y[val_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  train_loader = DataLoader(
84
  TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.long)),
85
  batch_size=args.batch_size, shuffle=True)
86
  val_loader = DataLoader(
87
  TensorDataset(torch.tensor(X_val, dtype=torch.float32), torch.tensor(y_val, dtype=torch.long)))
 
 
 
 
88
 
89
  # Model selection
90
  model = build_model(args.model, args.target_len).to(DEVICE)
91
 
92
  optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
93
  criterion = torch.nn.CrossEntropyLoss()
 
 
 
94
 
95
  for epoch in range(args.epochs):
96
  model.train()
@@ -98,12 +181,18 @@ for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
98
  for inputs, labels in train_loader:
99
  inputs = inputs.unsqueeze(1).to(DEVICE)
100
  labels = labels.to(DEVICE)
 
 
 
 
101
 
102
  optimizer.zero_grad()
103
  loss = criterion(model(inputs), labels)
104
  loss.backward()
105
  optimizer.step()
106
  RUNNING_LOSS += loss.item()
 
 
107
 
108
  # After fold loop (outside the epoch loop), print 1 line:
109
  print(f"✅ Fold {fold} done. Final loss: {RUNNING_LOSS:.4f}")
@@ -169,4 +258,6 @@ def save_diagnostics_log(fold_acc, confs, args_param, output_path):
169
  print(f"🧠 Diagnostics written to {output_path}")
170
 
171
  log_path = f"outputs/logs/raman_{args.model}_diagnostics.json"
172
- save_diagnostics_log(fold_accuracies, all_conf_matrices, args, log_path)
 
 
 
1
  import os
2
  import sys
3
+
4
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
5
  from datetime import datetime
6
  import argparse, numpy as np, torch
 
11
  import random
12
  import json
13
 
14
+ from utils.training_engine import TrainingEngine
15
+ from utils.training_manager import TrainingConfig
16
+
17
  # Reproducibility
18
  SEED = 42
19
  random.seed(SEED)
 
40
  parser.add_argument("--epochs", type=int, default=10)
41
  parser.add_argument("--learning-rate", type=float, default=1e-3)
42
  parser.add_argument("--model", type=str, default="figure2", choices=model_choices())
43
+ def parse_args():
44
+ """Parses command-line arguments for training."""
45
+ parser = argparse.ArgumentParser(
46
+ description="Run 10-fold CV on Raman data with optional preprocessing."
47
+ )
48
+ parser.add_argument("--target-len", type=int, default=500)
49
+ parser.add_argument("--baseline", action="store_true")
50
+ parser.add_argument("--smooth", action="store_true")
51
+ parser.add_argument("--normalize", action="store_true")
52
+ parser.add_argument("--batch-size", type=int, default=16)
53
+ parser.add_argument("--epochs", type=int, default=10)
54
+ parser.add_argument("--learning-rate", type=float, default=1e-3)
55
+ parser.add_argument("--model", type=str, default="figure2", choices=model_choices())
56
+ parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "cuda"])
57
+ parser.add_argument("--dataset-path", type=str, default="datasets/rdwp")
58
+ parser.add_argument("--num-folds", type=int, default=10)
59
+ parser.add_argument("--cv-strategy", type=str, default="stratified_kfold", choices=["stratified_kfold", "kfold"])
60
 
61
  args = parser.parse_args()
62
+ return parser.parse_args()
63
 
64
  # Constants
65
  # Raman-only dataset (RDWP)
 
70
  # Ensure output dirs exist
71
  os.makedirs("outputs", exist_ok=True)
72
  os.makedirs("outputs/logs", exist_ok=True)
73
def cli_progress_callback(progress_data: dict):
    """Print a console line for a training progress event.

    ``progress_data["type"]`` selects the message: ``"fold_start"``,
    ``"epoch_end"`` (written with a trailing carriage return so successive
    epochs overwrite the same line) or ``"fold_end"``. Any other type prints
    nothing.
    """
    event = progress_data["type"]

    if event == "fold_start":
        fold = progress_data["fold"]
        total_folds = progress_data["total_folds"]
        print(f"\n🔁 Fold {fold}/{total_folds}")
        return

    if event == "epoch_end":
        epoch = progress_data["epoch"]
        total_epochs = progress_data["total_epochs"]
        loss = progress_data["loss"]
        # end="\r" keeps the cursor on this line for the next epoch update.
        print(f" Epoch {epoch}/{total_epochs} | Loss: {loss:.4f}", end="\r")
        return

    if event == "fold_end":
        fold = progress_data["fold"]
        accuracy_pct = progress_data["accuracy"] * 100
        print(f"\n✅ Fold {fold} Accuracy: {accuracy_pct:.2f}%")
85
 
86
  print("Preprocessing Configuration:")
87
  print(f" Resample to : {args.target_len}")
 
89
  print(f" Baseline Correct: {'✅' if args.baseline else '❌'}")
90
  print(f" Smoothing : {'✅' if args.smooth else '❌'}")
91
  print(f" Normalization : {'✅' if args.normalize else '❌'}")
92
def save_diagnostics_log(results: dict, config: TrainingConfig, output_path: str):
    """Write a JSON diagnostics file for a finished training run.

    Args:
        results: Must provide ``fold_accuracies``, ``confusion_matrices``,
            ``mean_accuracy`` and ``std_accuracy``.
        config: Supplies ``model_name`` and ``to_dict()`` for the log header.
        output_path: Destination file; it is opened for writing as UTF-8.
    """
    accuracies = results["fold_accuracies"]
    matrices = results["confusion_matrices"]

    # One entry per fold, numbered from 1.
    per_fold = []
    for fold_number, (fold_acc, fold_cm) in enumerate(zip(accuracies, matrices), 1):
        per_fold.append(
            {
                "fold": fold_number,
                "accuracy": float(fold_acc),
                "confusion_matrix": fold_cm,
            }
        )

    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    payload = {
        "timestamp": stamp,
        "model_name": config.model_name,
        "config": config.to_dict(),
        "fold_metrics": per_fold,
        "overall": {
            "mean_accuracy": results["mean_accuracy"],
            "std_accuracy": results["std_accuracy"],
        },
    }

    with open(output_path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)
    print(f"🧠 Diagnostics written to {output_path}")
113
 
114
  # Load + Preprocess data
115
  print("🔄 Loading and preprocessing data ...")
 
128
  skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)
129
  fold_accuracies = []
130
  all_conf_matrices = []
131
+ def main():
132
+ """Main function to run the training process from the CLI."""
133
+ args = parse_args()
134
 
135
  for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
136
  print(f"\n🔁 Fold {fold}/{NUM_FOLDS}")
137
+ # Ensure output dirs exist
138
+ os.makedirs("outputs/weights", exist_ok=True)
139
+ os.makedirs("outputs/logs", exist_ok=True)
140
 
141
  X_train, X_val = X[train_idx], X[val_idx]
142
  y_train, y_val = y[train_idx], y[val_idx]
143
+ # Create TrainingConfig from CLI args
144
+ config = TrainingConfig(
145
+ model_name=args.model,
146
+ dataset_path=args.dataset_path,
147
+ target_len=args.target_len,
148
+ batch_size=args.batch_size,
149
+ epochs=args.epochs,
150
+ learning_rate=args.learning_rate,
151
+ num_folds=args.num_folds,
152
+ baseline_correction=args.baseline,
153
+ smoothing=args.smooth,
154
+ normalization=args.normalize,
155
+ device=args.device,
156
+ cv_strategy=args.cv_strategy,
157
+ )
158
 
159
  train_loader = DataLoader(
160
  TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.long)),
161
  batch_size=args.batch_size, shuffle=True)
162
  val_loader = DataLoader(
163
  TensorDataset(torch.tensor(X_val, dtype=torch.float32), torch.tensor(y_val, dtype=torch.long)))
164
+ print("🔄 Loading and preprocessing data...")
165
+ X, y = preprocess_dataset(config.dataset_path, target_len=config.target_len)
166
+ print(f"✅ Data Loaded: {X.shape[0]} samples, {X.shape[1]} features each.")
167
+ print(f"🔍 Using model: {config.model_name}")
168
 
169
  # Model selection
170
  model = build_model(args.model, args.target_len).to(DEVICE)
171
 
172
  optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
173
  criterion = torch.nn.CrossEntropyLoss()
174
+ # Run training
175
+ engine = TrainingEngine(config)
176
+ results = engine.run(X, y, progress_callback=cli_progress_callback)
177
 
178
  for epoch in range(args.epochs):
179
  model.train()
 
181
  for inputs, labels in train_loader:
182
  inputs = inputs.unsqueeze(1).to(DEVICE)
183
  labels = labels.to(DEVICE)
184
+ # Save final model and logs
185
+ model_path = f"outputs/weights/{config.model_name}_model.pth"
186
+ torch.save(results["model_state_dict"], model_path)
187
+ print(f"\n✅ Model saved to {model_path}")
188
 
189
  optimizer.zero_grad()
190
  loss = criterion(model(inputs), labels)
191
  loss.backward()
192
  optimizer.step()
193
  RUNNING_LOSS += loss.item()
194
+ log_path = f"outputs/logs/{config.model_name}_cli_diagnostics.json"
195
+ save_diagnostics_log(results, config, log_path)
196
 
197
  # After fold loop (outside the epoch loop), print 1 line:
198
  print(f"✅ Fold {fold} done. Final loss: {RUNNING_LOSS:.4f}")
 
258
  print(f"🧠 Diagnostics written to {output_path}")
259
 
260
  log_path = f"outputs/logs/raman_{args.model}_diagnostics.json"
261
+ save_diagnostics_log(fold_accuracies, all_conf_matrices, args, log_path)
262
+ if __name__ == "__main__":
263
+ main()
utils/multifile.py CHANGED
@@ -11,6 +11,7 @@ import json
11
  import csv
12
  import io
13
  from pathlib import Path
 
14
 
15
  from .preprocessing import preprocess_spectrum
16
  from .errors import ErrorHandler, safe_execute
@@ -35,7 +36,7 @@ def detect_file_format(filename: str, content: str) -> str:
35
  try:
36
  json.loads(content)
37
  return "json"
38
- except:
39
  pass
40
  elif suffix == ".csv":
41
  return "csv"
@@ -50,7 +51,7 @@ def detect_file_format(filename: str, content: str) -> str:
50
  try:
51
  json.loads(content)
52
  return "json"
53
- except:
54
  pass
55
 
56
  # Try CSV (look for commas in first few lines)
@@ -63,12 +64,7 @@ def detect_file_format(filename: str, content: str) -> str:
63
  return "txt"
64
 
65
 
66
- # /////////////////////////////////////////////////////
67
-
68
-
69
- def parse_json_spectrum(
70
- content: str, filename: str = "unknown"
71
- ) -> Tuple[np.ndarray, np.ndarray]:
72
  """
73
  Parse spectrum data from JSON format.
74
 
@@ -79,7 +75,7 @@ def parse_json_spectrum(
79
  """
80
 
81
  try:
82
- data = json.load(content)
83
 
84
  # Format 1: Object with arrays
85
  if isinstance(data, dict):
@@ -135,12 +131,9 @@ def parse_json_spectrum(
135
  )
136
 
137
  except json.JSONDecodeError as e:
138
- raise ValueError(f"Invalid JSON format: {str(e)}")
139
  except Exception as e:
140
- raise ValueError(f"Failed to parse JSON spectrum: {str(e)}")
141
-
142
-
143
- # /////////////////////////////////////////////////////
144
 
145
 
146
  def parse_csv_spectrum(
@@ -208,10 +201,7 @@ def parse_csv_spectrum(
208
  return np.array(x_vals), np.array(y_vals)
209
 
210
  except Exception as e:
211
- raise ValueError(f"Failed to parse CSV spectrum: {str(e)}")
212
-
213
-
214
- # /////////////////////////////////////////////////////
215
 
216
 
217
  def parse_spectrum_data(
@@ -235,7 +225,7 @@ def parse_spectrum_data(
235
 
236
  # Parse based on detected/specified format
237
  if file_format == "json":
238
- x, y = parse_json_spectrum(text_content, filename)
239
  elif file_format == "csv":
240
  x, y = parse_csv_spectrum(text_content, filename)
241
  else: # Default to TXT format
@@ -247,10 +237,7 @@ def parse_spectrum_data(
247
  return x, y
248
 
249
  except Exception as e:
250
- raise ValueError(f"Failed to parse spectrum data: {str(e)}")
251
-
252
-
253
- # /////////////////////////////////////////////////////
254
 
255
 
256
  def parse_txt_spectrum(
@@ -287,7 +274,7 @@ def parse_txt_spectrum(
287
  f"Parsing {filename}",
288
  )
289
 
290
- except Exception as e:
291
  ErrorHandler.log_warning(
292
  f"Error parsing line {i+1}: '{line}'. Error: {e}",
293
  f"Parsing {filename}",
@@ -302,9 +289,6 @@ def parse_txt_spectrum(
302
  return np.array(x_vals), np.array(y_vals)
303
 
304
 
305
- # /////////////////////////////////////////////////////
306
-
307
-
308
  def validate_spectrum_data(x: np.ndarray, y: np.ndarray, filename: str) -> None:
309
  """
310
  Validate parsed spectrum data for common issues.
@@ -332,9 +316,6 @@ def validate_spectrum_data(x: np.ndarray, y: np.ndarray, filename: str) -> None:
332
  )
333
 
334
 
335
- # /////////////////////////////////////////////////////
336
-
337
-
338
  def process_single_file(
339
  filename: str,
340
  text_content: str,
@@ -369,8 +350,11 @@ def process_single_file(
369
  )
370
 
371
  # 3. Run inference, passing modality
 
 
 
372
  prediction, logits_list, probs, inference_time, logits = run_inference_func(
373
- y_resampled, model_choice, modality=modality
374
  )
375
 
376
  if prediction is None:
@@ -418,7 +402,7 @@ def process_single_file(
418
  "y_resampled": y_resampled,
419
  }
420
 
421
- except Exception as e:
422
  ErrorHandler.log_error(e, f"processing {filename}")
423
  return {
424
  "filename": filename,
@@ -501,7 +485,7 @@ def process_multiple_files(
501
  },
502
  )
503
 
504
- except Exception as e:
505
  ErrorHandler.log_error(e, f"reading file {uploaded_file.name}")
506
  results.append(
507
  {
 
11
  import csv
12
  import io
13
  from pathlib import Path
14
+ import hashlib
15
 
16
  from .preprocessing import preprocess_spectrum
17
  from .errors import ErrorHandler, safe_execute
 
36
  try:
37
  json.loads(content)
38
  return "json"
39
+ except json.JSONDecodeError:
40
  pass
41
  elif suffix == ".csv":
42
  return "csv"
 
51
  try:
52
  json.loads(content)
53
  return "json"
54
+ except json.JSONDecodeError:
55
  pass
56
 
57
  # Try CSV (look for commas in first few lines)
 
64
  return "txt"
65
 
66
 
67
+ def parse_json_spectrum(content: str) -> Tuple[np.ndarray, np.ndarray]:
 
 
 
 
 
68
  """
69
  Parse spectrum data from JSON format.
70
 
 
75
  """
76
 
77
  try:
78
+ data = json.loads(content)
79
 
80
  # Format 1: Object with arrays
81
  if isinstance(data, dict):
 
131
  )
132
 
133
  except json.JSONDecodeError as e:
134
+ raise ValueError(f"Invalid JSON format: {str(e)}") from e
135
  except Exception as e:
136
+ raise ValueError(f"Failed to parse JSON spectrum: {str(e)}") from e
 
 
 
137
 
138
 
139
  def parse_csv_spectrum(
 
201
  return np.array(x_vals), np.array(y_vals)
202
 
203
  except Exception as e:
204
+ raise ValueError(f"Failed to parse CSV spectrum: {str(e)}") from e
 
 
 
205
 
206
 
207
  def parse_spectrum_data(
 
225
 
226
  # Parse based on detected/specified format
227
  if file_format == "json":
228
+ x, y = parse_json_spectrum(text_content)
229
  elif file_format == "csv":
230
  x, y = parse_csv_spectrum(text_content, filename)
231
  else: # Default to TXT format
 
237
  return x, y
238
 
239
  except Exception as e:
240
+ raise ValueError(f"Failed to parse spectrum data: {str(e)}") from e
 
 
 
241
 
242
 
243
  def parse_txt_spectrum(
 
274
  f"Parsing {filename}",
275
  )
276
 
277
+ except ValueError as e:
278
  ErrorHandler.log_warning(
279
  f"Error parsing line {i+1}: '{line}'. Error: {e}",
280
  f"Parsing {filename}",
 
289
  return np.array(x_vals), np.array(y_vals)
290
 
291
 
 
 
 
292
  def validate_spectrum_data(x: np.ndarray, y: np.ndarray, filename: str) -> None:
293
  """
294
  Validate parsed spectrum data for common issues.
 
316
  )
317
 
318
 
 
 
 
319
  def process_single_file(
320
  filename: str,
321
  text_content: str,
 
350
  )
351
 
352
  # 3. Run inference, passing modality
353
+ cache_key = hashlib.md5(
354
+ f"{y_resampled.tobytes()}{model_choice}".encode()
355
+ ).hexdigest()
356
  prediction, logits_list, probs, inference_time, logits = run_inference_func(
357
+ y_resampled, model_choice, modality=modality, cache_key=cache_key
358
  )
359
 
360
  if prediction is None:
 
402
  "y_resampled": y_resampled,
403
  }
404
 
405
+ except ValueError as e:
406
  ErrorHandler.log_error(e, f"processing {filename}")
407
  return {
408
  "filename": filename,
 
485
  },
486
  )
487
 
488
+ except ValueError as e:
489
  ErrorHandler.log_error(e, f"reading file {uploaded_file.name}")
490
  results.append(
491
  {
utils/training_engine.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Core Training Engine for the POLYMEROS project.
3
+
4
+ This module contains the primary logic for model training and validation,
5
+ encapsulated in a reusable `TrainingEngine` class. It is designed to be
6
+ called by different interfaces, such as the command-line script
7
+ (train_model.py) and the web UI's TrainingManager.
8
+
9
+ This approach ensures that the core training process is consistent,
10
+ maintainable, and follows the DRY (Don't Repeat Yourself) principle.
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from torch.utils.data import TensorDataset, DataLoader
17
+ from sklearn.metrics import confusion_matrix, accuracy_score
18
+
19
+ from .training_types import (
20
+ TrainingConfig,
21
+ TrainingProgress,
22
+ get_cv_splitter,
23
+ augment_spectral_data,
24
+ )
25
+ from models.registry import build as build_model
26
+
27
+
28
class TrainingEngine:
    """Run cross-validated training and validation for one training config.

    All settings are read from ``config``: model name, target length, batch
    size, epochs, learning rate, CV strategy and fold count, device, and the
    augmentation options. :meth:`run` is the single public entry point.
    """

    def __init__(self, config: "TrainingConfig"):
        """
        Initializes the TrainingEngine with a given configuration.

        Args:
            config (TrainingConfig): The configuration object for the training run.
        """
        self.config = config
        self.device = self._get_device()

    def _get_device(self) -> torch.device:
        """Return the compute device named by ``config.device``.

        ``"auto"`` resolves to CUDA when ``torch.cuda.is_available()`` and to
        CPU otherwise; any other value is passed to ``torch.device`` as-is.
        """
        if self.config.device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(self.config.device)

    def _build_loaders(self, X_train, y_train, X_val, y_val):
        """Wrap one fold's train/validation arrays in ``DataLoader`` objects."""
        train_loader = DataLoader(
            TensorDataset(
                torch.tensor(X_train, dtype=torch.float32),
                torch.tensor(y_train, dtype=torch.long),
            ),
            batch_size=self.config.batch_size,
            shuffle=True,
        )
        # Validation keeps the DataLoader defaults (no shuffling).
        val_loader = DataLoader(
            TensorDataset(
                torch.tensor(X_val, dtype=torch.float32),
                torch.tensor(y_val, dtype=torch.long),
            )
        )
        return train_loader, val_loader

    def _train_fold(self, model, train_loader, fold, progress_callback):
        """Train ``model`` for ``config.epochs`` epochs, reporting each epoch's mean loss."""
        optimizer = torch.optim.Adam(model.parameters(), lr=self.config.learning_rate)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(self.config.epochs):
            model.train()
            running_loss = 0.0
            for inputs, labels in train_loader:
                # Insert a dimension at index 1 before the model call.
                inputs = inputs.unsqueeze(1).to(self.device)
                labels = labels.to(self.device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            if progress_callback:
                progress_callback(
                    {
                        "type": "epoch_end",
                        "fold": fold,
                        "epoch": epoch + 1,
                        "total_epochs": self.config.epochs,
                        "loss": running_loss / len(train_loader),
                    }
                )

    def _evaluate(self, model, val_loader):
        """Return ``(true_labels, predicted_labels)`` lists for the validation loader."""
        model.eval()
        all_true, all_pred = [], []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.unsqueeze(1).to(self.device)
                outputs = model(inputs)
                # The predicted class is the index of the largest output.
                _, predicted = torch.max(outputs, 1)
                all_true.extend(labels.cpu().numpy())
                all_pred.extend(predicted.cpu().numpy())
        return all_true, all_pred

    def run(
        self, X: np.ndarray, y: np.ndarray, progress_callback: callable = None
    ) -> dict:
        """
        Executes the full cross-validation training and evaluation loop.

        A fresh model is built and trained for every fold. When given,
        ``progress_callback`` receives dict events whose ``"type"`` is
        ``"fold_start"``, ``"epoch_end"`` or ``"fold_end"``.

        Args:
            X (np.ndarray): Feature data.
            y (np.ndarray): Label data.
            progress_callback (callable, optional): A function to call with
                progress updates. Defaults to None.

        Returns:
            dict: ``fold_accuracies``, ``confusion_matrices`` (nested lists),
            ``mean_accuracy``, ``std_accuracy`` and ``model_state_dict``
            (the state dict of the model trained in the last fold, or
            ``None`` when the splitter yields no folds).
        """
        cv_splitter = get_cv_splitter(self.config.cv_strategy, self.config.num_folds)
        # Classes present in the whole dataset; used to pin the confusion
        # matrix layout so every fold reports the same rows and columns.
        class_labels = np.unique(y)

        fold_accuracies = []
        all_conf_matrices = []
        final_model_state = None

        for fold, (train_idx, val_idx) in enumerate(cv_splitter.split(X, y), 1):
            if progress_callback:
                progress_callback(
                    {
                        "type": "fold_start",
                        "fold": fold,
                        "total_folds": self.config.num_folds,
                    }
                )

            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]

            # Apply data augmentation if enabled (training split only).
            if self.config.enable_augmentation:
                X_train, y_train = augment_spectral_data(
                    X_train, y_train, noise_level=self.config.noise_level
                )

            train_loader, val_loader = self._build_loaders(
                X_train, y_train, X_val, y_val
            )

            model = build_model(self.config.model_name, self.config.target_len).to(
                self.device
            )
            self._train_fold(model, train_loader, fold, progress_callback)

            all_true, all_pred = self._evaluate(model, val_loader)

            acc = accuracy_score(all_true, all_pred)
            fold_accuracies.append(acc)
            # Without explicit labels, a fold that sees or predicts only one
            # class would produce a smaller matrix than the other folds. The
            # union keeps any predicted class outside ``y`` from being dropped.
            fold_labels = np.union1d(class_labels, all_pred)
            all_conf_matrices.append(
                confusion_matrix(all_true, all_pred, labels=fold_labels).tolist()
            )
            final_model_state = model.state_dict()

            if progress_callback:
                progress_callback({"type": "fold_end", "fold": fold, "accuracy": acc})

        return {
            "fold_accuracies": fold_accuracies,
            "confusion_matrices": all_conf_matrices,
            "mean_accuracy": np.mean(fold_accuracies),
            "std_accuracy": np.std(fold_accuracies),
            "model_state_dict": final_model_state,
        }
utils/training_manager.py CHANGED
@@ -12,16 +12,14 @@ import threading
12
  import concurrent.futures
13
  import multiprocessing
14
  from datetime import datetime, timedelta
15
- from dataclasses import dataclass, asdict, field
16
- from enum import Enum
17
  from typing import Dict, List, Optional, Callable, Any, Tuple
18
  from pathlib import Path
 
19
 
20
  import torch
21
  import torch.nn as nn
22
  import numpy as np
23
  from torch.utils.data import TensorDataset, DataLoader
24
- from sklearn.model_selection import StratifiedKFold, KFold, TimeSeriesSplit
25
  from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
26
  from sklearn.metrics.pairwise import cosine_similarity
27
  from scipy.signal import find_peaks
@@ -30,6 +28,14 @@ from scipy.spatial.distance import euclidean
30
  # Add project-specific imports
31
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
32
  from models.registry import choices as model_choices, build as build_model
 
 
 
 
 
 
 
 
33
  from utils.preprocessing import preprocess_spectrum
34
 
35
 
@@ -143,74 +149,21 @@ def calculate_spectroscopy_metrics(
143
  return metrics
144
 
145
 
146
def get_cv_splitter(strategy: str, n_splits: int = 10, random_state: int = 42):
    """Build the scikit-learn splitter named by ``strategy``.

    Args:
        strategy: ``"stratified_kfold"``, ``"kfold"`` or ``"time_series_split"``.
            Any other value is treated like ``"stratified_kfold"``.
        n_splits: Number of folds handed to the splitter.
        random_state: Seed for the shuffling splitters; not passed to
            ``TimeSeriesSplit``.

    Returns:
        A ``KFold``, ``TimeSeriesSplit`` or ``StratifiedKFold`` instance.
    """
    if strategy == "kfold":
        return KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    if strategy == "time_series_split":
        return TimeSeriesSplit(n_splits=n_splits)

    # "stratified_kfold" and every unrecognised name share this fallback.
    return StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
161
-
162
-
163
def augment_spectral_data(
    X: np.ndarray,
    y: np.ndarray,
    noise_level: float = 0.01,
    augmentation_factor: int = 2,
) -> Tuple[np.ndarray, np.ndarray]:
    """Extend ``X``/``y`` with randomly perturbed copies of ``X``.

    Each extra copy is built from the original ``X`` by, in this order:
    adding element-wise Gaussian noise (std ``noise_level``), adding one
    Gaussian offset per row (std ``noise_level * 0.5``), multiplying each row
    by a Gaussian factor (mean 1.0, std 0.05), and clipping negatives to 0.

    Args:
        X: 2-D array; rows are perturbed independently.
        y: Labels for the rows of ``X``; repeated unchanged for every copy.
        noise_level: Standard deviation of the additive noise.
        augmentation_factor: Total number of blocks in the output, counting
            the original. Values <= 1 disable augmentation.

    Returns:
        ``(X, y)`` untouched when ``augmentation_factor <= 1``; otherwise the
        original block followed by ``augmentation_factor - 1`` perturbed
        blocks, stacked with ``np.vstack`` / ``np.hstack``.
    """
    if augmentation_factor <= 1:
        return X, y

    row_shape = (X.shape[0], 1)
    x_blocks = [X]
    y_blocks = [y]

    for _ in range(augmentation_factor - 1):
        # Draw order (noise, offset, scale) matters: all three consume the
        # global NumPy random stream.
        perturbed = X + np.random.normal(0, noise_level, X.shape)
        perturbed = perturbed + np.random.normal(0, noise_level * 0.5, row_shape)
        perturbed = perturbed * np.random.normal(1.0, 0.05, row_shape)

        x_blocks.append(np.maximum(perturbed, 0))
        y_blocks.append(y)

    return np.vstack(x_blocks), np.hstack(y_blocks)
196
-
197
-
198
class TrainingStatus(Enum):
    """States a training job can be in.

    Every member's value is the lowercase string form of its name.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
206
 
207
 
208
class CVStrategy(Enum):
    """Names of the supported cross-validation strategies.

    Every member's value is the lowercase string form of its name.
    """

    STRATIFIED_KFOLD = "stratified_kfold"
    KFOLD = "kfold"
    TIME_SERIES_SPLIT = "time_series_split"
214
 
215
 
216
  @dataclass
@@ -224,15 +177,12 @@ class TrainingConfig:
224
  epochs: int = 10
225
  learning_rate: float = 1e-3
226
  num_folds: int = 10
227
- baseline_correction: bool = True
228
- smoothing: bool = True
229
- normalization: bool = True
230
  modality: str = "raman"
231
  device: str = "auto" # auto, cpu, cuda
232
  cv_strategy: str = "stratified_kfold" # New field for CV strategy
233
  spectral_weight: float = 0.1 # Weight for spectroscopy-specific metrics
234
- enable_augmentation: bool = False # Enable data augmentation
235
- noise_level: float = 0.01 # Noise level for augmentation
236
 
237
  def to_dict(self) -> Dict[str, Any]:
238
  """Convert to dictionary for serialization"""
@@ -308,10 +258,6 @@ class TrainingManager:
308
  self.output_dir = Path(output_dir)
309
  self.output_dir.mkdir(exist_ok=True)
310
  (self.output_dir / "weights").mkdir(exist_ok=True)
311
- (self.output_dir / "logs").mkdir(exist_ok=True)
312
-
313
- # Progress callbacks for UI updates
314
- self.progress_callbacks: Dict[str, List[Callable]] = {}
315
 
316
  def generate_job_id(self) -> str:
317
  """Generate unique job ID"""
@@ -324,20 +270,12 @@ class TrainingManager:
324
  job_id = self.generate_job_id()
325
  job = TrainingJob(job_id=job_id, config=config)
326
 
327
- # Set up output paths
328
- job.weights_path = str(self.output_dir / "weights" / f"{job_id}_model.pth")
329
- job.logs_path = str(self.output_dir / "logs" / f"{job_id}_log.json")
330
-
331
  self.jobs[job_id] = job
332
 
333
- # Register progress callback
334
- if progress_callback:
335
- if job_id not in self.progress_callbacks:
336
- self.progress_callbacks[job_id] = []
337
- self.progress_callbacks[job_id].append(progress_callback)
338
-
339
  # Submit to thread pool
340
- self.executor.submit(self._run_training_job, job)
 
 
341
 
342
  return job_id
343
 
@@ -346,25 +284,39 @@ class TrainingManager:
346
  try:
347
  job.status = TrainingStatus.RUNNING
348
  job.started_at = datetime.now()
349
- job.progress.start_time = job.started_at
350
-
351
- self._notify_progress(job.job_id, job)
352
 
353
- # Device selection
354
- device = self._get_device(job.config.device)
355
 
356
  # Load and preprocess data
357
  X, y = self._load_and_preprocess_data(job)
358
  if X is None or y is None:
359
  raise ValueError("Failed to load dataset")
360
 
361
- # Set reproducibility
362
- self._set_reproducibility()
363
-
364
- # Run cross-validation training
365
- self._run_cross_validation(job, X, y, device)
366
-
367
- # Save final results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  self._save_training_results(job)
369
 
370
  job.status = TrainingStatus.COMPLETED
@@ -377,16 +329,8 @@ class TrainingManager:
377
  job.completed_at = datetime.now()
378
 
379
  finally:
380
- self._notify_progress(job.job_id, job)
381
-
382
- def _get_device(self, device_preference: str) -> torch.device:
383
- """Get appropriate device for training"""
384
- if device_preference == "auto":
385
- return torch.device("cuda" if torch.cuda.is_available() else "cpu")
386
- elif device_preference == "cuda" and torch.cuda.is_available():
387
- return torch.device("cuda")
388
- else:
389
- return torch.device("cpu")
390
 
391
  def _load_and_preprocess_data(
392
  self, job: TrainingJob
@@ -576,134 +520,19 @@ class TrainingManager:
576
  print(f"Error loading dataset: {e}")
577
  return None, None
578
 
579
- def _set_reproducibility(self):
580
- """Set random seeds for reproducibility"""
581
- SEED = 42
582
- np.random.seed(SEED)
583
- torch.manual_seed(SEED)
584
- if torch.cuda.is_available():
585
- torch.cuda.manual_seed_all(SEED)
586
- torch.backends.cudnn.deterministic = True
587
- torch.backends.cudnn.benchmark = False
588
-
589
def _run_cross_validation(
    self, job: TrainingJob, X: np.ndarray, y: np.ndarray, device: torch.device
):
    """Train and validate one freshly built model per cross-validation fold.

    For every fold produced by ``get_cv_splitter(config.cv_strategy,
    config.num_folds)`` a new model is built, trained for ``config.epochs``
    epochs with Adam and cross-entropy loss, then evaluated on the held-out
    indices. Per-epoch loss/accuracy and the current fold/epoch are written
    to ``job.progress`` and pushed through ``self._notify_progress``.

    Augmentation, when enabled, is applied to the training portion of each
    fold only. Augmenting before the split would place perturbed copies of
    validation spectra in the training set and inflate validation scores.

    Args:
        job: Job whose ``config`` drives training and whose ``progress`` is
            updated in place.
        X: Feature array, indexed by the splitter's index arrays and fed to
            the model as float32 with an extra axis inserted at position 1.
        y: Label array, converted to ``torch.long``.
        device: Device that the model and batches are moved to.

    Side effects:
        Saves the model of the fold numbered ``config.num_folds`` to
        ``job.weights_path`` and stores fold accuracies, confusion matrices
        and spectroscopy metrics on ``job.progress``.
    """
    config = job.config

    cv_splitter = get_cv_splitter(config.cv_strategy, config.num_folds)

    fold_accuracies = []
    confusion_matrices = []
    spectroscopy_metrics = []

    for fold, (train_idx, val_idx) in enumerate(cv_splitter.split(X, y), 1):
        job.progress.current_fold = fold
        job.progress.current_epoch = 0

        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        # Augment after splitting so the validation rows stay untouched and
        # no derived copy of them reaches the training set.
        if config.enable_augmentation:
            X_train, y_train = augment_spectral_data(
                X_train,
                y_train,
                noise_level=config.noise_level,
                augmentation_factor=2,
            )

        train_loader = DataLoader(
            TensorDataset(
                torch.tensor(X_train, dtype=torch.float32),
                torch.tensor(y_train, dtype=torch.long),
            ),
            batch_size=config.batch_size,
            shuffle=True,
        )
        val_loader = DataLoader(
            TensorDataset(
                torch.tensor(X_val, dtype=torch.float32),
                torch.tensor(y_val, dtype=torch.long),
            ),
            batch_size=config.batch_size,
            shuffle=False,
        )

        # A new model and optimizer per fold: nothing carries over.
        model = build_model(config.model_name, config.target_len).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(config.epochs):
            job.progress.current_epoch = epoch + 1
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            for inputs, labels in train_loader:
                inputs = inputs.unsqueeze(1).to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            # An empty training loader would otherwise divide by zero.
            num_batches = len(train_loader)
            job.progress.current_loss = (
                running_loss / num_batches if num_batches else 0.0
            )
            job.progress.current_accuracy = correct / total if total else 0.0

            self._notify_progress(job.job_id, job)

        # Validation pass on the held-out indices.
        model.eval()
        val_predictions = []
        val_true = []
        val_probabilities = []

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.unsqueeze(1).to(device)
                outputs = model(inputs)
                probabilities = torch.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs, 1)

                val_predictions.extend(predicted.cpu().numpy())
                val_true.extend(labels.numpy())
                val_probabilities.extend(probabilities.cpu().numpy())

        fold_accuracy = accuracy_score(val_true, val_predictions)
        fold_cm = confusion_matrix(val_true, val_predictions).tolist()

        val_probabilities = np.array(val_probabilities)
        spectro_metrics = calculate_spectroscopy_metrics(
            np.array(val_true), np.array(val_predictions), val_probabilities
        )

        fold_accuracies.append(fold_accuracy)
        confusion_matrices.append(fold_cm)
        spectroscopy_metrics.append(spectro_metrics)

        # Only the last fold's weights are kept; no best-fold selection yet.
        if fold == config.num_folds:
            torch.save(model.state_dict(), job.weights_path)

    job.progress.fold_accuracies = fold_accuracies
    job.progress.confusion_matrices = confusion_matrices
    job.progress.spectroscopy_metrics = spectroscopy_metrics
704
 
705
  def _save_training_results(self, job: TrainingJob):
706
  """Save training results and logs with enhanced metrics"""
 
 
 
 
707
  # Calculate comprehensive summary metrics
708
  spectro_summary = {}
709
  if job.progress.spectroscopy_metrics:
@@ -744,17 +573,9 @@ class TrainingManager:
744
  "error_message": job.error_message,
745
  }
746
 
747
- with open(job.logs_path, "w") as f:
748
- json.dump(results, f, indent=2)
749
-
750
- def _notify_progress(self, job_id: str, job: TrainingJob):
751
- """Notify registered callbacks about progress updates"""
752
- if job_id in self.progress_callbacks:
753
- for callback in self.progress_callbacks[job_id]:
754
- try:
755
- callback(job)
756
- except Exception as e:
757
- print(f"Error in progress callback: {e}")
758
 
759
  def get_job_status(self, job_id: str) -> Optional[TrainingJob]:
760
  """Get current status of a training job"""
 
12
  import concurrent.futures
13
  import multiprocessing
14
  from datetime import datetime, timedelta
 
 
15
  from typing import Dict, List, Optional, Callable, Any, Tuple
16
  from pathlib import Path
17
+ from dataclasses import dataclass, field
18
 
19
  import torch
20
  import torch.nn as nn
21
  import numpy as np
22
  from torch.utils.data import TensorDataset, DataLoader
 
23
  from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
24
  from sklearn.metrics.pairwise import cosine_similarity
25
  from scipy.signal import find_peaks
 
28
  # Add project-specific imports
29
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
30
  from models.registry import choices as model_choices, build as build_model
31
+ from utils.training_engine import TrainingEngine
32
+ from utils.training_types import (
33
+ TrainingConfig,
34
+ TrainingProgress,
35
+ TrainingStatus,
36
+ CVStrategy,
37
+ get_cv_splitter,
38
+ )
39
  from utils.preprocessing import preprocess_spectrum
40
 
41
 
 
149
  return metrics
150
 
151
 
152
@dataclass
class AugmentationConfig:
    """Settings that control training-data augmentation."""

    # Master switch; augmentation is off unless this is set.
    enable_augmentation: bool = False
    # Noise level for augmentation
    noise_level: float = 0.01
 
 
 
158
 
159
 
160
@dataclass
class PreprocessingConfig:
    """Switches for the individual preprocessing steps; all on by default."""

    baseline_correction: bool = True
    smoothing: bool = True
    normalization: bool = True
167
 
168
 
169
  @dataclass
 
177
  epochs: int = 10
178
  learning_rate: float = 1e-3
179
  num_folds: int = 10
 
 
 
180
  modality: str = "raman"
181
  device: str = "auto" # auto, cpu, cuda
182
  cv_strategy: str = "stratified_kfold" # New field for CV strategy
183
  spectral_weight: float = 0.1 # Weight for spectroscopy-specific metrics
184
+ augmentation: AugmentationConfig = field(default_factory=AugmentationConfig)
185
+ preprocessing: PreprocessingConfig = field(default_factory=PreprocessingConfig)
186
 
187
  def to_dict(self) -> Dict[str, Any]:
188
  """Convert to dictionary for serialization"""
 
258
  self.output_dir = Path(output_dir)
259
  self.output_dir.mkdir(exist_ok=True)
260
  (self.output_dir / "weights").mkdir(exist_ok=True)
 
 
 
 
261
 
262
  def generate_job_id(self) -> str:
263
  """Generate unique job ID"""
 
270
  job_id = self.generate_job_id()
271
  job = TrainingJob(job_id=job_id, config=config)
272
 
 
 
 
 
273
  self.jobs[job_id] = job
274
 
 
 
 
 
 
 
275
  # Submit to thread pool
276
+ self.executor.submit(
277
+ self._run_training_job, job, progress_callback=progress_callback
278
+ )
279
 
280
  return job_id
281
 
 
284
  try:
285
  job.status = TrainingStatus.RUNNING
286
  job.started_at = datetime.now()
287
+ if job.progress:
288
+ job.progress.start_time = job.started_at
 
289
 
290
+ if progress_callback:
291
+ progress_callback(job)
292
 
293
  # Load and preprocess data
294
  X, y = self._load_and_preprocess_data(job)
295
  if X is None or y is None:
296
  raise ValueError("Failed to load dataset")
297
 
298
+ # Define a callback to update the job's progress object
299
+ def engine_progress_callback(progress_data: dict):
300
+ if job.progress:
301
+ if progress_data["type"] == "fold_start":
302
+ job.progress.current_fold = progress_data["fold"]
303
+ elif progress_data["type"] == "epoch_end":
304
+ job.progress.current_epoch = progress_data["epoch"]
305
+ job.progress.current_loss = progress_data["loss"]
306
+ if progress_callback:
307
+ progress_callback(job)
308
+
309
+ # Instantiate and run the training engine
310
+ engine = TrainingEngine(job.config)
311
+ results = engine.run(X, y, progress_callback=engine_progress_callback)
312
+
313
+ # Update job with results
314
+ if job.progress:
315
+ job.progress.fold_accuracies = results["fold_accuracies"]
316
+ job.progress.confusion_matrices = results["confusion_matrices"]
317
+
318
+ # Save model weights and logs
319
+ self._save_model_weights(job, results["model_state_dict"])
320
  self._save_training_results(job)
321
 
322
  job.status = TrainingStatus.COMPLETED
 
329
  job.completed_at = datetime.now()
330
 
331
  finally:
332
+ if progress_callback:
333
+ progress_callback(job)
 
 
 
 
 
 
 
 
334
 
335
  def _load_and_preprocess_data(
336
  self, job: TrainingJob
 
520
  print(f"Error loading dataset: {e}")
521
  return None, None
522
 
523
+ def _save_model_weights(self, job: TrainingJob, model_state_dict: dict):
524
+ """Saves the model's state dictionary to a file."""
525
+ weights_dir = self.output_dir / "weights"
526
+ weights_dir.mkdir(exist_ok=True)
527
+ job.weights_path = str(weights_dir / f"{job.config.model_name}_model.pth")
528
+ torch.save(model_state_dict, job.weights_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
 
530
  def _save_training_results(self, job: TrainingJob):
531
  """Save training results and logs with enhanced metrics"""
532
+ logs_dir = self.output_dir / "logs"
533
+ logs_dir.mkdir(exist_ok=True)
534
+ job.logs_path = str(logs_dir / f"{job.job_id}_log.json")
535
+
536
  # Calculate comprehensive summary metrics
537
  spectro_summary = {}
538
  if job.progress.spectroscopy_metrics:
 
573
  "error_message": job.error_message,
574
  }
575
 
576
+ if job.logs_path:
577
+ with open(job.logs_path, "w") as f:
578
+ json.dump(results, f, indent=2)
 
 
 
 
 
 
 
 
579
 
580
  def get_job_status(self, job_id: str) -> Optional[TrainingJob]:
581
  """Get current status of a training job"""
utils/training_types.py ADDED
File without changes