devjas1 commited on
Commit
218c86b
·
1 Parent(s): 12ab884

(chore): unify model selection via shared registry (train uses 'choices()'/'build()')

Browse files
Files changed (2) hide show
  1. models/registry.py +24 -0
  2. scripts/train_model.py +15 -2
models/registry.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/registry.py
2
+ from typing import Callable, Dict
3
+ from models.figure2_cnn import Figure2CNN
4
+ from models.resnet_cnn import ResNet1D
5
+ # from models.resnet18_vision import ResNet18Vision # (Step 2)
6
+
7
+ # Internal registry of model builders keyed by short name.
8
+ _REGISTRY: Dict[str, Callable[[int], object]] = {
9
+ "figure2": lambda L: Figure2CNN(input_length=L),
10
+ "resnet": lambda L: ResNet1D(input_length=L),
11
+ # "resnet18vision": lambda L: ResNet18Vision(input_length=L)
12
+ }
13
+
14
+ def choices():
15
+ """Return the list of available model keys."""
16
+ return list(_REGISTRY.keys())
17
+
18
+ def build(name: str, input_length: int):
19
+ """Instantiate a model by short name with the given input length."""
20
+ if name not in _REGISTRY:
21
+ raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
22
+ return _REGISTRY[name](input_length)
23
+
24
+ __all__ = ["choices", "build"]
scripts/train_model.py CHANGED
@@ -7,10 +7,23 @@ from torch.utils.data import TensorDataset, DataLoader
7
  from sklearn.model_selection import StratifiedKFold
8
  from sklearn.metrics import confusion_matrix
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  # Add project-specific imports
11
  from scripts.preprocess_dataset import preprocess_dataset
12
- from models.figure2_cnn import Figure2CNN
13
- from models.resnet_cnn import ResNet1D
14
 
15
  # Argument parser for CLI usage
16
  parser = argparse.ArgumentParser(
 
7
  from sklearn.model_selection import StratifiedKFold
8
  from sklearn.metrics import confusion_matrix
9
 
10
+ import random
11
+ import json
12
+
13
+ # Reproducibility
14
+ SEED = 42
15
+ random.seed(SEED)
16
+ np.random.seed(SEED)
17
+ torch.manual_seed(SEED)
18
+ if torch.cuda.is_available():
19
+ torch.cuda.manual_seed_all(SEED)
20
+ torch.backends.cudnn.deterministic = True
21
+ torch.backends.cudnn.benchmark = False
22
+
23
  # Add project-specific imports
24
  from scripts.preprocess_dataset import preprocess_dataset
25
+ from models.registry import choices as model_choices, build as build_model
26
+
27
 
28
  # Argument parser for CLI usage
29
  parser = argparse.ArgumentParser(