devjas1 committed on
Commit aecd727 · 1 Parent(s): 9fe46f4

(FEAT)[Implement Robust Training Management Backend]: Add training manager, config classes, data augmentation, metrics, and cross-validation utilities


- Developed 'TrainingManager' class to orchestrate training jobs, including job submission, tracking, and resource allocation (a brief usage sketch follows this list).
- Defined 'TrainingConfig' and 'TrainingStatus' for flexible experiment configuration and state monitoring.
- Implemented multiple cross-validation strategies (KFold, StratifiedKFold, TimeSeriesSplit) for flexible ML evaluation.
- Added spectroscopy-specific metrics and spectral cosine similarity computation for domain-relevant model assessment.
- Integrated secure data loading, preprocessing, and augmentation logic to support various file formats and enforce data integrity.
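
A minimal usage sketch of the API described above (illustrative only: the names mirror the definitions added in utils/training_manager.py below, and the model name and dataset path are placeholders, not values from this commit):

from utils.training_manager import TrainingConfig, get_training_manager

config = TrainingConfig(
    model_name="example_model",    # placeholder; must be a name known to models.registry
    dataset_path="datasets/demo",  # placeholder; directory with one subfolder per class
    epochs=5,
    num_folds=5,
    cv_strategy="stratified_kfold",
)

manager = get_training_manager()
job_id = manager.submit_training_job(
    config,
    progress_callback=lambda job: print(job.status, job.progress.current_epoch),
)
job = manager.get_job_status(job_id)  # poll from the UI for status and progress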

Files changed (1)
  1. utils/training_manager.py +817 -0
utils/training_manager.py ADDED
@@ -0,0 +1,817 @@
1
+ """
2
+ Training job management system for ML Hub functionality.
3
+ Handles asynchronous training jobs, progress tracking, and result management.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import json
9
+ import time
10
+ import uuid
11
+ import threading
12
+ import concurrent.futures
13
+ import multiprocessing
14
+ from datetime import datetime, timedelta
15
+ from dataclasses import dataclass, asdict, field
16
+ from enum import Enum
17
+ from typing import Dict, List, Optional, Callable, Any, Tuple
18
+ from pathlib import Path
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import numpy as np
23
+ from torch.utils.data import TensorDataset, DataLoader
24
+ from sklearn.model_selection import StratifiedKFold, KFold, TimeSeriesSplit
25
+ from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
26
+ from sklearn.metrics.pairwise import cosine_similarity
27
+ from scipy.signal import find_peaks
28
+ from scipy.spatial.distance import euclidean
29
+
30
+ # Add project-specific imports
31
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
32
+ from models.registry import choices as model_choices, build as build_model
33
+ from utils.preprocessing import preprocess_spectrum
34
+
35
+
36
+ def spectral_cosine_similarity(y_true: np.ndarray, y_pred: np.ndarray) -> float:
37
+ """Calculate cosine similarity between spectral predictions and true values"""
38
+ # Reshape if needed for cosine similarity calculation
39
+ if y_true.ndim == 1:
40
+ y_true = y_true.reshape(1, -1)
41
+ if y_pred.ndim == 1:
42
+ y_pred = y_pred.reshape(1, -1)
43
+
44
+ return float(cosine_similarity(y_true, y_pred)[0, 0])
45
+
46
+
47
+ def peak_matching_score(
48
+ spectrum1: np.ndarray,
49
+ spectrum2: np.ndarray,
50
+ height_threshold: float = 0.1,
51
+ distance: int = 5,
52
+ ) -> float:
53
+ """Calculate peak matching score between two spectra"""
54
+ try:
55
+ # Find peaks in both spectra
56
+ peaks1, _ = find_peaks(spectrum1, height=height_threshold, distance=distance)
57
+ peaks2, _ = find_peaks(spectrum2, height=height_threshold, distance=distance)
58
+
59
+ if len(peaks1) == 0 or len(peaks2) == 0:
60
+ return 0.0
61
+
62
+ # Calculate matching peaks (within tolerance)
63
+ tolerance = 3  # peak-position tolerance in index units (find_peaks returns indices)
64
+ matches = 0
65
+
66
+ for peak1 in peaks1:
67
+ for peak2 in peaks2:
68
+ if abs(peak1 - peak2) <= tolerance:
69
+ matches += 1
70
+ break
71
+
72
+ # Return normalized matching score
73
+ return matches / max(len(peaks1), len(peaks2))
74
+ except Exception:
75
+ return 0.0
76
+
77
+
78
+ def spectral_euclidean_distance(y_true: np.ndarray, y_pred: np.ndarray) -> float:
79
+ """Calculate normalized Euclidean distance between spectra"""
80
+ try:
81
+ distance = euclidean(y_true.flatten(), y_pred.flatten())
82
+ # Normalize by the length of the spectrum
83
+ return distance / len(y_true.flatten())
84
+ except Exception:
85
+ return float("inf")
86
+
87
+
88
+ def calculate_spectroscopy_metrics(
89
+ y_true: np.ndarray, y_pred: np.ndarray, probabilities: Optional[np.ndarray] = None
90
+ ) -> Dict[str, float]:
91
+ """Calculate comprehensive spectroscopy-specific metrics"""
92
+ metrics = {}
93
+
94
+ try:
95
+ # Standard classification metrics
96
+ metrics["accuracy"] = accuracy_score(y_true, y_pred)
97
+ metrics["f1_score"] = f1_score(y_true, y_pred, average="weighted")
98
+
99
+ # Spectroscopy-specific metrics
100
+ if probabilities is not None and len(probabilities.shape) > 1:
101
+ # For classification with probabilities, use cosine similarity on prob distributions
102
+ unique_classes = np.unique(y_true)
103
+ if len(unique_classes) > 1:
104
+ # Convert true labels to one-hot for similarity calculation
105
+ y_true_onehot = np.eye(len(unique_classes))[y_true]
106
+ metrics["cosine_similarity"] = float(
107
+ cosine_similarity(
108
+ y_true_onehot.mean(axis=0).reshape(1, -1),
109
+ probabilities.mean(axis=0).reshape(1, -1),
110
+ )[0, 0]
111
+ )
112
+
113
+ # Add bias audit metric (class distribution comparison)
114
+ unique_true, counts_true = np.unique(y_true, return_counts=True)
115
+ unique_pred, counts_pred = np.unique(y_pred, return_counts=True)
116
+
117
+ # Calculate distribution difference (Jensen-Shannon divergence approximation)
118
+ true_dist = counts_true / len(y_true)
119
+ pred_dist = np.zeros_like(true_dist)
120
+
121
+ for i, class_label in enumerate(unique_true):
122
+ if class_label in unique_pred:
123
+ pred_idx = np.where(unique_pred == class_label)[0][0]
124
+ pred_dist[i] = counts_pred[pred_idx] / len(y_pred)
125
+
126
+ # Simple distribution similarity (1 - average absolute difference)
127
+ metrics["distribution_similarity"] = 1.0 - np.mean(
128
+ np.abs(true_dist - pred_dist)
129
+ )
130
+
131
+ except Exception as e:
132
+ print(f"Error calculating spectroscopy metrics: {e}")
133
+ # Return basic metrics
134
+ metrics = {
135
+ "accuracy": accuracy_score(y_true, y_pred) if len(y_true) > 0 else 0.0,
136
+ "f1_score": (
137
+ f1_score(y_true, y_pred, average="weighted") if len(y_true) > 0 else 0.0
138
+ ),
139
+ "cosine_similarity": 0.0,
140
+ "distribution_similarity": 0.0,
141
+ }
142
+
143
+ return metrics
144
+
145
+
146
+ def get_cv_splitter(strategy: str, n_splits: int = 10, random_state: int = 42):
147
+ """Get cross-validation splitter based on strategy"""
148
+ if strategy == "stratified_kfold":
149
+ return StratifiedKFold(
150
+ n_splits=n_splits, shuffle=True, random_state=random_state
151
+ )
152
+ elif strategy == "kfold":
153
+ return KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
154
+ elif strategy == "time_series_split":
155
+ return TimeSeriesSplit(n_splits=n_splits)
156
+ else:
157
+ # Default to stratified k-fold
158
+ return StratifiedKFold(
159
+ n_splits=n_splits, shuffle=True, random_state=random_state
160
+ )
161
+
162
+
163
+ def augment_spectral_data(
164
+ X: np.ndarray,
165
+ y: np.ndarray,
166
+ noise_level: float = 0.01,
167
+ augmentation_factor: int = 2,
168
+ ) -> Tuple[np.ndarray, np.ndarray]:
169
+ """Augment spectral data with realistic noise and variations"""
170
+ if augmentation_factor <= 1:
171
+ return X, y
172
+
173
+ augmented_X = [X]
174
+ augmented_y = [y]
175
+
176
+ for i in range(augmentation_factor - 1):
177
+ # Add Gaussian noise
178
+ noise = np.random.normal(0, noise_level, X.shape)
179
+ X_noisy = X + noise
180
+
181
+ # Add baseline drift (common in spectroscopy)
182
+ baseline_drift = np.random.normal(0, noise_level * 0.5, (X.shape[0], 1))
183
+ X_drift = X_noisy + baseline_drift
184
+
185
+ # Add intensity scaling variation
186
+ intensity_scale = np.random.normal(1.0, 0.05, (X.shape[0], 1))
187
+ X_scaled = X_drift * intensity_scale
188
+
189
+ # Ensure no negative values
190
+ X_scaled = np.maximum(X_scaled, 0)
191
+
192
+ augmented_X.append(X_scaled)
193
+ augmented_y.append(y)
194
+
195
+ return np.vstack(augmented_X), np.hstack(augmented_y)
196
+
197
+
198
+ class TrainingStatus(Enum):
199
+ """Training job status enumeration"""
200
+
201
+ PENDING = "pending"
202
+ RUNNING = "running"
203
+ COMPLETED = "completed"
204
+ FAILED = "failed"
205
+ CANCELLED = "cancelled"
206
+
207
+
208
+ class CVStrategy(Enum):
209
+ """Cross-validation strategy enumeration"""
210
+
211
+ STRATIFIED_KFOLD = "stratified_kfold"
212
+ KFOLD = "kfold"
213
+ TIME_SERIES_SPLIT = "time_series_split"
214
+
215
+
216
+ @dataclass
217
+ class TrainingConfig:
218
+ """Training configuration parameters"""
219
+
220
+ model_name: str
221
+ dataset_path: str
222
+ target_len: int = 500
223
+ batch_size: int = 16
224
+ epochs: int = 10
225
+ learning_rate: float = 1e-3
226
+ num_folds: int = 10
227
+ baseline_correction: bool = True
228
+ smoothing: bool = True
229
+ normalization: bool = True
230
+ modality: str = "raman"
231
+ device: str = "auto" # auto, cpu, cuda
232
+ cv_strategy: str = "stratified_kfold"  # cross-validation strategy (see CVStrategy)
233
+ spectral_weight: float = 0.1 # Weight for spectroscopy-specific metrics
234
+ enable_augmentation: bool = False # Enable data augmentation
235
+ noise_level: float = 0.01 # Noise level for augmentation
236
+
237
+ def to_dict(self) -> Dict[str, Any]:
238
+ """Convert to dictionary for serialization"""
239
+ return asdict(self)
240
+
241
+
242
+ @dataclass
243
+ class TrainingProgress:
244
+ """Training progress tracking with enhanced metrics"""
245
+
246
+ current_fold: int = 0
247
+ total_folds: int = 10
248
+ current_epoch: int = 0
249
+ total_epochs: int = 10
250
+ current_loss: float = 0.0
251
+ current_accuracy: float = 0.0
252
+ fold_accuracies: List[float] = field(default_factory=list)
253
+ confusion_matrices: List[List[List[int]]] = field(default_factory=list)
254
+ spectroscopy_metrics: List[Dict[str, float]] = field(default_factory=list)
255
+ start_time: Optional[datetime] = None
256
+ end_time: Optional[datetime] = None
257
+
258
+
259
+ @dataclass
260
+ class TrainingJob:
261
+ """Training job container"""
262
+
263
+ job_id: str
264
+ config: TrainingConfig
265
+ status: TrainingStatus = TrainingStatus.PENDING
266
+ progress: Optional[TrainingProgress] = None
267
+ error_message: Optional[str] = None
268
+ created_at: Optional[datetime] = None
269
+ started_at: Optional[datetime] = None
270
+ completed_at: Optional[datetime] = None
271
+ weights_path: Optional[str] = None
272
+ logs_path: Optional[str] = None
273
+
274
+ def __post_init__(self):
275
+ if self.progress is None:
276
+ self.progress = TrainingProgress(
277
+ total_folds=self.config.num_folds, total_epochs=self.config.epochs
278
+ )
279
+ if self.created_at is None:
280
+ self.created_at = datetime.now()
281
+
282
+
283
+ class TrainingManager:
284
+ """Manager for training jobs with async execution and progress tracking"""
285
+
286
+ def __init__(
287
+ self,
288
+ max_workers: int = 2,
289
+ output_dir: str = "outputs",
290
+ use_multiprocessing: bool = True,
291
+ ):
292
+ self.max_workers = max_workers
293
+ self.use_multiprocessing = use_multiprocessing
294
+
295
+ # Use ProcessPoolExecutor for CPU/GPU-bound tasks, ThreadPoolExecutor for I/O-bound
296
+ if use_multiprocessing:
297
+ # Limit workers to available CPU cores to prevent oversubscription
298
+ actual_workers = min(max_workers, multiprocessing.cpu_count())
299
+ self.executor = concurrent.futures.ProcessPoolExecutor(
300
+ max_workers=actual_workers
301
+ )
302
+ else:
303
+ self.executor = concurrent.futures.ThreadPoolExecutor(
304
+ max_workers=max_workers
305
+ )
306
+
307
+ self.jobs: Dict[str, TrainingJob] = {}
308
+ self.output_dir = Path(output_dir)
309
+ self.output_dir.mkdir(exist_ok=True)
310
+ (self.output_dir / "weights").mkdir(exist_ok=True)
311
+ (self.output_dir / "logs").mkdir(exist_ok=True)
312
+
313
+ # Progress callbacks for UI updates
314
+ self.progress_callbacks: Dict[str, List[Callable]] = {}
315
+
316
+ def generate_job_id(self) -> str:
317
+ """Generate unique job ID"""
318
+ return f"train_{uuid.uuid4().hex[:8]}_{int(time.time())}"
319
+
320
+ def submit_training_job(
321
+ self, config: TrainingConfig, progress_callback: Optional[Callable] = None
322
+ ) -> str:
323
+ """Submit a new training job"""
324
+ job_id = self.generate_job_id()
325
+ job = TrainingJob(job_id=job_id, config=config)
326
+
327
+ # Set up output paths
328
+ job.weights_path = str(self.output_dir / "weights" / f"{job_id}_model.pth")
329
+ job.logs_path = str(self.output_dir / "logs" / f"{job_id}_log.json")
330
+
331
+ self.jobs[job_id] = job
332
+
333
+ # Register progress callback
334
+ if progress_callback:
335
+ if job_id not in self.progress_callbacks:
336
+ self.progress_callbacks[job_id] = []
337
+ self.progress_callbacks[job_id].append(progress_callback)
338
+
339
+ # Submit to the configured executor (process or thread pool)
340
+ self.executor.submit(self._run_training_job, job)
341
+
342
+ return job_id
343
+
344
+ def _run_training_job(self, job: TrainingJob) -> None:
345
+ """Execute training job (runs in separate thread)"""
346
+ try:
347
+ job.status = TrainingStatus.RUNNING
348
+ job.started_at = datetime.now()
349
+ job.progress.start_time = job.started_at
350
+
351
+ self._notify_progress(job.job_id, job)
352
+
353
+ # Device selection
354
+ device = self._get_device(job.config.device)
355
+
356
+ # Load and preprocess data
357
+ X, y = self._load_and_preprocess_data(job)
358
+ if X is None or y is None:
359
+ raise ValueError("Failed to load dataset")
360
+
361
+ # Set reproducibility
362
+ self._set_reproducibility()
363
+
364
+ # Run cross-validation training
365
+ self._run_cross_validation(job, X, y, device)
366
+
367
+ # Save final results
368
+ self._save_training_results(job)
369
+
370
+ job.status = TrainingStatus.COMPLETED
371
+ job.completed_at = datetime.now()
372
+ job.progress.end_time = job.completed_at
373
+
374
+ except Exception as e:
375
+ job.status = TrainingStatus.FAILED
376
+ job.error_message = str(e)
377
+ job.completed_at = datetime.now()
378
+
379
+ finally:
380
+ self._notify_progress(job.job_id, job)
381
+
382
+ def _get_device(self, device_preference: str) -> torch.device:
383
+ """Get appropriate device for training"""
384
+ if device_preference == "auto":
385
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
386
+ elif device_preference == "cuda" and torch.cuda.is_available():
387
+ return torch.device("cuda")
388
+ else:
389
+ return torch.device("cpu")
390
+
391
+ def _load_and_preprocess_data(
392
+ self, job: TrainingJob
393
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
394
+ """Load and preprocess dataset with enhanced validation and security"""
395
+ try:
396
+ config = job.config
397
+ dataset_path = Path(config.dataset_path)
398
+
399
+ # Enhanced path validation and security
400
+ if not dataset_path.exists():
401
+ raise FileNotFoundError(f"Dataset path not found: {dataset_path}")
402
+
403
+ # Validate dataset path is within allowed directories (security)
404
+ try:
405
+ dataset_path = dataset_path.resolve()
406
+ allowed_bases = [
407
+ Path("datasets").resolve(),
408
+ Path("data").resolve(),
409
+ Path("/tmp").resolve(),
410
+ ]
411
+ if not any(
412
+ str(dataset_path).startswith(str(base)) for base in allowed_bases
413
+ ):
414
+ raise ValueError(
415
+ f"Dataset path outside allowed directories: {dataset_path}"
416
+ )
417
+ except Exception as e:
418
+ print(f"Path validation error: {e}")
419
+ raise ValueError("Invalid dataset path")
420
+
421
+ # Load data from dataset directory
422
+ X, y = [], []
423
+ total_files = 0
424
+ processed_files = 0
425
+ max_files_per_class = 1000 # Limit to prevent memory issues
426
+ max_file_size = 10 * 1024 * 1024 # 10MB per file
427
+
428
+ # Look for data files in the dataset directory
429
+ for label_dir in dataset_path.iterdir():
430
+ if not label_dir.is_dir():
431
+ continue
432
+
433
+ label = 0 if "stable" in label_dir.name.lower() else 1
434
+ files_in_class = 0
435
+
436
+ # Support multiple file formats
437
+ file_patterns = ["*.txt", "*.csv", "*.json"]
438
+
439
+ for pattern in file_patterns:
440
+ for file_path in label_dir.glob(pattern):
441
+ total_files += 1
442
+
443
+ # Security: Check file size
444
+ if file_path.stat().st_size > max_file_size:
445
+ print(
446
+ f"Skipping large file: {file_path} ({file_path.stat().st_size} bytes)"
447
+ )
448
+ continue
449
+
450
+ # Limit files per class
451
+ if files_in_class >= max_files_per_class:
452
+ print(
453
+ f"Reached maximum files per class ({max_files_per_class}) for {label_dir.name}"
454
+ )
455
+ break
456
+
457
+ try:
458
+ # Load spectrum data based on file type
459
+ if file_path.suffix.lower() == ".txt":
460
+ data = np.loadtxt(file_path)
461
+ if data.ndim == 2 and data.shape[1] >= 2:
462
+ x_raw, y_raw = data[:, 0], data[:, 1]
463
+ elif data.ndim == 1:
464
+ # Single column data
465
+ x_raw = np.arange(len(data))
466
+ y_raw = data
467
+ else:
468
+ continue
469
+
470
+ elif file_path.suffix.lower() == ".csv":
471
+ import pandas as pd
472
+
473
+ df = pd.read_csv(file_path)
474
+ if df.shape[1] >= 2:
475
+ x_raw, y_raw = (
476
+ df.iloc[:, 0].values,
477
+ df.iloc[:, 1].values,
478
+ )
479
+ else:
480
+ x_raw = np.arange(len(df))
481
+ y_raw = df.iloc[:, 0].values
482
+
483
+ elif file_path.suffix.lower() == ".json":
484
+ with open(file_path, "r") as f:
485
+ data_dict = json.load(f)
486
+ if isinstance(data_dict, dict):
487
+ if "x" in data_dict and "y" in data_dict:
488
+ x_raw, y_raw = np.array(
489
+ data_dict["x"]
490
+ ), np.array(data_dict["y"])
491
+ elif "spectrum" in data_dict:
492
+ y_raw = np.array(data_dict["spectrum"])
493
+ x_raw = np.arange(len(y_raw))
494
+ else:
495
+ continue
496
+ else:
497
+ continue
498
+ else:
499
+ continue
500
+
501
+ # Validate data integrity
502
+ if len(x_raw) != len(y_raw) or len(x_raw) < 10:
503
+ print(
504
+ f"Invalid data in file {file_path}: insufficient data points"
505
+ )
506
+ continue
507
+
508
+ # Check for NaN or infinite values
509
+ if np.any(np.isnan(y_raw)) or np.any(np.isinf(y_raw)):
510
+ print(
511
+ f"Invalid data in file {file_path}: NaN or infinite values"
512
+ )
513
+ continue
514
+
515
+ # Validate reasonable value ranges for spectroscopy
516
+ if np.min(y_raw) < -1000 or np.max(y_raw) > 1e6:
517
+ print(
518
+ f"Suspicious data values in file {file_path}: outside expected range"
519
+ )
520
+ continue
521
+
522
+ # Preprocess spectrum
523
+ _, y_processed = preprocess_spectrum(
524
+ x_raw,
525
+ y_raw,
526
+ modality=config.modality,
527
+ target_len=config.target_len,
528
+ do_baseline=config.baseline_correction,
529
+ do_smooth=config.smoothing,
530
+ do_normalize=config.normalization,
531
+ )
532
+
533
+ # Final validation of processed data
534
+ if (
535
+ y_processed is None
536
+ or len(y_processed) != config.target_len
537
+ ):
538
+ print(f"Preprocessing failed for file {file_path}")
539
+ continue
540
+
541
+ X.append(y_processed)
542
+ y.append(label)
543
+ files_in_class += 1
544
+ processed_files += 1
545
+
546
+ except Exception as e:
547
+ print(f"Error processing file {file_path}: {e}")
548
+ continue
549
+
550
+ # Validate final dataset
551
+ if len(X) == 0:
552
+ raise ValueError("No valid data files found in dataset")
553
+
554
+ if len(X) < 10:
555
+ raise ValueError(
556
+ f"Insufficient data: only {len(X)} samples found (minimum 10 required)"
557
+ )
558
+
559
+ # Check class balance
560
+ unique_labels, counts = np.unique(y, return_counts=True)
561
+ if len(unique_labels) < 2:
562
+ raise ValueError("Dataset must contain at least 2 classes")
563
+
564
+ min_class_size = min(counts)
565
+ if min_class_size < 3:
566
+ raise ValueError(
567
+ f"Insufficient samples in one class: minimum {min_class_size} (need at least 3)"
568
+ )
569
+
570
+ print(f"Dataset loaded: {processed_files}/{total_files} files processed")
571
+ print(f"Class distribution: {dict(zip(unique_labels, counts))}")
572
+
573
+ return np.array(X, dtype=np.float32), np.array(y, dtype=np.int64)
574
+
575
+ except Exception as e:
576
+ print(f"Error loading dataset: {e}")
577
+ return None, None
578
+
579
+ def _set_reproducibility(self):
580
+ """Set random seeds for reproducibility"""
581
+ SEED = 42
582
+ np.random.seed(SEED)
583
+ torch.manual_seed(SEED)
584
+ if torch.cuda.is_available():
585
+ torch.cuda.manual_seed_all(SEED)
586
+ torch.backends.cudnn.deterministic = True
587
+ torch.backends.cudnn.benchmark = False
588
+
589
+ def _run_cross_validation(
590
+ self, job: TrainingJob, X: np.ndarray, y: np.ndarray, device: torch.device
591
+ ):
592
+ """Run configurable cross-validation training with spectroscopy metrics"""
593
+ config = job.config
594
+
595
+ # Apply data augmentation if enabled
596
+ if config.enable_augmentation:
597
+ X, y = augment_spectral_data(
598
+ X, y, noise_level=config.noise_level, augmentation_factor=2
599
+ )
600
+
601
+ # Get appropriate CV splitter
602
+ cv_splitter = get_cv_splitter(config.cv_strategy, config.num_folds)
603
+
604
+ fold_accuracies = []
605
+ confusion_matrices = []
606
+ spectroscopy_metrics = []
607
+
608
+ for fold, (train_idx, val_idx) in enumerate(cv_splitter.split(X, y), 1):
609
+ job.progress.current_fold = fold
610
+ job.progress.current_epoch = 0
611
+
612
+ # Prepare data
613
+ X_train, X_val = X[train_idx], X[val_idx]
614
+ y_train, y_val = y[train_idx], y[val_idx]
615
+
616
+ train_loader = DataLoader(
617
+ TensorDataset(
618
+ torch.tensor(X_train, dtype=torch.float32),
619
+ torch.tensor(y_train, dtype=torch.long),
620
+ ),
621
+ batch_size=config.batch_size,
622
+ shuffle=True,
623
+ )
624
+ val_loader = DataLoader(
625
+ TensorDataset(
626
+ torch.tensor(X_val, dtype=torch.float32),
627
+ torch.tensor(y_val, dtype=torch.long),
628
+ ),
629
+ batch_size=config.batch_size,
630
+ shuffle=False,
631
+ )
632
+
633
+ # Initialize model
634
+ model = build_model(config.model_name, config.target_len).to(device)
635
+ optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
636
+ criterion = nn.CrossEntropyLoss()
637
+
638
+ # Training loop
639
+ for epoch in range(config.epochs):
640
+ job.progress.current_epoch = epoch + 1
641
+ model.train()
642
+ running_loss = 0.0
643
+ correct = 0
644
+ total = 0
645
+
646
+ for inputs, labels in train_loader:
647
+ inputs = inputs.unsqueeze(1).to(device)
648
+ labels = labels.to(device)
649
+
650
+ optimizer.zero_grad()
651
+ outputs = model(inputs)
652
+ loss = criterion(outputs, labels)
653
+ loss.backward()
654
+ optimizer.step()
655
+
656
+ running_loss += loss.item()
657
+ _, predicted = torch.max(outputs.data, 1)
658
+ total += labels.size(0)
659
+ correct += (predicted == labels).sum().item()
660
+
661
+ job.progress.current_loss = running_loss / len(train_loader)
662
+ job.progress.current_accuracy = correct / total
663
+
664
+ self._notify_progress(job.job_id, job)
665
+
666
+ # Validation with comprehensive metrics
667
+ model.eval()
668
+ val_predictions = []
669
+ val_true = []
670
+ val_probabilities = []
671
+
672
+ with torch.no_grad():
673
+ for inputs, labels in val_loader:
674
+ inputs = inputs.unsqueeze(1).to(device)
675
+ outputs = model(inputs)
676
+ probabilities = torch.softmax(outputs, dim=1)
677
+ _, predicted = torch.max(outputs, 1)
678
+
679
+ val_predictions.extend(predicted.cpu().numpy())
680
+ val_true.extend(labels.numpy())
681
+ val_probabilities.extend(probabilities.cpu().numpy())
682
+
683
+ # Calculate standard metrics
684
+ fold_accuracy = float(accuracy_score(val_true, val_predictions))
685
+ fold_cm = confusion_matrix(val_true, val_predictions).tolist()
686
+
687
+ # Calculate spectroscopy-specific metrics
688
+ val_probabilities = np.array(val_probabilities)
689
+ spectro_metrics = calculate_spectroscopy_metrics(
690
+ np.array(val_true), np.array(val_predictions), val_probabilities
691
+ )
692
+
693
+ fold_accuracies.append(fold_accuracy)
694
+ confusion_matrices.append(fold_cm)
695
+ spectroscopy_metrics.append(spectro_metrics)
696
+
697
+ # Save best model weights (from last fold for now)
698
+ if fold == config.num_folds:
699
+ torch.save(model.state_dict(), job.weights_path)
700
+
701
+ job.progress.fold_accuracies = fold_accuracies
702
+ job.progress.confusion_matrices = confusion_matrices
703
+ job.progress.spectroscopy_metrics = spectroscopy_metrics
704
+
705
+ def _save_training_results(self, job: TrainingJob):
706
+ """Save training results and logs with enhanced metrics"""
707
+ # Calculate comprehensive summary metrics
708
+ spectro_summary = {}
709
+ if job.progress.spectroscopy_metrics:
710
+ # Average across all folds for each metric
711
+ metric_keys = job.progress.spectroscopy_metrics[0].keys()
712
+ for key in metric_keys:
713
+ values = [
714
+ fold_metrics.get(key, 0.0)
715
+ for fold_metrics in job.progress.spectroscopy_metrics
716
+ ]
717
+ spectro_summary[f"mean_{key}"] = float(np.mean(values))
718
+ spectro_summary[f"std_{key}"] = float(np.std(values))
719
+
720
+ results = {
721
+ "job_id": job.job_id,
722
+ "config": job.config.to_dict(),
723
+ "status": job.status.value,
724
+ "created_at": job.created_at.isoformat(),
725
+ "started_at": job.started_at.isoformat() if job.started_at else None,
726
+ "completed_at": job.completed_at.isoformat() if job.completed_at else None,
727
+ "progress": {
728
+ "fold_accuracies": job.progress.fold_accuracies,
729
+ "confusion_matrices": job.progress.confusion_matrices,
730
+ "spectroscopy_metrics": job.progress.spectroscopy_metrics,
731
+ "mean_accuracy": (
732
+ float(np.mean(job.progress.fold_accuracies))
733
+ if job.progress.fold_accuracies
734
+ else 0.0
735
+ ),
736
+ "std_accuracy": (
737
+ float(np.std(job.progress.fold_accuracies))
738
+ if job.progress.fold_accuracies
739
+ else 0.0
740
+ ),
741
+ "spectroscopy_summary": spectro_summary,
742
+ },
743
+ "weights_path": job.weights_path,
744
+ "error_message": job.error_message,
745
+ }
746
+
747
+ with open(job.logs_path, "w") as f:
748
+ json.dump(results, f, indent=2)
749
+
750
+ def _notify_progress(self, job_id: str, job: TrainingJob):
751
+ """Notify registered callbacks about progress updates"""
752
+ if job_id in self.progress_callbacks:
753
+ for callback in self.progress_callbacks[job_id]:
754
+ try:
755
+ callback(job)
756
+ except Exception as e:
757
+ print(f"Error in progress callback: {e}")
758
+
759
+ def get_job_status(self, job_id: str) -> Optional[TrainingJob]:
760
+ """Get current status of a training job"""
761
+ return self.jobs.get(job_id)
762
+
763
+ def list_jobs(
764
+ self, status_filter: Optional[TrainingStatus] = None
765
+ ) -> List[TrainingJob]:
766
+ """List all jobs, optionally filtered by status"""
767
+ jobs = list(self.jobs.values())
768
+ if status_filter:
769
+ jobs = [job for job in jobs if job.status == status_filter]
770
+ return sorted(jobs, key=lambda j: j.created_at, reverse=True)
771
+
772
+ def cancel_job(self, job_id: str) -> bool:
773
+ """Cancel a running job"""
774
+ job = self.jobs.get(job_id)
775
+ if job and job.status == TrainingStatus.RUNNING:
776
+ job.status = TrainingStatus.CANCELLED
777
+ job.completed_at = datetime.now()
778
+ # Note: this only marks the job as cancelled; the already-running worker is not interrupted
779
+ return True
780
+ return False
781
+
782
+ def cleanup_old_jobs(self, max_age_hours: int = 24):
783
+ """Clean up old completed/failed jobs"""
784
+ cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
785
+ to_remove = []
786
+
787
+ for job_id, job in self.jobs.items():
788
+ if (
789
+ job.status
790
+ in [
791
+ TrainingStatus.COMPLETED,
792
+ TrainingStatus.FAILED,
793
+ TrainingStatus.CANCELLED,
794
+ ]
795
+ and job.completed_at
796
+ and job.completed_at < cutoff_time
797
+ ):
798
+ to_remove.append(job_id)
799
+
800
+ for job_id in to_remove:
801
+ del self.jobs[job_id]
802
+
803
+ def shutdown(self):
804
+ """Shutdown the training manager"""
805
+ self.executor.shutdown(wait=True)
806
+
807
+
808
+ # Global training manager instance
809
+ _training_manager = None
810
+
811
+
812
+ def get_training_manager() -> TrainingManager:
813
+ """Get global training manager instance"""
814
+ global _training_manager
815
+ if _training_manager is None:
816
+ _training_manager = TrainingManager()
817
+ return _training_manager
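
One caveat for callers of this module: with the default use_multiprocessing=True the job runs in a ProcessPoolExecutor worker, so updates to the TrainingJob (status, progress) are made in a child process and are not reflected in the jobs dict or callbacks held by the parent. A UI that polls in-process state would therefore likely construct the manager with use_multiprocessing=False. A minimal sketch under that assumption (model name and dataset path are placeholders):

import json
from utils.training_manager import TrainingManager, TrainingConfig, TrainingStatus

# Thread-pool execution keeps job state visible to this process
manager = TrainingManager(max_workers=1, use_multiprocessing=False)

def on_progress(job):
    # Called from the worker thread after each epoch and on status changes
    print(f"[{job.job_id}] fold {job.progress.current_fold}/{job.progress.total_folds} "
          f"epoch {job.progress.current_epoch} acc={job.progress.current_accuracy:.3f}")

config = TrainingConfig(model_name="example_model", dataset_path="datasets/demo")  # placeholders
job_id = manager.submit_training_job(config, progress_callback=on_progress)

# After the job reports COMPLETED, per-fold metrics are in the JSON log it wrote
job = manager.get_job_status(job_id)
if job and job.status == TrainingStatus.COMPLETED and job.logs_path:
    with open(job.logs_path) as f:
        results = json.load(f)
    print(results["progress"]["mean_accuracy"], results["progress"]["spectroscopy_summary"])

manager.shutdown()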