import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers.ops as xops


class SmallGPT(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=8, n_layers=6, max_length=128, pad_idx=0):
        super().__init__()
        self.d_model = d_model
        self.max_length = max_length

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.position_embedding = nn.Embedding(max_length, d_model)

        # Transformer blocks
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, n_heads) for _ in range(n_layers)]
        )

        # Output
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Init weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.03)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.03)
            # normal_ overwrites the zero row created by padding_idx; restore it
            if module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].fill_(0.0)

    def forward(self, x):
        batch_size, seq_len = x.size()
        assert seq_len <= self.max_length, "sequence length exceeds position embedding size"

        # Position indices
        pos = torch.arange(0, seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand(batch_size, seq_len)

        # Embeddings
        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(pos)
        x = tok_emb + pos_emb

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final layer norm and projection
        x = self.ln_f(x)
        logits = self.head(x)
        return logits


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model)

    def forward(self, x):
        # Pre-norm residual blocks
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch, seq_len, d_model = x.size()
        qkv = self.qkv(x)  # [B, S, 3*D]
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape into heads
        q = q.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)  # [B, H, S, Hd]
        k = k.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        # Flatten heads into the batch dim for xformers: [B*H, S, Hd]
        # (memory_efficient_attention also accepts [B, S, H, Hd] directly; this
        # layout is equivalent and simply treats each head as its own batch element)
        q = q.reshape(batch * self.n_heads, seq_len, self.head_dim)
        k = k.reshape(batch * self.n_heads, seq_len, self.head_dim)
        v = v.reshape(batch * self.n_heads, seq_len, self.head_dim)

        # Memory-efficient attention with a causal (lower-triangular) mask
        out = xops.memory_efficient_attention(
            q, k, v, attn_bias=xops.LowerTriangularMask()
        )  # [B*H, S, Hd]

        # Reshape back to [B, S, D]
        out = out.view(batch, self.n_heads, seq_len, self.head_dim)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.proj(out)


class MLP(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.fc1 = nn.Linear(d_model, 4 * d_model)
        self.fc2 = nn.Linear(4 * d_model, d_model)
        self.silu = nn.SiLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.silu(x)
        x = self.fc2(x)
        return x


DEFAULT_CONFIG = {
    "vocab_size": 24_005,
    "d_model": 256,
    "n_heads": 8,
    "n_layers": 6,
    "max_length": 128,
}
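

# --- Minimal usage sketch (illustrative addition, not part of the original module). ---
# Assumes xformers is installed with its GPU kernels; memory_efficient_attention
# generally requires a CUDA device, so the forward/loss smoke test only runs there.
if __name__ == "__main__":
    if torch.cuda.is_available():
        device = torch.device("cuda")
        model = SmallGPT(**DEFAULT_CONFIG).to(device)

        # Random token ids in [0, vocab_size), shape [batch, seq_len]
        tokens = torch.randint(0, DEFAULT_CONFIG["vocab_size"], (2, 64), device=device)
        logits = model(tokens)  # [2, 64, vocab_size]
        print(logits.shape)

        # Next-token loss: targets are the inputs shifted left by one position
        loss = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),
            tokens[:, 1:].reshape(-1),
        )
        print(loss.item())
    else:
        print("No CUDA device found; xformers memory-efficient attention needs GPU kernels.")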