"""A custom model for causal language modeling, compatible with HuggingFace."""

import torch
import torch.nn as nn

from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


class SSLLMConfig(PretrainedConfig):
    """Configuration class for the SSLLM model, compatible with HuggingFace."""

    model_type = "ssllm"

    def __init__(
        self,
        vocab_size=100277,
        d_model=768,
        num_heads=12,
        num_layers=10,
        d_ff=2560,
        max_seq_len=1024,
        dropout_rate=0.1,
        attention_dropout=0.1,
        stochastic_depth_rate=0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.d_ff = d_ff
        self.max_seq_len = max_seq_len
        self.dropout_rate = dropout_rate
        self.attention_dropout = attention_dropout
        self.stochastic_depth_rate = stochastic_depth_rate

        # Aliases under the standard HuggingFace attribute names, so generic
        # tooling that reads hidden_size, num_attention_heads, etc. still works.
        self.hidden_size = d_model
        self.num_attention_heads = num_heads
        self.num_hidden_layers = num_layers
        self.intermediate_size = d_ff
        self.max_position_embeddings = max_seq_len


class MultiHeadSelfAttention(nn.Module):
    """Multi-head causal self-attention module matching SSLLM exactly."""

    def __init__(self, d_model, num_heads, attention_dropout, dropout_rate):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            d_model,
            num_heads,
            dropout=attention_dropout,
            bias=True,
            batch_first=True,
        )
        self.resid_dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        B, T, C = x.size()

        # Upper-triangular boolean mask: position i may not attend to positions > i.
        causal_mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()

        # is_causal=True is a hint that attn_mask is a causal mask, allowing
        # PyTorch to select a fused attention kernel where available.
        attn_output, _ = self.attention(x, x, x, attn_mask=causal_mask, is_causal=True)

        return self.resid_dropout(attn_output)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block matching SSLLM exactly."""

    def __init__(self, d_model, num_heads, d_ff, dropout_rate, attention_dropout, stochastic_depth_rate):
        super().__init__()
        self.attn = MultiHeadSelfAttention(d_model, num_heads, attention_dropout, dropout_rate)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout_rate),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)
        # Stand-in for stochastic depth (plain dropout); kept to match SSLLM,
        # but note it is not applied in forward.
        self.drop_path = nn.Dropout(stochastic_depth_rate) if stochastic_depth_rate > 0 else nn.Identity()

    def forward(self, x):
        # Self-attention sub-layer (pre-norm) with residual connection.
        normed_x = self.norm1(x)
        attn_out = self.attn(normed_x)
        x = x + self.dropout(attn_out)

        # Feed-forward sub-layer (pre-norm) with residual connection.
        normed_x = self.norm2(x)
        ff_out = self.ff(normed_x)
        x = x + self.dropout(ff_out)

        return x


class SSLLMForCausalLM(PreTrainedModel, GenerationMixin):
    """SSLLM model for causal language modeling, compatible with HuggingFace."""

    config_class = SSLLMConfig

    def __init__(self, config):
        super().__init__(config)

        # Token embeddings plus learned absolute positional embeddings.
        self.token_embed = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embed = nn.Parameter(torch.zeros(1, config.max_seq_len, config.d_model))
        self.dropout = nn.Dropout(config.dropout_rate)

        # Stack of transformer blocks.
        self.blocks = nn.ModuleList([
            TransformerBlock(
                config.d_model,
                config.num_heads,
                config.d_ff,
                config.dropout_rate,
                config.attention_dropout,
                config.stochastic_depth_rate,
            )
            for _ in range(config.num_layers)
        ])

        # Final layer norm and (untied) language-modeling head.
        self.ln_f = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights: Xavier-uniform for linear layers, small normal for embeddings."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(self, input_ids, attention_mask=None, labels=None, past_key_values=None, **kwargs):
        # attention_mask and past_key_values are accepted for HuggingFace API
        # compatibility but are not used: attention is always fully causal over
        # the provided input_ids and no KV cache is kept.
        B, T = input_ids.size()

        # Embed tokens and add positional embeddings truncated to the sequence length.
        tok_emb = self.token_embed(input_ids)
        pos_emb = self.pos_embed[:, :T, :]
        x = self.dropout(tok_emb + pos_emb)

        # Run the transformer stack.
        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.head(x)

        loss = None
        if labels is not None:
            # Shift so that the logits at position i predict the token at position i + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, past_key_values=None, **kwargs):
        """Prepare inputs for generation; the full sequence is re-encoded each step (no KV cache)."""
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
        }

    def get_output_embeddings(self):
        """Get output embeddings for generation."""
        return self.head

    def set_output_embeddings(self, new_embeddings):
        """Set output embeddings."""
        self.head = new_embeddings

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.token_embed

    def set_input_embeddings(self, new_embeddings):
        """Set input embeddings."""
        self.token_embed = new_embeddings
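

# The block below is not part of the model definition; it is a minimal,
# illustrative sketch of how the classes above plug into the HuggingFace API.
# The tiny hyperparameter values and the optional Auto* registration are
# assumptions for demonstration only, and generate() behavior depends on the
# installed transformers version.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    # Optional: register the custom classes so AutoConfig / AutoModelForCausalLM
    # can resolve model_type == "ssllm".
    AutoConfig.register("ssllm", SSLLMConfig)
    AutoModelForCausalLM.register(SSLLMConfig, SSLLMForCausalLM)

    # Tiny configuration for a quick smoke test (illustrative values only).
    config = SSLLMConfig(vocab_size=1000, d_model=64, num_heads=4, num_layers=2, d_ff=128, max_seq_len=64)
    model = SSLLMForCausalLM(config)
    model.eval()

    # Forward pass with labels to check that logits and loss have sane shapes.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print("logits:", tuple(outputs.logits.shape), "loss:", float(outputs.loss))

    # Greedy generation via GenerationMixin; without a KV cache each step
    # re-encodes the full sequence, which is fine at this scale.
    generated = model.generate(input_ids[:, :4], max_new_tokens=8, do_sample=False)
    print("generated:", tuple(generated.shape))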