File size: 4,465 Bytes
218c86b
71b3dbd
218c86b
 
71b3dbd
218c86b
 
 
 
 
71b3dbd
218c86b
 
71b3dbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218c86b
 
 
 
71b3dbd
 
 
 
 
 
218c86b
 
 
 
 
 
71b3dbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6373c5a
 
71b3dbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6373c5a
 
71b3dbd
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# models/registry.py
from typing import Callable, Dict, List, Any
from models.figure2_cnn import Figure2CNN
from models.resnet_cnn import ResNet1D
from models.resnet18_vision import ResNet18Vision

# Internal registry of model builders keyed by short name.
# Each value is a callable that receives the input length and returns a newly
# constructed model; build() looks names up here and choices() lists the keys.
_REGISTRY: Dict[str, Callable[[int], object]] = {
    "figure2": lambda L: Figure2CNN(input_length=L),
    "resnet": lambda L: ResNet1D(input_length=L),
    "resnet18vision": lambda L: ResNet18Vision(input_length=L),
}

# Per-model metadata, keyed by the same short names as the builder registry.
# Each entry records the expected input length, the number of output classes,
# a human-readable description, the supported modalities and a citation.
# spec() and get_model_info() hand out copies of these entries, and
# models_for_modality() filters on the "modalities" list.
_MODEL_SPECS: Dict[str, Dict[str, Any]] = {
    "figure2": {
        "input_length": 500,
        "num_classes": 2,
        "description": "Figure 2 baseline custom implementation",
        "modalities": ["raman", "ftir"],
        "citation": "Neo et al., 2023, Resour. Conserv. Recycl., 188, 106718",
    },
    "resnet": {
        "input_length": 500,
        "num_classes": 2,
        "description": "(Residual Network) uses skip connections to train much deeper networks",
        "modalities": ["raman", "ftir"],
        "citation": "Custom ResNet implementation",
    },
    "resnet18vision": {
        "input_length": 500,
        "num_classes": 2,
        "description": "excels at image recognition tasks by using 'residual blocks' to train more efficiently",
        "modalities": ["raman", "ftir"],
        "citation": "ResNet18 Vision adaptation",
    },
}

# Placeholder metadata for future model expansions.
# These keys have no builder in the registry, so they cannot be built; they
# are only reported by planned_models() and, as a fallback, get_model_info().
_FUTURE_MODELS: Dict[str, Dict[str, Any]] = {
    "densenet1d": {
        "description": "DenseNet1D for spectroscopy (placeholder)",
        "status": "planned",
    },
    "ensemble_cnn": {
        "description": "Ensemble of CNN variants (placeholder)",
        "status": "planned",
    },
}


def choices():
    """List the short names of every model that can currently be built.

    Returns:
        A new list holding the keys of the builder registry.
    """
    return list(_REGISTRY)


def planned_models():
    """List the short names of models that are planned but not yet buildable.

    Returns:
        A new list holding the keys of the planned-model table.
    """
    return list(_FUTURE_MODELS)


def build(name: str, input_length: int):
    """Create a model instance for a registered short name.

    Args:
        name: Key of the model in the builder registry.
        input_length: Value passed to the registered builder callable.

    Returns:
        Whatever the registered builder returns for ``input_length``.

    Raises:
        ValueError: If ``name`` is not a registered model key.
    """
    if name in _REGISTRY:
        builder = _REGISTRY[name]
        return builder(input_length)
    raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")


def build_multiple(names: List[str], input_length: int) -> Dict[str, Any]:
    """Build multiple models for comparison.

    Every name is validated before any model is instantiated, so an unknown
    name never triggers the construction of models that would be discarded.

    Args:
        names: Registry keys of the models to build.
        input_length: Input length passed to every builder.

    Returns:
        A dict mapping each requested name to its newly built model. A name
        listed more than once is built each time and the last instance wins.

    Raises:
        ValueError: If any name is not a registered model key; the first
            unknown name in input order is the one reported.
    """
    # Materialise once so a one-shot iterable survives both passes below.
    requested = list(names)
    for name in requested:
        if name not in _REGISTRY:
            raise ValueError(f"Unknown model '{name}'. Available: {choices()}")
    return {name: build(name, input_length) for name in requested}


def register_model(
    name: str, builder: Callable[[int], object], spec: Dict[str, Any]
) -> None:
    """Dynamically register a new model builder together with its metadata.

    Args:
        name: Short name under which the model becomes available.
        builder: Callable that accepts the input length and returns a model.
        spec: Metadata mapping for the model. A shallow copy is stored, so
            later changes to the caller's dict do not affect the registry.

    Raises:
        ValueError: If ``name`` is already registered.
        TypeError: If ``builder`` is not callable, or ``spec`` cannot be
            copied into a dict (e.g. ``None``).
    """
    if name in _REGISTRY:
        raise ValueError(f"Model '{name}' already registered.")
    if not callable(builder):
        raise TypeError("Builder must be a callable that accepts an integer argument.")
    # Copy before touching either table: an invalid spec then leaves the
    # registry unchanged, and the stored metadata cannot be altered through
    # the caller's reference (the accessors hand out copies as well).
    stored_spec = dict(spec)
    _REGISTRY[name] = builder
    _MODEL_SPECS[name] = stored_spec


def spec(name: str):
    """Look up the stored specification for a registered model key.

    Args:
        name: Key of the model in the specification table.

    Returns:
        A shallow copy of the model's specification entry, so callers can
        modify the result without changing the stored metadata's top level.

    Raises:
        KeyError: If no specification is stored under ``name``.
    """
    if name not in _MODEL_SPECS:
        raise KeyError(f"Unknown model '{name}'. Available: {choices()}")
    return _MODEL_SPECS[name].copy()


def get_model_info(name: str) -> Dict[str, Any]:
    """Fetch the metadata entry for an available or a planned model.

    The specification table of available models is consulted first; the
    planned-model table is only used when the name is not found there.

    Args:
        name: Short name of the model.

    Returns:
        A shallow copy of the matching metadata entry.

    Raises:
        KeyError: If ``name`` appears in neither table.
    """
    for table in (_MODEL_SPECS, _FUTURE_MODELS):
        if name in table:
            return table[name].copy()
    raise KeyError(f"Unknown model '{name}'")


def models_for_modality(modality: str) -> List[str]:
    """Find the models whose specification lists the given modality.

    Args:
        modality: Modality name to look for in each entry's "modalities" list.

    Returns:
        The matching model keys in specification-table order. Entries without
        a "modalities" field never match.
    """
    return [
        model_name
        for model_name, details in _MODEL_SPECS.items()
        if modality in details.get("modalities", [])
    ]


def validate_model_list(names: List[str]) -> List[str]:
    """Filter a list of model names down to those that can be built.

    Args:
        names: Candidate model names.

    Returns:
        The names from ``names`` that are registered model keys, in their
        original order (duplicates are kept). Unknown names are dropped
        silently rather than raising.
    """
    available = choices()
    # Membership test: the previous identity check (`is`) compared each name
    # with the list object itself and therefore never matched anything.
    return [name for name in names if name in available]


# Public API of this module: the names exported by `from ... import *`.
# The underscore-prefixed tables above are intentionally not listed.
__all__ = [
    "choices",
    "build",
    "spec",
    "build_multiple",
    "register_model",
    "get_model_info",
    "models_for_modality",
    "validate_model_list",
    "planned_models",
]