import os

import jax
import jax.numpy as jnp
import yaml
from flax import serialization
from ml_collections import ConfigDict

from .activations import get_activation, list_activations
from .mlp import MLP
from .siren import SIREN
from .wire import WIRE

# Registry mapping model names (as used in config files) to their flax module classes.
model_key_dict = {
    "MLP": MLP,
    "SIREN": SIREN,
    "WIRE": WIRE,
}


def make_model(config):
    """
    Create and configure a flax neural network nn.Module based on a configuration.

    Args:
        config: Model configuration containing:
            - model_name: Type of model (MLP, SIREN, WIRE, etc.)
            - output_dim: Number of output dimensions
            - hidden_dim: Hidden layer dimension
            - num_layers: Number of layers
            - activation: Activation function name
            - extra_model_args: Additional model-specific arguments

    Returns:
        model (nn.Module): Configured flax nn.Module instance ready for training.

    Note:
        Handles the special case of the WIRE and SIREN models, which do not
        accept an activation function as an argument.
    """
    model_cls = get_model(config.model_name)
    if config.extra_model_args is not None:
        if config.model_name in ("WIRE", "SIREN"):
            # WIRE and SIREN define their own activations internally.
            model = model_cls(
                output_dim=config.output_dim,
                hidden_dim=config.hidden_dim,
                num_layers=config.num_layers,
                **config.extra_model_args,
            )
        else:
            model = model_cls(
                output_dim=config.output_dim,
                hidden_dim=config.hidden_dim,
                num_layers=config.num_layers,
                act=get_activation(config.activation),
                **config.extra_model_args,
            )
    else:
        model = model_cls(
            output_dim=config.output_dim,
            hidden_dim=config.hidden_dim,
            num_layers=config.num_layers,
            act=get_activation(config.activation),
        )
    return model


def load_metric_from_model(model_dir):
    """
    Load the model state from a given directory and return the metric tensor
    function. If the model has an output dimension of 10, meaning it was
    trained only on the symmetric part of the metric, the full metric tensor
    is reconstructed.

    Args:
        model_dir (str): Directory containing the model state file.

    Returns:
        callable: The metric tensor function from the model, mapping
            coordinates to a (4, 4) metric.
    """
    with open(os.path.join(model_dir, "params.msgpack"), "rb") as f:
        params = serialization.msgpack_restore(f.read())
    with open(os.path.join(model_dir, "architecture.yml"), "r") as f:
        config_model = yaml.load(f, Loader=yaml.FullLoader)
    config_model = ConfigDict(config_model)
    model = make_model(config_model.architecture)

    output_dim = config_model.architecture.output_dim
    if output_dim == 16:
        # The model predicts all 16 metric components directly.
        return lambda coords: model.apply(params, coords).reshape(4, 4)
    elif output_dim == 10:
        # The model predicts only the 10 independent components;
        # rebuild the full symmetric (4, 4) tensor.
        return lambda coords: reconstruct_full_metric(model.apply(params, coords), n=4)
    raise ValueError(
        f"Unsupported output dimension {output_dim}; expected 16 (full metric) "
        "or 10 (symmetric part)."
    )


def reconstruct_full_metric(metric_sym: jax.Array, n: int) -> jax.Array:
    """
    Return the fully reconstructed (n, n) metric tensor from the
    symmetry-reduced metric, i.e. from its n * (n + 1) / 2 upper-triangular
    components in row-major order.
    """
    i, j = jnp.triu_indices(n, k=0)
    matrix = jnp.zeros((n, n))
    matrix = matrix.at[i, j].set(metric_sym)
    matrix = matrix.at[j, i].set(metric_sym)
    return matrix


def get_model(model_name: str):
    """
    Get the model class by name.

    Args:
        model_name (str): Name of the model.

    Returns:
        nn.Module: The model class.
    """
    if model_name not in model_key_dict:
        raise ValueError(
            f"Model `{model_name}` is not supported. "
            f"Supported models are: {list(model_key_dict.keys())}"
        )
    return model_key_dict[model_name]


def create_model_configs():
    """
    Create a dictionary of model configurations.

    Returns:
        dict: A dictionary of model configurations.
    """
    model_configs = {
        "MLP": {},
        "SIREN": {
            "omega_0": 3.,
        },
        "WIRE": {
            "first_omega_0": 4.,
            "hidden_omega_0": 4.,
            "scale": 5.,
        },
    }
    return model_configs


model_configs = create_model_configs()


def get_extra_model_cfg(model_name: str):
    """
    Get the extra model configuration for a given model name.

    Args:
        model_name (str): Name of the model.

    Returns:
        dict: The extra model configuration.
    """
    if model_name not in model_configs:
        raise ValueError(
            f"Model `{model_name}` is not supported. "
            f"Available models are: {list(model_configs.keys())}"
        )
    return model_configs[model_name]
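

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as comments so importing this module
# has no side effects). It shows how the helpers above are intended to fit
# together. The concrete field values, the activation name, and the
# (4,)-shaped coordinate input are assumptions, not fixed by this module.
#
#   import jax
#   import jax.numpy as jnp
#   from ml_collections import ConfigDict
#
#   cfg = ConfigDict({
#       "model_name": "SIREN",
#       "output_dim": 10,                              # symmetric part of a 4x4 metric
#       "hidden_dim": 64,
#       "num_layers": 4,
#       "activation": "tanh",                          # placeholder; unused for SIREN/WIRE
#       "extra_model_args": get_extra_model_cfg("SIREN"),
#   })
#   model = make_model(cfg)
#
#   coords = jnp.zeros((4,))                           # a single coordinate point
#   params = model.init(jax.random.PRNGKey(0), coords)
#   g_sym = model.apply(params, coords)                # shape (10,)
#   g = reconstruct_full_metric(g_sym, n=4)            # full symmetric (4, 4) metric
# ---------------------------------------------------------------------------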