|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from transformers import PreTrainedModel, PretrainedConfig
|
|
from typing import Optional, Tuple
|
|
|
|
|
|
try:
|
|
from .configuration_snowflake_core import SnowflakeCoreConfig
|
|
except ImportError:
|
|
SnowflakeCoreConfig = PretrainedConfig
|
|
|
|
class FusedSelfAttention(nn.Module):
    """Multi-head self-attention using one fused linear layer for Q, K and V.

    Args:
        embed_dim: Size of the last dimension of the input. Must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, T, C)`` with ``C == embed_dim``.
            attn_mask: Optional mask broadcastable to ``(T, T)`` over
                (query, key). A boolean mask marks disallowed positions with
                ``True``; any other dtype is added to the attention scores
                (e.g. ``0`` for allowed and ``-inf`` for disallowed).
            key_padding_mask: Optional boolean tensor of shape ``(B, T)``;
                ``True`` marks keys that must not be attended to.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()
        qkv = self.qkv_proj(x)
        # (B, T, 3C) -> (3, B, heads, T, head_dim) so q/k/v split on dim 0.
        qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if attn_mask is not None:
            mask = attn_mask.unsqueeze(0).unsqueeze(0)
            if mask.dtype == torch.bool:
                # Casting a bool mask to float and adding it would only shift
                # masked scores by +1.0 instead of excluding them.
                attn_scores = attn_scores.masked_fill(mask, float('-inf'))
            else:
                attn_scores = attn_scores + mask.to(attn_scores.dtype)
        if key_padding_mask is not None:
            # Use the most negative finite value rather than -inf: a query row
            # whose permitted keys are all padding would otherwise be entirely
            # -inf and softmax would return NaN for it. Padded keys still get
            # zero probability whenever any unmasked key exists in the row.
            attn_scores = attn_scores.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                torch.finfo(attn_scores.dtype).min,
            )
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_output = attn_probs @ v
        # (B, heads, T, head_dim) -> (B, T, C)
        attn_output = attn_output.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(attn_output)
|
|
|
|
class GPTBlock(nn.Module):
    """Transformer block that normalises before each sub-layer.

    The block runs LayerNorm -> self-attention -> dropout, then
    LayerNorm -> MLP (Linear, GELU, Dropout, Linear) -> dropout, adding each
    result back onto the residual stream.

    Args:
        embed_dim: Width of the residual stream.
        num_heads: Number of attention heads given to the attention layer.
        dropout: Dropout probability used after attention, inside the MLP and
            after the MLP.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        hidden_dim = 4 * embed_dim

        # Attention sub-layer.
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = FusedSelfAttention(embed_dim, num_heads)
        self.dropout1 = nn.Dropout(dropout)

        # Feed-forward sub-layer.
        self.ln2 = nn.LayerNorm(embed_dim)
        feed_forward_layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
        ]
        self.mlp = nn.Sequential(*feed_forward_layers)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """Run the block on ``x`` and return a tensor of the same shape.

        ``attn_mask`` and ``key_padding_mask`` are passed through unchanged to
        the attention layer.
        """
        attended = self.attn(
            self.ln1(x),
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
        )
        x = x + self.dropout1(attended)
        mlp_out = self.mlp(self.ln2(x))
        return x + self.dropout2(mlp_out)
|
|
|
|
class SnowflakeCoreG1(PreTrainedModel):
    """Causal language model built from a stack of ``GPTBlock`` layers.

    Token embeddings and learned position embeddings are summed, passed
    through dropout and ``num_layers`` blocks, normalised with a final
    LayerNorm and projected to vocabulary logits by a bias-free linear head.

    The config must provide ``vocab_size``, ``embed_dim``, ``num_heads``,
    ``num_layers`` and ``max_length``; ``ffn_dim`` and ``dropout`` are
    optional.
    """

    config_class = SnowflakeCoreConfig
    # NOTE(review): no gradient-checkpointing logic is visible in this class;
    # confirm the flag is actually honoured before relying on it.
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.embed_dim = config.embed_dim
        self.num_heads = config.num_heads
        self.num_layers = config.num_layers
        self.max_length = config.max_length
        # Stored for reference only: the blocks below always use 4 * embed_dim.
        self.ffn_dim = getattr(config, 'ffn_dim', 4 * config.embed_dim)
        self.dropout = getattr(config, 'dropout', 0.1)

        self.embed = nn.Embedding(self.vocab_size, self.embed_dim)
        self.pos_embed = nn.Embedding(self.max_length, self.embed_dim)
        self.dropout_layer = nn.Dropout(self.dropout)
        self.blocks = nn.ModuleList([
            GPTBlock(self.embed_dim, self.num_heads, self.dropout) for _ in range(self.num_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim)
        self.lm_head = nn.Linear(self.embed_dim, self.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed = value

    def get_output_embeddings(self):
        """Return the language-model head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-model head."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> dict:
        """Compute next-token logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids of shape ``(B, T)``.
            attention_mask: Optional ``(B, T)`` tensor; positions equal to
                ``0`` are treated as padding and hidden from attention.
            labels: Optional ``(B, T)`` target ids. They are shifted
                internally so position ``t`` predicts token ``t + 1``.
                Targets equal to ``config.pad_token_id`` or ``-100`` are
                ignored by the loss.
            **kwargs: Accepted and ignored.

        Returns:
            ``{"logits": ...}`` with logits of shape ``(B, T, vocab)``, plus
            a ``"loss"`` entry when ``labels`` is provided.

        Raises:
            ValueError: If ``input_ids`` is missing or longer than the
                position-embedding table.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        _, seq_len = input_ids.size()
        max_positions = self.pos_embed.num_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum supported length {max_positions}"
            )
        device = input_ids.device

        pos = torch.arange(0, seq_len, device=device).unsqueeze(0)
        x = self.embed(input_ids) + self.pos_embed(pos)
        x = self.dropout_layer(x)

        # Additive causal mask: 0 where attention is allowed, -inf strictly
        # above the diagonal. It has to be built on a float tensor; calling
        # masked_fill(-inf) on a bool tensor leaves it boolean, so no position
        # would actually be masked.
        future_positions = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1
        )
        causal_mask = torch.zeros(
            seq_len, seq_len, device=device, dtype=x.dtype
        ).masked_fill(future_positions, float('-inf'))

        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = attention_mask == 0

        for block in self.blocks:
            x = block(x, attn_mask=causal_mask, key_padding_mask=key_padding_mask)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Use the logits' own width so a replaced output head still works.
            shift_logits = logits[:, :-1, :].contiguous().view(-1, logits.size(-1))
            shift_labels = labels[:, 1:].contiguous().view(-1)
            # cross_entropy needs an int ignore_index; fall back to -100 when
            # the config does not define a pad token.
            pad_token_id = getattr(self.config, 'pad_token_id', None)
            ignore_index = -100 if pad_token_id is None else pad_token_id
            if ignore_index != -100:
                # Also honour the -100 "ignore" sentinel, which would
                # otherwise be an out-of-range class index.
                shift_labels = shift_labels.masked_fill(shift_labels == -100, ignore_index)
            loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index)

        if loss is not None:
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, **kwargs):
        """Load a model by delegating unchanged to the parent ``from_pretrained``."""
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)