import math
from typing import Any, Callable, Dict, Sequence, Union

import jax
import jax.numpy as jnp
from flax import linen as nn
from jax import dtypes, random
from jax.nn.initializers import Initializer

class FourierEmbs(nn.Module):
    """Random Fourier feature embedding.

    Projects the inputs with a random Gaussian kernel and returns the concatenated
    cosine and sine features, producing an output of size `embed_dim`.
    """

    embed_scale: float
    embed_dim: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        # Random projection matrix with one column per half of the embedding.
        kernel = self.param(
            "kernel",
            jax.nn.initializers.normal(self.embed_scale, dtype=self.dtype),
            (x.shape[-1], self.embed_dim // 2),
        )
        y = jnp.concatenate(
            [jnp.cos(jnp.dot(x, kernel)), jnp.sin(jnp.dot(x, kernel))], axis=-1
        )
        return y

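# A minimal usage sketch (not part of the original module); the scale, shapes, and
# PRNG seed below are illustrative assumptions only:
#
#     embs = FourierEmbs(embed_scale=1.0, embed_dim=64)
#     x = jnp.ones((8, 2))                          # batch of 2-D coordinates
#     params = embs.init(jax.random.PRNGKey(0), x)  # kernel shape: (2, 32)
#     y = embs.apply(params, x)                     # output shape: (8, 64)
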
def _weight_fact(init_fn, mean, stddev, dtype=jnp.float32):
    """Wrap `init_fn` so the weight matrix is factored as kernel = g * v."""

    def init(key, shape):
        key1, key2 = jax.random.split(key)
        w = init_fn(key1, shape)
        # Per-output-unit scale, sampled in log space so that g stays positive.
        g = mean + nn.initializers.normal(stddev, dtype=dtype)(key2, (shape[-1],))
        g = jnp.exp(g)
        v = w / g
        return g, v

    return init

class Dense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.glorot_normal()
    bias_init: Callable = nn.initializers.zeros
    reparam: Union[None, Dict] = None
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        if self.reparam is None:
            # Plain dense kernel; the trailing argument forwards the dtype to the initializer.
            kernel = self.param(
                "kernel", self.kernel_init, (x.shape[-1], self.features), self.dtype
            )
        elif self.reparam["type"] == "weight_fact":
            # Weight-factorized kernel: stored as a (g, v) pair and recombined below.
            g, v = self.param(
                "kernel",
                _weight_fact(
                    self.kernel_init,
                    mean=self.reparam["mean"],
                    stddev=self.reparam["stddev"],
                    dtype=self.dtype,
                ),
                (x.shape[-1], self.features),
            )
            kernel = g * v
        else:
            raise ValueError(f"Unknown reparameterization type: {self.reparam['type']}")

        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)

        y = jnp.dot(x, kernel) + bias

        return y

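# A minimal usage sketch (not part of the original module) for the weight-factorized
# path; the reparam values, feature sizes, and PRNG seed are illustrative assumptions:
#
#     layer = Dense(features=128, reparam={"type": "weight_fact", "mean": 1.0, "stddev": 0.1})
#     x = jnp.ones((8, 16))
#     params = layer.init(jax.random.PRNGKey(0), x)  # the "kernel" entry is a (g, v) pair
#     y = layer.apply(params, x)                     # output shape: (8, 128)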

def _compute_fans(
    shape: tuple,
    in_axis: Union[int, Sequence[int]] = -2,
    out_axis: Union[int, Sequence[int]] = -1,
    batch_axis: Union[int, Sequence[int]] = (),
):
    """Compute effective input and output sizes for a linear or convolutional layer.

    Axes not in in_axis, out_axis, or batch_axis are assumed to constitute the "receptive field" of
    a convolution (kernel spatial dimensions).
    """
    if len(shape) <= 1:
        raise ValueError(
            f"Can't compute input and output sizes of a {len(shape)}"
            "-dimensional weights tensor. Must be at least 2D."
        )

    if isinstance(in_axis, int):
        in_size = shape[in_axis]
    else:
        in_size = math.prod([shape[i] for i in in_axis])
    if isinstance(out_axis, int):
        out_size = shape[out_axis]
    else:
        out_size = math.prod([shape[i] for i in out_axis])
    if isinstance(batch_axis, int):
        batch_size = shape[batch_axis]
    else:
        batch_size = math.prod([shape[i] for i in batch_axis])
    receptive_field_size = math.prod(shape) / in_size / out_size / batch_size
    fan_in = in_size * receptive_field_size
    fan_out = out_size * receptive_field_size
    return fan_in, fan_out

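# A worked example (not part of the original module) of the fan computation above,
# for an assumed 3x3 convolution kernel of shape (3, 3, 16, 32) with the default axes:
#
#     in_size = 16, out_size = 32, receptive_field_size = 3 * 3 = 9
#     fan_in  = 16 * 9 = 144
#     fan_out = 32 * 9 = 288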

def custom_uniform(
    numerator: float = 6,
    mode: str = "fan_in",
    dtype: jnp.dtype = jnp.float32,
    in_axis: Union[int, Sequence[int]] = -2,
    out_axis: Union[int, Sequence[int]] = -1,
    batch_axis: Sequence[int] = (),
    distribution: str = "uniform",
) -> Initializer:
    """Builds an initializer that returns real uniformly-distributed random arrays.

    :param numerator: the numerator of the range of the random distribution.
    :type numerator: float
    :param mode: the mode for computing the range of the random distribution.
    :type mode: str
    :param dtype: optional; the initializer's default dtype.
    :type dtype: jnp.dtype
    :param in_axis: the axis or axes that specify the input size.
    :type in_axis: Union[int, Sequence[int]]
    :param out_axis: the axis or axes that specify the output size.
    :type out_axis: Union[int, Sequence[int]]
    :param batch_axis: the axis or axes that specify the batch size.
    :type batch_axis: Sequence[int]
    :param distribution: the distribution of the random distribution.
    :type distribution: str

    :return: An initializer that returns arrays whose values are uniformly distributed in
        the range ``[-range, range)``.
    :rtype: Initializer
    """

    def init(key: jax.Array, shape: tuple, dtype: Any = dtype) -> Any:
        dtype = dtypes.canonicalize_dtype(dtype)
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis)
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError(f"invalid mode for variance scaling initializer: {mode}")
        if distribution == "uniform":
            return random.uniform(
                key,
                shape,
                dtype,
                minval=-jnp.sqrt(numerator / denominator),
                maxval=jnp.sqrt(numerator / denominator),
            )
        elif distribution == "normal":
            return random.normal(key, shape, dtype) * jnp.sqrt(numerator / denominator)
        elif distribution == "uniform_squared":
            return random.uniform(
                key, shape, dtype, minval=-numerator / denominator, maxval=numerator / denominator
            )
        else:
            raise ValueError(
                f"invalid distribution for variance scaling initializer: {distribution}"
            )

    return init
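
# A minimal usage sketch (not part of the original module) combining custom_uniform
# with the Dense layer above; the numerator, mode, and shapes are illustrative assumptions:
#
#     init_fn = custom_uniform(numerator=6, mode="fan_in", distribution="uniform")
#     layer = Dense(features=64, kernel_init=init_fn)
#     params = layer.init(jax.random.PRNGKey(0), jnp.ones((4, 10)))
#     y = layer.apply(params, jnp.ones((4, 10)))    # output shape: (4, 64)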