import os

import jax
import jax.numpy as jnp
import yaml
from flax import serialization
from ml_collections import ConfigDict

from .activations import get_activation, list_activations
from .mlp import MLP
from .siren import SIREN
from .wire import WIRE

model_key_dict = {
    "MLP": MLP,
    "SIREN": SIREN,
    "WIRE": WIRE,
}

def make_model(config):
    """
    Create and configure a flax neural network nn.Module based on configuration.

    Args:
        config: Model configuration containing:
            - model_name: Type of model; one of "MLP", "SIREN", or "WIRE"
            - output_dim: Number of output dimensions
            - hidden_dim: Hidden layer dimension
            - num_layers: Number of layers
            - activation: Activation function name
            - extra_model_args: Additional model-specific arguments (or None)

    Returns:
        model (nn.Module): Configured flax nn.Module instance ready for training.

    Note:
        Handles the special case of the WIRE and SIREN models, which do not
        accept an activation function as an argument.
    """
    model = get_model(config.model_name)
    if config.extra_model_args is not None:
        if config.model_name in ("WIRE", "SIREN"):
            model = model(output_dim=config.output_dim,
                          hidden_dim=config.hidden_dim,
                          num_layers=config.num_layers,
                          **config.extra_model_args)
        else:
            model = model(output_dim=config.output_dim,
                          hidden_dim=config.hidden_dim,
                          num_layers=config.num_layers,
                          act=get_activation(config.activation),
                          **config.extra_model_args)
    else:
        model = model(output_dim=config.output_dim,
                      hidden_dim=config.hidden_dim,
                      num_layers=config.num_layers,
                      act=get_activation(config.activation),
                      )
    return model
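
# Hypothetical usage sketch for make_model (not part of the original module).
# The config field values below are illustrative assumptions, as is the "tanh"
# activation name (it presumes get_activation recognises it) and the
# 4-dimensional coordinate input used to initialise the parameters.
def _example_make_model():
    cfg = ConfigDict({
        "model_name": "MLP",
        "output_dim": 16,
        "hidden_dim": 64,
        "num_layers": 4,
        "activation": "tanh",
        "extra_model_args": None,
    })
    model = make_model(cfg)
    # Initialise parameters on a dummy coordinate point; the (4,) input shape
    # is an assumption based on how load_metric_from_model applies the model.
    params = model.init(jax.random.PRNGKey(0), jnp.zeros((4,)))
    return model, params
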
def load_metric_from_model(model_dir):
    """
    Load the model state from a given directory.

    If the model has an output dimension of 10, meaning it was trained only on
    the independent (upper-triangular) components of the symmetric metric, the
    full metric tensor is reconstructed.

    Args:
        model_dir (str): Directory containing the model state file.

    Returns:
        callable: A function mapping coordinates to the (4, 4) metric tensor.
    """
    with open(os.path.join(model_dir, "params.msgpack"), "rb") as f:
        params = serialization.msgpack_restore(f.read())
    with open(os.path.join(model_dir, "architecture.yml"), "r") as f:
        config_model = yaml.load(f, Loader=yaml.FullLoader)
    config_model = ConfigDict(config_model)
    model = make_model(config_model.architecture)
    if config_model.architecture.output_dim == 16:
        return lambda coords: model.apply(params, coords).reshape(4, 4)
    elif config_model.architecture.output_dim == 10:
        # The network outputs only the 10 independent components of the
        # symmetric 4x4 metric, so reconstruct the full tensor.
        return lambda coords: reconstruct_full_metric(model.apply(params, coords), 4)
    else:
        raise ValueError(
            f"Unsupported output_dim {config_model.architecture.output_dim}; expected 16 or 10."
        )
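
# Hypothetical usage sketch for load_metric_from_model (not part of the
# original module). The directory path and the coordinate values are made-up
# placeholders; the coordinate ordering is only an assumption.
def _example_load_metric():
    metric_fn = load_metric_from_model("runs/example_model")  # placeholder path
    coords = jnp.array([0.0, 10.0, jnp.pi / 2, 0.0])  # e.g. (t, r, theta, phi)
    return metric_fn(coords)  # (4, 4) metric tensor at these coordinates
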
def reconstruct_full_metric(metric_sym: jax.Array, n: int) -> jax.Array:
    """Return the full (n, n) metric tensor reconstructed from its
    symmetry-reduced (upper-triangular) components."""
    i, j = jnp.triu_indices(n, k=0)
    matrix = jnp.zeros((n, n))
    matrix = matrix.at[i, j].set(metric_sym)
    matrix = matrix.at[j, i].set(metric_sym)
    return matrix
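
# Small sketch (not part of the original module) of how reconstruct_full_metric
# maps the 10 independent components of a symmetric 4x4 tensor back to the full
# matrix; the input values are arbitrary.
def _example_reconstruct_metric():
    # The 10 upper-triangular components in row-major order:
    # (00, 01, 02, 03, 11, 12, 13, 22, 23, 33), as given by jnp.triu_indices.
    metric_sym = jnp.arange(10, dtype=jnp.float32)
    g = reconstruct_full_metric(metric_sym, n=4)
    assert jnp.allclose(g, g.T)  # symmetric by construction
    return g
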
def get_model(model_name: str):
    """
    Get the model class by name.

    Args:
        model_name (str): Name of the model.

    Returns:
        nn.Module: The model class.
    """
    if model_name not in model_key_dict:
        raise ValueError(f"Model `{model_name}` is not supported. Supported models are: {list(model_key_dict.keys())}")
    return model_key_dict[model_name]
def create_model_configs():
    """
    Create a dictionary of default extra arguments for each supported model.

    Returns:
        dict: A dictionary mapping model names to their extra arguments.
    """
    model_configs = {
        "MLP": {},
        "SIREN": {
            "omega_0": 3.,
        },
        "WIRE": {
            "first_omega_0": 4.,
            "hidden_omega_0": 4.,
            "scale": 5.,
        },
    }
    return model_configs


model_configs = create_model_configs()
def get_extra_model_cfg(model_name: str):
    """
    Get the extra model configuration for a given model name.

    Args:
        model_name (str): Name of the model.

    Returns:
        dict: The extra model configuration.
    """
    if model_name not in model_configs:
        raise ValueError(f"Model `{model_name}` is not supported. Available models are: {list(model_configs.keys())}")
    return model_configs[model_name]
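
# Hypothetical sketch (not part of the original module) tying
# get_extra_model_cfg to make_model: a WIRE config assembled from the defaults
# above. The hidden_dim, num_layers, and output_dim values are illustrative.
def _example_wire_from_defaults():
    cfg = ConfigDict({
        "model_name": "WIRE",
        "output_dim": 10,
        "hidden_dim": 128,
        "num_layers": 5,
        "activation": None,  # not used for WIRE/SIREN
        "extra_model_args": get_extra_model_cfg("WIRE"),
    })
    return make_model(cfg)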