devjas1 committed
Commit 6e2806f
Parent: a602039

(TESTS)[Comprehensive Test Suite for Training Manager]: Validate training manager, config, data loading, augmentation, and metrics


- Added extensive tests covering job submission, execution, device selection, job listing, and error handling for invalid dataset paths (see the usage sketch just below this summary).
- Validated cross-validation strategies, including fallback logic for unknown strategies.
- Tested spectroscopy metrics, spectral similarity, and data augmentation to ensure robust domain-specific evaluation.
- Enhanced tests for secure and versatile dataset loading, supporting CSV, JSON, and TXT formats with proper class balance.
- Ensured all new features and edge cases are covered to maintain the reliability and future extensibility of the training backend.
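
A minimal usage sketch of the job API these tests exercise. The class, method, and attribute names come from the test file below; the output directory and dataset path are illustrative placeholders, not repository defaults:

    from utils.training_manager import TrainingManager, TrainingConfig, TrainingStatus

    manager = TrainingManager(max_workers=1, output_dir="outputs", use_multiprocessing=False)
    config = TrainingConfig(
        model_name="figure2",
        dataset_path="datasets/my_dataset",  # hypothetical path
        epochs=1,
        batch_size=4,
    )
    job_id = manager.submit_training_job(config)

    # Poll for progress; job objects expose status and error_message.
    job = manager.get_job_status(job_id)
    if job.status == TrainingStatus.FAILED:
        print(job.error_message)

    manager.shutdown()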

Files changed (1)
  1. tests/test_training_manager.py +368 -0
tests/test_training_manager.py ADDED
@@ -0,0 +1,368 @@
"""
Tests for the training manager functionality.
"""

import pytest
import tempfile
import shutil
from pathlib import Path
import numpy as np
import torch
import json
import pandas as pd

from utils.training_manager import (
    TrainingManager,
    TrainingConfig,
    TrainingStatus,
    get_training_manager,
    CVStrategy,
    get_cv_splitter,
    calculate_spectroscopy_metrics,
    augment_spectral_data,
    spectral_cosine_similarity,
)


def create_test_dataset(dataset_path: Path, num_samples: int = 10):
    """Create a test dataset for training"""
    # Create directories
    (dataset_path / "stable").mkdir(parents=True, exist_ok=True)
    (dataset_path / "weathered").mkdir(parents=True, exist_ok=True)

    # Generate synthetic spectra
    wavenumbers = np.linspace(400, 4000, 200)

    for i in range(num_samples // 2):
        # Stable samples
        intensities = np.random.normal(0.5, 0.1, len(wavenumbers))
        data = np.column_stack([wavenumbers, intensities])
        np.savetxt(dataset_path / "stable" / f"stable_{i}.txt", data)

        # Weathered samples
        intensities = np.random.normal(0.3, 0.1, len(wavenumbers))
        data = np.column_stack([wavenumbers, intensities])
        np.savetxt(dataset_path / "weathered" / f"weathered_{i}.txt", data)


@pytest.fixture
def temp_dataset():
    """Create temporary dataset for testing"""
    temp_dir = Path(tempfile.mkdtemp())
    dataset_path = temp_dir / "test_dataset"
    create_test_dataset(dataset_path)
    yield dataset_path
    shutil.rmtree(temp_dir)


@pytest.fixture
def training_manager():
    """Create training manager for testing"""
    temp_dir = Path(tempfile.mkdtemp())
    # Use ThreadPoolExecutor for tests to avoid multiprocessing complexities
    manager = TrainingManager(
        max_workers=1, output_dir=str(temp_dir), use_multiprocessing=False
    )
    yield manager
    manager.shutdown()
    shutil.rmtree(temp_dir)


def test_training_config():
    """Test training configuration creation"""
    config = TrainingConfig(
        model_name="figure2", dataset_path="/test/path", epochs=5, batch_size=8
    )

    assert config.model_name == "figure2"
    assert config.epochs == 5
    assert config.batch_size == 8
    assert config.device == "auto"


def test_training_manager_initialization(training_manager):
    """Test training manager initialization"""
    assert training_manager.max_workers == 1
    assert len(training_manager.jobs) == 0


def test_submit_training_job(training_manager, temp_dataset):
    """Test submitting a training job"""
    config = TrainingConfig(
        model_name="figure2", dataset_path=str(temp_dataset), epochs=1, batch_size=4
    )

    job_id = training_manager.submit_training_job(config)

    assert job_id is not None
    assert len(job_id) > 0
    assert job_id in training_manager.jobs

    job = training_manager.get_job_status(job_id)
    assert job is not None
    assert job.config.model_name == "figure2"


def test_training_job_execution(training_manager, temp_dataset):
    """Test actual training job execution (lightweight test)"""
    config = TrainingConfig(
        model_name="figure2",
        dataset_path=str(temp_dataset),
        epochs=1,
        num_folds=2,  # Reduced for testing
        batch_size=4,
    )

    job_id = training_manager.submit_training_job(config)

    # Wait a moment for job to start
    import time

    time.sleep(1)

    job = training_manager.get_job_status(job_id)
    assert job.status in [
        TrainingStatus.PENDING,
        TrainingStatus.RUNNING,
        TrainingStatus.COMPLETED,
        TrainingStatus.FAILED,
    ]


def test_list_jobs(training_manager, temp_dataset):
    """Test listing jobs with filters"""
    config = TrainingConfig(
        model_name="figure2", dataset_path=str(temp_dataset), epochs=1
    )

    job_id = training_manager.submit_training_job(config)

    all_jobs = training_manager.list_jobs()
    assert len(all_jobs) >= 1

    pending_jobs = training_manager.list_jobs(TrainingStatus.PENDING)
    running_jobs = training_manager.list_jobs(TrainingStatus.RUNNING)

    # Job should be in one of these states
    assert len(pending_jobs) + len(running_jobs) >= 1


def test_global_training_manager():
    """Test global training manager singleton"""
    manager1 = get_training_manager()
    manager2 = get_training_manager()

    assert manager1 is manager2  # Should be same instance


def test_device_selection(training_manager):
    """Test device selection logic"""
    # Test auto device selection
    device = training_manager._get_device("auto")
    assert device.type in ["cpu", "cuda"]

    # Test CPU selection
    device = training_manager._get_device("cpu")
    assert device.type == "cpu"

    # Test CUDA selection (should fallback to CPU if not available)
    device = training_manager._get_device("cuda")
    if torch.cuda.is_available():
        assert device.type == "cuda"
    else:
        assert device.type == "cpu"


def test_invalid_dataset_path(training_manager):
    """Test handling of invalid dataset path"""
    config = TrainingConfig(
        model_name="figure2", dataset_path="/nonexistent/path", epochs=1
    )

    job_id = training_manager.submit_training_job(config)

    # Wait for job to process
    import time

    time.sleep(2)

    job = training_manager.get_job_status(job_id)
    assert job.status == TrainingStatus.FAILED
    assert "dataset" in job.error_message.lower()


def test_configurable_cv_strategies():
    """Test different cross-validation strategies"""
    # Test StratifiedKFold
    skf = get_cv_splitter("stratified_kfold", n_splits=5)
    assert hasattr(skf, "split")

    # Test KFold
    kf = get_cv_splitter("kfold", n_splits=5)
    assert hasattr(kf, "split")

    # Test TimeSeriesSplit
    tss = get_cv_splitter("time_series_split", n_splits=5)
    assert hasattr(tss, "split")

    # Test default fallback
    default = get_cv_splitter("invalid_strategy", n_splits=5)
    assert hasattr(default, "split")


def test_spectroscopy_metrics():
    """Test spectroscopy-specific metrics calculation"""
    # Create test data
    y_true = np.array([0, 0, 1, 1, 0, 1])
    y_pred = np.array([0, 1, 1, 1, 0, 0])
    probabilities = np.array(
        [[0.8, 0.2], [0.4, 0.6], [0.3, 0.7], [0.2, 0.8], [0.9, 0.1], [0.6, 0.4]]
    )

    metrics = calculate_spectroscopy_metrics(y_true, y_pred, probabilities)

    # Check that all expected metrics are present
    assert "accuracy" in metrics
    assert "f1_score" in metrics
    assert "cosine_similarity" in metrics
    assert "distribution_similarity" in metrics

    # Check that metrics are reasonable
    assert 0 <= metrics["accuracy"] <= 1
    assert 0 <= metrics["f1_score"] <= 1
    assert -1 <= metrics["cosine_similarity"] <= 1
    assert 0 <= metrics["distribution_similarity"] <= 1


def test_spectral_cosine_similarity():
    """Test cosine similarity calculation for spectral data"""
    # Create test spectra
    spectrum1 = np.array([1, 2, 3, 4, 5])
    spectrum2 = np.array([2, 4, 6, 8, 10])  # Perfect correlation
    spectrum3 = np.array([5, 4, 3, 2, 1])  # Anti-correlation

    # Test perfect correlation
    sim1 = spectral_cosine_similarity(spectrum1, spectrum2)
    assert abs(sim1 - 1.0) < 1e-10

    # Test that similarity exists
    sim2 = spectral_cosine_similarity(spectrum1, spectrum3)
    assert -1 <= sim2 <= 1  # Valid cosine similarity range

    # Test self-similarity
    sim3 = spectral_cosine_similarity(spectrum1, spectrum1)
    assert abs(sim3 - 1.0) < 1e-10


def test_data_augmentation():
    """Test spectral data augmentation"""
    # Create test data
    X = np.random.rand(10, 100)
    y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    # Test augmentation
    X_aug, y_aug = augment_spectral_data(X, y, noise_level=0.01, augmentation_factor=3)

    # Check that data is augmented
    assert X_aug.shape[0] == X.shape[0] * 3
    assert y_aug.shape[0] == y.shape[0] * 3
    assert X_aug.shape[1] == X.shape[1]  # Same number of features

    # Test no augmentation
    X_no_aug, y_no_aug = augment_spectral_data(X, y, augmentation_factor=1)
    assert np.array_equal(X_no_aug, X)
    assert np.array_equal(y_no_aug, y)


def test_enhanced_training_config():
    """Test enhanced training configuration with new parameters"""
    config = TrainingConfig(
        model_name="figure2",
        dataset_path="/test/path",
        cv_strategy="time_series_split",
        enable_augmentation=True,
        noise_level=0.02,
        spectral_weight=0.2,
    )

    assert config.cv_strategy == "time_series_split"
    assert config.enable_augmentation == True
    assert config.noise_level == 0.02
    assert config.spectral_weight == 0.2

    # Test serialization includes new fields
    config_dict = config.to_dict()
    assert "cv_strategy" in config_dict
    assert "enable_augmentation" in config_dict
    assert "noise_level" in config_dict
    assert "spectral_weight" in config_dict


def test_enhanced_dataset_loading_security():
    """Test enhanced dataset loading with security features"""
    temp_dir = Path(tempfile.mkdtemp())
    training_manager = TrainingManager(
        max_workers=1, output_dir=str(temp_dir), use_multiprocessing=False
    )

    try:
        # Create a test dataset with different file formats
        dataset_dir = temp_dir / "test_dataset"
        (dataset_dir / "stable").mkdir(parents=True)
        (dataset_dir / "weathered").mkdir(parents=True)

        # Create multiple files to meet minimum requirements
        for i in range(6):  # Create 6 files per class
            # Create CSV files
            csv_data = pd.DataFrame(
                {
                    "wavenumber": np.linspace(400, 4000, 100),
                    "intensity": np.random.rand(100),
                }
            )
            csv_data.to_csv(
                dataset_dir / "stable" / f"test_stable_{i}.csv", index=False
            )

            # Create JSON files
            json_data = {
                "x": np.linspace(400, 4000, 100).tolist(),
                "y": np.random.rand(100).tolist(),
            }
            with open(dataset_dir / "weathered" / f"test_weathered_{i}.json", "w") as f:
                json.dump(json_data, f)

        # Test configuration with enhanced features
        config = TrainingConfig(
            model_name="figure2",
            dataset_path=str(dataset_dir),
            epochs=1,
            cv_strategy="kfold",
            enable_augmentation=True,
            noise_level=0.01,
        )

        # Test that the enhanced loading works
        from utils.training_manager import TrainingJob, TrainingProgress

        job = TrainingJob(job_id="test", config=config, progress=TrainingProgress())

        # This should work with the enhanced data loading
        X, y = training_manager._load_and_preprocess_data(job)

        # Should load data from multiple formats
        assert X is not None
        assert y is not None
        assert len(X) >= 10  # Should have at least 10 samples total

        # Test that we have both classes
        unique_classes = np.unique(y)
        assert len(unique_classes) >= 2

    finally:
        training_manager.shutdown()
        shutil.rmtree(temp_dir)


if __name__ == "__main__":
    pytest.main([__file__])
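
Reference sketches (not the repository's implementations) of the two domain helpers whose contracts the suite above pins down most tightly, spectral_cosine_similarity and augment_spectral_data. They mirror only the behavior the assertions check; the zero-norm fallback is an assumption:

    import numpy as np

    def spectral_cosine_similarity(a, b):
        """Cosine similarity between two 1-D spectra, bounded to [-1, 1].

        Scale-invariant, so sim(x, 2 * x) == sim(x, x) == 1, as the tests assert.
        """
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            return 0.0  # assumption: degenerate (all-zero) spectra score 0
        return float(np.dot(a, b) / denom)

    def augment_spectral_data(X, y, noise_level=0.01, augmentation_factor=1):
        """Return (X_aug, y_aug) with augmentation_factor copies of every sample.

        A factor of 1 returns the inputs unchanged; a factor of k yields k * N rows,
        where the extra copies carry additive Gaussian noise scaled by noise_level.
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)
        if augmentation_factor <= 1:
            return X, y
        copies = [X]  # keep the originals as the first block
        for _ in range(augmentation_factor - 1):
            copies.append(X + np.random.normal(0.0, noise_level, size=X.shape))
        return np.vstack(copies), np.tile(y, augmentation_factor)

Since the file ends with pytest.main([__file__]), the suite can be run either via pytest tests/test_training_manager.py -v or directly with python tests/test_training_manager.py from the project root (assuming utils is importable from there).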