from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput
import torch
import torch.nn as nn
from torch.nn import functional as F


class BVVAbsConfig(PretrainedConfig):
    model_type = "bvv_abs"

    def __init__(
        self,
        vocab_size=131072,
        n_embd=4096,
        n_head=32,
        n_layer=4,
        block_size=1024,
        pad_id=57344,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.pad_id = pad_id


class RotaryEmbedding(nn.Module):
    def __init__(self, dim):  # dim = head_dim (not n_embd!)
        super().__init__()
        # Standard RoPE inverse frequencies: 1 / 10000^(2i / dim)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seq_len, device):
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)  # (seq_len, dim)
        return emb


def apply_rotary_emb(x, rot_emb):
    # x: (B, n_head, seq_len, head_dim)
    # rot_emb: (seq_len, head_dim)
    seq_len = x.shape[-2]
    rot_emb = rot_emb[:seq_len]
    cos = torch.cos(rot_emb).unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, head_dim)
    sin = torch.sin(rot_emb).unsqueeze(0).unsqueeze(0)

    # Rotate adjacent channel pairs: (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos)
    x_shape = x.shape
    x = x.reshape(*x_shape[:-1], -1, 2)  # (..., head_dim/2, 2)
    x1 = x[..., 0]
    x2 = x[..., 1]
    cos = cos.reshape(*cos.shape[:-1], -1, 2)[..., 0]
    sin = sin.reshape(*sin.shape[:-1], -1, 2)[..., 0]
    x1_rot = x1 * cos - x2 * sin
    x2_rot = x1 * sin + x2 * cos
    x_rot = torch.stack([x1_rot, x2_rot], dim=-1)
    return x_rot.reshape(x_shape)


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head

        self.q_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.k_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.v_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.o_proj = nn.Linear(n_embd, n_embd, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim)
        self.dropout = nn.Dropout(0.0)

        # Causal mask: lower-triangular matrix of ones
        self.register_buffer(
            "tril",
            torch.tril(torch.ones(block_size, block_size)),
            persistent=False,
        )

    def forward(self, x):
        # x: (B, T, n_embd)
        B, T, C = x.shape
        q = self.q_proj(x)  # (B, T, n_embd)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, n_head, T, head_dim)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Rotary embeddings
        rot_emb = self.rotary_emb(seq_len=T, device=x.device)  # (T, head_dim)
        q = apply_rotary_emb(q, rot_emb)
        k = apply_rotary_emb(k, rot_emb)

        # Scaled dot-product attention with causal masking
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** -0.5)  # (B, n_head, T, T)
        attn_scores = attn_scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.dropout(attn_probs)

        out = torch.matmul(attn_probs, v)  # (B, n_head, T, head_dim)
        out = out.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, n_embd)
        return self.o_proj(out)


class TransformerMLP(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(0.0),
        )

    def forward(self, x):
        return self.net(x)


class TransformerBlock(nn.Module):
    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.self_attn = MultiHeadSelfAttention(n_embd, n_head, block_size)
        self.mlp = TransformerMLP(n_embd)
        self.input_layernorm = nn.LayerNorm(n_embd)
        self.post_attention_layernorm = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Pre-norm residual connections around attention and MLP
        x = x + self.self_attn(self.input_layernorm(x))
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x


class BVVAbsForCausalLM(PreTrainedModel):
    config_class = BVVAbsConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        self.transformer_layers = nn.Sequential(*[
            TransformerBlock(config.n_embd, n_head=config.n_head, block_size=config.block_size)
            for _ in range(config.n_layer)
        ])
        self.final_layernorm = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.token_embeddings(idx)
        x = self.transformer_layers(x)
        x = self.final_layernorm(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            logits_flat = logits.reshape(-1, logits.size(-1))
            targets_flat = targets.reshape(-1)
            # Padding tokens are excluded from the loss (pad_id defaults to 57344)
            loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=self.config.pad_id)

        return CausalLMOutput(
            logits=logits,
            loss=loss,
        )

    def generate(self, input_ids=None, max_new_tokens=None, max_length=None,
                 temperature=1.0, top_k=None, top_p=None, do_sample=True,
                 pad_token_id=None, eos_token_id=None, **kwargs):
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        idx = input_ids
        if max_new_tokens is None:
            if max_length is not None:
                max_new_tokens = max_length - idx.shape[1]
            else:
                max_new_tokens = 50

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop the running sequence to the last block_size tokens
                idx_cond = idx[:, -self.config.block_size:]
                outputs = self(idx_cond)
                logits = outputs.logits[:, -1, :] / temperature

                # Top-k filtering: keep only the k highest-scoring logits
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = float('-inf')

                # Top-p (nucleus) filtering: drop the tail beyond cumulative probability top_p
                if top_p is not None:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift the mask right so the first token crossing the threshold is kept
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = float('-inf')

                probs = F.softmax(logits, dim=-1)
                if do_sample:
                    idx_next = torch.multinomial(probs, num_samples=1)
                else:
                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)

                idx = torch.cat((idx, idx_next), dim=1)

                if eos_token_id is not None and (idx_next == eos_token_id).any():
                    break

        return idx
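

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original model code): shows how
# the classes above fit together, using a deliberately tiny configuration so
# it runs quickly on CPU. The hyperparameter values below are placeholders
# for demonstration only, not the released checkpoint's settings. Note that
# forward() takes positional `idx` and `targets`, not the usual HF
# `input_ids`/`labels` keyword names.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tiny_config = BVVAbsConfig(
        vocab_size=1024,   # placeholder; the real default is 131072
        n_embd=64,
        n_head=4,
        n_layer=2,
        block_size=32,
        pad_id=0,
    )
    model = BVVAbsForCausalLM(tiny_config)

    # Forward pass with targets: next-token prediction loss on shifted inputs
    tokens = torch.randint(0, tiny_config.vocab_size, (2, 16))  # (batch, seq_len)
    out = model(tokens[:, :-1], targets=tokens[:, 1:])
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))

    # Autoregressive sampling with the custom generate() loop
    prompt = torch.randint(0, tiny_config.vocab_size, (1, 8))
    generated = model.generate(prompt, max_new_tokens=10, temperature=1.0, top_k=50)
    print("generated shape:", tuple(generated.shape))  # (1, 18)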