import math
from typing import Any, Callable, Dict, Sequence, Union

import jax
import jax.numpy as jnp
from flax import linen as nn
from jax import dtypes, random
from jax.nn.initializers import Initializer


class FourierEmbs(nn.Module):
    """Random Fourier feature embedding: projects inputs through a random
    kernel (stored as a parameter) and concatenates the cosine and sine of
    the projection, yielding `embed_dim` features."""

    embed_scale: float
    embed_dim: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        kernel = self.param(
            "kernel",
            jax.nn.initializers.normal(self.embed_scale, dtype=self.dtype),
            (x.shape[-1], self.embed_dim // 2),
        )
        y = jnp.concatenate(
            [jnp.cos(jnp.dot(x, kernel)), jnp.sin(jnp.dot(x, kernel))], axis=-1
        )
        return y
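
# A minimal usage sketch (hypothetical shapes and scale): embed a batch of 2-D
# coordinates into 256 Fourier features; cos and sin each contribute
# embed_dim // 2 columns.
#
#   embs = FourierEmbs(embed_scale=10.0, embed_dim=256)
#   x = jnp.ones((8, 2))
#   params = embs.init(jax.random.PRNGKey(0), x)
#   y = embs.apply(params, x)  # y.shape == (8, 256)
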
def _weight_fact(init_fn, mean, stddev, dtype=jnp.float32):
    """Wrap `init_fn` so the kernel is stored in factorized form (g, v) with
    kernel = g * v, where g = exp(s) and s ~ N(mean, stddev)."""

    def init(key, shape):
        key1, key2 = jax.random.split(key)
        w = init_fn(key1, shape)
        # Per-output-feature scale, sampled in log space so that g > 0.
        g = mean + nn.initializers.normal(stddev, dtype=dtype)(key2, (shape[-1],))
        g = jnp.exp(g)
        v = w / g
        return g, v

    return init
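
# Sketch of what _weight_fact produces (hypothetical shape): the parameter is
# stored as the pair (g, v), and the effective kernel is recovered as g * v.
#
#   init = _weight_fact(nn.initializers.glorot_normal(), mean=1.0, stddev=0.1)
#   g, v = init(jax.random.PRNGKey(0), (16, 64))
#   # g.shape == (64,), v.shape == (16, 64); g * v equals the original draw w.
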
class Dense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.glorot_normal()
    bias_init: Callable = nn.initializers.zeros
    reparam: Union[None, Dict] = None
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        if self.reparam is None:
            # Initializers follow the JAX signature (key, shape, dtype);
            # `self.param` forwards the extra arguments after the key.
            kernel = self.param(
                "kernel", self.kernel_init, (x.shape[-1], self.features), self.dtype
            )
        elif self.reparam["type"] == "weight_fact":
            g, v = self.param(
                "kernel",
                _weight_fact(
                    self.kernel_init,
                    mean=self.reparam["mean"],
                    stddev=self.reparam["stddev"],
                    dtype=self.dtype,
                ),
                (x.shape[-1], self.features),
            )
            kernel = g * v
        else:
            raise ValueError(f"invalid reparameterization type: {self.reparam['type']}")

        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        y = jnp.dot(x, kernel) + bias
        return y
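
# A minimal usage sketch (hypothetical hyperparameters): a Dense layer with the
# weight-factorized reparameterization enabled.
#
#   layer = Dense(features=64, reparam={"type": "weight_fact", "mean": 1.0, "stddev": 0.1})
#   x = jnp.ones((4, 16))
#   params = layer.init(jax.random.PRNGKey(0), x)
#   out = layer.apply(params, x)  # out.shape == (4, 64)
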
def _compute_fans(
    shape: tuple,
    in_axis: Union[int, Sequence[int]] = -2,
    out_axis: Union[int, Sequence[int]] = -1,
    batch_axis: Union[int, Sequence[int]] = (),
):
    """Compute effective input and output sizes for a linear or convolutional layer.

    Axes not in `in_axis`, `out_axis`, or `batch_axis` are assumed to constitute
    the "receptive field" of a convolution (kernel spatial dimensions).
    """
    if len(shape) <= 1:
        raise ValueError(
            f"Can't compute input and output sizes of a {len(shape)}"
            "-dimensional weights tensor. Must be at least 2D."
        )

    if isinstance(in_axis, int):
        in_size = shape[in_axis]
    else:
        in_size = math.prod([shape[i] for i in in_axis])
    if isinstance(out_axis, int):
        out_size = shape[out_axis]
    else:
        out_size = math.prod([shape[i] for i in out_axis])
    if isinstance(batch_axis, int):
        batch_size = shape[batch_axis]
    else:
        batch_size = math.prod([shape[i] for i in batch_axis])

    receptive_field_size = math.prod(shape) / in_size / out_size / batch_size
    fan_in = in_size * receptive_field_size
    fan_out = out_size * receptive_field_size
    return fan_in, fan_out
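
# Worked example: for a 3x3 convolution kernel of shape (3, 3, 16, 32) with the
# default axes, the receptive field size is 3 * 3 = 9, so
#   fan_in  = 16 * 9 = 144
#   fan_out = 32 * 9 = 288
#
#   assert _compute_fans((3, 3, 16, 32)) == (144.0, 288.0)
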
def custom_uniform(
    numerator: float = 6,
    mode: str = "fan_in",
    dtype: jnp.dtype = jnp.float32,
    in_axis: Union[int, Sequence[int]] = -2,
    out_axis: Union[int, Sequence[int]] = -1,
    batch_axis: Sequence[int] = (),
    distribution: str = "uniform",
) -> Initializer:
    """Builds a variance-scaling initializer with a configurable numerator.

    :param numerator: the numerator of ``numerator / fan``, which sets the scale of the distribution.
    :type numerator: float
    :param mode: one of ``"fan_in"``, ``"fan_out"``, or ``"fan_avg"``, selecting the denominator.
    :type mode: str
    :param dtype: optional; the initializer's default dtype.
    :type dtype: jnp.dtype
    :param in_axis: the axis or axes that specify the input size.
    :type in_axis: Union[int, Sequence[int]]
    :param out_axis: the axis or axes that specify the output size.
    :type out_axis: Union[int, Sequence[int]]
    :param batch_axis: the axis or axes that specify the batch size.
    :type batch_axis: Sequence[int]
    :param distribution: one of ``"uniform"``, ``"normal"``, or ``"uniform_squared"``.
    :type distribution: str
    :return: An initializer that returns arrays sampled from ``distribution``; for
        ``"uniform"``, values lie in ``[-sqrt(numerator / fan), sqrt(numerator / fan))``.
    :rtype: Initializer
    """

    def init(key: jax.Array, shape: tuple, dtype: Any = dtype) -> Any:
        dtype = dtypes.canonicalize_dtype(dtype)
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis)
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError(f"invalid mode for variance scaling initializer: {mode}")

        if distribution == "uniform":
            return random.uniform(
                key,
                shape,
                dtype,
                minval=-jnp.sqrt(numerator / denominator),
                maxval=jnp.sqrt(numerator / denominator),
            )
        elif distribution == "normal":
            return random.normal(key, shape, dtype) * jnp.sqrt(numerator / denominator)
        elif distribution == "uniform_squared":
            return random.uniform(
                key,
                shape,
                dtype,
                minval=-numerator / denominator,
                maxval=numerator / denominator,
            )
        else:
            raise ValueError(
                f"invalid distribution for variance scaling initializer: {distribution}"
            )

    return init
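
# A usage sketch: with numerator=3, mode="fan_in", distribution="uniform", this
# reproduces a LeCun-style uniform initializer, U(-sqrt(3 / fan_in), sqrt(3 / fan_in)).
#
#   init_fn = custom_uniform(numerator=3, mode="fan_in")
#   w = init_fn(jax.random.PRNGKey(0), (16, 64))  # fan_in = 16, shape (16, 64)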