devjas1 committed on
Commit
12ab884
·
1 Parent(s): 4b66627

(chore): pre-flight hardening for model expansion (seeds, typo, diagnostics, dtypes, optional deterministic cuDNN)

Browse files
Files changed (1) hide show
  1. scripts/train_model.py +13 -11
scripts/train_model.py CHANGED
@@ -1,4 +1,5 @@
1
- import os, sys, json
 
2
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
3
  from datetime import datetime
4
  import argparse, numpy as np, torch
@@ -21,8 +22,8 @@ parser.add_argument("--normalize", action="store_true")
21
  parser.add_argument("--batch-size", type=int, default=16)
22
  parser.add_argument("--epochs", type=int, default=10)
23
  parser.add_argument("--learning-rate", type=float, default=1e-3)
24
- parser.add_argument("--model", type=str, default="figure2",
25
- choices=["figure2", "resnet"])
26
  args = parser.parse_args()
27
 
28
  # Constants
@@ -36,7 +37,8 @@ os.makedirs("outputs", exist_ok=True)
36
  os.makedirs("outputs/logs", exist_ok=True)
37
 
38
  print("Preprocessing Configuration:")
39
- print(f" Reseample to : {args.target_len}")
 
40
  print(f" Baseline Correct: {'βœ…' if args.baseline else '❌'}")
41
  print(f" Smoothing : {'βœ…' if args.smooth else '❌'}")
42
  print(f" Normalization : {'βœ…' if args.normalize else '❌'}")
@@ -66,14 +68,13 @@ for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
66
  y_train, y_val = y[train_idx], y[val_idx]
67
 
68
  train_loader = DataLoader(
69
- TensorDataset(torch.tensor(X_train), torch.tensor(y_train)),
70
  batch_size=args.batch_size, shuffle=True)
71
  val_loader = DataLoader(
72
- TensorDataset(torch.tensor(X_val), torch.tensor(y_val)),batch_size=args.batch_size)
73
 
74
  # Model selection
75
- model = (Figure2CNN if args.model == "figure2" else ResNet1D)(
76
- input_length=args.target_len).to(DEVICE)
77
 
78
  optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
79
  criterion = torch.nn.CrossEntropyLoss()
@@ -127,9 +128,10 @@ print(f"βœ… Model saved to {model_path}")
127
 
128
 
129
  def save_diagnostics_log(fold_acc, confs, args_param, output_path):
130
- fold_metrics = [{"fold": i+1, "accuracy": acc,
131
- "confusion_matrix": c.tolist()}
132
- for i, (a, c) in enumerate(zip(fold_acc, confs))]
 
133
  log = {
134
  "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
135
  "preprocessing": {
 
1
+ import os
2
+ import sys
3
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
4
  from datetime import datetime
5
  import argparse, numpy as np, torch
 
22
  parser.add_argument("--batch-size", type=int, default=16)
23
  parser.add_argument("--epochs", type=int, default=10)
24
  parser.add_argument("--learning-rate", type=float, default=1e-3)
25
+ parser.add_argument("--model", type=str, default="figure2", choices=model_choices())
26
+
27
  args = parser.parse_args()
28
 
29
  # Constants
 
37
  os.makedirs("outputs/logs", exist_ok=True)
38
 
39
  print("Preprocessing Configuration:")
40
+ print(f" Resample to : {args.target_len}")
41
+
42
  print(f" Baseline Correct: {'βœ…' if args.baseline else '❌'}")
43
  print(f" Smoothing : {'βœ…' if args.smooth else '❌'}")
44
  print(f" Normalization : {'βœ…' if args.normalize else '❌'}")
 
68
  y_train, y_val = y[train_idx], y[val_idx]
69
 
70
  train_loader = DataLoader(
71
+ TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.long)),
72
  batch_size=args.batch_size, shuffle=True)
73
  val_loader = DataLoader(
74
+ TensorDataset(torch.tensor(X_val, dtype=torch.float32), torch.tensor(y_val, dtype=torch.long)))
75
 
76
  # Model selection
77
+ model = build_model(args.model, args.target_len).to(DEVICE)
 
78
 
79
  optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
80
  criterion = torch.nn.CrossEntropyLoss()
 
128
 
129
 
130
  def save_diagnostics_log(fold_acc, confs, args_param, output_path):
131
+ fold_metrics = [
132
+ {"fold": i + 1, "accuracy": float(a), "confusion_matrix": c.tolist()}
133
+ for i, (a, c) in enumerate(zip(fold_acc, confs))
134
+ ]
135
  log = {
136
  "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
137
  "preprocessing": {