|
from .mlp import MLP |
|
from .siren import SIREN |
|
from .wire import WIRE |
|
from .activations import get_activation, list_activations |
|
from flax import serialization |
|
import os |
|
import yaml |
|
import jax |
|
import jax.numpy as jnp |
|
from ml_collections import ConfigDict |
|
|
|
# Registry of the supported architectures: maps the `model_name` string used in
# configs to the model class that `get_model` returns for it.
model_key_dict = {"MLP": MLP, "SIREN": SIREN, "WIRE": WIRE}
|
|
|
def make_model(config):
    """
    Create and configure a flax neural network nn.Module based on configuration.

    Args:
        config: Model configuration containing:
            - model_name: Type of model (MLP, SIREN, WIRE, etc.)
            - output_dim: Number of output dimensions
            - hidden_dim: Hidden layer dimensions
            - num_layers: Number of layers
            - activation: Activation function name (only read for models
              that take an activation, i.e. not SIREN/WIRE)
            - extra_model_args: Additional model-specific keyword arguments,
              or None when there are none

    Returns:
        model (nn.Module): Configured flax nn.Module instance ready for training

    Raises:
        ValueError: If `config.model_name` is not a supported model
            (raised by `get_model`).

    Note:
        WIRE and SIREN models don't accept an activation function as an
        argument, so `act` is never passed to them -- regardless of whether
        `extra_model_args` is set.
    """
    model_cls = get_model(config.model_name)

    # Arguments shared by every supported architecture.
    model_kwargs = {
        "output_dim": config.output_dim,
        "hidden_dim": config.hidden_dim,
        "num_layers": config.num_layers,
    }

    # Only pass `act` to models that take one. Previously this was skipped for
    # SIREN/WIRE only when extra_model_args was set, so the same models were
    # handed an `act` keyword whenever extra_model_args was None.
    if config.model_name not in ("SIREN", "WIRE"):
        model_kwargs["act"] = get_activation(config.activation)

    extra_args = config.extra_model_args
    if extra_args is None:
        extra_args = {}

    # Unpacking both mappings keeps the original behaviour of raising a
    # TypeError if extra_model_args repeats one of the base keywords.
    return model_cls(**model_kwargs, **extra_args)
|
|
|
def load_metric_from_model(model_dir):
    """
    Load a trained model from a directory and return its metric function.

    Reads the parameters from `params.msgpack` and the model configuration
    from `architecture.yml` (both inside `model_dir`), rebuilds the model with
    `make_model`, and wraps `model.apply` so that it returns a (4, 4) array.

    If the model has output dimension of 10, meaning it was trained only on
    the symmetric part of the metric, the full metric tensor is reconstructed
    with `reconstruct_full_metric`.

    Args:
        model_dir (str): Directory containing `params.msgpack` and
            `architecture.yml`.

    Returns:
        callable: Function mapping `coords` to the (4, 4) metric tensor
            produced by the model.

    Raises:
        ValueError: If the stored `output_dim` is neither 16 (full 4x4
            metric) nor 10 (symmetric part only).
    """
    with open(os.path.join(model_dir, "params.msgpack"), "rb") as f:
        params = serialization.msgpack_restore(f.read())

    # NOTE(review): yaml.load with FullLoader should only be used on trusted
    # files. If model directories can come from untrusted sources, consider
    # yaml.safe_load -- confirm first that the architecture files contain no
    # Python-specific tags.
    with open(os.path.join(model_dir, "architecture.yml"), "r") as f:
        config_model = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    model = make_model(config_model.architecture)
    output_dim = config_model.architecture.output_dim

    if output_dim == 16:
        # The network predicts all 16 components directly.
        return lambda coords: model.apply(params, coords).reshape(4, 4)

    if output_dim == 10:
        # The network predicts the 10 independent components of a symmetric
        # 4x4 matrix. The matrix size must be passed explicitly: it is a
        # required argument of reconstruct_full_metric.
        return lambda coords: reconstruct_full_metric(
            model.apply(params, coords), 4
        ).reshape(4, 4)

    # Fail loudly instead of silently returning None, which would only
    # surface later as "'NoneType' object is not callable".
    raise ValueError(
        f"Cannot build a 4x4 metric from a model with output_dim={output_dim}; "
        "expected 16 (full metric) or 10 (symmetric part)."
    )
|
|
|
def reconstruct_full_metric(metric_sym: jax.Array, n: "int | None" = None) -> jax.Array:
    """
    Return the fully reconstructed (n, n) metric tensor from the symmetry reduced metric.

    Args:
        metric_sym: The n * (n + 1) / 2 independent components of the
            symmetric matrix, ordered like `jnp.triu_indices(n)` (upper
            triangle including the diagonal, row by row).
        n: Side length of the matrix. If omitted, it is inferred from the
            number of components in `metric_sym`.

    Returns:
        jax.Array: The symmetric (n, n) matrix.

    Raises:
        ValueError: If `n` is omitted and the number of components is not of
            the form n * (n + 1) / 2.
    """
    if n is None:
        import math  # local import: only needed when inferring n

        count = int(jnp.size(metric_sym))
        # Solve n * (n + 1) / 2 == count for n, then verify the result.
        n = (math.isqrt(8 * count + 1) - 1) // 2
        if n * (n + 1) // 2 != count:
            raise ValueError(
                f"Cannot infer the matrix size from {count} components; "
                "expected n * (n + 1) / 2 values for some integer n."
            )

    rows, cols = jnp.triu_indices(n, k=0)
    full = jnp.zeros((n, n))
    # Fill the upper triangle, then mirror it into the lower triangle; the
    # diagonal is simply written twice with the same values.
    full = full.at[rows, cols].set(metric_sym)
    full = full.at[cols, rows].set(metric_sym)
    return full
|
|
|
def get_model(model_name: str):
    """
    Look up a model class in the `model_key_dict` registry.

    Args:
        model_name (str): Registry key of the requested model.

    Returns:
        nn.Module: The model class registered under `model_name`.

    Raises:
        ValueError: If `model_name` is not a key of `model_key_dict`.
    """
    if model_name in model_key_dict:
        return model_key_dict[model_name]

    # Unknown name: report it together with the list of valid keys.
    raise ValueError(f"Model `{model_name}` is not supported. Supported models are: {list(model_key_dict.keys())}")
|
|
|
def create_model_configs():
    """
    Build the per-model dictionaries of extra configuration values.

    Returns:
        dict: Maps each model name ("MLP", "SIREN", "WIRE") to a dictionary
            of model-specific settings; the "MLP" entry is empty. A new
            dictionary is created on every call.
    """
    siren_cfg = {"omega_0": 3.0}
    wire_cfg = {
        "first_omega_0": 4.0,
        "hidden_omega_0": 4.0,
        "scale": 5.0,
    }
    return {"MLP": {}, "SIREN": siren_cfg, "WIRE": wire_cfg}


# Built once at import time; `get_extra_model_cfg` reads from this mapping.
model_configs = create_model_configs()
|
|
|
def get_extra_model_cfg(model_name: str):
    """
    Return the entry of the module-level `model_configs` for a model name.

    Args:
        model_name (str): Name of the model.

    Returns:
        dict: The extra configuration stored for `model_name`. This is the
            dictionary held in `model_configs` itself, not a copy.

    Raises:
        ValueError: If `model_name` is not a key of `model_configs`.
    """
    if model_name in model_configs:
        return model_configs[model_name]

    # Unknown name: report it together with the list of valid keys.
    raise ValueError(f"Model `{model_name}` is not supported. Available models are: {list(model_configs.keys())}")