218c86b devjas1: (chore): unify model selection via shared registry (train uses 'choices()'/'build()')
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import argparse
import json
import random
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
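# Deterministic cuDNN kernels trade some speed for run-to-run reproducibility;
# benchmark=False also stops cuDNN from auto-tuning convolution algorithms.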

# Project-specific imports (resolvable via the sys.path tweak above)
from scripts.preprocess_dataset import preprocess_dataset
from models.registry import choices as model_choices, build as build_model
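# Assumed registry contract: choices() lists the registered model names
# (used to populate argparse below) and build(name, target_len) returns a
# constructed torch.nn.Module for that name.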

# Argument parser for CLI usage
parser = argparse.ArgumentParser(
    description="Run 10-fold CV on Raman data with optional preprocessing.")
parser.add_argument("--target-len", type=int, default=500)
parser.add_argument("--baseline", action="store_true")
parser.add_argument("--smooth", action="store_true")
parser.add_argument("--normalize", action="store_true")
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--learning-rate", type=float, default=1e-3)
parser.add_argument("--model", type=str, default="figure2", choices=model_choices())
args = parser.parse_args()
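
# Example invocation (script path assumed, adjust to this file's location):
#   python scripts/train.py --model figure2 --baseline --smooth --normalize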

# Constants
# Raman-only dataset (RDWP)
DATASET_PATH = "datasets/rdwp"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_FOLDS = 10

# Ensure output dirs exist
os.makedirs("outputs", exist_ok=True)
os.makedirs("outputs/logs", exist_ok=True)
| print("Preprocessing Configuration:") | |
| print(f" Resample to : {args.target_len}") | |
| print(f" Baseline Correct: {'β ' if args.baseline else 'β'}") | |
| print(f" Smoothing : {'β ' if args.smooth else 'β'}") | |
| print(f" Normalization : {'β ' if args.normalize else 'β'}") | |

# Load + preprocess data
print("Loading and preprocessing data ...")
X, y = preprocess_dataset(
    DATASET_PATH,
    target_len=args.target_len,
    baseline_correction=args.baseline,
    apply_smoothing=args.smooth,
    normalize=args.normalize,
)
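# preprocess_dataset is expected to return one spectrum of length target_len
# per sample plus an integer class label for each (see the shape print below).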
X, y = np.array(X, np.float32), np.array(y, np.int64)
print(f"✓ Data Loaded: {X.shape[0]} samples, {X.shape[1]} features each.")
print(f"Using model: {args.model}")

# Cross-validation
skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SEED)
fold_accuracies = []
all_conf_matrices = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
    print(f"\nFold {fold}/{NUM_FOLDS}")
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]

    train_loader = DataLoader(
        TensorDataset(torch.tensor(X_train, dtype=torch.float32),
                      torch.tensor(y_train, dtype=torch.long)),
        batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(
        TensorDataset(torch.tensor(X_val, dtype=torch.float32),
                      torch.tensor(y_val, dtype=torch.long)),
        batch_size=args.batch_size)

    # Model selection via the shared registry
    model = build_model(args.model, args.target_len).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs = inputs.unsqueeze(1).to(DEVICE)  # (B, L) -> (B, 1, L) for Conv1d
            labels = labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

    # After the epoch loop (still inside the fold loop): one summary line per
    # fold. running_loss holds the summed batch losses of the final epoch.
    print(f"✓ Fold {fold} done. Final-epoch training loss: {running_loss:.4f}")

    # Evaluation
    model.eval()
    all_true, all_pred = [], []
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.unsqueeze(1).to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            all_true.extend(labels.cpu().numpy())
            all_pred.extend(predicted.cpu().numpy())

    acc = 100 * np.mean(np.array(all_true) == np.array(all_pred))
    fold_accuracies.append(acc)
    all_conf_matrices.append(confusion_matrix(all_true, all_pred))
    print(f"✓ Fold {fold} Accuracy: {acc:.2f}%")

# Save a model checkpoint after the final fold
model_path = f"outputs/{args.model}_model.pth"
torch.save(model.state_dict(), model_path)
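# Note: only the weights from the last fold survive here; each fold rebuilds
# its model at the top of the loop, so earlier folds' weights are discarded.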

# Summary
mean_acc, std_acc = np.mean(fold_accuracies), np.std(fold_accuracies)
print("\nCross-Validation Results:")
for i, a in enumerate(fold_accuracies, 1):
    print(f"Fold {i}: {a:.2f}%")
print(f"\n✓ Mean Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}%")
print(f"✓ Model saved to {model_path}")

# Save diagnostics
def save_diagnostics_log(fold_acc, confs, args_param, output_path):
    fold_metrics = [
        {"fold": i + 1, "accuracy": float(a), "confusion_matrix": c.tolist()}
        for i, (a, c) in enumerate(zip(fold_acc, confs))
    ]
    log = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "preprocessing": {
            "target_len": args_param.target_len,
            "baseline": args_param.baseline,
            "smooth": args_param.smooth,
            "normalize": args_param.normalize,
        },
        "fold_metrics": fold_metrics,
        "overall": {
            "mean_accuracy": float(np.mean(fold_acc)),
            "std_accuracy": float(np.std(fold_acc)),
            "num_folds": len(fold_acc),
            "batch_size": args_param.batch_size,
            "epochs": args_param.epochs,
            "learning_rate": args_param.learning_rate,
            "device": str(DEVICE),
        },
    }
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(log, f, indent=2)
    print(f"Diagnostics written to {output_path}")

log_path = f"outputs/logs/raman_{args.model}_diagnostics.json"
save_diagnostics_log(fold_accuracies, all_conf_matrices, args, log_path)
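
# -----------------------------------------------------------------------------
# Reference sketch: the registry contract assumed above. models/registry.py is
# not shown in this commit, so everything below is a hypothetical
# reconstruction; '_Figure2CNN' and its num_classes default are placeholders,
# not the project's actual architecture.
import torch.nn as nn

class _Figure2CNN(nn.Module):
    """Placeholder 1D CNN standing in for the real 'figure2' model."""

    def __init__(self, input_length: int, num_classes: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 8, kernel_size=5, padding=2),  # preserves length
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(8 * input_length, num_classes),
        )

    def forward(self, x):  # x: (B, 1, input_length)
        return self.net(x)

_REGISTRY = {
    "figure2": lambda input_len: _Figure2CNN(input_length=input_len),
}

def choices():
    """Model names accepted by --model (feeds argparse 'choices=')."""
    return sorted(_REGISTRY)

def build(name, input_len):
    """Instantiate the model registered under 'name'."""
    return _REGISTRY[name](input_len)
# -----------------------------------------------------------------------------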