devjas1 commited on
Commit
9fe46f4
·
1 Parent(s): 4dd9134

(FEAT)[Create Model Training UI Component]: Introduce comprehensive UI for model training and experiment management

Browse files

- Added a new module dedicated to rendering the model training interface, enabling users to configure, launch, and track ML experiments.
- Established a code structure for future expansion, including support for job status monitoring, dataset selection, and advanced configuration.
- Provided foundation for interactive feedback and integration with backend training manager.

Files changed (1) hide show
  1. modules/training_ui.py +1035 -0
modules/training_ui.py ADDED
@@ -0,0 +1,1035 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training UI components for the ML Hub functionality.
3
+ Provides interface for model training, dataset management, and progress tracking.
4
+ """
5
+
6
+ import os
7
+ import time
8
+ import torch
9
+ import streamlit as st
10
+ import pandas as pd
11
+ import numpy as np
12
+ import plotly.graph_objects as go
13
+ from plotly.subplots import make_subplots
14
+ from pathlib import Path
15
+ from typing import Dict, List, Optional
16
+ import json
17
+ from datetime import datetime, timedelta
18
+
19
+ from models.registry import choices as model_choices, get_model_info
20
+ from utils.training_manager import (
21
+ get_training_manager,
22
+ TrainingConfig,
23
+ TrainingStatus,
24
+ TrainingJob,
25
+ )
26
+
27
+
28
+ def render_training_tab():
29
+ """Render the main training interface tab"""
30
+ st.markdown("## 🎯 Model Training Hub")
31
+ st.markdown(
32
+ "Train any model from the registry on your datasets with real-time progress tracking."
33
+ )
34
+
35
+ # Create columns for layout
36
+ config_col, status_col = st.columns([1, 1])
37
+
38
+ with config_col:
39
+ render_training_configuration()
40
+
41
+ with status_col:
42
+ render_training_status()
43
+
44
+ # Full-width progress and results section
45
+ st.markdown("---")
46
+ render_training_progress()
47
+
48
+ st.markdown("---")
49
+ render_training_history()
50
+
51
+
52
+ def render_training_configuration():
53
+ """Render training configuration panel"""
54
+ st.markdown("### ⚙️ Training Configuration")
55
+
56
+ with st.expander("Model Selection", expanded=True):
57
+ # Model selection
58
+ available_models = model_choices()
59
+ selected_model = st.selectbox(
60
+ "Select Model Architecture",
61
+ available_models,
62
+ help="Choose from available model architectures in the registry",
63
+ )
64
+
65
+ # Store in session state
66
+ st.session_state["selected_model"] = selected_model
67
+
68
+ # Display model info
69
+ if selected_model:
70
+ try:
71
+ model_info = get_model_info(selected_model)
72
+ st.info(
73
+ f"**{selected_model}**: {model_info.get('description', 'No description available')}"
74
+ )
75
+
76
+ # Model specs
77
+ col1, col2 = st.columns(2)
78
+ with col1:
79
+ st.metric("Parameters", model_info.get("parameters", "Unknown"))
80
+ st.metric("Speed", model_info.get("speed", "Unknown"))
81
+ with col2:
82
+ if "performance" in model_info:
83
+ perf = model_info["performance"]
84
+ st.metric("Accuracy", f"{perf.get('accuracy', 0):.3f}")
85
+ st.metric("F1 Score", f"{perf.get('f1_score', 0):.3f}")
86
+ except KeyError:
87
+ st.warning(f"Model info not available for {selected_model}")
88
+
89
+ with st.expander("Dataset Selection", expanded=True):
90
+ render_dataset_selection()
91
+
92
+ with st.expander("Training Parameters", expanded=True):
93
+ render_training_parameters()
94
+
95
+ # Training action button
96
+ st.markdown("---")
97
+ if st.button("🚀 Start Training", type="primary", use_container_width=True):
98
+ start_training_job()
99
+
100
+
101
+ def render_dataset_selection():
102
+ """Render dataset selection and upload interface"""
103
+ st.markdown("#### Dataset Management")
104
+
105
+ # Dataset source selection
106
+ dataset_source = st.radio(
107
+ "Dataset Source",
108
+ ["Upload New Dataset", "Use Existing Dataset"],
109
+ horizontal=True,
110
+ )
111
+
112
+ if dataset_source == "Upload New Dataset":
113
+ render_dataset_upload()
114
+ else:
115
+ render_existing_dataset_selection()
116
+
117
+
118
+ def render_dataset_upload():
119
+ """Render dataset upload interface"""
120
+ st.markdown("##### Upload Dataset")
121
+
122
+ uploaded_files = st.file_uploader(
123
+ "Upload spectrum files (.txt, .csv, .json)",
124
+ accept_multiple_files=True,
125
+ type=["txt", "csv", "json"],
126
+ help="Upload multiple spectrum files. Organize them in folders named 'stable' and 'weathered' or label them accordingly.",
127
+ )
128
+
129
+ if uploaded_files:
130
+ st.success(f"✅ {len(uploaded_files)} files uploaded")
131
+
132
+ # Dataset organization
133
+ st.markdown("##### Dataset Organization")
134
+
135
+ dataset_name = st.text_input(
136
+ "Dataset Name",
137
+ placeholder="e.g., my_polymer_dataset",
138
+ help="Name for your dataset (will create a folder)",
139
+ )
140
+
141
+ # File labeling
142
+ st.markdown("**Label your files:**")
143
+ file_labels = {}
144
+
145
+ for i, file in enumerate(uploaded_files[:10]): # Limit display for performance
146
+ col1, col2 = st.columns([2, 1])
147
+ with col1:
148
+ st.text(file.name)
149
+ with col2:
150
+ file_labels[file.name] = st.selectbox(
151
+ f"Label for {file.name}", ["stable", "weathered"], key=f"label_{i}"
152
+ )
153
+
154
+ if len(uploaded_files) > 10:
155
+ st.info(
156
+ f"Showing first 10 files. {len(uploaded_files) - 10} more files will use default labeling based on filename."
157
+ )
158
+
159
+ if st.button("💾 Save Dataset") and dataset_name:
160
+ save_uploaded_dataset(uploaded_files, dataset_name, file_labels)
161
+
162
+
163
+ def render_existing_dataset_selection():
164
+ """Render existing dataset selection"""
165
+ st.markdown("##### Available Datasets")
166
+
167
+ # Scan for existing datasets
168
+ datasets_dir = Path("datasets")
169
+ if datasets_dir.exists():
170
+ available_datasets = [d.name for d in datasets_dir.iterdir() if d.is_dir()]
171
+
172
+ if available_datasets:
173
+ selected_dataset = st.selectbox(
174
+ "Select Dataset",
175
+ available_datasets,
176
+ help="Choose from previously uploaded or existing datasets",
177
+ )
178
+
179
+ if selected_dataset:
180
+ st.session_state["selected_dataset"] = str(
181
+ datasets_dir / selected_dataset
182
+ )
183
+ display_dataset_info(datasets_dir / selected_dataset)
184
+ else:
185
+ st.warning("No datasets found. Please upload a dataset first.")
186
+ else:
187
+ st.warning("Datasets directory not found. Please upload a dataset first.")
188
+
189
+
190
+ def display_dataset_info(dataset_path: Path):
191
+ """Display information about selected dataset"""
192
+ if not dataset_path.exists():
193
+ return
194
+
195
+ # Count files by category
196
+ file_counts = {}
197
+ total_files = 0
198
+
199
+ for category_dir in dataset_path.iterdir():
200
+ if category_dir.is_dir():
201
+ count = (
202
+ len(list(category_dir.glob("*.txt")))
203
+ + len(list(category_dir.glob("*.csv")))
204
+ + len(list(category_dir.glob("*.json")))
205
+ )
206
+ file_counts[category_dir.name] = count
207
+ total_files += count
208
+
209
+ if file_counts:
210
+ st.info(f"**Dataset**: {dataset_path.name}")
211
+
212
+ col1, col2 = st.columns(2)
213
+ with col1:
214
+ st.metric("Total Files", total_files)
215
+ with col2:
216
+ st.metric("Categories", len(file_counts))
217
+
218
+ # Display breakdown
219
+ for category, count in file_counts.items():
220
+ st.text(f"• {category}: {count} files")
221
+
222
+
223
+ def render_training_parameters():
224
+ """Render training parameter configuration with enhanced options"""
225
+ st.markdown("#### Training Parameters")
226
+
227
+ col1, col2 = st.columns(2)
228
+
229
+ with col1:
230
+ epochs = st.number_input("Epochs", min_value=1, max_value=100, value=10)
231
+ batch_size = st.selectbox("Batch Size", [8, 16, 32, 64], index=1)
232
+ learning_rate = st.select_slider(
233
+ "Learning Rate",
234
+ options=[1e-4, 5e-4, 1e-3, 5e-3, 1e-2],
235
+ value=1e-3,
236
+ format_func=lambda x: f"{x:.0e}",
237
+ )
238
+
239
+ with col2:
240
+ num_folds = st.number_input(
241
+ "Cross-Validation Folds", min_value=3, max_value=10, value=10
242
+ )
243
+ target_len = st.number_input(
244
+ "Target Length", min_value=100, max_value=1000, value=500
245
+ )
246
+ modality = st.selectbox("Modality", ["raman", "ftir"], index=0)
247
+
248
+ # Advanced Cross-Validation Options
249
+ st.markdown("**Cross-Validation Strategy**")
250
+ cv_strategy = st.selectbox(
251
+ "CV Strategy",
252
+ ["stratified_kfold", "kfold", "time_series_split"],
253
+ index=0,
254
+ help="Choose CV strategy: Stratified K-Fold (recommended for balanced datasets), K-Fold (for any dataset), Time Series Split (for temporal data)",
255
+ )
256
+
257
+ # Data Augmentation Options
258
+ st.markdown("**Data Augmentation**")
259
+ col1, col2 = st.columns(2)
260
+
261
+ with col1:
262
+ enable_augmentation = st.checkbox(
263
+ "Enable Spectral Augmentation",
264
+ value=False,
265
+ help="Add realistic noise and variations to improve model robustness",
266
+ )
267
+ with col2:
268
+ noise_level = st.slider(
269
+ "Noise Level",
270
+ min_value=0.001,
271
+ max_value=0.05,
272
+ value=0.01,
273
+ step=0.001,
274
+ disabled=not enable_augmentation,
275
+ help="Amount of Gaussian noise to add for augmentation",
276
+ )
277
+
278
+ # Spectroscopy-Specific Options
279
+ st.markdown("**Spectroscopy-Specific Settings**")
280
+ spectral_weight = st.slider(
281
+ "Spectral Metrics Weight",
282
+ min_value=0.0,
283
+ max_value=1.0,
284
+ value=0.1,
285
+ step=0.05,
286
+ help="Weight for spectroscopy-specific metrics (cosine similarity, peak matching)",
287
+ )
288
+
289
+ # Preprocessing options
290
+ st.markdown("**Preprocessing Options**")
291
+ col1, col2, col3 = st.columns(3)
292
+
293
+ with col1:
294
+ baseline_correction = st.checkbox("Baseline Correction", value=True)
295
+ with col2:
296
+ smoothing = st.checkbox("Smoothing", value=True)
297
+ with col3:
298
+ normalization = st.checkbox("Normalization", value=True)
299
+
300
+ # Device selection
301
+ device_options = ["auto", "cpu"]
302
+ if torch.cuda.is_available():
303
+ device_options.append("cuda")
304
+
305
+ device = st.selectbox("Device", device_options, index=0)
306
+
307
+ # Store parameters in session state
308
+ st.session_state.update(
309
+ {
310
+ "train_epochs": epochs,
311
+ "train_batch_size": batch_size,
312
+ "train_learning_rate": learning_rate,
313
+ "train_num_folds": num_folds,
314
+ "train_target_len": target_len,
315
+ "train_modality": modality,
316
+ "train_cv_strategy": cv_strategy,
317
+ "train_enable_augmentation": enable_augmentation,
318
+ "train_noise_level": noise_level,
319
+ "train_spectral_weight": spectral_weight,
320
+ "train_baseline_correction": baseline_correction,
321
+ "train_smoothing": smoothing,
322
+ "train_normalization": normalization,
323
+ "train_device": device,
324
+ }
325
+ )
326
+
327
+
328
+ def render_training_status():
329
+ """Render training status and active jobs"""
330
+ st.markdown("### 📊 Training Status")
331
+
332
+ training_manager = get_training_manager()
333
+
334
+ # Active jobs
335
+ active_jobs = training_manager.list_jobs(TrainingStatus.RUNNING)
336
+ pending_jobs = training_manager.list_jobs(TrainingStatus.PENDING)
337
+
338
+ if active_jobs or pending_jobs:
339
+ st.markdown("#### Active Jobs")
340
+ for job in active_jobs + pending_jobs:
341
+ render_job_status_card(job)
342
+
343
+ # Recent completed jobs
344
+ completed_jobs = training_manager.list_jobs(TrainingStatus.COMPLETED)[
345
+ :3
346
+ ] # Show last 3
347
+ if completed_jobs:
348
+ st.markdown("#### Recent Completed")
349
+ for job in completed_jobs:
350
+ render_job_status_card(job, compact=True)
351
+
352
+
353
+ def render_job_status_card(job: TrainingJob, compact: bool = False):
354
+ """Render a status card for a training job"""
355
+ status_color = {
356
+ TrainingStatus.PENDING: "🟡",
357
+ TrainingStatus.RUNNING: "🔵",
358
+ TrainingStatus.COMPLETED: "🟢",
359
+ TrainingStatus.FAILED: "🔴",
360
+ TrainingStatus.CANCELLED: "⚫",
361
+ }
362
+
363
+ with st.expander(
364
+ f"{status_color[job.status]} {job.config.model_name} - {job.job_id[:8]}",
365
+ expanded=not compact,
366
+ ):
367
+ if not compact:
368
+ col1, col2 = st.columns(2)
369
+ with col1:
370
+ st.text(f"Model: {job.config.model_name}")
371
+ st.text(f"Dataset: {Path(job.config.dataset_path).name}")
372
+ st.text(f"Status: {job.status.value}")
373
+ with col2:
374
+ st.text(f"Created: {job.created_at.strftime('%H:%M:%S')}")
375
+ if job.status == TrainingStatus.RUNNING:
376
+ st.text(
377
+ f"Fold: {job.progress.current_fold}/{job.progress.total_folds}"
378
+ )
379
+ st.text(
380
+ f"Epoch: {job.progress.current_epoch}/{job.progress.total_epochs}"
381
+ )
382
+
383
+ if job.status == TrainingStatus.RUNNING:
384
+ # Progress bars
385
+ fold_progress = job.progress.current_fold / job.progress.total_folds
386
+ epoch_progress = job.progress.current_epoch / job.progress.total_epochs
387
+
388
+ st.progress(fold_progress)
389
+ st.caption(
390
+ f"Overall: {fold_progress:.1%} | Current Loss: {job.progress.current_loss:.4f}"
391
+ )
392
+
393
+ elif job.status == TrainingStatus.COMPLETED and job.progress.fold_accuracies:
394
+ mean_acc = np.mean(job.progress.fold_accuracies)
395
+ std_acc = np.std(job.progress.fold_accuracies)
396
+ st.success(f"✅ Accuracy: {mean_acc:.3f} ± {std_acc:.3f}")
397
+
398
+ elif job.status == TrainingStatus.FAILED:
399
+ st.error(f"❌ Error: {job.error_message}")
400
+
401
+
402
+ def render_training_progress():
403
+ """Render detailed training progress visualization"""
404
+ st.markdown("### 📈 Training Progress")
405
+
406
+ training_manager = get_training_manager()
407
+ active_jobs = training_manager.list_jobs(TrainingStatus.RUNNING)
408
+
409
+ if not active_jobs:
410
+ st.info("No active training jobs. Start a training job to see progress here.")
411
+ return
412
+
413
+ # Job selector for multiple active jobs
414
+ if len(active_jobs) > 1:
415
+ selected_job_id = st.selectbox(
416
+ "Select Job to Monitor",
417
+ [job.job_id for job in active_jobs],
418
+ format_func=lambda x: f"{x[:8]} - {next(job.config.model_name for job in active_jobs if job.job_id == x)}",
419
+ )
420
+ selected_job = next(job for job in active_jobs if job.job_id == selected_job_id)
421
+ else:
422
+ selected_job = active_jobs[0]
423
+
424
+ # Real-time progress visualization
425
+ render_job_progress_details(selected_job)
426
+
427
+
428
+ def render_job_progress_details(job: TrainingJob):
429
+ """Render detailed progress for a specific job with enhanced metrics"""
430
+ col1, col2 = st.columns(2)
431
+
432
+ with col1:
433
+ st.metric(
434
+ "Current Fold", f"{job.progress.current_fold}/{job.progress.total_folds}"
435
+ )
436
+ st.metric(
437
+ "Current Epoch", f"{job.progress.current_epoch}/{job.progress.total_epochs}"
438
+ )
439
+
440
+ with col2:
441
+ st.metric("Current Loss", f"{job.progress.current_loss:.4f}")
442
+ st.metric("Current Accuracy", f"{job.progress.current_accuracy:.3f}")
443
+
444
+ # Progress bars
445
+ fold_progress = (
446
+ job.progress.current_fold / job.progress.total_folds
447
+ if job.progress.total_folds > 0
448
+ else 0
449
+ )
450
+ epoch_progress = (
451
+ job.progress.current_epoch / job.progress.total_epochs
452
+ if job.progress.total_epochs > 0
453
+ else 0
454
+ )
455
+
456
+ st.progress(fold_progress)
457
+ st.caption(f"Overall Progress: {fold_progress:.1%}")
458
+
459
+ st.progress(epoch_progress)
460
+ st.caption(f"Current Fold Progress: {epoch_progress:.1%}")
461
+
462
+ # Enhanced metrics visualization
463
+ if job.progress.fold_accuracies and job.progress.spectroscopy_metrics:
464
+ col1, col2 = st.columns(2)
465
+
466
+ with col1:
467
+ # Standard accuracy chart
468
+ fig_acc = go.Figure(
469
+ data=go.Bar(
470
+ x=[f"Fold {i+1}" for i in range(len(job.progress.fold_accuracies))],
471
+ y=job.progress.fold_accuracies,
472
+ name="Validation Accuracy",
473
+ marker_color="lightblue",
474
+ )
475
+ )
476
+ fig_acc.update_layout(
477
+ title="Cross-Validation Accuracies by Fold",
478
+ yaxis_title="Accuracy",
479
+ height=300,
480
+ )
481
+ st.plotly_chart(fig_acc, use_container_width=True)
482
+
483
+ with col2:
484
+ # Spectroscopy-specific metrics
485
+ if len(job.progress.spectroscopy_metrics) > 0:
486
+ # Extract metrics across folds
487
+ f1_scores = [
488
+ m.get("f1_score", 0) for m in job.progress.spectroscopy_metrics
489
+ ]
490
+ cosine_sim = [
491
+ m.get("cosine_similarity", 0)
492
+ for m in job.progress.spectroscopy_metrics
493
+ ]
494
+ dist_sim = [
495
+ m.get("distribution_similarity", 0)
496
+ for m in job.progress.spectroscopy_metrics
497
+ ]
498
+
499
+ fig_spectro = go.Figure()
500
+
501
+ # Add traces for different metrics
502
+ fig_spectro.add_trace(
503
+ go.Scatter(
504
+ x=[f"Fold {i+1}" for i in range(len(f1_scores))],
505
+ y=f1_scores,
506
+ mode="lines+markers",
507
+ name="F1 Score",
508
+ line=dict(color="green"),
509
+ )
510
+ )
511
+
512
+ if any(c > 0 for c in cosine_sim):
513
+ fig_spectro.add_trace(
514
+ go.Scatter(
515
+ x=[f"Fold {i+1}" for i in range(len(cosine_sim))],
516
+ y=cosine_sim,
517
+ mode="lines+markers",
518
+ name="Cosine Similarity",
519
+ line={"color": "orange"},
520
+ )
521
+ )
522
+
523
+ fig_spectro.add_trace(
524
+ go.Scatter(
525
+ x=[f"Fold {i+1}" for i in range(len(dist_sim))],
526
+ y=dist_sim,
527
+ mode="lines+markers",
528
+ name="Distribution Similarity",
529
+ line=dict(color="purple"),
530
+ )
531
+ )
532
+
533
+ fig_spectro.update_layout(
534
+ title="Spectroscopy-Specific Metrics by Fold",
535
+ yaxis_title="Score",
536
+ height=300,
537
+ legend=dict(
538
+ orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1
539
+ ),
540
+ )
541
+ st.plotly_chart(fig_spectro, use_container_width=True)
542
+
543
+ elif job.progress.fold_accuracies:
544
+ # Fallback to standard accuracy chart only
545
+ fig = go.Figure(
546
+ data=go.Bar(
547
+ x=[f"Fold {i+1}" for i in range(len(job.progress.fold_accuracies))],
548
+ y=job.progress.fold_accuracies,
549
+ name="Validation Accuracy",
550
+ )
551
+ )
552
+ fig.update_layout(
553
+ title="Cross-Validation Accuracies by Fold",
554
+ yaxis_title="Accuracy",
555
+ height=300,
556
+ )
557
+ st.plotly_chart(fig, use_container_width=True)
558
+
559
+
560
+ def render_training_history():
561
+ """Render training history and results"""
562
+ st.markdown("### 📚 Training History")
563
+
564
+ training_manager = get_training_manager()
565
+ all_jobs = training_manager.list_jobs()
566
+
567
+ if not all_jobs:
568
+ st.info("No training history available. Start training some models!")
569
+ return
570
+
571
+ # Convert to DataFrame for display
572
+ history_data = []
573
+ for job in all_jobs:
574
+ row = {
575
+ "Job ID": job.job_id[:8],
576
+ "Model": job.config.model_name,
577
+ "Dataset": Path(job.config.dataset_path).name,
578
+ "Status": job.status.value,
579
+ "Created": job.created_at.strftime("%Y-%m-%d %H:%M"),
580
+ "Duration": "",
581
+ "Accuracy": "",
582
+ }
583
+
584
+ if job.completed_at and job.started_at:
585
+ duration = job.completed_at - job.started_at
586
+ row["Duration"] = str(duration).split(".")[0] # Remove microseconds
587
+
588
+ if job.status == TrainingStatus.COMPLETED and job.progress.fold_accuracies:
589
+ mean_acc = np.mean(job.progress.fold_accuracies)
590
+ std_acc = np.std(job.progress.fold_accuracies)
591
+ row["Accuracy"] = f"{mean_acc:.3f} ± {std_acc:.3f}"
592
+
593
+ history_data.append(row)
594
+
595
+ df = pd.DataFrame(history_data)
596
+ st.dataframe(df, use_container_width=True)
597
+
598
+ # Job details
599
+ if st.checkbox("Show detailed results"):
600
+ completed_jobs = [
601
+ job for job in all_jobs if job.status == TrainingStatus.COMPLETED
602
+ ]
603
+ if completed_jobs:
604
+ selected_job_id = st.selectbox(
605
+ "Select job for details",
606
+ [job.job_id for job in completed_jobs],
607
+ format_func=lambda x: f"{x[:8]} - {next(job.config.model_name for job in completed_jobs if job.job_id == x)}",
608
+ )
609
+
610
+ selected_job = next(
611
+ job for job in completed_jobs if job.job_id == selected_job_id
612
+ )
613
+ render_training_results(selected_job)
614
+
615
+
616
+ def render_training_results(job: TrainingJob):
617
+ """Render detailed training results for a completed job with enhanced metrics"""
618
+ st.markdown(f"#### Results for {job.config.model_name} - {job.job_id[:8]}")
619
+
620
+ if not job.progress.fold_accuracies:
621
+ st.warning("No results available for this job.")
622
+ return
623
+
624
+ # Summary metrics
625
+ mean_acc = np.mean(job.progress.fold_accuracies)
626
+ std_acc = np.std(job.progress.fold_accuracies)
627
+
628
+ # Enhanced metrics display
629
+ col1, col2, col3, col4 = st.columns(4)
630
+ with col1:
631
+ st.metric("Mean Accuracy", f"{mean_acc:.3f}")
632
+ with col2:
633
+ st.metric("Std Deviation", f"{std_acc:.3f}")
634
+ with col3:
635
+ st.metric("Best Fold", f"{max(job.progress.fold_accuracies):.3f}")
636
+ with col4:
637
+ st.metric("CV Strategy", job.config.cv_strategy.replace("_", " ").title())
638
+
639
+ # Spectroscopy-specific metrics summary
640
+ if job.progress.spectroscopy_metrics:
641
+ st.markdown("**Spectroscopy-Specific Metrics Summary**")
642
+ spectro_summary = {}
643
+
644
+ for metric_name in ["f1_score", "cosine_similarity", "distribution_similarity"]:
645
+ values = [
646
+ m.get(metric_name, 0)
647
+ for m in job.progress.spectroscopy_metrics
648
+ if m.get(metric_name, 0) > 0
649
+ ]
650
+ if values:
651
+ spectro_summary[metric_name] = {
652
+ "mean": np.mean(values),
653
+ "std": np.std(values),
654
+ "best": max(values),
655
+ }
656
+
657
+ if spectro_summary:
658
+ cols = st.columns(len(spectro_summary))
659
+ for i, (metric, stats) in enumerate(spectro_summary.items()):
660
+ with cols[i]:
661
+ metric_display = metric.replace("_", " ").title()
662
+ st.metric(
663
+ f"{metric_display}",
664
+ f"{stats['mean']:.3f} ± {stats['std']:.3f}",
665
+ f"Best: {stats['best']:.3f}",
666
+ )
667
+
668
+ # Configuration summary
669
+ with st.expander("Training Configuration"):
670
+ config_display = {
671
+ "Model": job.config.model_name,
672
+ "Dataset": Path(job.config.dataset_path).name,
673
+ "Epochs": job.config.epochs,
674
+ "Batch Size": job.config.batch_size,
675
+ "Learning Rate": job.config.learning_rate,
676
+ "CV Folds": job.config.num_folds,
677
+ "CV Strategy": job.config.cv_strategy,
678
+ "Augmentation": "Enabled" if job.config.enable_augmentation else "Disabled",
679
+ "Noise Level": (
680
+ job.config.noise_level if job.config.enable_augmentation else "N/A"
681
+ ),
682
+ "Spectral Weight": job.config.spectral_weight,
683
+ "Device": job.config.device,
684
+ }
685
+
686
+ config_df = pd.DataFrame(
687
+ list(config_display.items()), columns=["Parameter", "Value"]
688
+ )
689
+ st.dataframe(config_df, use_container_width=True)
690
+
691
+ # Enhanced visualizations
692
+ col1, col2 = st.columns(2)
693
+
694
+ with col1:
695
+ # Accuracy distribution
696
+ fig_acc = go.Figure(
697
+ data=go.Box(y=job.progress.fold_accuracies, name="Fold Accuracies")
698
+ )
699
+ fig_acc.update_layout(
700
+ title="Cross-Validation Accuracy Distribution", yaxis_title="Accuracy"
701
+ )
702
+ st.plotly_chart(fig_acc, use_container_width=True)
703
+
704
+ with col2:
705
+ # Metrics comparison if available
706
+ if (
707
+ job.progress.spectroscopy_metrics
708
+ and len(job.progress.spectroscopy_metrics) > 0
709
+ ):
710
+ metrics_df = pd.DataFrame(job.progress.spectroscopy_metrics)
711
+
712
+ if not metrics_df.empty:
713
+ fig_metrics = go.Figure()
714
+
715
+ for col in metrics_df.columns:
716
+ if col in [
717
+ "accuracy",
718
+ "f1_score",
719
+ "cosine_similarity",
720
+ "distribution_similarity",
721
+ ]:
722
+ fig_metrics.add_trace(
723
+ go.Scatter(
724
+ x=list(range(1, len(metrics_df) + 1)),
725
+ y=metrics_df[col],
726
+ mode="lines+markers",
727
+ name=col.replace("_", " ").title(),
728
+ )
729
+ )
730
+
731
+ fig_metrics.update_layout(
732
+ title="All Metrics Across Folds",
733
+ xaxis_title="Fold",
734
+ yaxis_title="Score",
735
+ height=300,
736
+ )
737
+ st.plotly_chart(fig_metrics, use_container_width=True)
738
+
739
+ # Download options
740
+ col1, col2, col3 = st.columns(3)
741
+ with col1:
742
+ if st.button("📥 Download Weights", key=f"weights_{job.job_id}"):
743
+ if job.weights_path and os.path.exists(job.weights_path):
744
+ with open(job.weights_path, "rb") as f:
745
+ st.download_button(
746
+ "Download Model Weights",
747
+ f.read(),
748
+ file_name=f"{job.config.model_name}_{job.job_id[:8]}.pth",
749
+ mime="application/octet-stream",
750
+ )
751
+
752
+ with col2:
753
+ if st.button("📄 Download Logs", key=f"logs_{job.job_id}"):
754
+ if job.logs_path and os.path.exists(job.logs_path):
755
+ with open(job.logs_path, "r") as f:
756
+ st.download_button(
757
+ "Download Training Logs",
758
+ f.read(),
759
+ file_name=f"training_log_{job.job_id[:8]}.json",
760
+ mime="application/json",
761
+ )
762
+
763
+ with col3:
764
+ if st.button("📊 Download Metrics CSV", key=f"metrics_{job.job_id}"):
765
+ # Create comprehensive metrics CSV
766
+ metrics_data = []
767
+ for i, (acc, spectro) in enumerate(
768
+ zip(
769
+ job.progress.fold_accuracies,
770
+ job.progress.spectroscopy_metrics or [],
771
+ )
772
+ ):
773
+ row = {"fold": i + 1, "accuracy": acc}
774
+ if spectro:
775
+ row.update(spectro)
776
+ metrics_data.append(row)
777
+
778
+ metrics_df = pd.DataFrame(metrics_data)
779
+ csv = metrics_df.to_csv(index=False)
780
+ st.download_button(
781
+ "Download Metrics CSV",
782
+ csv,
783
+ file_name=f"metrics_{job.job_id[:8]}.csv",
784
+ mime="text/csv",
785
+ )
786
+
787
+ # Interpretability section
788
+ if st.checkbox("🔍 Show Model Interpretability", key=f"interpret_{job.job_id}"):
789
+ render_model_interpretability(job)
790
+
791
+
792
+ def render_model_interpretability(job: TrainingJob):
793
+ """Render model interpretability features"""
794
+ st.markdown("##### 🔍 Model Interpretability")
795
+
796
+ try:
797
+ # Try to load the trained model for interpretation
798
+ if not job.weights_path or not os.path.exists(job.weights_path):
799
+ st.warning("Model weights not available for interpretation.")
800
+ return
801
+
802
+ # Simple feature importance visualization
803
+ st.markdown("**Feature Importance Analysis**")
804
+
805
+ # Generate mock feature importance for demonstration
806
+ # In a real implementation, this would use SHAP, Captum, or gradient-based methods
807
+ wavenumbers = np.linspace(400, 4000, job.config.target_len)
808
+
809
+ # Simulate feature importance (peaks at common polymer bands)
810
+ importance = np.zeros_like(wavenumbers)
811
+
812
+ # Simulate important regions for polymer degradation
813
+ # C-H stretch (2800-3000 cm⁻¹)
814
+ ch_region = (wavenumbers >= 2800) & (wavenumbers <= 3000)
815
+ importance[ch_region] = np.random.normal(0.8, 0.1, (np.sum(ch_region),))
816
+
817
+ # C=O stretch (1600-1800 cm⁻¹) - often changes with degradation
818
+ co_region = (wavenumbers >= 1600) & (wavenumbers <= 1800)
819
+ importance[co_region] = np.random.normal(0.9, 0.1, int(np.sum(co_region)))
820
+
821
+ # Fingerprint region (400-1500 cm⁻¹)
822
+ fingerprint_region = (wavenumbers >= 400) & (wavenumbers <= 1500)
823
+ importance[fingerprint_region] = np.random.normal(
824
+ 0.3, 0.2, int(np.sum(fingerprint_region))
825
+ )
826
+
827
+ # Normalize importance
828
+ importance = np.abs(importance)
829
+ importance = (
830
+ importance / np.max(importance) if np.max(importance) > 0 else importance
831
+ )
832
+
833
+ # Create interpretability plot
834
+ fig_interpret = go.Figure()
835
+
836
+ # Add feature importance
837
+ fig_interpret.add_trace(
838
+ go.Scatter(
839
+ x=wavenumbers,
840
+ y=importance,
841
+ mode="lines",
842
+ name="Feature Importance",
843
+ fill="tonexty",
844
+ line=dict(color="red", width=2),
845
+ )
846
+ )
847
+
848
+ # Add annotations for important regions
849
+ fig_interpret.add_annotation(
850
+ x=2900,
851
+ y=0.8,
852
+ text="C-H Stretch<br>(Polymer backbone)",
853
+ showarrow=True,
854
+ arrowhead=2,
855
+ arrowcolor="blue",
856
+ bgcolor="lightblue",
857
+ bordercolor="blue",
858
+ )
859
+
860
+ fig_interpret.add_annotation(
861
+ x=1700,
862
+ y=0.9,
863
+ text="C=O Stretch<br>(Degradation marker)",
864
+ showarrow=True,
865
+ arrowhead=2,
866
+ arrowcolor="red",
867
+ bgcolor="lightcoral",
868
+ bordercolor="red",
869
+ )
870
+
871
+ fig_interpret.update_layout(
872
+ title="Model Feature Importance for Polymer Degradation Classification",
873
+ xaxis_title="Wavenumber (cm⁻¹)",
874
+ yaxis_title="Feature Importance",
875
+ height=400,
876
+ showlegend=False,
877
+ )
878
+
879
+ st.plotly_chart(fig_interpret, use_container_width=True)
880
+
881
+ # Interpretation insights
882
+ st.markdown("**Key Insights:**")
883
+ col1, col2 = st.columns(2)
884
+
885
+ with col1:
886
+ st.info(
887
+ "🔬 **High Importance Regions:**\n"
888
+ "- C=O stretch (1600-1800 cm⁻¹): Critical for degradation detection\n"
889
+ "- C-H stretch (2800-3000 cm⁻¹): Polymer backbone changes"
890
+ )
891
+
892
+ with col2:
893
+ st.info(
894
+ "📊 **Model Behavior:**\n"
895
+ "- Focuses on spectral regions known to change with polymer degradation\n"
896
+ "- Fingerprint region provides molecular specificity"
897
+ )
898
+
899
+ # Attention heatmap simulation
900
+ st.markdown("**Spectral Attention Heatmap**")
901
+
902
+ # Create a 2D heatmap showing attention across different samples
903
+ n_samples = 10
904
+ attention_matrix = np.random.beta(2, 5, (n_samples, len(wavenumbers)))
905
+
906
+ # Enhance attention in important regions
907
+ for i in range(n_samples):
908
+ attention_matrix[i, ch_region] *= np.random.uniform(2, 4)
909
+ attention_matrix[i, co_region] *= np.random.uniform(3, 5)
910
+
911
+ fig_heatmap = go.Figure(
912
+ data=go.Heatmap(
913
+ z=attention_matrix,
914
+ x=wavenumbers[::10], # Subsample for display
915
+ y=[f"Sample {i+1}" for i in range(n_samples)],
916
+ colorscale="Viridis",
917
+ colorbar=dict(title="Attention Score"),
918
+ )
919
+ )
920
+
921
+ fig_heatmap.update_layout(
922
+ title="Model Attention Across Different Samples",
923
+ xaxis_title="Wavenumber (cm⁻¹)",
924
+ yaxis_title="Sample",
925
+ height=300,
926
+ )
927
+
928
+ st.plotly_chart(fig_heatmap, use_container_width=True)
929
+
930
+ st.markdown(
931
+ "**Note:** *This interpretability analysis is simulated for demonstration. "
932
+ "In production, this would use actual gradient-based attribution methods "
933
+ "(SHAP, Integrated Gradients, etc.) on the trained model.*"
934
+ )
935
+
936
+ except Exception as e:
937
+ st.error(f"Error generating interpretability analysis: {e}")
938
+ st.info("Interpretability features require the trained model to be available.")
939
+
940
+
941
+ def start_training_job():
942
+ """Start a new training job with current configuration"""
943
+ # Validate configuration
944
+ if "selected_dataset" not in st.session_state:
945
+ st.error("❌ Please select a dataset first.")
946
+ return
947
+
948
+ if not Path(st.session_state["selected_dataset"]).exists():
949
+ st.error("❌ Selected dataset path does not exist.")
950
+ return
951
+
952
+ # Create training configuration
953
+ config = TrainingConfig(
954
+ model_name=st.session_state.get("selected_model", "figure2"),
955
+ dataset_path=st.session_state["selected_dataset"],
956
+ target_len=st.session_state.get("train_target_len", 500),
957
+ batch_size=st.session_state.get("train_batch_size", 16),
958
+ epochs=st.session_state.get("train_epochs", 10),
959
+ learning_rate=st.session_state.get("train_learning_rate", 1e-3),
960
+ num_folds=st.session_state.get("train_num_folds", 10),
961
+ baseline_correction=st.session_state.get("train_baseline_correction", True),
962
+ smoothing=st.session_state.get("train_smoothing", True),
963
+ normalization=st.session_state.get("train_normalization", True),
964
+ modality=st.session_state.get("train_modality", "raman"),
965
+ device=st.session_state.get("train_device", "auto"),
966
+ cv_strategy=st.session_state.get("train_cv_strategy", "stratified_kfold"),
967
+ enable_augmentation=st.session_state.get("train_enable_augmentation", False),
968
+ noise_level=st.session_state.get("train_noise_level", 0.01),
969
+ spectral_weight=st.session_state.get("train_spectral_weight", 0.1),
970
+ )
971
+
972
+ # Submit job
973
+ training_manager = get_training_manager()
974
+ job_id = training_manager.submit_training_job(config)
975
+
976
+ st.success(f"✅ Training job started! Job ID: {job_id[:8]}")
977
+ st.info("Monitor progress in the Training Status section above.")
978
+
979
+ # Auto-refresh to show new job
980
+ time.sleep(1)
981
+ st.rerun()
982
+
983
+
984
+ def save_uploaded_dataset(
985
+ uploaded_files, dataset_name: str, file_labels: Dict[str, str]
986
+ ):
987
+ """Save uploaded dataset to local storage"""
988
+ try:
989
+ # Create dataset directory
990
+ dataset_dir = Path("datasets") / dataset_name
991
+ dataset_dir.mkdir(parents=True, exist_ok=True)
992
+
993
+ # Create label directories
994
+ (dataset_dir / "stable").mkdir(exist_ok=True)
995
+ (dataset_dir / "weathered").mkdir(exist_ok=True)
996
+
997
+ # Save files
998
+ saved_count = 0
999
+ for file in uploaded_files:
1000
+ # Determine label
1001
+ label = file_labels.get(file.name, "stable") # Default to stable
1002
+ if "weathered" in file.name.lower() or "degraded" in file.name.lower():
1003
+ label = "weathered"
1004
+
1005
+ # Save file
1006
+ target_path = dataset_dir / label / file.name
1007
+ with open(target_path, "wb") as f:
1008
+ f.write(file.getbuffer())
1009
+ saved_count += 1
1010
+
1011
+ st.success(
1012
+ f"✅ Dataset '{dataset_name}' saved successfully! {saved_count} files processed."
1013
+ )
1014
+ st.session_state["selected_dataset"] = str(dataset_dir)
1015
+
1016
+ # Display saved dataset info
1017
+ display_dataset_info(dataset_dir)
1018
+
1019
+ except Exception as e:
1020
+ st.error(f"❌ Error saving dataset: {str(e)}")
1021
+
1022
+
1023
+ # Auto-refresh for active training jobs
1024
+ def setup_training_auto_refresh():
1025
+ """Set up auto-refresh for training progress"""
1026
+ if "training_auto_refresh" not in st.session_state:
1027
+ st.session_state.training_auto_refresh = True
1028
+
1029
+ training_manager = get_training_manager()
1030
+ active_jobs = training_manager.list_jobs(TrainingStatus.RUNNING)
1031
+
1032
+ if active_jobs and st.session_state.training_auto_refresh:
1033
+ # Auto-refresh every 5 seconds if there are active jobs
1034
+ time.sleep(5)
1035
+ st.rerun()