File size: 1,257 Bytes
5569e06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn

import math



class ResamplerProjector(nn.Module):
    """LayerNorm followed by a bias-free two-layer GELU MLP.

    The input's last dimension (``proj_input_size``) is normalized and then
    mapped to ``hidden_size`` by ``Linear -> GELU -> Linear``. All parameters
    are (re)initialized through the module-level ``init_weights`` function.
    """

    def __init__(self, proj_input_size, hidden_size):
        """Build the normalization layer and the projection MLP.

        Args:
            proj_input_size: Size of the last dimension of the incoming
                features; used for the LayerNorm and the first Linear layer.
            hidden_size: Output size of both Linear layers.
        """
        super().__init__()

        self.pre_proj_layernorm = nn.LayerNorm(proj_input_size)

        # Keep the three-entry Sequential layout so parameter names
        # (mlp.0.weight / mlp.2.weight) stay stable for checkpoints.
        input_proj = nn.Linear(proj_input_size, hidden_size, bias=False)
        output_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.mlp = nn.Sequential(input_proj, nn.GELU(), output_proj)

        # MLP first, LayerNorm second: same initialization order as before.
        for submodule in (self.mlp, self.pre_proj_layernorm):
            submodule.apply(init_weights)

    def forward(self, x, *args, **kwargs):
        """Flatten the middle dimensions of ``x``, normalize, and project.

        Args:
            x: Tensor whose first dimension is kept, whose last dimension is
                fed to the LayerNorm, and whose remaining dimensions are
                collapsed into a single one.
                NOTE(review): presumably (batch, ..., features) - confirm
                against the caller.
            *args: Accepted for call-site compatibility; ignored.
            **kwargs: Accepted for call-site compatibility; ignored.

        Returns:
            Tensor of shape ``(x.shape[0], -1, hidden_size)``.
        """
        leading, trailing = x.shape[0], x.shape[-1]
        flattened = x.reshape(leading, -1, trailing)
        normalized = self.pre_proj_layernorm(flattened)
        return self.mlp(normalized)

def init_weights(m: nn.Module) -> None:
    """Initialize the parameters of a single module in place.

    Suitable as the callback for ``nn.Module.apply``. Modules that are
    neither ``nn.Linear`` nor ``nn.LayerNorm`` are left untouched.

    * ``nn.Linear``: weight drawn from N(0, 0.02^2); bias, if present, zeroed.
    * ``nn.LayerNorm``: weight set to one and bias set to zero, each only if
      the parameter exists.

    Args:
        m: The module whose parameters should be initialized.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

    if isinstance(m, nn.LayerNorm):
        # LayerNorm(elementwise_affine=False) has no weight/bias, and
        # LayerNorm(bias=False) has no bias; skip what is missing instead of
        # crashing on None.
        if m.weight is not None:
            torch.nn.init.ones_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)