"""
A custom model for causal language modeling, compatible with HuggingFace.
"""
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
class SSLLMConfig(PretrainedConfig):
"""Configuration class for SSLLM model compatible with HuggingFace."""
model_type = "ssllm"
def __init__(
self,
vocab_size=100277,
d_model=768,
num_heads=12,
num_layers=10,
d_ff=2560,
max_seq_len=1024,
dropout_rate=0.1,
attention_dropout=0.1,
stochastic_depth_rate=0.1,
**kwargs
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.d_model = d_model
self.num_heads = num_heads
self.num_layers = num_layers
self.d_ff = d_ff
self.max_seq_len = max_seq_len
self.dropout_rate = dropout_rate
self.attention_dropout = attention_dropout
self.stochastic_depth_rate = stochastic_depth_rate
# HuggingFace compatibility
self.hidden_size = d_model
self.num_attention_heads = num_heads
self.num_hidden_layers = num_layers
self.intermediate_size = d_ff
self.max_position_embeddings = max_seq_len
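
# Usage sketch (illustrative, not part of the module): the defaults above describe the
# model shipped with this repo, and the hidden_size / num_attention_heads / num_hidden_layers
# aliases let generic HuggingFace utilities read the config.
#
#   config = SSLLMConfig(vocab_size=100277, d_model=768, num_heads=12, num_layers=10)
#   assert config.hidden_size == config.d_model
#   config.save_pretrained("ssllm-config")              # writes config.json with model_type="ssllm"
#   reloaded = SSLLMConfig.from_pretrained("ssllm-config")
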
class MultiHeadSelfAttention(nn.Module):
"""Multi-head self-attention module matching SSLLM exactly."""
def __init__(self, d_model, num_heads, attention_dropout, dropout_rate):
super().__init__()
self.attention = nn.MultiheadAttention(
d_model,
num_heads,
dropout=attention_dropout,
bias=True,
batch_first=True
)
self.resid_dropout = nn.Dropout(dropout_rate)
    def forward(self, x):
        T = x.size(1)
        # Boolean causal mask: True above the diagonal marks positions a query may not attend to
        causal_mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        # Apply attention; is_causal=True is a hint that the mask is causal, and
        # need_weights=False skips materialising attention weights we discard anyway
        attn_output, _ = self.attention(x, x, x, attn_mask=causal_mask, is_causal=True, need_weights=False)
        return self.resid_dropout(attn_output)
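
# For example (illustrative), with T=4 the boolean causal mask built in forward() is:
#
#   tensor([[False,  True,  True,  True],
#           [False, False,  True,  True],
#           [False, False, False,  True],
#           [False, False, False, False]])
#
# True entries are positions a query is *not* allowed to attend to, so each token
# only sees itself and earlier tokens.
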
class TransformerBlock(nn.Module):
"""Transformer block matching SSLLM exactly."""
def __init__(self, d_model, num_heads, d_ff, dropout_rate, attention_dropout, stochastic_depth_rate):
super().__init__()
self.attn = MultiHeadSelfAttention(d_model, num_heads, attention_dropout, dropout_rate)
self.ff = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout_rate),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout_rate)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout_rate)
        # Note: nn.Dropout here is element-wise dropout, not per-sample stochastic depth
        # (DropPath), and drop_path is not applied anywhere in forward(); it is kept only
        # to mirror the original SSLLM definition.
        self.drop_path = nn.Dropout(stochastic_depth_rate) if stochastic_depth_rate > 0 else nn.Identity()
def forward(self, x):
# Pre-layer norm for attention
normed_x = self.norm1(x)
attn_out = self.attn(normed_x)
x = x + self.dropout(attn_out)
# Pre-layer norm for feed-forward
normed_x = self.norm2(x)
ff_out = self.ff(normed_x)
x = x + self.dropout(ff_out)
return x
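
# Shape sketch (illustrative): a block maps (batch, seq_len, d_model) -> (batch, seq_len, d_model),
# so blocks can be stacked freely.
#
#   block = TransformerBlock(d_model=768, num_heads=12, d_ff=2560,
#                            dropout_rate=0.1, attention_dropout=0.1, stochastic_depth_rate=0.1)
#   y = block(torch.randn(2, 16, 768))   # -> torch.Size([2, 16, 768])
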
class SSLLMForCausalLM(PreTrainedModel, GenerationMixin):
"""SSLLM model for causal language modeling, compatible with HuggingFace."""
config_class = SSLLMConfig
def __init__(self, config):
super().__init__(config)
self.token_embed = nn.Embedding(config.vocab_size, config.d_model)
        # Learned absolute positional embeddings, initialised to zeros; as a bare Parameter
        # they are not touched by _init_weights below.
        self.pos_embed = nn.Parameter(torch.zeros(1, config.max_seq_len, config.d_model))
self.dropout = nn.Dropout(config.dropout_rate)
# Create transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(
config.d_model,
config.num_heads,
config.d_ff,
config.dropout_rate,
config.attention_dropout,
config.stochastic_depth_rate
) for _ in range(config.num_layers)
])
# Final layer norm and head
self.ln_f = nn.LayerNorm(config.d_model)
self.head = nn.Linear(config.d_model, config.vocab_size)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.ones_(module.weight)
torch.nn.init.zeros_(module.bias)
    def forward(self, input_ids, attention_mask=None, labels=None, past_key_values=None, **kwargs):
        # Note: attention_mask and past_key_values are accepted for HuggingFace API
        # compatibility but are not used; there is no padding mask and no KV cache.
        B, T = input_ids.size()
        if T > self.config.max_seq_len:
            raise ValueError(f"Sequence length {T} exceeds max_seq_len {self.config.max_seq_len}")
        # Token + learned positional embeddings
        tok_emb = self.token_embed(input_ids)
        pos_emb = self.pos_embed[:, :T, :]
        x = self.dropout(tok_emb + pos_emb)
# Apply transformer blocks
for block in self.blocks:
x = block(x)
x = self.ln_f(x)
logits = self.head(x)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
)
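
    # Loss sketch (illustrative): with labels == input_ids, position t's logits are scored
    # against token t+1, e.g. for tokens [a, b, c, d] the pairs are (a->b), (b->c), (c->d);
    # the final position has no target and is dropped by the slicing above.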
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, past_key_values=None, **kwargs):
"""Prepare inputs for generation."""
# If attention_mask is not provided, create one
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
}
def get_output_embeddings(self):
"""Get output embeddings for generation."""
return self.head
def set_output_embeddings(self, new_embeddings):
"""Set output embeddings."""
self.head = new_embeddings
def get_input_embeddings(self):
"""Get input embeddings."""
return self.token_embed
def set_input_embeddings(self, new_embeddings):
"""Set input embeddings."""
self.token_embed = new_embeddings
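

# Minimal end-to-end sketch (illustrative, not part of the original module; run with
# `python ssllm_hf.py`). The register calls are standard transformers APIs; no tokenizer
# is assumed here, so random token ids stand in for real input.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    # Register the custom classes so the Auto* factories recognise model_type="ssllm".
    AutoConfig.register("ssllm", SSLLMConfig)
    AutoModelForCausalLM.register(SSLLMConfig, SSLLMForCausalLM)

    config = SSLLMConfig(num_layers=2, max_seq_len=128)   # small settings for a smoke test
    model = SSLLMForCausalLM(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids, labels=input_ids)
    print("logits:", out.logits.shape, "loss:", out.loss.item())

    generated = model.generate(input_ids, max_new_tokens=8, do_sample=False)
    print("generated:", generated.shape)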