import torch
import torch.nn as nn
import torch.nn.init as init


class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


class Adapter(nn.Module):
    """Adapter block: dropout -> linear expansion -> Swish -> linear projection -> normalization."""

    def __init__(self, input_size, output_size, adapter_norm="layer_norm",
                 init_type="glorot", query_length=32, dropout_prob=0.1):
        super().__init__()
        self.query_length = query_length
        self.dropout_prob = dropout_prob
        self.adapter_norm = adapter_norm

        self.dropout = nn.Dropout(p=self.dropout_prob)
        # Expand to twice the input width, apply Swish, then project to the output size.
        self.c_fc = nn.Linear(input_size, input_size * 2)
        self.act = Swish()
        self.c_proj = nn.Linear(input_size * 2, output_size)

        # Normalize over the (query_length, output_size) dims for layer norm,
        # or over the query dimension for batch norm.
        if adapter_norm == "layer_norm":
            self.norm = nn.LayerNorm([self.query_length, output_size])
        elif adapter_norm == "batch_norm":
            self.norm = nn.BatchNorm1d(self.query_length)
        else:
            raise ValueError("adapter_norm must be 'layer_norm' or 'batch_norm'.")

        self.init_type = init_type.lower()
        self._initialize_weights()

    def forward(self, hidden_states):
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def _initialize_weights(self):
        # Initialize all linear layers according to the requested scheme.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                if self.init_type == "glorot":
                    init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
                elif self.init_type == "normal":
                    init.normal_(m.weight, mean=0, std=0.01)
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
                else:
                    raise ValueError("Invalid initialization type specified.")