AndreiB137 commited on
Commit
2da731b
·
1 Parent(s): 3ee0667

update file structure

Browse files
Files changed (46) hide show
  1. __init__.py → flax_models/__init__.py +87 -30
  2. flax_models/__pycache__/__init__.cpython-311.pyc +0 -0
  3. flax_models/__pycache__/activations.cpython-311.pyc +0 -0
  4. flax_models/__pycache__/mlp.cpython-311.pyc +0 -0
  5. flax_models/__pycache__/siren.cpython-311.pyc +0 -0
  6. flax_models/__pycache__/utils.cpython-311.pyc +0 -0
  7. flax_models/__pycache__/wire.cpython-311.pyc +0 -0
  8. activations.py → flax_models/activations.py +0 -0
  9. mlp.py → flax_models/mlp.py +0 -0
  10. siren.py → flax_models/siren.py +0 -0
  11. flax_models/tmp.txt +0 -0
  12. utils.py → flax_models/utils.py +0 -0
  13. wire.py → flax_models/wire.py +0 -0
  14. gw/cartesian/silu/architecture.yml +9 -0
  15. gw/cartesian/silu/params.msgpack +3 -0
  16. gw/cartesian/silu/train_data.yml +19 -0
  17. gw/cartesian/siren/architecture.yml +10 -0
  18. gw/cartesian/siren/params.msgpack +3 -0
  19. gw/cartesian/siren/train_data.yml +19 -0
  20. gw/cartesian/wire/architecture.yml +12 -0
  21. gw/cartesian/wire/params.msgpack +3 -0
  22. gw/cartesian/wire/train_data.yml +19 -0
  23. gw/tmp.txt +0 -0
  24. kerr/boyer_lindquist/a_0.623/architecture.yml +9 -0
  25. kerr/boyer_lindquist/a_0.623/params.msgpack +3 -0
  26. kerr/boyer_lindquist/a_0.623/train_data.yml +19 -0
  27. kerr/boyer_lindquist/a_0.628/architecture.yml +9 -0
  28. kerr/boyer_lindquist/a_0.628/params.msgpack +3 -0
  29. kerr/boyer_lindquist/a_0.628/train_data.yml +19 -0
  30. kerr/boyer_lindquist/prograde/architecture.yml +9 -0
  31. kerr/boyer_lindquist/prograde/params.msgpack +3 -0
  32. kerr/boyer_lindquist/prograde/train_data.yml +19 -0
  33. kerr/boyer_lindquist/zackiger/architecture.yml +9 -0
  34. kerr/boyer_lindquist/zackiger/params.msgpack +3 -0
  35. kerr/boyer_lindquist/zackiger/train_data.yml +19 -0
  36. kerr/kerr_schild_cartesian/architecture.yml +9 -0
  37. kerr/kerr_schild_cartesian/params.msgpack +3 -0
  38. kerr/kerr_schild_cartesian/train_data.yml +19 -0
  39. kerr/tmp.txt +0 -0
  40. schwarzschild/spherical/close_event_horizon/architecture.yml +9 -0
  41. schwarzschild/spherical/close_event_horizon/params.msgpack +3 -0
  42. schwarzschild/spherical/close_event_horizon/train_data.yml +19 -0
  43. schwarzschild/spherical/perihelion/architecture.yml +9 -0
  44. schwarzschild/spherical/perihelion/params.msgpack +3 -0
  45. schwarzschild/spherical/perihelion/train_data.yml +19 -0
  46. schwarzschild/tmp.txt +0 -0
__init__.py → flax_models/__init__.py RENAMED
@@ -1,18 +1,99 @@
1
- from .mlp_pinn import MLP_PINN
2
- from .PirateNet import PirateNet
3
  from .mlp import MLP
4
  from .siren import SIREN
5
  from .wire import WIRE
6
  from .activations import get_activation, list_activations
 
 
 
 
 
 
7
 
8
  model_key_dict = {
9
  "MLP": MLP,
10
  "SIREN": SIREN,
11
- "WIRE": WIRE,
12
- "PirateNet": PirateNet,
13
- "MLP_PINN": MLP_PINN
14
  }
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def get_model(model_name : str):
17
  """
18
  Get the model class by name.
@@ -44,31 +125,7 @@ def create_model_configs():
44
  "first_omega_0": 4.,
45
  "hidden_omega_0": 4.,
46
  "scale": 5.,
47
- },
48
- "PirateNet": {
49
- "nonlinearity": 0.0,
50
- "pi_init": None,
51
- "reparam": {
52
- "type": "weight_fact",
53
- "mean": 1.0,
54
- "stddev": 0.1,
55
- },
56
- "fourier_emb": {
57
- "embed_scale": 2.,
58
- "embed_dim": 256,
59
- },
60
- },
61
- "MLP_PINN": {
62
- "reparam": {
63
- "type": "weight_fact",
64
- "mean": 1.0,
65
- "stddev": 0.1,
66
- },
67
- "fourier_emb": {
68
- "embed_scale": 2.,
69
- "embed_dim": 256,
70
- },
71
- },
72
 
73
  }
74
  return model_configs
 
 
 
1
  from .mlp import MLP
2
  from .siren import SIREN
3
  from .wire import WIRE
4
  from .activations import get_activation, list_activations
5
+ from flax import serialization
6
+ import os
7
+ import yaml
8
+ import jax
9
+ import jax.numpy as jnp
10
+ from ml_collections import ConfigDict
11
 
12
  model_key_dict = {
13
  "MLP": MLP,
14
  "SIREN": SIREN,
15
+ "WIRE": WIRE
 
 
16
  }
17
 
18
+ def make_model(config):
19
+ """
20
+ Create and configure a flax neural network nn.Module based on configuration.
21
+
22
+ Args:
23
+ config: Model configuration containing:
24
+ - model_name: Type of model (MLP, SIREN, WIRE, etc.)
25
+ - output_dim: Number of output dimensions
26
+ - hidden_dim: Hidden layer dimensions
27
+ - num_layers: Number of layers
28
+ - activation: Activation function name
29
+ - extra_model_args: Additional model-specific arguments
30
+
31
+ Returns:
32
+ model (nn.Module): Configured flax nn.Module instance ready for training
33
+
34
+ Note:
35
+ Handles special case for WIRE and SIREN models which don't accept
36
+ activation functions as an argument.
37
+ """
38
+
39
+ model = get_model(config.model_name)
40
+ if config.extra_model_args is not None:
41
+ if config.model_name == "WIRE" or config.model_name == "SIREN":
42
+ model = model(output_dim=config.output_dim,
43
+ hidden_dim=config.hidden_dim,
44
+ num_layers=config.num_layers,
45
+ **config.extra_model_args)
46
+ else:
47
+ model = model(output_dim=config.output_dim,
48
+ hidden_dim=config.hidden_dim,
49
+ num_layers=config.num_layers,
50
+ act=get_activation(config.activation),
51
+ **config.extra_model_args)
52
+ else:
53
+ model = model(output_dim=config.output_dim,
54
+ hidden_dim=config.hidden_dim,
55
+ num_layers=config.num_layers,
56
+ act=get_activation(config.activation),
57
+ )
58
+
59
+ return model
60
+
61
+ def load_metric_from_model(model_dir):
62
+ """
63
+ Load the model state from a given directory.
64
+ If the model has output dimension of 10, meaning
65
+ it was trained only on the symmetric part of the metric,
66
+ it reconstructs the full metric tensor.
67
+
68
+ Args:
69
+ model_dir (str): Directory containing the model state file.
70
+
71
+ Returns:
72
+ callable: The metric tensor function from the model.
73
+
74
+ """
75
+ with open(os.path.join(model_dir, "params.msgpack"), "rb") as f:
76
+ params = serialization.msgpack_restore(f.read())
77
+
78
+ with open(os.path.join(model_dir, "architecture.yml"), "r") as f:
79
+ config_model = yaml.load(f, Loader=yaml.FullLoader)
80
+
81
+ config_model = ConfigDict(config_model)
82
+ model = make_model(config_model.architecture)
83
+
84
+ if config_model.architecture.output_dim == 16:
85
+ return lambda coords: model.apply(params, coords).reshape(4, 4)
86
+ elif config_model.architecture.output_dim == 10:
87
+ return lambda coords: reconstruct_full_metric(model.apply(params, coords)).reshape(4, 4)
88
+
89
+ def reconstruct_full_metric(metric_sym: jax.Array, n : int) -> jax.Array:
90
+ """returns the fully reconstructed (n, n) metric tensor from the symmetry reduced metric"""
91
+ i, j = jnp.triu_indices(n, k=0)
92
+ matrix = jnp.zeros((n, n))
93
+ matrix = matrix.at[i, j].set(metric_sym)
94
+ matrix = matrix.at[j, i].set(metric_sym)
95
+ return matrix
96
+
97
  def get_model(model_name : str):
98
  """
99
  Get the model class by name.
 
125
  "first_omega_0": 4.,
126
  "hidden_omega_0": 4.,
127
  "scale": 5.,
128
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  }
131
  return model_configs
flax_models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (7.06 kB). View file
 
flax_models/__pycache__/activations.cpython-311.pyc ADDED
Binary file (5.17 kB). View file
 
flax_models/__pycache__/mlp.cpython-311.pyc ADDED
Binary file (2.88 kB). View file
 
flax_models/__pycache__/siren.cpython-311.pyc ADDED
Binary file (3.59 kB). View file
 
flax_models/__pycache__/utils.cpython-311.pyc ADDED
Binary file (9.88 kB). View file
 
flax_models/__pycache__/wire.cpython-311.pyc ADDED
Binary file (6.84 kB). View file
 
activations.py → flax_models/activations.py RENAMED
File without changes
mlp.py → flax_models/mlp.py RENAMED
File without changes
siren.py → flax_models/siren.py RENAMED
File without changes
flax_models/tmp.txt DELETED
File without changes
utils.py → flax_models/utils.py RENAMED
File without changes
wire.py → flax_models/wire.py RENAMED
File without changes
gw/cartesian/silu/architecture.yml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ architecture:
2
+ activation: silu
3
+ extra_model_args: {}
4
+ hidden_dim: 128
5
+ model_name: MLP
6
+ num_layers: 5
7
+ output_dim: 16
8
+ training:
9
+ metric_type: distortion
gw/cartesian/silu/params.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a71c9e504d93ddbb04ac9c9d325d1b93e18d9d46e33b212a9d17c3157ef2074
3
+ size 341507
gw/cartesian/silu/train_data.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ endpoint:
2
+ - true
3
+ - true
4
+ - true
5
+ - true
6
+ grid_range:
7
+ - - 0.0
8
+ - 10.0
9
+ - - 0.0
10
+ - 10.0
11
+ - - 0.0
12
+ - 10.0
13
+ - - 0.0
14
+ - 10.0
15
+ grid_shape:
16
+ - 140
17
+ - 10
18
+ - 10
19
+ - 140
gw/cartesian/siren/architecture.yml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ architecture:
2
+ activation: silu
3
+ extra_model_args:
4
+ omega_0: 1.0
5
+ hidden_dim: 128
6
+ model_name: SIREN
7
+ num_layers: 5
8
+ output_dim: 16
9
+ training:
10
+ metric_type: distortion
gw/cartesian/siren/params.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea59e055998adf61be227c9c1b9a011348e6be52a621c353061c2567c0cd018d
3
+ size 341591
gw/cartesian/siren/train_data.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ endpoint:
2
+ - true
3
+ - true
4
+ - true
5
+ - true
6
+ grid_range:
7
+ - - 0.0
8
+ - 10.0
9
+ - - 0.0
10
+ - 10.0
11
+ - - 0.0
12
+ - 10.0
13
+ - - 0.0
14
+ - 10.0
15
+ grid_shape:
16
+ - 140
17
+ - 10
18
+ - 10
19
+ - 140
gw/cartesian/wire/architecture.yml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ architecture:
2
+ activation: silu
3
+ extra_model_args:
4
+ first_omega_0: 1
5
+ hidden_omega_0: 1
6
+ scale: 1
7
+ hidden_dim: 90
8
+ model_name: WIRE
9
+ num_layers: 5
10
+ output_dim: 16
11
+ training:
12
+ metric_type: distortion
gw/cartesian/wire/params.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0d95701e9c4d6e037afbac8021d9fb9b0388b65180fb10ec0a2089fadc9b3b7
3
+ size 337857
gw/cartesian/wire/train_data.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ endpoint:
2
+ - true
3
+ - true
4
+ - true
5
+ - true
6
+ grid_range:
7
+ - - 0.0
8
+ - 10.0
9
+ - - 0.0
10
+ - 10.0
11
+ - - 0.0
12
+ - 10.0
13
+ - - 0.0
14
+ - 10.0
15
+ grid_shape:
16
+ - 140
17
+ - 10
18
+ - 10
19
+ - 140
gw/tmp.txt DELETED
File without changes
kerr/boyer_lindquist/a_0.623/architecture.yml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ architecture:
2
+ activation: silu
3
+ extra_model_args: {}
4
+ hidden_dim: 190
5
+ model_name: MLP
6
+ num_layers: 6
7
+ output_dim: 16
8
+ training:
9
+ metric_type: distortion
kerr/boyer_lindquist/a_0.623/params.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed0b22feeba09f6106b75fdcc160c0abf76e7e654c8aa0986f6f163b442e4dcc
3
+ size 742275
kerr/boyer_lindquist/a_0.623/train_data.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ endpoint:
2
+ - true
3
+ - true
4
+ - true
5
+ - false
6
+ grid_range:
7
+ - - 0.0
8
+ - 0.0
9
+ - - 3.0
10
+ - 8.0
11
+ - - 0.01
12
+ - 3.1315926535897933
13
+ - - 0.0
14
+ - 6.283185307179586
15
+ grid_shape:
16
+ - 1
17
+ - 128
18
+ - 128
19
+ - 128
kerr/boyer_lindquist/a_0.628/architecture.yml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ architecture:
2
+ activation: silu
3
+ extra_model_args: {}
4
+ hidden_dim: 190
5
+ model_name: MLP
6
+ num_layers: 6
7
+ output_dim: 16
8
+ training:
9
+ metric_type: distortion
kerr/boyer_lindquist/a_0.628/params.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5be530410902ee569bbccadadef0827b022196b413812cc9d77564d01ebcb0f1
3
+ size 742275
kerr/boyer_lindquist/a_0.628/train_data.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ endpoint:
2
+ - true
3
+ - true
4
+ - true
5
+ - false
6
+ grid_range:
7
+ - - 0.0
8
+ - 0.0
9
+ - - 3.0
10
+ - 8.0
11
+ - - 0.01
12
+ - 3.1315926535897933
13
+ - - 0.0
14
+ - 6.283185307179586
15
+ grid_shape:
16
+ - 1
17
+ - 128
18
+ - 128
19
+ - 128
kerr/boyer_lindquist/prograde/architecture.yml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ architecture:
2
+ activation: silu
3
+ extra_model_args: {}
4
+ hidden_dim: 190
5
+ model_name: MLP
6
+ num_layers: 6
7
+ output_dim: 16
8
+ training:
9
+ metric_type: distortion
kerr/boyer_lindquist/prograde/params.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ac2705cc1280e5f97246ca752cf87384b1fc76ca2329259dfd540151ffec27a
3
+ size 742275
kerr/boyer_lindquist/prograde/train_data.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ endpoint:
2
+ - true
3
+ - true
4
+ - true
5
+ - false
6
+ grid_range:
7
+ - - 0.0
8
+ - 0.0
9
+ - - 3.0
10
+ - 8.0
11
+ - - 0.01
12
+ - 3.1315926535897933
13
+ - - 0.0
14
+ - 6.283185307179586
15
+ grid_shape:
16
+ - 1
17
+ - 128
18
+ - 128
19
+ - 128
kerr/boyer_lindquist/zackiger/architecture.yml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ architecture:
2
+ activation: silu
3
+ extra_model_args: {}
4
+ hidden_dim: 190
5
+ model_name: MLP
6
+ num_layers: 6
7
+ output_dim: 16
8
+ training:
9
+ metric_type: distortion
kerr/boyer_lindquist/zackiger/params.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82cecefe0018bb127fb5c7950ee7569793fd246736ee93839fef1bcf143a1e27
3
+ size 742275
kerr/boyer_lindquist/zackiger/train_data.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ endpoint:
2
+ - true
3
+ - true
4
+ - true
5
+ - false
6
+ grid_range:
7
+ - - 0.0
8
+ - 0.0
9
+ - - 6.0
10
+ - 14.0
11
+ - - 0.01
12
+ - 3.1315926535897933
13
+ - - 0.0
14
+ - 6.283185307179586
15
+ grid_shape:
16
+ - 1
17
+ - 128
18
+ - 128
19
+ - 128
kerr/kerr_schild_cartesian/architecture.yml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ architecture:
2
+ activation: silu
3
+ extra_model_args: {}
4
+ hidden_dim: 190
5
+ model_name: MLP
6
+ num_layers: 5
7
+ output_dim: 16
8
+ training:
9
+ metric_type: full_flatten
kerr/kerr_schild_cartesian/params.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9b7a968b6f2557abfad7e0f7267cbe20a2a57ef9d72fbf3c63a28b595663d1e
3
+ size 742275
kerr/kerr_schild_cartesian/train_data.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ endpoint:
2
+ - true
3
+ - true
4
+ - true
5
+ - true
6
+ grid_range:
7
+ - - 0.0
8
+ - 0.0
9
+ - - -3.0
10
+ - 3.0
11
+ - - -3.0
12
+ - 3.0
13
+ - - 0.1
14
+ - 3.0
15
+ grid_shape:
16
+ - 1
17
+ - 128
18
+ - 128
19
+ - 128
kerr/tmp.txt DELETED
File without changes
schwarzschild/spherical/close_event_horizon/architecture.yml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ architecture:
2
+ activation: silu
3
+ extra_model_args: {}
4
+ hidden_dim: 256
5
+ model_name: MLP
6
+ num_layers: 5
7
+ output_dim: 16
8
+ training:
9
+ metric_type: distortion
schwarzschild/spherical/close_event_horizon/params.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a22f2792349d0281035e6eed84294d7422f7391f99d196ab0e3aa1a0b876cd8c
3
+ size 1337877
schwarzschild/spherical/close_event_horizon/train_data.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ endpoint:
2
+ - true
3
+ - true
4
+ - true
5
+ - false
6
+ grid_range:
7
+ - - 0.0
8
+ - 0.0
9
+ - - 2.1
10
+ - 50.0
11
+ - - 0.01
12
+ - 3.1315926535897933
13
+ - - 0.0
14
+ - 6.283185307179586
15
+ grid_shape:
16
+ - 1
17
+ - 128
18
+ - 128
19
+ - 128
schwarzschild/spherical/perihelion/architecture.yml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ architecture:
2
+ activation: silu
3
+ extra_model_args: {}
4
+ hidden_dim: 128
5
+ model_name: MLP
6
+ num_layers: 6
7
+ output_dim: 16
8
+ training:
9
+ metric_type: distortion
schwarzschild/spherical/perihelion/params.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63a34941f366b9424b90f4807310aa937c966bcdbd781d027c3dfb4f8df309b1
3
+ size 407620
schwarzschild/spherical/perihelion/train_data.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ endpoint:
2
+ - true
3
+ - true
4
+ - true
5
+ - false
6
+ grid_range:
7
+ - - 0.0
8
+ - 0.0
9
+ - - 5.0
10
+ - 140.0
11
+ - - 0.01
12
+ - 3.1315926535897933
13
+ - - 0.0
14
+ - 6.283185307179586
15
+ grid_shape:
16
+ - 1
17
+ - 128
18
+ - 128
19
+ - 128
schwarzschild/tmp.txt DELETED
File without changes