AndreiB137 commited on
Commit
3ee0667
·
verified ·
1 Parent(s): 9874994

Delete PirateNet.py

Browse files
Files changed (1) hide show
  1. PirateNet.py +0 -85
PirateNet.py DELETED
@@ -1,85 +0,0 @@
1
- import jax
2
- import jax.numpy as jnp
3
- import flax.linen as nn
4
- from .utils import Dense, FourierEmbs
5
- from typing import Union, Dict, Callable
6
-
7
- class PIModifiedBottleneck(nn.Module):
8
- hidden_dim: int
9
- output_dim: int
10
- act: Callable
11
- nonlinearity: float
12
- reparam: Union[None, Dict]
13
- dtype: jnp.dtype = jnp.float32
14
-
15
- @nn.compact
16
- def __call__(self, x, u, v):
17
- identity = x
18
-
19
- x = Dense(features=self.hidden_dim, reparam=self.reparam, dtype=self.dtype)(x)
20
- x = self.act(x)
21
-
22
- x = x * u + (1 - x) * v
23
-
24
- x = Dense(features=self.hidden_dim, reparam=self.reparam, dtype=self.dtype)(x)
25
- x = self.act(x)
26
-
27
- x = x * u + (1 - x) * v
28
-
29
- x = Dense(features=self.output_dim, reparam=self.reparam, dtype=self.dtype)(x)
30
- x = self.act(x)
31
-
32
- alpha = self.param("alpha", nn.initializers.constant(self.nonlinearity), (1,))
33
- x = alpha * x + (1 - alpha) * identity
34
-
35
- return x
36
-
37
- class PirateNet(nn.Module):
38
- num_layers: int
39
- hidden_dim: int
40
- output_dim: int
41
- act: Callable = nn.silu
42
- nonlinearity: float = 0.0
43
- pi_init: Union[None, jnp.ndarray] = None
44
- reparam : Union[None, Dict] = None
45
- fourier_emb : Union[None, Dict] = None
46
- dtype: jnp.dtype = jnp.float32
47
-
48
- @nn.compact
49
- def __call__(self, x):
50
- embs = FourierEmbs(**self.fourier_emb)(x)
51
- x = embs
52
-
53
- u = Dense(features=self.hidden_dim, reparam=self.reparam, dtype=self.dtype)(x)
54
- u = self.act(u)
55
-
56
- v = Dense(features=self.hidden_dim, reparam=self.reparam, dtype=self.dtype)(x)
57
- v = self.act(v)
58
-
59
- for _ in range(self.num_layers):
60
- x = PIModifiedBottleneck(
61
- hidden_dim=self.hidden_dim,
62
- output_dim=x.shape[-1],
63
- act=self.act,
64
- nonlinearity=self.nonlinearity,
65
- reparam=self.reparam,
66
- dtype=self.dtype
67
- )(x, u, v)
68
-
69
- if self.pi_init is not None:
70
- kernel = self.param("pi_init", nn.initializers.constant(self.pi_init, dtype=self.dtype), self.pi_init.shape)
71
- y = jnp.dot(x, kernel)
72
-
73
- else:
74
- y = Dense(features=self.output_dim, reparam=self.reparam, dtype=self.dtype)(x)
75
-
76
- return x, y
77
-
78
- if __name__ == "__main__":
79
- # Example usage
80
- from activations import cauchy
81
- cauchy_mod = lambda x : cauchy()(x)
82
- model = PirateNet(num_layers=3, hidden_dim=32, output_dim=16, act=cauchy_mod, reparam=None, fourier_emb={'embed_scale': 1.0, 'embed_dim': 64})
83
- params = model.init(jax.random.PRNGKey(0), jnp.ones(3))
84
- output = model.apply(params, jnp.ones(3))
85
- print(params)