import math
import torch
import torch.nn as nn
from torch.nn import functional as f
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from language_config import BigBrainLanguageConfig
def _make_causal_mask(size: int) -> torch.Tensor:
    # Lower-triangular mask: position i may only attend to positions j <= i.
    return torch.tril(torch.ones(size, size))
class RootMeanSquareNorm(nn.Module):
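    """Root-mean-square layer normalization: rescales by 1/RMS(x) and applies a learned gain.

    Unlike standard LayerNorm, no mean is subtracted and no bias is added.
    """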
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_eps = eps
def forward(self, x: torch.Tensor):
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_eps)
return self.weight * x
class MultiLayerPerceptron(nn.Module):
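    """Gated feed-forward block: down_proj(act_fn(gate_proj(x)) * up_proj(x))."""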
def __init__(self, config: BigBrainLanguageConfig):
super().__init__()
self.config = config
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class RotaryPositionalEmbedding(nn.Module):
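    """Rotary positional embedding (RoPE) applied to the leading `dim` channels of each attention head.

    The cos/sin tables are built lazily and cached for the longest sequence length seen so far.
    """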
def __init__(self, dim: int, base: int = 10000):
super().__init__()
self.dim = dim
self.base = base
self.cos = None
self.sin = None
    def _build_cache(self, x: torch.Tensor):
        # x arrives as (batch, num_heads, seq_len, head_dim); cache cos/sin tables up to seq_len.
        # The original indexed the cache by x.shape[0] (the batch dimension), so the rotation
        # angle varied with the batch index instead of the token position; use the sequence
        # dimension instead.
        seq_len = x.shape[2]
        if self.cos is not None and seq_len <= self.cos.shape[2]:
            return
        theta = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
        seq_idx = torch.arange(seq_len, device=x.device).float()
        idx_theta = torch.einsum('a,b->ab', seq_idx, theta)
        idx_theta = torch.cat([idx_theta, idx_theta], dim=1)
        # Shape (1, 1, seq_len, dim) so the tables broadcast over batch and heads.
        self.cos = idx_theta.cos()[None, None, :, :]
        self.sin = idx_theta.sin()[None, None, :, :]
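    # "Rotate half": builds (-x2, x1) from x = (x1, x2) along the last dimension.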
def _neg_half(self, x: torch.Tensor):
d_2 = self.dim // 2
return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
    def forward(self, x: torch.Tensor):
        self._build_cache(x)
        seq_len = x.shape[2]
        # Rotate the first `dim` channels, pass the remaining channels through unchanged.
        x_rope, x_pass = x[..., :self.dim], x[..., self.dim:]
        neg_half_x = self._neg_half(x_rope)
        x_rope = (x_rope * self.cos[:, :, :seq_len]) + (neg_half_x * self.sin[:, :, :seq_len])
        return torch.cat((x_rope, x_pass), dim=-1)
class RotaryMultiHeadAttention(nn.Module):
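    """Multi-head self-attention with rotary position embeddings applied to the query and key projections."""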
def __init__(self, config: BigBrainLanguageConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads
        if (self.head_dim * config.num_attention_heads) != config.hidden_size:
            raise ValueError('hidden_size must be evenly divisible by num_attention_heads')
self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.rope_e = RotaryPositionalEmbedding(self.head_dim, config.rope_theta)
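    # (batch, seq_len, hidden_size) -> (batch, num_heads, seq_len, head_dim)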
def _shape(self, tensor: torch.Tensor, batch_size: int, seq_len: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
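    # (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, hidden_size)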
def _reshape(self, tensor: torch.Tensor, batch_size: int, seq_len: int):
return tensor.transpose(1, 2).contiguous().reshape(batch_size, seq_len, self.hidden_size)
def forward(self, states: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
batch_size, seq_len, _ = states.size()
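        # Project to queries/keys/values, split into heads, and apply RoPE to queries and keys.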
q_states = self.rope_e(self._shape(self.q_proj(states), batch_size, seq_len))
k_states = self.rope_e(self._shape(self.k_proj(states), batch_size, seq_len))
v_states = self._shape(self.v_proj(states), batch_size, seq_len)
        # Scaled dot-product attention scores: (batch, num_heads, seq_len, seq_len).
        attn_weights = torch.matmul(q_states, k_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        # Clamp the logits to keep them in a numerically safe range.
        attn_weights = torch.clamp(attn_weights, min=-1024.0, max=1024.0)
        if mask is not None:
            # Positions where the causal mask is zero may not be attended to.
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))
        attn_weights = f.softmax(attn_weights, dim=-1)
attn_outputs = torch.matmul(attn_weights, v_states)
return self._reshape(attn_outputs, batch_size, seq_len)
class BigBrainDecoderLayer(nn.Module):
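    """Pre-norm decoder block: RMSNorm -> self-attention and RMSNorm -> feed-forward, each wrapped in a residual connection."""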
def __init__(self, config: BigBrainLanguageConfig):
super().__init__()
self.config = config
self.self_attn = RotaryMultiHeadAttention(config)
self.feed_forward = MultiLayerPerceptron(config)
self.input_norm = RootMeanSquareNorm(config.hidden_size, config.layer_norm_eps)
self.attn_norm = RootMeanSquareNorm(config.hidden_size, config.layer_norm_eps)
        self.register_buffer('attn_mask', _make_causal_mask(config.max_position_embeddings))
def forward(self, x: torch.Tensor):
        _, seq_len, _ = x.size()
        # Slice the cached causal mask to the current sequence length.
        mask = self.attn_mask[:seq_len, :seq_len]
x = x + self.self_attn(self.input_norm(x), mask)
x = x + self.feed_forward(self.attn_norm(x))
return x
class BigBrainLanguageModel(PreTrainedModel):
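    """Decoder-only language model: token embeddings, a stack of BigBrainDecoderLayer blocks, a final RMSNorm, and a vocabulary projection head."""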
config_class = BigBrainLanguageConfig
base_model_prefix = 'big-brain-lm'
def __init__(self, config: BigBrainLanguageConfig):
super().__init__(config)
self.config = config
self.tok_embed = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
self.layers = nn.ModuleList([BigBrainDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RootMeanSquareNorm(config.hidden_size, config.layer_norm_eps)
self.linear = nn.Linear(config.hidden_size, config.vocab_size)
self.post_init()
def _init_weights(self, module):
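        # Normal(0, initializer_range) init for weights; biases and the padding embedding row are zeroed.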
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def forward(self, input_ids: torch.Tensor, target_ids: torch.Tensor = None):
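        # input_ids: (batch, seq_len) token ids. target_ids, if provided, are per-position labels
        # for the returned logits; no shifting is applied here, so the caller must align them.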
hidden_states = self.tok_embed(input_ids)
for decoder_layer in self.layers:
hidden_states = decoder_layer(hidden_states)
        hidden_states = self.norm(hidden_states)
        logits = self.linear(hidden_states)
        if target_ids is None:
            return logits, None
        # Flatten the (batch, seq_len, vocab) logits and targets for token-level cross-entropy.
        b, t, c = logits.size()
        loss = f.cross_entropy(logits.view(b * t, c), target_ids.view(b * t))
        return logits, loss
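

if __name__ == '__main__':
    # Minimal smoke test: an illustrative sketch, not part of the original module. It assumes
    # BigBrainLanguageConfig accepts the keyword arguments used throughout this file (a standard
    # PretrainedConfig subclass would); adjust the values to match the real config class.
    config = BigBrainLanguageConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        max_position_embeddings=128,
        hidden_act='silu',
        layer_norm_eps=1e-6,
        rope_theta=10000,
        initializer_range=0.02,
        pad_token_id=0,
    )
    model = BigBrainLanguageModel(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    logits, loss = model(input_ids, target_ids=input_ids)
    print(logits.shape, loss.item())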