Spaces:
Sleeping
Sleeping
| import hydra | |
| import torch | |
| import torch.nn as nn | |
| from einops.layers.torch import Rearrange | |
| from nn.inr import make_functional, params_to_tensor, wrap_func | |
class GraphProbeFeatures(nn.Module):
    """Evaluate a batch of INRs on a shared set of probe inputs.

    The module builds a functional, ``torch.vmap``-ed version of the INR
    described by ``inr_model`` and feeds it a set of probe points that is
    shared across the batch. The values returned by the vmapped call are
    concatenated (and optionally projected) into one feature tensor.

    Args:
        d_in: Size of the last dimension of the randomly initialised probe
            inputs. Only used when ``input_init`` is ``None``.
        num_inputs: Number of probe points. Also the input width of every
            projection layer when ``proj_dim`` is given.
        inr_model: Config passed to ``hydra.utils.instantiate`` to build the
            INR. The resulting object must expose ``num_layers`` when
            ``proj_dim`` is given.
        input_init: Optional tensor used as the probe inputs. When given, the
            inputs are stored as a frozen (``requires_grad=False``) parameter;
            otherwise they are sampled uniformly from [-1, 1) with shape
            ``(1, num_inputs, d_in)`` and are trainable.
        proj_dim: If not ``None``, each returned activation is projected from
            ``num_inputs`` to ``proj_dim`` by its own ``Linear`` + ``LayerNorm``.
    """

    def __init__(self, d_in, num_inputs, inr_model, input_init=None, proj_dim=None):
        super().__init__()
        inr = hydra.utils.instantiate(inr_model)
        # NOTE(review): make_functional / params_to_tensor / wrap_func come
        # from the project-local nn.inr module; presumably they turn the INR
        # into a stateless callable that takes a flat parameter vector --
        # confirm against nn/inr.py.
        fmodel, params = make_functional(inr)
        # Only the shapes are used here; parameter values arrive in forward().
        _, vshapes = params_to_tensor(params)
        # vmap maps over dim 0 of both arguments (parameters and inputs).
        self.sirens = torch.vmap(wrap_func(fmodel, vshapes))

        if input_init is None:
            # Uniform samples in [-1, 1); the leading 1 is expanded to the
            # batch size in forward().
            inputs = 2 * torch.rand(1, num_inputs, d_in) - 1
        else:
            inputs = input_init
        # Probe points are learned only when they were randomly initialised.
        self.inputs = nn.Parameter(inputs, requires_grad=input_init is None)

        # Flatten each weight so the "o" axis is the outer (slow) index.
        self.reshape_weights = Rearrange("b i o 1 -> b (o i)")
        self.reshape_biases = Rearrange("b o 1 -> b o")

        self.proj_dim = proj_dim
        if proj_dim is not None:
            # One projection head per indexed output of the vmapped INR.
            self.proj = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Linear(num_inputs, proj_dim),
                        nn.LayerNorm(proj_dim),
                    )
                    for _ in range(inr.num_layers + 1)
                ]
            )

    def forward(self, weights, biases):
        """Probe the INRs defined by ``weights`` and ``biases``.

        Args:
            weights: Iterable of per-layer weight tensors matching the
                pattern ``b i o 1``.
            biases: Iterable of per-layer bias tensors matching the pattern
                ``b o 1``. Must contain as many entries as ``weights``.

        Returns:
            With ``proj_dim`` set: the projected activations concatenated
            along dim 1, so the last dimension is ``proj_dim``. Otherwise:
            the activations concatenated along their last dimension and then
            permuted with ``(0, 2, 1)``.

        Raises:
            ValueError: If ``weights`` and ``biases`` differ in length.
        """
        flat_weights = [self.reshape_weights(w) for w in weights]
        flat_biases = [self.reshape_biases(b) for b in biases]
        # zip() would silently drop the surplus layers and build a parameter
        # vector for the wrong network, so fail loudly instead.
        if len(flat_weights) != len(flat_biases):
            raise ValueError(
                "weights and biases must have the same number of layers, "
                f"got {len(flat_weights)} and {len(flat_biases)}"
            )

        # Flat layout per sample: w0, b0, w1, b1, ...
        interleaved = []
        for flat_w, flat_b in zip(flat_weights, flat_biases):
            interleaved.extend((flat_w, flat_b))
        params_flat = torch.cat(interleaved, dim=-1)

        batch_size = params_flat.shape[0]
        out = self.sirens(params_flat, self.inputs.expand(batch_size, -1, -1))

        if self.proj_dim is None:
            return torch.cat(out, dim=-1).permute(0, 2, 1)

        # The permute moves dim 1 (which the Linear layers require to be of
        # size num_inputs) to the end so each head can project it.
        projected = [
            proj(out[i].permute(0, 2, 1)) for i, proj in enumerate(self.proj)
        ]
        return torch.cat(projected, dim=1)