# File size: 8,410 Bytes (d278d9d)
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
class HierarchicalPositionEncoding(nn.Module):
    """
    Fixed sinusoidal position encoding concatenated from three scales:
    - Fine-grained local position (token level): one table row per token
    - Medium-scale position (segment level): one table row per 8 tokens
    - Coarse-grained position (document level): one table row per 32 tokens

    The channel budget is split as ``d_model // 2`` (local), ``d_model // 4``
    (segment) and the remainder (document), so the concatenated encoding is
    exactly ``d_model`` wide.

    Args:
        d_model: total width of the returned encoding.
        max_len: longest sequence length that can be encoded.
        base: base of the geometric frequency progression of the sinusoids.
    """

    # Number of consecutive tokens that share one row of each coarse table.
    SEGMENT_SPAN = 8
    DOC_SPAN = 32

    def __init__(self, d_model, max_len=1024, base=10000):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.base = base
        # Split embedding dimensions for different scales
        self.local_dim = d_model // 2
        self.segment_dim = d_model // 4
        self.doc_dim = d_model - self.local_dim - self.segment_dim
        # Ceil division: a trailing partial segment/document chunk still needs
        # its own row. Equal to floor division when max_len is a multiple of
        # the span, so table shapes are unchanged in that case.
        segment_rows = -(-max_len // self.SEGMENT_SPAN)
        doc_rows = -(-max_len // self.DOC_SPAN)
        # Create position encodings for different scales
        self.register_buffer('local_pe', self._create_pe(max_len, self.local_dim))
        self.register_buffer('segment_pe', self._create_pe(segment_rows, self.segment_dim))
        self.register_buffer('doc_pe', self._create_pe(doc_rows, self.doc_dim))

    def _create_pe(self, max_len, d_model):
        """Return a ``(1, max_len, d_model)`` sin/cos table (sin on even columns, cos on odd)."""
        pe = torch.zeros(max_len, d_model)
        if d_model == 0:
            # Nothing to fill; also avoids dividing by zero in the frequency term.
            return pe.unsqueeze(0)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(self.base) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd width there is one fewer odd column than even columns,
        # so drop the last frequency for the cosine half.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        return pe.unsqueeze(0)

    def forward(self, x):
        """
        Build the encoding for the sequence length of ``x``.

        Args:
            x: tensor of shape ``(B, T, C)``; only ``T`` is used.

        Returns:
            Tensor of shape ``(1, T, d_model)`` (no batch expansion is done here).

        Raises:
            ValueError: if ``T`` exceeds ``max_len``.
        """
        B, T, C = x.shape
        if T > self.max_len:
            raise ValueError(
                f"sequence length {T} exceeds max_len {self.max_len}"
            )
        # Number of coarse rows needed to cover T tokens, rounding up so that a
        # partial final chunk is covered; the trailing [:T] trims the overshoot.
        n_segments = -(-T // self.SEGMENT_SPAN)
        n_docs = -(-T // self.DOC_SPAN)
        # Get positional encodings at different scales
        local_pos = self.local_pe[:, :T, :]
        segment_pos = self.segment_pe[:, :n_segments, :].repeat_interleave(self.SEGMENT_SPAN, dim=1)[:, :T, :]
        doc_pos = self.doc_pe[:, :n_docs, :].repeat_interleave(self.DOC_SPAN, dim=1)[:, :T, :]
        # Combine all scales along the channel dimension
        return torch.cat([local_pos, segment_pos, doc_pos], dim=-1)
class MultiScaleAttention(nn.Module):
    """
    Multi-head causal self-attention.

    Queries, keys and values for all heads come from a single fused linear
    projection; each position attends only to itself and earlier positions.

    Args:
        config: object exposing ``n_embd``, ``n_head``, ``bias`` and ``dropout``.
            ``n_embd`` must be divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

    def forward(self, x):
        """
        Apply causal self-attention.

        Args:
            x: tensor of shape ``(B, T, C)`` with ``C == n_embd``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.shape  # batch size, sequence length, embedding dimensionality
        head_size = C // self.n_head
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_size))
        # Lower-triangular mask: position i may only see positions <= i. Built on
        # the fly (not as a registered buffer) so the module's state_dict keys stay
        # the same. Without it, next-token training could read the target token.
        causal = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
        att = att.masked_fill(~causal, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
class Block(nn.Module):
    """
    Pre-LayerNorm transformer block: attention and a 4x-wide GELU MLP, each
    wrapped in a residual connection.

    Args:
        config: object exposing ``n_embd``, ``bias`` and ``dropout`` (plus
            whatever ``MultiScaleAttention`` reads from it).
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = MultiScaleAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
            c_proj=nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
            act=nn.GELU(),
            dropout=nn.Dropout(config.dropout),
        ))

    def mlpf(self, x):
        """
        Run the feed-forward sub-layer: c_fc -> act -> c_proj -> dropout.

        Defined as a method rather than a lambda stored on the instance: a
        closure would keep pointing at the original ``mlp`` after
        ``copy.deepcopy`` (function objects are shared, not copied) and would
        make the module unpicklable.
        """
        m = self.mlp
        return m.dropout(m.c_proj(m.act(m.c_fc(x))))

    def forward(self, x):
        """Apply attention then the MLP, each as a residual update; shape is preserved."""
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlpf(self.ln_2(x))
        return x
class GPTModified(nn.Module):
    """
    GPT-style language model that adds a fixed hierarchical position encoding
    to the token embeddings before a stack of transformer blocks.

    Args:
        config: object exposing ``vocab_size``, ``block_size``, ``n_embd``,
            ``n_layer`` and ``dropout`` (plus whatever ``Block`` reads from it).
            ``vocab_size`` and ``block_size`` must not be ``None``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            hpe=HierarchicalPositionEncoding(config.n_embd, config.block_size),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Initialize weights
        self.apply(self._init_weights)
        # Apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
        # Report number of parameters
        print(f"number of parameters: {self.get_num_params() / 1e6:.2f}M")

    def get_num_params(self, non_embedding=True):
        """
        Count the model's parameters.

        Args:
            non_embedding: when true, the token embedding table (``wte``) is
                excluded from the count.

        Returns:
            The parameter count as an int.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wte.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zeros for Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """
        Compute logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape ``(b, t)`` with ``t <= block_size``.
            targets: optional integer tensor with the same number of elements
                as ``idx``; entries equal to ``-1`` are ignored by the loss.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape ``(b, t, vocab_size)``
            and ``loss`` is ``None`` when ``targets`` is not given.
        """
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        # Forward pass
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        # The encoder only reads the sequence length from its argument; its
        # output is added to the token embeddings by broadcasting.
        pos_emb = self.transformer.hpe(tok_emb)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        # If we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: integer token ids of shape ``(b, t)`` used as the prompt.
            max_new_tokens: number of tokens to append.
            temperature: divisor applied to the final-step logits before softmax.
            top_k: if given, only the ``top_k`` highest logits stay eligible.

        Returns:
            Tensor of shape ``(b, t + max_new_tokens)``: prompt plus samples.
        """
        for _ in range(max_new_tokens):
            # If the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # Forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # Pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # Optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # v[:, [-1]] is the k-th largest logit per row; everything below it is masked out
                logits[logits < v[:, [-1]]] = -float('Inf')
            # Apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
|