import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers.ops as xops

class SmallGPT(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=8, n_layers=6, max_length=128, pad_idx=0):
        super().__init__()
        self.d_model = d_model
        self.max_length = max_length

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.position_embedding = nn.Embedding(max_length, d_model)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads) for _ in range(n_layers)
        ])

        # Output
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Init weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.03)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.03)

    def forward(self, x):
        batch_size, seq_len = x.size()

        # position indices
        pos = torch.arange(0, seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand(batch_size, seq_len)

        # Embeddings
        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(pos)
        x = tok_emb + pos_emb

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final layer norm and projection
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
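
# --- Illustrative helper (not part of the original module) ---
# A minimal greedy-decoding sketch showing how SmallGPT's forward pass could be
# used autoregressively. The function name `greedy_generate` and its arguments
# are assumptions added for illustration, not an API defined by this file.
@torch.no_grad()
def greedy_generate(model, prompt_ids, max_new_tokens=32):
    # prompt_ids: [batch, prompt_len] token ids on the same device as the model.
    ids = prompt_ids
    for _ in range(max_new_tokens):
        # Truncate to the positional-embedding window before each forward pass.
        context = ids[:, -model.max_length:]
        logits = model(context)                           # [batch, seq, vocab]
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
    return ids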

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model)

    def forward(self, x):
        # Pre-norm residual layout: each sublayer reads a LayerNorm'd input
        # and its output is added back to the residual stream.
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch, seq_len, d_model = x.size()
        qkv = self.qkv(x)  # [B, S, 3*D]
        q, k, v = qkv.chunk(3, dim=-1)

        # reshape into heads
        q = q.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)  # [B, H, S, Hd]
        k = k.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        # flatten for xformers: [B*H, S, Hd]
        q = q.reshape(batch * self.n_heads, seq_len, self.head_dim)
        k = k.reshape(batch * self.n_heads, seq_len, self.head_dim)
        v = v.reshape(batch * self.n_heads, seq_len, self.head_dim)

        # apply memory-efficient attention with causal mask
        out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())
        # out: [B*H, S, Hd]

        # reshape back
        out = out.view(batch, self.n_heads, seq_len, self.head_dim)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.proj(out)
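
# --- Illustrative check (not part of the original module) ---
# A minimal sketch comparing the xformers path above against PyTorch's built-in
# F.scaled_dot_product_attention with is_causal=True, which computes the same
# causal attention. The function name `_check_causal_attention`, the test
# shapes, and the tolerances are assumptions added for illustration; the
# xformers kernels generally expect a CUDA device (half precision is the
# best-supported path).
def _check_causal_attention(batch=2, seq_len=16, d_model=256, n_heads=8,
                            device="cuda", dtype=torch.float16):
    attn = CausalSelfAttention(d_model, n_heads).to(device=device, dtype=dtype)
    x = torch.randn(batch, seq_len, d_model, device=device, dtype=dtype)
    out = attn(x)

    # Reference: reuse the same qkv/proj weights, but let PyTorch do the attention.
    q, k, v = attn.qkv(x).chunk(3, dim=-1)
    q = q.view(batch, seq_len, n_heads, -1).transpose(1, 2)  # [B, H, S, Hd]
    k = k.view(batch, seq_len, n_heads, -1).transpose(1, 2)
    v = v.view(batch, seq_len, n_heads, -1).transpose(1, 2)
    ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    ref = attn.proj(ref.transpose(1, 2).contiguous().view(batch, seq_len, d_model))

    # Loose tolerance: the two kernels accumulate in different orders.
    return torch.allclose(out, ref, atol=1e-2, rtol=1e-2)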

class MLP(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        # Position-wise feed-forward: 4x expansion with SiLU activation
        self.fc1 = nn.Linear(d_model, 4 * d_model)
        self.fc2 = nn.Linear(4 * d_model, d_model)
        self.silu = nn.SiLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.silu(x)
        x = self.fc2(x)
        return x

DEFAULT_CONFIG = {
    "vocab_size": 24_005,
    "d_model": 256,
    "n_heads": 8,
    "n_layers": 6,
    "max_length": 128,
}
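
# --- Illustrative usage (not part of the original module) ---
# A minimal sketch showing how the model might be instantiated from
# DEFAULT_CONFIG and trained for one step with a next-token cross-entropy loss.
# The batch size, learning rate, optimizer choice, and random token data are
# assumptions added for illustration; the xformers attention kernels generally
# expect a CUDA device.
if __name__ == "__main__":
    device = "cuda"  # assumed: xformers kernels typically require a GPU
    model = SmallGPT(**DEFAULT_CONFIG).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    # Dummy batch of token ids: [batch, seq_len].
    tokens = torch.randint(
        0, DEFAULT_CONFIG["vocab_size"],
        (8, DEFAULT_CONFIG["max_length"]), device=device,
    )

    logits = model(tokens)  # [batch, seq_len, vocab_size]

    # Next-token prediction: logits at position t predict the token at t + 1.
    # ignore_index=0 matches the module's default pad_idx of 0.
    loss = F.cross_entropy(
        logits[:, :-1, :].reshape(-1, DEFAULT_CONFIG["vocab_size"]),
        tokens[:, 1:].reshape(-1),
        ignore_index=0,
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"logits: {tuple(logits.shape)}, loss: {loss.item():.4f}")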