import jax.numpy as jnp
from flax import linen as nn
from typing import Any
import jax
from .utils import custom_uniform
from jax.nn.initializers import Initializer
def complex_kernel_uniform_init(numerator : float = 6,
                                mode : str = "fan_in",
                                dtype : jnp.dtype = jnp.float32,
                                distribution: str = "uniform") -> Initializer:
    """Build an initializer that samples complex kernels as ``real + 1j*imag``.

    Both components are drawn with ``custom_uniform`` using the same
    ``numerator``/``mode``/``distribution`` settings.

    Args:
        numerator: scaling numerator forwarded to ``custom_uniform``.
        mode: fan mode forwarded to ``custom_uniform`` (e.g. ``"fan_in"``).
        dtype: default dtype for each sampled component.
        distribution: distribution name forwarded to ``custom_uniform``.

    Returns:
        An ``Initializer`` with signature ``(key, shape, dtype) -> array``.
    """
    def init(key: jax.Array, shape: tuple, dtype: Any = dtype) -> Any:
        # BUG FIX: the original reused the same PRNG key for both draws, which
        # made the imaginary part identical to the real part (the "complex"
        # kernel collapsed to r * (1 + 1j)). Split the key so the two
        # components are sampled independently.
        real_key, imag_key = jax.random.split(key)
        real_kernel = custom_uniform(numerator=numerator, mode=mode,
                                     distribution=distribution)(real_key, shape, dtype)
        imag_kernel = custom_uniform(numerator=numerator, mode=mode,
                                     distribution=distribution)(imag_key, shape, dtype)
        return real_kernel + 1j * imag_kernel
    return init
class WIRE(nn.Module):
    """WIRE network: a stack of Gabor-wavelet layers followed by a real-valued
    linear readout.

    The hidden stack uses ``ComplexGaborLayer`` when ``complexgabor`` is set
    (forcing complex64 activations), otherwise ``RealGaborLayer`` with
    ``dtype``. The output is the real part of a final ``nn.Dense`` projection.
    """
    output_dim: int            # features of the final linear readout
    hidden_dim: int            # width of every hidden Gabor layer
    num_layers: int            # number of hidden layers after the first one
    hidden_omega_0: float      # frequency scale for hidden layers
    first_omega_0: float       # frequency scale for the first layer
    scale: float               # Gaussian-envelope scale s_0 passed to each layer
    complexgabor: bool = False # switch between complex and real Gabor layers
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.complexgabor:
            layer_cls, layer_dtype = ComplexGaborLayer, jnp.complex64
        else:
            layer_cls, layer_dtype = RealGaborLayer, self.dtype
        # First layer is flagged so it gets the first-layer init scheme and
        # omega; the remaining num_layers use the hidden settings.
        layers = [
            layer_cls(
                output_dim=self.hidden_dim,
                omega_0=self.first_omega_0,
                s_0=self.scale,
                is_first_layer=True,
                dtype=layer_dtype,
            )
        ]
        for _ in range(self.num_layers):
            layers.append(
                layer_cls(
                    output_dim=self.hidden_dim,
                    omega_0=self.hidden_omega_0,
                    s_0=self.scale,
                    is_first_layer=False,
                    dtype=layer_dtype,
                )
            )
        self.kernel_net = layers
        self.output_linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=custom_uniform(numerator=1, mode="fan_in", distribution="normal"),
            param_dtype=self.dtype,
        )

    def __call__(self, x):
        hidden = x
        for layer in self.kernel_net:
            hidden = layer(hidden)
        # Activations may be complex; the readout keeps only the real part.
        return jnp.real(self.output_linear(hidden))
class ComplexGaborLayer(nn.Module):
    """Complex Gabor wavelet layer: exp(1j*omega_0*z - |s_0*z|^2) for z = Wx+b.

    Uses a complex-uniform kernel init whose numerator/distribution depend on
    whether this is the first layer (matching the WIRE init scheme).
    """
    output_dim: int            # features of the internal Dense projection
    omega_0: float             # frequency scale applied to the linear output
    s_0: float                 # Gaussian-envelope scale
    is_first_layer: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # First layer: numerator 1, squared-uniform draw; later layers follow
        # the 6/omega^2 scheme over a plain uniform.
        c = 1 if self.is_first_layer else 6 / self.omega_0**2
        distrib = "uniform_squared" if self.is_first_layer else "uniform"
        if self.is_first_layer:
            dtype = self.dtype
        else:
            # Hidden layers receive complex activations, so their params are
            # complex regardless of the configured dtype.
            dtype = jnp.complex64
        self.linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=complex_kernel_uniform_init(numerator=c, mode="fan_in", distribution=distrib),
            param_dtype=dtype
        )

    def __call__(self, x):
        # PERF FIX: the original evaluated self.linear(x) twice for the same
        # input, doubling the layer's matmul cost for an identical result.
        # Compute it once and reuse it for both omega and scale.
        lin = self.linear(x)
        omega = self.omega_0 * lin
        scale = self.s_0 * lin
        return jnp.exp(1j * omega - (jnp.abs(scale)**2))
class RealGaborLayer(nn.Module):
    """Real Gabor wavelet layer: cos(omega_0*f(x)) * exp(-(s_0*g(x))^2),
    where ``f`` and ``g`` are two independent Dense projections.
    """
    output_dim: int            # features of each internal Dense projection
    omega_0: float             # frequency scale applied to the freqs branch
    s_0: float                 # Gaussian-envelope scale for the scales branch
    is_first_layer: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Init scheme mirrors WIRE: first layer uses numerator 1 with a
        # squared-uniform draw, later layers 6/omega^2 with a plain uniform.
        if self.is_first_layer:
            numerator, distrib = 1, "uniform_squared"
        else:
            numerator, distrib = 6 / self.omega_0**2, "uniform"
        self.freqs = nn.Dense(
            features=self.output_dim,
            kernel_init=custom_uniform(numerator=numerator, mode="fan_in",
                                       distribution=distrib, dtype=self.dtype),
            use_bias=True,
            param_dtype=self.dtype
        )
        self.scales = nn.Dense(
            features=self.output_dim,
            kernel_init=custom_uniform(numerator=numerator, mode="fan_in",
                                       distribution=distrib, dtype=self.dtype),
            use_bias=True,
            param_dtype=self.dtype
        )

    def __call__(self, x):
        freq_term = jnp.cos(self.omega_0 * self.freqs(x))
        envelope = jnp.exp(-((self.s_0 * self.scales(x)) ** 2))
        return freq_term * envelope