# EinFields/flax_models/__init__.py
from .mlp import MLP
from .siren import SIREN
from .wire import WIRE
from .activations import get_activation, list_activations
from flax import serialization
import os
import yaml
import jax
import jax.numpy as jnp
from ml_collections import ConfigDict
model_key_dict = {
"MLP": MLP,
"SIREN": SIREN,
"WIRE": WIRE
}
def make_model(config):
"""
    Create and configure a Flax neural network (nn.Module) from the given configuration.
Args:
config: Model configuration containing:
- model_name: Type of model (MLP, SIREN, WIRE, etc.)
- output_dim: Number of output dimensions
- hidden_dim: Hidden layer dimensions
- num_layers: Number of layers
- activation: Activation function name
- extra_model_args: Additional model-specific arguments
Returns:
model (nn.Module): Configured flax nn.Module instance ready for training
Note:
        Handles the special case of the WIRE and SIREN models, which use fixed
        built-in nonlinearities and do not accept an activation function argument.
"""
    model_cls = get_model(config.model_name)
    kwargs = dict(
        output_dim=config.output_dim,
        hidden_dim=config.hidden_dim,
        num_layers=config.num_layers,
    )
    # WIRE and SIREN come with their own fixed nonlinearities, so they are
    # constructed without an `act` argument.
    if config.model_name not in ("WIRE", "SIREN"):
        kwargs["act"] = get_activation(config.activation)
    if config.extra_model_args is not None:
        kwargs.update(config.extra_model_args)
    return model_cls(**kwargs)
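
# A minimal usage sketch for `make_model`, assuming a ConfigDict with the same field
# names that `load_metric_from_model` reads from `architecture.yml` below. The width,
# depth, and the activation name "gelu" are illustrative; pick any name reported by
# `list_activations()` and an input dimension matching your coordinates (4 here).
#
#   cfg = ConfigDict({
#       "model_name": "MLP",
#       "output_dim": 16,
#       "hidden_dim": 128,
#       "num_layers": 4,
#       "activation": "gelu",
#       "extra_model_args": None,
#   })
#   model = make_model(cfg)
#   params = model.init(jax.random.PRNGKey(0), jnp.zeros((4,)))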
def load_metric_from_model(model_dir):
"""
Load the model state from a given directory.
If the model has output dimension of 10, meaning
it was trained only on the symmetric part of the metric,
it reconstructs the full metric tensor.
Args:
model_dir (str): Directory containing the model state file.
Returns:
callable: The metric tensor function from the model.
"""
with open(os.path.join(model_dir, "params.msgpack"), "rb") as f:
params = serialization.msgpack_restore(f.read())
with open(os.path.join(model_dir, "architecture.yml"), "r") as f:
config_model = yaml.load(f, Loader=yaml.FullLoader)
config_model = ConfigDict(config_model)
model = make_model(config_model.architecture)
    if config_model.architecture.output_dim == 16:
        # The full 4x4 metric is predicted directly.
        return lambda coords: model.apply(params, coords).reshape(4, 4)
    elif config_model.architecture.output_dim == 10:
        # Only the 10 independent components were learned; rebuild the symmetric 4x4 tensor.
        return lambda coords: reconstruct_full_metric(model.apply(params, coords), 4)
    else:
        raise ValueError(
            f"Unsupported output_dim `{config_model.architecture.output_dim}`; expected 16 or 10."
        )
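
# A usage sketch for `load_metric_from_model`. The directory name is hypothetical; it
# only needs to contain the `params.msgpack` and `architecture.yml` files loaded above.
# The 4-vector of coordinates is assumed to match the convention the model was trained with.
#
#   metric_fn = load_metric_from_model("runs/schwarzschild")
#   coords = jnp.array([0.0, 10.0, jnp.pi / 2, 0.0])
#   g = metric_fn(coords)                 # (4, 4) metric tensor at `coords`
#   dg = jax.jacfwd(metric_fn)(coords)    # (4, 4, 4) coordinate derivatives of the metric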
def reconstruct_full_metric(metric_sym: jax.Array, n: int) -> jax.Array:
    """Reconstruct the full (n, n) symmetric metric tensor from its n(n+1)/2
    upper-triangular components (row-major order, as given by jnp.triu_indices)."""
i, j = jnp.triu_indices(n, k=0)
matrix = jnp.zeros((n, n))
matrix = matrix.at[i, j].set(metric_sym)
matrix = matrix.at[j, i].set(metric_sym)
return matrix
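
# Worked example (sketch) of the upper-triangular ordering assumed by
# `reconstruct_full_metric`: for n = 2 the input is (g_00, g_01, g_11), so
#
#   reconstruct_full_metric(jnp.array([1., 2., 3.]), 2)
#
# returns [[1., 2.], [2., 3.]]. For the 4x4 case the 10 components follow the
# row-major order produced by jnp.triu_indices(4).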
def get_model(model_name : str):
"""
Get the model class by name.
Args:
model_name (str): Name of the model.
Returns:
        type[nn.Module]: The model class (not an instance).
"""
if model_name not in model_key_dict:
raise ValueError(f"Model `{model_name}` is not supported. Supported models are: {list(model_key_dict.keys())}")
return model_key_dict[model_name]
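
# For illustration, `get_model` returns the class itself, not an instance:
#
#   SirenCls = get_model("SIREN")
#   model = SirenCls(output_dim=10, hidden_dim=128, num_layers=4, omega_0=3.0)
#
# The width/depth values and omega_0 are illustrative (omega_0 mirrors the default in
# `create_model_configs` below); unknown names raise a ValueError listing the supported models.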
def create_model_configs():
"""
Create a dictionary of model configurations.
Returns:
dict: A dictionary of model configurations.
"""
model_configs = {
"MLP": {},
"SIREN": {
"omega_0": 3.
},
"WIRE": {
"first_omega_0": 4.,
"hidden_omega_0": 4.,
"scale": 5.,
}
}
return model_configs
model_configs = create_model_configs()
def get_extra_model_cfg(model_name: str):
"""
Get the extra model configuration for a given model name.
Args:
model_name (str): Name of the model.
Returns:
dict: The extra model configuration.
"""
if model_name not in model_configs:
raise ValueError(f"Model `{model_name}` is not supported. Available models are: {list(model_configs.keys())}")
return model_configs[model_name]
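
# A sketch combining the default extra arguments with `make_model`, here for WIRE; the
# field values are illustrative and the activation entry is ignored for WIRE/SIREN.
#
#   cfg = ConfigDict({
#       "model_name": "WIRE",
#       "output_dim": 10,
#       "hidden_dim": 128,
#       "num_layers": 5,
#       "activation": "gelu",
#       "extra_model_args": get_extra_model_cfg("WIRE"),
#   })
#   model = make_model(cfg)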