File size: 8,602 Bytes
218c86b
71b3dbd
218c86b
 
71b3dbd
0be85e4
218c86b
 
 
 
 
71b3dbd
0be85e4
 
 
218c86b
 
71b3dbd
 
 
 
 
0be85e4
71b3dbd
 
0be85e4
 
 
71b3dbd
 
 
 
 
 
 
0be85e4
 
 
71b3dbd
 
 
 
 
 
 
0be85e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71b3dbd
 
 
 
 
 
0be85e4
71b3dbd
0be85e4
 
71b3dbd
 
0be85e4
 
 
 
 
 
 
71b3dbd
0be85e4
 
 
 
 
 
 
 
71b3dbd
 
 
 
218c86b
 
 
 
71b3dbd
 
 
 
 
 
218c86b
 
 
 
 
 
71b3dbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6373c5a
 
71b3dbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0be85e4
71b3dbd
 
6373c5a
 
0be85e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71b3dbd
 
 
 
 
 
 
 
 
 
0be85e4
 
 
71b3dbd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# models/registry.py
from typing import Callable, Dict, List, Any
from models.figure2_cnn import Figure2CNN
from models.resnet_cnn import ResNet1D
from models.resnet18_vision import ResNet18Vision
from models.enhanced_cnn import EnhancedCNN, EfficientSpectralCNN, HybridSpectralNet

# Internal registry of model builders keyed by short name.
# Each value is a one-argument factory: it receives the input length ``L`` and
# forwards it as ``input_length`` to the model class, so no model is
# constructed until the factory is actually called (see ``build``).
# ``register_model`` can add further entries at runtime.
_REGISTRY: Dict[str, Callable[[int], object]] = {
    "figure2": lambda L: Figure2CNN(input_length=L),
    "resnet": lambda L: ResNet1D(input_length=L),
    "resnet18vision": lambda L: ResNet18Vision(input_length=L),
    "enhanced_cnn": lambda L: EnhancedCNN(input_length=L),
    "efficient_cnn": lambda L: EfficientSpectralCNN(input_length=L),
    "hybrid_net": lambda L: HybridSpectralNet(input_length=L),
}

# Per-model metadata, keyed by the same short names as ``_REGISTRY``.
# Within this module only "modalities" (modality filtering) and "description"
# (capability heuristic) are read; every other field ("input_length",
# "num_classes", "citation", "performance", "parameters", "speed" and the
# optional "features" list) is carried as data and handed back to callers.
# NOTE(review): the accuracy / f1_score figures and parameter counts are
# hard-coded literals; their provenance is not visible in this file, so
# confirm them before quoting.
_MODEL_SPECS: Dict[str, Dict[str, Any]] = {
    "figure2": {
        "input_length": 500,
        "num_classes": 2,
        "description": "Figure 2 baseline custom implementation",
        "modalities": ["raman", "ftir"],
        "citation": "Neo et al., 2023, Resour. Conserv. Recycl., 188, 106718",
        "performance": {"accuracy": 0.948, "f1_score": 0.943},
        "parameters": "~500K",
        "speed": "fast",
    },
    "resnet": {
        "input_length": 500,
        "num_classes": 2,
        "description": "(Residual Network) uses skip connections to train much deeper networks",
        "modalities": ["raman", "ftir"],
        "citation": "Custom ResNet implementation",
        "performance": {"accuracy": 0.962, "f1_score": 0.959},
        "parameters": "~100K",
        "speed": "very_fast",
    },
    "resnet18vision": {
        "input_length": 500,
        "num_classes": 2,
        "description": "excels at image recognition tasks by using 'residual blocks' to train more efficiently",
        "modalities": ["raman", "ftir"],
        "citation": "ResNet18 Vision adaptation",
        "performance": {"accuracy": 0.945, "f1_score": 0.940},
        "parameters": "~11M",
        "speed": "medium",
    },
    "enhanced_cnn": {
        "input_length": 500,
        "num_classes": 2,
        "description": "Enhanced CNN with attention mechanisms and multi-scale feature extraction",
        "modalities": ["raman", "ftir"],
        "citation": "Custom enhanced architecture with attention",
        "performance": {"accuracy": 0.975, "f1_score": 0.973},
        "parameters": "~800K",
        "speed": "medium",
        "features": ["attention", "multi_scale", "batch_norm", "dropout"],
    },
    "efficient_cnn": {
        "input_length": 500,
        "num_classes": 2,
        "description": "Efficient CNN optimized for real-time inference with depthwise separable convolutions",
        "modalities": ["raman", "ftir"],
        "citation": "Custom efficient architecture",
        "performance": {"accuracy": 0.955, "f1_score": 0.952},
        "parameters": "~200K",
        "speed": "very_fast",
        "features": ["depthwise_separable", "lightweight", "real_time"],
    },
    "hybrid_net": {
        "input_length": 500,
        "num_classes": 2,
        "description": "Hybrid network combining CNN backbone with self-attention mechanisms",
        "modalities": ["raman", "ftir"],
        "citation": "Custom hybrid CNN-Transformer architecture",
        "performance": {"accuracy": 0.968, "f1_score": 0.965},
        "parameters": "~1.2M",
        "speed": "medium",
        "features": ["self_attention", "cnn_backbone", "transformer_head"],
    },
}

# Placeholder metadata for models that are planned but not implemented.
# These names have no builder in ``_REGISTRY``, so ``build`` rejects them;
# they are only listed by ``planned_models`` and described by
# ``get_model_info``.
_FUTURE_MODELS: Dict[str, Dict[str, Any]] = {
    "densenet1d": {
        "description": "DenseNet1D for spectroscopy with dense connections",
        "status": "planned",
        "modalities": ["raman", "ftir"],
        "features": ["dense_connections", "parameter_efficient"],
    },
    "ensemble_cnn": {
        "description": "Ensemble of multiple CNN variants for robust predictions",
        "status": "planned",
        "modalities": ["raman", "ftir"],
        "features": ["ensemble", "robust", "high_accuracy"],
    },
    "vision_transformer": {
        "description": "Vision Transformer adapted for 1D spectral data",
        "status": "planned",
        "modalities": ["raman", "ftir"],
        "features": ["transformer", "attention", "state_of_art"],
    },
    "autoencoder_cnn": {
        "description": "CNN with autoencoder for unsupervised feature learning",
        "status": "planned",
        "modalities": ["raman", "ftir"],
        "features": ["autoencoder", "unsupervised", "feature_learning"],
    },
}


def choices():
    """List the short names of every model that can currently be built.

    Returns:
        A new list holding the keys of ``_REGISTRY``; mutating it does not
        affect the registry.
    """
    return [*_REGISTRY]


def planned_models():
    """List the short names of models that are planned but not yet buildable.

    Returns:
        A new list holding the keys of ``_FUTURE_MODELS``.
    """
    return list(_FUTURE_MODELS)


def build(name: str, input_length: int):
    """Construct the model registered under ``name``.

    Args:
        name: Short model name; must be a key of ``_REGISTRY``.
        input_length: Value handed to the registered builder.

    Returns:
        Whatever the registered builder returns for ``input_length``.

    Raises:
        ValueError: If ``name`` is not registered.
    """
    if name in _REGISTRY:
        factory = _REGISTRY[name]
        return factory(input_length)
    raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")


def build_multiple(names: List[str], input_length: int) -> Dict[str, Any]:
    """Build several models at once, e.g. for side-by-side comparison.

    Every name is checked before any model is constructed, so an unknown
    name fails fast instead of after earlier models have already been
    instantiated and then thrown away.

    Args:
        names: Short model names; each must be a key of ``_REGISTRY``.
        input_length: Passed unchanged to every builder.

    Returns:
        Dict mapping each requested name to its newly built model. A name
        listed more than once is built each time and the last instance wins.

    Raises:
        ValueError: If any name is not registered.
    """
    # Materialise once: the names are walked twice (validate, then build).
    names = list(names)
    for name in names:
        if name not in _REGISTRY:
            raise ValueError(f"Unknown model '{name}'. Available: {choices()}")
    return {name: build(name, input_length) for name in names}


def register_model(
    name: str, builder: Callable[[int], object], spec: Dict[str, Any]
) -> None:
    """Dynamically register a new model builder together with its metadata.

    Args:
        name: Short name to register; must not already be in ``_REGISTRY``.
        builder: Callable invoked with the input length to create the model.
        spec: Metadata mapping stored in ``_MODEL_SPECS`` under ``name``.

    Raises:
        ValueError: If ``name`` is already registered.
        TypeError: If ``builder`` is not callable, or ``spec`` cannot be
            converted to a dict.
    """
    if name in _REGISTRY:
        raise ValueError(f"Model '{name}' already registered.")
    if not callable(builder):
        raise TypeError("Builder must be a callable that accepts an integer argument.")
    # Take a private copy before touching either table: later mutation of the
    # caller's dict must not leak into the registry, and an invalid spec is
    # rejected here rather than leaving a half-registered model behind.
    stored_spec = dict(spec)
    _REGISTRY[name] = builder
    _MODEL_SPECS[name] = stored_spec


def spec(name: str):
    """Look up the metadata recorded for a registered model.

    Args:
        name: Short model name; must be a key of ``_MODEL_SPECS``.

    Returns:
        A shallow copy of the model's full metadata dict (for the built-in
        models this includes ``input_length`` and ``num_classes``).

    Raises:
        KeyError: If no metadata is stored for ``name``.
    """
    if name not in _MODEL_SPECS:
        raise KeyError(f"Unknown model '{name}'. Available: {choices()}")
    return _MODEL_SPECS[name].copy()


def get_model_info(name: str) -> Dict[str, Any]:
    """Return metadata for a model, whether registered or only planned.

    Registered models (``_MODEL_SPECS``) take precedence over planned ones
    (``_FUTURE_MODELS``) when a name appears in both.

    Args:
        name: Short model name.

    Returns:
        A shallow copy of the matching metadata dict.

    Raises:
        KeyError: If ``name`` is in neither table.
    """
    for table in (_MODEL_SPECS, _FUTURE_MODELS):
        if name in table:
            return table[name].copy()
    raise KeyError(f"Unknown model '{name}'")


def models_for_modality(modality: str) -> List[str]:
    """List the registered models whose metadata declares ``modality``.

    Args:
        modality: Modality label to look for in each model's "modalities"
            entry; models without that entry never match.

    Returns:
        Matching model names in ``_MODEL_SPECS`` iteration order.
    """
    return [
        model_name
        for model_name, meta in _MODEL_SPECS.items()
        if modality in meta.get("modalities", [])
    ]


def validate_model_list(names: List[str]) -> List[str]:
    """Filter ``names`` down to the ones that can currently be built.

    Unknown names are dropped silently; order and duplicates of the input
    are preserved.

    Args:
        names: Candidate short model names.

    Returns:
        The subset of ``names`` present in ``choices()``.
    """
    known = choices()
    return [candidate for candidate in names if candidate in known]


def get_models_metadata() -> Dict[str, Dict[str, Any]]:
    """Return the metadata of every model in ``_MODEL_SPECS``.

    Returns:
        A new dict mapping each model name to a shallow copy of its metadata.
    """
    return {key: meta.copy() for key, meta in _MODEL_SPECS.items()}


def is_model_compatible(name: str, modality: str) -> bool:
    """Tell whether model ``name`` declares support for ``modality``.

    Args:
        name: Short model name.
        modality: Modality label to look for in the model's "modalities"
            entry.

    Returns:
        ``True`` only if ``name`` has metadata in ``_MODEL_SPECS`` and that
        metadata lists ``modality``; ``False`` otherwise, including for
        unknown names.
    """
    return name in _MODEL_SPECS and modality in _MODEL_SPECS[name].get(
        "modalities", []
    )


def get_model_capabilities(name: str) -> Dict[str, Any]:
    """Return a model's metadata extended with capability flags.

    The result is a shallow copy of the model's ``_MODEL_SPECS`` entry with
    the keys "available", "status", "supported_tasks" and
    "performance_metrics" added (overwriting same-named metadata keys). All
    added values are fixed constants except "memory_efficient".

    Args:
        name: Short model name; must be a key of ``_MODEL_SPECS``.

    Returns:
        The extended metadata dict.

    Raises:
        KeyError: If ``name`` has no metadata.
    """
    if name not in _MODEL_SPECS:
        raise KeyError(f"Unknown model '{name}'")

    # Named ``capabilities`` so the module-level ``spec()`` is not shadowed.
    capabilities = _MODEL_SPECS[name].copy()

    # The flag used to look for "resnet" in the description only, but none of
    # the built-in descriptions contains that word, so it was always False.
    # Matching the model key as well makes the ResNet entries register.
    # NOTE(review): "is a ResNet" as the criterion for memory efficiency is
    # inherited from the original heuristic - confirm it is what is wanted.
    description = capabilities.get("description", "")
    is_resnet = "resnet" in str(name).lower() or "resnet" in description.lower()

    capabilities.update(
        {
            "available": True,
            "status": "active",
            "supported_tasks": ["binary_classification"],
            "performance_metrics": {
                "supports_confidence": True,
                "supports_batch": True,
                "memory_efficient": is_resnet,
            },
        }
    )
    return capabilities


# Public API of this module. The underscore-prefixed tables (``_REGISTRY``,
# ``_MODEL_SPECS``, ``_FUTURE_MODELS``) are not exported; callers reach them
# through the functions listed here.
__all__ = [
    "choices",
    "build",
    "spec",
    "build_multiple",
    "register_model",
    "get_model_info",
    "models_for_modality",
    "validate_model_list",
    "planned_models",
    "get_models_metadata",
    "is_model_compatible",
    "get_model_capabilities",
]