File size: 4,011 Bytes
2b59497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import jax.numpy as jnp
from flax import linen as nn
from typing import Any

import jax
from .utils import custom_uniform
from jax.nn.initializers import Initializer
    

def complex_kernel_uniform_init(numerator: float = 6,
                                mode: str = "fan_in",
                                dtype: jnp.dtype = jnp.float32,
                                distribution: str = "uniform") -> Initializer:
    """Build an initializer that returns a complex-valued kernel.

    The real and imaginary parts are drawn separately through
    ``custom_uniform`` (same ``numerator``, ``mode`` and ``distribution``)
    and combined as ``real + 1j * imag``.

    Args:
        numerator: Forwarded to ``custom_uniform`` as ``numerator``.
        mode: Forwarded to ``custom_uniform`` as ``mode``.
        dtype: Default dtype of the returned ``init`` function, used when the
            caller does not pass one explicitly.
        distribution: Forwarded to ``custom_uniform`` as ``distribution``.

    Returns:
        A function ``init(key, shape, dtype)`` producing the complex kernel.
    """
    def init(key: jax.Array, shape: tuple, dtype: Any = dtype) -> Any:
        # Split the key so the two parts are independent draws; reusing the
        # same key for both would make the imaginary part an exact copy of
        # the real part.
        real_key, imag_key = jax.random.split(key)

        sampler = custom_uniform(numerator=numerator, mode=mode, distribution=distribution)
        # NOTE(review): when this is used as a Dense ``kernel_init`` with a
        # complex ``param_dtype``, ``dtype`` arrives here as a complex dtype;
        # confirm that ``custom_uniform`` accepts it.
        real_kernel = sampler(real_key, shape, dtype)
        imag_kernel = sampler(imag_key, shape, dtype)

        return real_kernel + 1j * imag_kernel

    return init


class WIRE(nn.Module):
    """Stack of Gabor layers followed by a dense read-out.

    Attributes:
        output_dim: Number of features of the final dense layer.
        hidden_dim: ``output_dim`` given to every Gabor layer.
        num_layers: Number of Gabor layers added after the first one
            (the stack holds ``num_layers + 1`` Gabor layers in total).
        hidden_omega_0: ``omega_0`` given to the layers after the first.
        first_omega_0: ``omega_0`` given to the first layer.
        scale: ``s_0`` given to every Gabor layer.
        complexgabor: If true, use ``ComplexGaborLayer`` with ``complex64``;
            otherwise use ``RealGaborLayer`` with ``dtype``.
        dtype: dtype for the real Gabor layers and ``param_dtype`` of the
            final dense layer.
    """
    output_dim: int
    hidden_dim: int
    num_layers: int
    hidden_omega_0: float
    first_omega_0: float
    scale: float
    complexgabor: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Choose the layer class together with the dtype handed to it.
        layer_cls, layer_dtype = (
            (ComplexGaborLayer, jnp.complex64)
            if self.complexgabor
            else (RealGaborLayer, self.dtype)
        )

        # One (omega_0, is_first_layer) entry per Gabor layer, first layer first.
        layer_specs = [(self.first_omega_0, True)]
        layer_specs += [(self.hidden_omega_0, False)] * self.num_layers

        self.kernel_net = [
            layer_cls(
                output_dim=self.hidden_dim,
                omega_0=omega,
                s_0=self.scale,
                is_first_layer=is_first,
                dtype=layer_dtype,
            )
            for omega, is_first in layer_specs
        ]

        readout_init = custom_uniform(numerator=1, mode="fan_in", distribution="normal")
        self.output_linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=readout_init,
            param_dtype=self.dtype,
        )

    def __call__(self, x):
        """Apply every Gabor layer in order, then return the real part of the read-out."""
        hidden = x
        for gabor_layer in self.kernel_net:
            hidden = gabor_layer(hidden)

        return jnp.real(self.output_linear(hidden))


class ComplexGaborLayer(nn.Module):
    """Dense projection followed by a complex Gabor nonlinearity.

    With ``z`` the output of the dense layer, the layer returns
    ``exp(1j * omega_0 * z - |s_0 * z|**2)``.

    Attributes:
        output_dim: Number of features of the dense layer.
        omega_0: Multiplier applied to ``z`` in the oscillatory term.
        s_0: Multiplier applied to ``z`` in the decay term.
        is_first_layer: Selects the initializer settings and the parameter
            dtype (see ``setup``).
        dtype: Parameter dtype used when ``is_first_layer`` is true.
    """
    output_dim: int
    omega_0: float
    s_0: float
    is_first_layer: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # First layer: numerator 1 with "uniform_squared"; later layers:
        # numerator 6 / omega_0**2 with "uniform".
        c = 1 if self.is_first_layer else 6 / self.omega_0**2
        distrib = "uniform_squared" if self.is_first_layer else "uniform"

        # The first layer keeps the configured dtype; every later layer
        # always stores complex64 parameters.
        if self.is_first_layer:
            dtype = self.dtype
        else:
            dtype = jnp.complex64

        self.linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=complex_kernel_uniform_init(numerator=c, mode="fan_in", distribution=distrib),
            param_dtype=dtype
        )

    def __call__(self, x):
        """Return ``exp(1j * omega_0 * z - |s_0 * z|**2)`` with ``z = linear(x)``."""
        # Both terms use the same projection, so evaluate the dense layer
        # once instead of running it twice on the same input.
        lin = self.linear(x)
        omega = self.omega_0 * lin
        scale = self.s_0 * lin

        return jnp.exp(1j * omega - (jnp.abs(scale)**2))


class RealGaborLayer(nn.Module):
    """Real Gabor layer built from two separate dense projections.

    Returns ``cos(omega_0 * freqs(x)) * exp(-(s_0 * scales(x))**2)``.

    Attributes:
        output_dim: Number of features of both dense layers.
        omega_0: Multiplier applied to the ``freqs`` projection.
        s_0: Multiplier applied to the ``scales`` projection.
        is_first_layer: Selects the initializer settings (see ``setup``).
        dtype: Parameter dtype of both dense layers, also forwarded to
            ``custom_uniform``.
    """
    output_dim: int
    omega_0: float
    s_0: float
    is_first_layer: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Initializer settings depend on whether this is the first layer.
        if self.is_first_layer:
            numerator, distribution = 1, "uniform_squared"
        else:
            numerator, distribution = 6 / self.omega_0**2, "uniform"

        # Projection that feeds the cosine term.
        freq_init = custom_uniform(
            numerator=numerator, mode="fan_in", distribution=distribution, dtype=self.dtype
        )
        self.freqs = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=freq_init,
            param_dtype=self.dtype,
        )

        # Projection that feeds the exponential decay term.
        scale_init = custom_uniform(
            numerator=numerator, mode="fan_in", distribution=distribution, dtype=self.dtype
        )
        self.scales = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=scale_init,
            param_dtype=self.dtype,
        )

    def __call__(self, x):
        """Return the cosine term multiplied by the exponential decay term."""
        phase = self.omega_0 * self.freqs(x)
        width = self.s_0 * self.scales(x)

        carrier = jnp.cos(phase)
        envelope = jnp.exp(-(width**2))
        return carrier * envelope