import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from typing import Optional
# Optional: import custom config if present
try:
    from .configuration_snowflake_core import SnowflakeCoreConfig
except ImportError:
    # Fall back to the generic config if the custom one is unavailable.
    SnowflakeCoreConfig = PretrainedConfig


class FusedSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == embed_dim
        ), "embed_dim must be divisible by num_heads"
        # Single fused projection for queries, keys, and values.
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        B, T, C = x.size()
        qkv = self.qkv_proj(x)  # [B, T, 3 * C]
        qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # Each: [B, num_heads, T, head_dim]
        attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [B, num_heads, T, T]
        if attn_mask is not None:
            # Additive mask of shape [T, T], broadcast over batch and heads.
            attn_scores = attn_scores + attn_mask.unsqueeze(0).unsqueeze(0).to(attn_scores.dtype)
        if key_padding_mask is not None:
            # Boolean mask of shape [B, T]; True marks padding keys to ignore.
            attn_scores = attn_scores.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_output = attn_probs @ v  # [B, num_heads, T, head_dim]
        attn_output = attn_output.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(attn_output)


class GPTBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = FusedSelfAttention(embed_dim, num_heads)
        self.dropout1 = nn.Dropout(dropout)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        h = self.ln1(x)
        attn_output = self.attn(h, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
        x = x + self.dropout1(attn_output)
        x = x + self.dropout2(self.mlp(self.ln2(x)))
        return x


class SnowflakeCoreG1(PreTrainedModel):
    config_class = SnowflakeCoreConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.embed_dim = config.embed_dim
        self.num_heads = config.num_heads
        self.num_layers = config.num_layers
        self.max_length = config.max_length
        # Note: ffn_dim is read from the config, but GPTBlock currently uses a fixed 4 * embed_dim MLP.
        self.ffn_dim = getattr(config, 'ffn_dim', 4 * config.embed_dim)
        self.dropout = getattr(config, 'dropout', 0.1)
        self.embed = nn.Embedding(self.vocab_size, self.embed_dim)
        self.pos_embed = nn.Embedding(self.max_length, self.embed_dim)
        self.dropout_layer = nn.Dropout(self.dropout)
        self.blocks = nn.ModuleList([
            GPTBlock(self.embed_dim, self.num_heads, self.dropout) for _ in range(self.num_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim)
        self.lm_head = nn.Linear(self.embed_dim, self.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> dict:
        B, T = input_ids.size()
        pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
        x = self.embed(input_ids) + self.pos_embed(pos)
        x = self.dropout_layer(x)
        # Additive causal mask: 0 on and below the diagonal, -inf above it,
        # so each position can only attend to itself and earlier positions.
        causal_mask = torch.zeros(T, T, device=input_ids.device)
        causal_mask = causal_mask.masked_fill(
            torch.triu(torch.ones(T, T, device=input_ids.device, dtype=torch.bool), diagonal=1),
            float('-inf'),
        )
        key_padding_mask = None
        if attention_mask is not None:
            # attention_mask uses 1 for real tokens and 0 for padding; True marks padding.
            key_padding_mask = attention_mask == 0
        for block in self.blocks:
            x = block(x, attn_mask=causal_mask, key_padding_mask=key_padding_mask)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Shift so that the token at position t predicts the token at position t + 1.
            shift_logits = logits[:, :-1, :].contiguous().view(-1, self.vocab_size)
            shift_labels = labels[:, 1:].contiguous().view(-1)
            pad_token_id = getattr(self.config, 'pad_token_id', None)
            ignore_index = pad_token_id if pad_token_id is not None else -100
            loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index)
        if loss is not None:
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, **kwargs):
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
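

# --- Usage sketch (illustrative only) ---
# A minimal example of constructing and running the model. The config field
# names (vocab_size, embed_dim, num_heads, num_layers, max_length, pad_token_id)
# are taken from the attributes read in __init__ above; the concrete values are
# arbitrary assumptions for demonstration, not defaults shipped with the model.
if __name__ == "__main__":
    config = SnowflakeCoreConfig(
        vocab_size=32000,
        embed_dim=256,
        num_heads=8,
        num_layers=4,
        max_length=128,
        pad_token_id=0,
    )
    model = SnowflakeCoreG1(config)
    model.eval()

    input_ids = torch.randint(1, config.vocab_size, (2, 16))  # [batch, seq_len]
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    print(outputs["loss"].item(), outputs["logits"].shape)  # scalar loss, [2, 16, vocab_size]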