Spaces:
Running
Running
devjas1
commited on
Commit
·
218c86b
1
Parent(s):
12ab884
(chore): unify model selection via shared registry (train uses 'choices()'/'build()')
Browse files- models/registry.py +24 -0
- scripts/train_model.py +15 -2
models/registry.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# models/registry.py
|
2 |
+
from typing import Callable, Dict
|
3 |
+
from models.figure2_cnn import Figure2CNN
|
4 |
+
from models.resnet_cnn import ResNet1D
|
5 |
+
# from models.resnet18_vision import ResNet18Vision # (Step 2)
|
6 |
+
|
7 |
+
# Internal registry of model builders keyed by short name.
|
8 |
+
_REGISTRY: Dict[str, Callable[[int], object]] = {
|
9 |
+
"figure2": lambda L: Figure2CNN(input_length=L),
|
10 |
+
"resnet": lambda L: ResNet1D(input_length=L),
|
11 |
+
# "resnet18vision": lambda L: ResNet18Vision(input_length=L)
|
12 |
+
}
|
13 |
+
|
14 |
+
def choices():
|
15 |
+
"""Return the list of available model keys."""
|
16 |
+
return list(_REGISTRY.keys())
|
17 |
+
|
18 |
+
def build(name: str, input_length: int):
|
19 |
+
"""Instantiate a model by short name with the given input length."""
|
20 |
+
if name not in _REGISTRY:
|
21 |
+
raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
|
22 |
+
return _REGISTRY[name](input_length)
|
23 |
+
|
24 |
+
__all__ = ["choices", "build"]
|
scripts/train_model.py
CHANGED
@@ -7,10 +7,23 @@ from torch.utils.data import TensorDataset, DataLoader
|
|
7 |
from sklearn.model_selection import StratifiedKFold
|
8 |
from sklearn.metrics import confusion_matrix
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
# Add project-specific imports
|
11 |
from scripts.preprocess_dataset import preprocess_dataset
|
12 |
-
from models.
|
13 |
-
|
14 |
|
15 |
# Argument parser for CLI usage
|
16 |
parser = argparse.ArgumentParser(
|
|
|
7 |
from sklearn.model_selection import StratifiedKFold
|
8 |
from sklearn.metrics import confusion_matrix
|
9 |
|
10 |
+
import random
|
11 |
+
import json
|
12 |
+
|
13 |
+
# Reproducibility
|
14 |
+
SEED = 42
|
15 |
+
random.seed(SEED)
|
16 |
+
np.random.seed(SEED)
|
17 |
+
torch.manual_seed(SEED)
|
18 |
+
if torch.cuda.is_available():
|
19 |
+
torch.cuda.manual_seed_all(SEED)
|
20 |
+
torch.backends.cudnn.deterministic = True
|
21 |
+
torch.backends.cudnn.benchmark = False
|
22 |
+
|
23 |
# Add project-specific imports
|
24 |
from scripts.preprocess_dataset import preprocess_dataset
|
25 |
+
from models.registry import choices as model_choices, build as build_model
|
26 |
+
|
27 |
|
28 |
# Argument parser for CLI usage
|
29 |
parser = argparse.ArgumentParser(
|