# Snapshot 2b59497 (file size: 1,377 bytes)
# Flax
import jax.numpy as jnp
from flax import linen as nn
import jax
from typing import Callable
class MLP(nn.Module):
    """Fully connected network built from ``nn.Dense`` layers.

    Architecture: ``num_layers + 1`` hidden ``Dense(hidden_dim)`` layers, each
    followed by ``act``, then a final linear ``Dense(output_dim)`` layer with
    no activation. Note that ``num_layers`` counts the hidden layers *after*
    the first one, so ``num_layers=0`` still yields one hidden layer.

    Every layer uses a bias and a ``glorot_normal`` kernel initializer, and
    stores its parameters in ``dtype`` (passed to Dense as ``param_dtype``).
    """

    hidden_dim: int  # width of every hidden Dense layer
    output_dim: int  # width of the final (linear) Dense layer
    num_layers: int  # number of extra hidden layers after the first one
    act: Callable = nn.silu  # activation applied after each hidden layer
    dtype: jnp.dtype = jnp.float32  # parameter dtype (Dense ``param_dtype``)

    @nn.compact
    def __call__(self, x):
        """Run the forward pass.

        Args:
          x: input array; ``nn.Dense`` maps its last axis.

        Returns:
          The output of the final Dense layer, whose last axis has size
          ``output_dim``.
        """

        def dense(features):
            # Single source of truth for the layer configuration so the hidden
            # and output layers cannot drift apart. Layers are still created
            # in call order, so their auto-generated names (Dense_0, Dense_1,
            # ...) and therefore the parameter tree match the original layout.
            return nn.Dense(
                features=features,
                use_bias=True,
                kernel_init=nn.initializers.glorot_normal(dtype=self.dtype),
                param_dtype=self.dtype,
            )

        # First hidden layer plus `num_layers` further hidden layers.
        for _ in range(self.num_layers + 1):
            x = self.act(dense(self.hidden_dim)(x))
        # Output projection: linear, no activation.
        return dense(self.output_dim)(x)
if __name__ == "__main__":
    # Example usage: build a small MLP, initialise it, run one forward pass
    # and print the output shape.
    # Draw the input data and the initial parameters from independent keys;
    # reusing a single PRNG key for both would correlate the two draws.
    data_key, init_key = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.uniform(data_key, (1, 3), minval=-3, maxval=3)
    model = MLP(hidden_dim=32, output_dim=16, num_layers=3)
    params = model.init(init_key, x)
    print(model.apply(params, x).shape)