import os

import jax
import jax.numpy as jnp
import yaml
from flax import serialization
from ml_collections import ConfigDict

from .activations import get_activation, list_activations
from .mlp import MLP
from .siren import SIREN
from .wire import WIRE

model_key_dict = {
    "MLP": MLP,
    "SIREN": SIREN,
    "WIRE": WIRE
}

def make_model(config):
    """
    Create and configure a flax neural network nn.Module based on configuration.

    Args:
        config: Model configuration containing:
            - model_name: Type of model (MLP, SIREN, WIRE, etc.)
            - output_dim: Number of output dimensions
            - hidden_dim: Hidden layer dimensions
            - num_layers: Number of layers
            - activation: Activation function name
            - extra_model_args: Additional model-specific arguments
            
    Returns:
        model (nn.Module): Configured flax nn.Module instance ready for training
        
    Note:
        Handles special case for WIRE and SIREN models which don't accept
        activation functions as an argument.
    """

    model_cls = get_model(config.model_name)
    extra_args = config.extra_model_args if config.extra_model_args is not None else {}

    # WIRE and SIREN define their own activations, so `act` is not passed to them.
    if config.model_name in ("WIRE", "SIREN"):
        return model_cls(output_dim=config.output_dim,
                         hidden_dim=config.hidden_dim,
                         num_layers=config.num_layers,
                         **extra_args)

    return model_cls(output_dim=config.output_dim,
                     hidden_dim=config.hidden_dim,
                     num_layers=config.num_layers,
                     act=get_activation(config.activation),
                     **extra_args)

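# Illustrative usage sketch for `make_model` (kept as comments so it is not
# executed on import). The field names follow the docstring above; the concrete
# values (hidden_dim=128, a 4-dimensional coordinate input, omega_0=3.0) are
# assumptions for the example only.
#
#   from ml_collections import ConfigDict
#   cfg = ConfigDict({
#       "model_name": "SIREN",
#       "output_dim": 16,
#       "hidden_dim": 128,
#       "num_layers": 4,
#       "activation": "tanh",                  # ignored for SIREN/WIRE
#       "extra_model_args": {"omega_0": 3.0},
#   })
#   model = make_model(cfg)
#   params = model.init(jax.random.PRNGKey(0), jnp.zeros((4,)))
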
def load_metric_from_model(model_dir):
    """
    Load the model state from a given directory.
    If the model has output dimension of 10, meaning
    it was trained only on the symmetric part of the metric,
    it reconstructs the full metric tensor.
    
    Args:
        model_dir (str): Directory containing the model state file.

    Returns:
        callable: The metric tensor function from the model.

    """
    with open(os.path.join(model_dir, "params.msgpack"), "rb") as f:
        params = serialization.msgpack_restore(f.read())
    
    with open(os.path.join(model_dir, "architecture.yml"), "r") as f:
        config_model = yaml.load(f, Loader=yaml.FullLoader)

    config_model = ConfigDict(config_model)
    model = make_model(config_model.architecture)

    if config_model.architecture.output_dim == 16:
        return lambda coords: model.apply(params, coords).reshape(4, 4)
    elif config_model.architecture.output_dim == 10:
        # Only the 10 independent components of the symmetric 4x4 metric were learned.
        return lambda coords: reconstruct_full_metric(model.apply(params, coords), n=4)
    else:
        raise ValueError(
            f"Unsupported output_dim `{config_model.architecture.output_dim}`; expected 16 or 10."
        )

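# Illustrative usage sketch for `load_metric_from_model` (comments only). The
# directory path and the 4-dimensional coordinate vector are assumptions for
# the example.
#
#   metric_fn = load_metric_from_model("path/to/model_dir")
#   g = metric_fn(jnp.array([0.0, 3.0, jnp.pi / 2, 0.0]))   # shape (4, 4)
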
def reconstruct_full_metric(metric_sym: jax.Array, n: int) -> jax.Array:
    """
    Reconstruct the full (n, n) symmetric metric tensor from its n*(n+1)/2
    upper-triangular components, ordered row-major as by `jnp.triu_indices(n)`.
    """
    i, j = jnp.triu_indices(n, k=0)
    matrix = jnp.zeros((n, n))
    matrix = matrix.at[i, j].set(metric_sym)
    matrix = matrix.at[j, i].set(metric_sym)
    return matrix

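# Worked example (sketch): for n = 2, `jnp.triu_indices(2)` orders the upper
# triangle as (0, 0), (0, 1), (1, 1), so
#
#   reconstruct_full_metric(jnp.array([a, b, c]), n=2)
#
# yields [[a, b],
#         [b, c]].
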
def get_model(model_name: str):
    """
    Get the model class by name.

    Args:
        model_name (str): Name of the model.

    Returns:
        nn.Module: The model class.
    """
    if model_name not in model_key_dict:
        raise ValueError(f"Model `{model_name}` is not supported. Supported models are: {list(model_key_dict.keys())}")
    
    return model_key_dict[model_name]

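# Example (sketch): `get_model` returns the model class, not an instance.
#
#   siren_cls = get_model("SIREN")     # -> the SIREN class
#   get_model("Transformer")           # -> raises ValueError
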
def create_model_configs():
    """
    Create a dictionary of model configurations.

    Returns:
        dict: A dictionary of model configurations.
    """
    model_configs = {
        "MLP": {},
        "SIREN": {
            "omega_0": 3.
        },
        "WIRE": {
            "first_omega_0": 4.,
            "hidden_omega_0": 4.,
            "scale": 5.,
        }

    }
    return model_configs

model_configs = create_model_configs()

def get_extra_model_cfg(model_name: str):
    """
    Get the extra model configuration for a given model name.

    Args:
        model_name (str): Name of the model.

    Returns:
        dict: The extra model configuration.
    """
    if model_name not in model_configs:
        raise ValueError(f"Model `{model_name}` is not supported. Available models are: {list(model_configs.keys())}")

    return model_configs[model_name]
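
# Example (sketch): pull the default extra arguments for a model and feed them
# back into `make_model` via `extra_model_args`.
#
#   extra = get_extra_model_cfg("WIRE")
#   # -> {"first_omega_0": 4., "hidden_omega_0": 4., "scale": 5.}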