""" A custom model for causal language modeling, compatible with HuggingFace. """ import torch import torch.nn as nn from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions class SSLLMConfig(PretrainedConfig): """Configuration class for SSLLM model compatible with HuggingFace.""" model_type = "ssllm" def __init__( self, vocab_size=100277, d_model=768, num_heads=12, num_layers=10, d_ff=2560, max_seq_len=1024, dropout_rate=0.1, attention_dropout=0.1, stochastic_depth_rate=0.1, **kwargs ): super().__init__(**kwargs) self.vocab_size = vocab_size self.d_model = d_model self.num_heads = num_heads self.num_layers = num_layers self.d_ff = d_ff self.max_seq_len = max_seq_len self.dropout_rate = dropout_rate self.attention_dropout = attention_dropout self.stochastic_depth_rate = stochastic_depth_rate # HuggingFace compatibility self.hidden_size = d_model self.num_attention_heads = num_heads self.num_hidden_layers = num_layers self.intermediate_size = d_ff self.max_position_embeddings = max_seq_len class MultiHeadSelfAttention(nn.Module): """Multi-head self-attention module matching SSLLM exactly.""" def __init__(self, d_model, num_heads, attention_dropout, dropout_rate): super().__init__() self.attention = nn.MultiheadAttention( d_model, num_heads, dropout=attention_dropout, bias=True, batch_first=True ) self.resid_dropout = nn.Dropout(dropout_rate) def forward(self, x): B, T, C = x.size() # Create causal mask causal_mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool() # Apply attention attn_output, _ = self.attention(x, x, x, attn_mask=causal_mask, is_causal=True) return self.resid_dropout(attn_output) class TransformerBlock(nn.Module): """Transformer block matching SSLLM exactly.""" def __init__(self, d_model, num_heads, d_ff, dropout_rate, attention_dropout, stochastic_depth_rate): super().__init__() self.attn = MultiHeadSelfAttention(d_model, num_heads, attention_dropout, dropout_rate) self.ff = nn.Sequential( nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout_rate), nn.Linear(d_ff, d_model), nn.Dropout(dropout_rate) ) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout_rate) self.drop_path = nn.Dropout(stochastic_depth_rate) if stochastic_depth_rate > 0 else nn.Identity() def forward(self, x): # Pre-layer norm for attention normed_x = self.norm1(x) attn_out = self.attn(normed_x) x = x + self.dropout(attn_out) # Pre-layer norm for feed-forward normed_x = self.norm2(x) ff_out = self.ff(normed_x) x = x + self.dropout(ff_out) return x class SSLLMForCausalLM(PreTrainedModel, GenerationMixin): """SSLLM model for causal language modeling, compatible with HuggingFace.""" config_class = SSLLMConfig def __init__(self, config): super().__init__(config) self.token_embed = nn.Embedding(config.vocab_size, config.d_model) self.pos_embed = nn.Parameter(torch.zeros(1, config.max_seq_len, config.d_model)) self.dropout = nn.Dropout(config.dropout_rate) # Create transformer blocks self.blocks = nn.ModuleList([ TransformerBlock( config.d_model, config.num_heads, config.d_ff, config.dropout_rate, config.attention_dropout, config.stochastic_depth_rate ) for _ in range(config.num_layers) ]) # Final layer norm and head self.ln_f = nn.LayerNorm(config.d_model) self.head = nn.Linear(config.d_model, config.vocab_size) # Initialize weights self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): 
torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.01) elif isinstance(module, nn.LayerNorm): torch.nn.init.ones_(module.weight) torch.nn.init.zeros_(module.bias) def forward(self, input_ids, attention_mask=None, labels=None, past_key_values=None, **kwargs): B, T = input_ids.size() # Embeddings tok_emb = self.token_embed(input_ids) pos_emb = self.pos_embed[:, :T, :] x = self.dropout(tok_emb + pos_emb) # Apply transformer blocks for block in self.blocks: x = block(x) x = self.ln_f(x) logits = self.head(x) loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return CausalLMOutputWithCrossAttentions( loss=loss, logits=logits, past_key_values=None, hidden_states=None, attentions=None, ) def prepare_inputs_for_generation(self, input_ids, attention_mask=None, past_key_values=None, **kwargs): """Prepare inputs for generation.""" # If attention_mask is not provided, create one if attention_mask is None: attention_mask = torch.ones_like(input_ids, dtype=torch.bool) return { "input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values, } def get_output_embeddings(self): """Get output embeddings for generation.""" return self.head def set_output_embeddings(self, new_embeddings): """Set output embeddings.""" self.head = new_embeddings def get_input_embeddings(self): """Get input embeddings.""" return self.token_embed def set_input_embeddings(self, new_embeddings): """Set input embeddings.""" self.token_embed = new_embeddings
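

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition):
# builds a tiny SSLLMConfig, runs a forward pass with labels to get a loss,
# and calls the HuggingFace generate() API. The hyperparameters below are
# assumed placeholder values chosen so the example runs quickly on CPU; they
# do not correspond to any trained SSLLM checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Tiny, assumed configuration for a quick smoke test.
    config = SSLLMConfig(
        vocab_size=1000,
        d_model=64,
        num_heads=4,
        num_layers=2,
        d_ff=128,
        max_seq_len=128,
    )
    model = SSLLMForCausalLM(config)
    model.eval()

    # Random token ids stand in for real tokenizer output.
    input_ids = torch.randint(0, config.vocab_size, (1, 16))

    # Forward pass with labels: the model shifts labels internally and returns
    # a cross-entropy loss alongside the logits.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, labels=input_ids)
    print("loss:", outputs.loss.item())
    print("logits shape:", tuple(outputs.logits.shape))

    # Greedy decoding via GenerationMixin. The model returns no past_key_values,
    # so each generation step re-runs the full forward pass over the sequence.
    generated = model.generate(input_ids, max_new_tokens=8, do_sample=False)
    print("generated shape:", tuple(generated.shape))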