polymer-aging-ml / models /registry.py
devjas1
(FEAT/REFAC)[Expand Registry & Metadata]: Enhance model registry with new models, richer metadata, and utility functions.
0be85e4
raw
history blame
8.6 kB
# models/registry.py
import copy
from typing import Callable, Dict, List, Any

from models.figure2_cnn import Figure2CNN
from models.resnet_cnn import ResNet1D
from models.resnet18_vision import ResNet18Vision
from models.enhanced_cnn import EnhancedCNN, EfficientSpectralCNN, HybridSpectralNet
# Builder table: short model name -> callable that takes the input length and
# returns a freshly constructed model. Each built-in builder forwards the
# length to the model constructor as ``input_length``.
_REGISTRY: Dict[str, Callable[[int], object]] = {
    "figure2": lambda length: Figure2CNN(input_length=length),
    "resnet": lambda length: ResNet1D(input_length=length),
    "resnet18vision": lambda length: ResNet18Vision(input_length=length),
    "enhanced_cnn": lambda length: EnhancedCNN(input_length=length),
    "efficient_cnn": lambda length: EfficientSpectralCNN(input_length=length),
    "hybrid_net": lambda length: HybridSpectralNet(input_length=length),
}
# Metadata for each buildable model, keyed by the same short names as the
# builder table. Every entry carries "input_length", "num_classes",
# "description", "modalities", "citation", "performance", "parameters" and
# "speed"; "features" is optional and only present on some entries.
# NOTE(review): the provenance of the performance figures is not recorded in
# this file - confirm before quoting them.
_MODEL_SPECS: Dict[str, Dict[str, Any]] = {
    # --- baseline and ResNet-style models ---
    "figure2": {
        "input_length": 500,
        "num_classes": 2,
        "description": "Figure 2 baseline custom implementation",
        "modalities": ["raman", "ftir"],
        "citation": "Neo et al., 2023, Resour. Conserv. Recycl., 188, 106718",
        "performance": {"accuracy": 0.948, "f1_score": 0.943},
        "parameters": "~500K",
        "speed": "fast",
    },
    "resnet": {
        "input_length": 500,
        "num_classes": 2,
        "description": "(Residual Network) uses skip connections to train much deeper networks",
        "modalities": ["raman", "ftir"],
        "citation": "Custom ResNet implementation",
        "performance": {"accuracy": 0.962, "f1_score": 0.959},
        "parameters": "~100K",
        "speed": "very_fast",
    },
    "resnet18vision": {
        "input_length": 500,
        "num_classes": 2,
        "description": "excels at image recognition tasks by using 'residual blocks' to train more efficiently",
        "modalities": ["raman", "ftir"],
        "citation": "ResNet18 Vision adaptation",
        "performance": {"accuracy": 0.945, "f1_score": 0.940},
        "parameters": "~11M",
        "speed": "medium",
    },
    # --- entries that additionally list a "features" tag set ---
    "enhanced_cnn": {
        "input_length": 500,
        "num_classes": 2,
        "description": "Enhanced CNN with attention mechanisms and multi-scale feature extraction",
        "modalities": ["raman", "ftir"],
        "citation": "Custom enhanced architecture with attention",
        "performance": {"accuracy": 0.975, "f1_score": 0.973},
        "parameters": "~800K",
        "speed": "medium",
        "features": ["attention", "multi_scale", "batch_norm", "dropout"],
    },
    "efficient_cnn": {
        "input_length": 500,
        "num_classes": 2,
        "description": "Efficient CNN optimized for real-time inference with depthwise separable convolutions",
        "modalities": ["raman", "ftir"],
        "citation": "Custom efficient architecture",
        "performance": {"accuracy": 0.955, "f1_score": 0.952},
        "parameters": "~200K",
        "speed": "very_fast",
        "features": ["depthwise_separable", "lightweight", "real_time"],
    },
    "hybrid_net": {
        "input_length": 500,
        "num_classes": 2,
        "description": "Hybrid network combining CNN backbone with self-attention mechanisms",
        "modalities": ["raman", "ftir"],
        "citation": "Custom hybrid CNN-Transformer architecture",
        "performance": {"accuracy": 0.968, "f1_score": 0.965},
        "parameters": "~1.2M",
        "speed": "medium",
        "features": ["self_attention", "cnn_backbone", "transformer_head"],
    },
}
# Models that are announced but have no builder yet. These keys are reported
# by ``planned_models()`` and can be looked up through ``get_model_info()``,
# but they are absent from the builder table and so cannot be built.
_FUTURE_MODELS = {
    "densenet1d": {
        "description": "DenseNet1D for spectroscopy with dense connections",
        "status": "planned",
        "modalities": ["raman", "ftir"],
        "features": ["dense_connections", "parameter_efficient"],
    },
    "ensemble_cnn": {
        "description": "Ensemble of multiple CNN variants for robust predictions",
        "status": "planned",
        "modalities": ["raman", "ftir"],
        "features": ["ensemble", "robust", "high_accuracy"],
    },
    "vision_transformer": {
        "description": "Vision Transformer adapted for 1D spectral data",
        "status": "planned",
        "modalities": ["raman", "ftir"],
        "features": ["transformer", "attention", "state_of_art"],
    },
    "autoencoder_cnn": {
        "description": "CNN with autoencoder for unsupervised feature learning",
        "status": "planned",
        "modalities": ["raman", "ftir"],
        "features": ["autoencoder", "unsupervised", "feature_learning"],
    },
}
def choices():
    """List the short names of every model that can currently be built.

    Returns:
        A new list holding the keys of the builder registry, in registration
        order.
    """
    return [*_REGISTRY]
def planned_models():
    """List the short names of models that are planned but not yet buildable.

    Returns:
        A new list holding the keys of the planned-model table.
    """
    return [*_FUTURE_MODELS]
def build(name: str, input_length: int):
    """Construct the model registered under ``name``.

    Args:
        name: Short registry key of the model to build.
        input_length: Value handed to the registered builder.

    Returns:
        Whatever the registered builder returns for ``input_length``.

    Raises:
        ValueError: If ``name`` is not a registered model key.
    """
    if name in _REGISTRY:
        return _REGISTRY[name](input_length)
    raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
def build_multiple(names: List[str], input_length: int) -> Dict[str, Any]:
    """Build multiple models for comparison.

    Every requested name is checked before any model is instantiated, so an
    unknown name is reported without first constructing models that would
    only be thrown away by the exception.

    Args:
        names: Short registry keys of the models to build. Any iterable of
            names is accepted.
        input_length: Value handed to each model's builder.

    Returns:
        A dictionary mapping each requested name to its built model. A name
        listed more than once is built each time and the last instance wins.

    Raises:
        ValueError: If any name is not a registered model key; the message
            names the first unknown entry.
    """
    # Snapshot so a one-shot iterable survives the two passes below.
    requested = list(names)
    for name in requested:
        if name not in _REGISTRY:
            raise ValueError(f"Unknown model '{name}'. Available: {choices()}")
    return {name: build(name, input_length) for name in requested}
def register_model(
    name: str, builder: Callable[[int], object], spec: Dict[str, Any]
) -> None:
    """Dynamically register a new model together with its metadata.

    Args:
        name: Short key the model will be built and looked up under.
        builder: Callable that takes the input length and returns a model.
        spec: Metadata for the model. An independent copy is stored, so later
            changes to the caller's dictionary do not alter the registry.

    Raises:
        ValueError: If ``name`` is already registered.
        TypeError: If ``builder`` is not callable.
    """
    if name in _REGISTRY:
        raise ValueError(f"Model '{name}' already registered.")
    if not callable(builder):
        raise TypeError("Builder must be a callable that accepts an integer argument.")
    # Copy before touching either table: a spec that cannot be copied then
    # fails here and leaves the registry unchanged rather than half-updated.
    stored_spec = copy.deepcopy(dict(spec))
    _REGISTRY[name] = builder
    _MODEL_SPECS[name] = stored_spec
def spec(name: str) -> Dict[str, Any]:
    """Return the metadata recorded for a registered model.

    The result is the model's complete specification entry (for the built-in
    models this includes ``input_length`` and ``num_classes`` alongside the
    descriptive fields), not just the input length and class count.

    Args:
        name: Short key of the model.

    Returns:
        An independent copy of the model's specification; modifying it,
        including its nested lists and dictionaries, does not affect the
        registry.

    Raises:
        KeyError: If no specification is stored under ``name``.
    """
    if name not in _MODEL_SPECS:
        raise KeyError(f"Unknown model '{name}'. Available: {choices()}")
    # Deep copy: a shallow one would still share the nested "modalities" and
    # "performance" objects with the registry.
    return copy.deepcopy(_MODEL_SPECS[name])
def get_model_info(name: str) -> Dict[str, Any]:
    """Get the metadata of a registered or planned model.

    Registered models are consulted first, so a key present in both tables
    resolves to the registered entry.

    Args:
        name: Short key of the model.

    Returns:
        An independent copy of the matching metadata entry; modifying it,
        including its nested lists and dictionaries, does not affect the
        registry.

    Raises:
        KeyError: If ``name`` is neither a registered nor a planned model.
    """
    for table in (_MODEL_SPECS, _FUTURE_MODELS):
        if name in table:
            # Deep copy so nested lists/dicts are not shared with the table.
            return copy.deepcopy(table[name])
    raise KeyError(f"Unknown model '{name}'")
def models_for_modality(modality: str) -> List[str]:
    """Find the registered models whose metadata lists ``modality``.

    Args:
        modality: Modality label to look for, e.g. ``"raman"`` or ``"ftir"``.

    Returns:
        The keys of all matching models, in specification-table order. A
        model whose metadata has no ``"modalities"`` entry never matches.
    """
    return [
        model_name
        for model_name, metadata in _MODEL_SPECS.items()
        if modality in metadata.get("modalities", [])
    ]
def validate_model_list(names: List[str]) -> List[str]:
    """Filter ``names`` down to the ones that are buildable models.

    Args:
        names: Candidate model keys.

    Returns:
        The candidates that are registered, in their original order and with
        duplicates kept. Unknown names are dropped silently.
    """
    registered = choices()
    return [candidate for candidate in names if candidate in registered]
def get_models_metadata() -> Dict[str, Dict[str, Any]]:
    """Get metadata for all registered models.

    Returns:
        A dictionary mapping each model key to an independent copy of its
        specification; modifying the result, including nested lists and
        dictionaries, does not affect the registry.
    """
    # Deep copy each entry: a shallow copy would still share the nested
    # "modalities"/"performance" objects with the registry.
    return {
        model_name: copy.deepcopy(metadata)
        for model_name, metadata in _MODEL_SPECS.items()
    }
def is_model_compatible(name: str, modality: str) -> bool:
    """Tell whether a registered model lists ``modality`` as supported.

    Args:
        name: Short key of the model.
        modality: Modality label to test.

    Returns:
        ``True`` when the model has a specification whose ``"modalities"``
        entry contains ``modality``; ``False`` otherwise, including when the
        model is unknown.
    """
    return name in _MODEL_SPECS and modality in _MODEL_SPECS[name].get(
        "modalities", []
    )
def get_model_capabilities(name: str) -> Dict[str, Any]:
    """Get detailed capabilities of a registered model.

    The result is the model's specification extended with the fixed entries
    ``available``, ``status``, ``supported_tasks`` and ``performance_metrics``
    (these override same-named keys in the specification, if any).

    Args:
        name: Short key of the model.

    Returns:
        A new dictionary independent of the registry.
        ``performance_metrics["memory_efficient"]`` is a name-based heuristic:
        it is ``True`` when the model key or its description mentions
        "resnet" (case-insensitive).

    Raises:
        KeyError: If no specification is stored under ``name``.
    """
    if name not in _MODEL_SPECS:
        raise KeyError(f"Unknown model '{name}'")
    # Deep copy so nested lists/dicts are not shared with the registry.
    capabilities = copy.deepcopy(_MODEL_SPECS[name])
    description = str(capabilities.get("description", ""))
    # Match on the key as well as the description: the built-in ResNet
    # descriptions never spell out "resnet", so a description-only test is
    # False for every built-in model.
    mentions_resnet = "resnet" in name.lower() or "resnet" in description.lower()
    capabilities.update(
        {
            "available": True,
            "status": "active",
            "supported_tasks": ["binary_classification"],
            "performance_metrics": {
                "supports_confidence": True,
                "supports_batch": True,
                "memory_efficient": mentions_resnet,
            },
        }
    )
    return capabilities
# Names exported by a star-import of this module.
__all__ = [
    "choices",
    "build",
    "spec",
    "build_multiple",
    "register_model",
    "get_model_info",
    "models_for_modality",
    "validate_model_list",
    "planned_models",
    "get_models_metadata",
    "is_model_compatible",
    "get_model_capabilities",
]