import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, List, Dict, Any, Union


class HindiCausalLMConfig:
    """Configuration class for Hindi Causal Language Model"""

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 512,
        layer_norm_eps: float = 1e-12,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.tie_word_embeddings = tie_word_embeddings

        # Add any additional kwargs as attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_embedding_config(cls, config_dict, **kwargs):
        """Create LM config from embedding model config"""
        # Check if override parameters are provided
        override_params = {}
        for key in [
            "num_hidden_layers",
            "hidden_size",
            "num_attention_heads",
            "intermediate_size",
            "max_position_embeddings",
            "vocab_size",
        ]:
            if key in kwargs:
                override_params[key] = kwargs.pop(key)

        # Get hidden size first to calculate an appropriate number of attention heads
        hidden_size = override_params.get("hidden_size", config_dict.get("hidden_size", 768))

        # If num_attention_heads is not provided, choose a value that divides hidden_size evenly
        if "num_attention_heads" not in override_params:
            # Default options to try: 12, 16, 8, 4
            for heads in [12, 16, 8, 4]:
                if hidden_size % heads == 0:
                    print(f"Automatically setting num_attention_heads to {heads} to match hidden_size {hidden_size}")
                    override_params["num_attention_heads"] = heads
                    break

            # If none of the defaults work, find the largest factor of hidden_size that is <= 32
            if "num_attention_heads" not in override_params:
                for heads in range(min(32, hidden_size), 0, -1):
                    if hidden_size % heads == 0:
                        print(f"Automatically setting num_attention_heads to {heads} to match hidden_size {hidden_size}")
                        override_params["num_attention_heads"] = heads
                        break

        # Build the config, with overrides taking precedence
        config_params = {
            "vocab_size": override_params.get("vocab_size", config_dict.get("vocab_size", 32000)),
            "hidden_size": hidden_size,
            "num_hidden_layers": override_params.get("num_hidden_layers", config_dict.get("num_hidden_layers", 12)),
            "num_attention_heads": override_params.get("num_attention_heads", config_dict.get("num_attention_heads", 12)),
            "intermediate_size": override_params.get("intermediate_size", config_dict.get("intermediate_size", 3072)),
            "hidden_dropout_prob": config_dict.get("hidden_dropout_prob", 0.1),
            "attention_probs_dropout_prob": config_dict.get("attention_probs_dropout_prob", 0.1),
            "max_position_embeddings": override_params.get("max_position_embeddings", config_dict.get("max_position_embeddings", 512)),
            "layer_norm_eps": config_dict.get("layer_norm_eps", 1e-12),
            "pad_token_id": config_dict.get("pad_token_id", 0),
        }

        # Verify that hidden_size is divisible by num_attention_heads
        if config_params["hidden_size"] % config_params["num_attention_heads"] != 0:
            raise ValueError(
                f"Hidden size ({config_params['hidden_size']}) must be divisible by the number of attention "
                f"heads ({config_params['num_attention_heads']})"
            )

        # Add remaining kwargs
        config_params.update(kwargs)

        # Create and return the config
        lm_config = cls(**config_params)
        return lm_config

    def to_dict(self):
        """Convert config to dictionary"""
        return {k: v for k, v in self.__dict__.items()}


class CausalSelfAttention(nn.Module):
    """Causal self-attention layer"""

    def __init__(self, config):
        super().__init__()
        assert config.hidden_size % config.num_attention_heads == 0

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.attention_dropout_prob = config.attention_probs_dropout_prob

        # Query, Key, Value projections
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # Output projection
        self.output = nn.Sequential(
            nn.Linear(self.all_head_size, config.hidden_size),
            nn.Dropout(config.attention_probs_dropout_prob)
        )

        # Causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer(
            "causal_mask",
            torch.triu(
                torch.ones(config.max_position_embeddings, config.max_position_embeddings) * -1e10,
                diagonal=1
            )
        )

    def transpose_for_scores(self, x):
        # Reshape from [batch_size, seq_length, hidden_size] to [batch_size, seq_length, num_heads, head_size]
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_shape)
        # Transpose to [batch_size, num_heads, seq_length, head_size]
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_length = hidden_states.size()[:2]

        # Project inputs to queries, keys, and values
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product attention
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Apply causal mask - prevents attending to future tokens
        causal_mask = self.causal_mask[:seq_length, :seq_length]
        attention_scores = attention_scores + causal_mask

        # Apply attention mask if provided
        if attention_mask is not None:
            # Expand mask to match attention_scores shape
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = (1.0 - attention_mask) * -10000.0
            attention_scores = attention_scores + attention_mask

        # Softmax normalization, with dropout taken from the config rather than a hard-coded value
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = F.dropout(
            attention_probs, p=self.attention_dropout_prob, training=self.training
        )

        # Apply attention to values
        context_layer = torch.matmul(attention_probs, value_layer)

        # Reshape back to [batch_size, seq_length, hidden_size]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_shape)

        # Final output projection
        output = self.output(context_layer)

        return output


class TransformerBlock(nn.Module):
    """Transformer block with causal attention for language modeling"""

    def __init__(self, config):
        super().__init__()
        self.attention = CausalSelfAttention(config)
        self.attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
            nn.Dropout(config.hidden_dropout_prob)
        )
        self.ffn_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None):
        # Self-attention block with residual connection and layer norm
        attn_output = self.attention(hidden_states, attention_mask)
        hidden_states = self.attention_layernorm(hidden_states + attn_output)

        # Feed-forward block with residual connection and layer norm
        ffn_output = self.ffn(hidden_states)
        hidden_states = self.ffn_layernorm(hidden_states + ffn_output)

        return hidden_states


class HindiCausalLM(nn.Module):
    """Hindi Causal Language Model for text generation"""

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Embeddings
        self.token_embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size
        )
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_hidden_layers)
        ])

        # LM head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Tie weights if configured
        if config.tie_word_embeddings:
            self.lm_head.weight = self.token_embeddings.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights with small random values"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_input_embeddings(self):
        return self.token_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.token_embeddings = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=True
    ):
        device = input_ids.device
        batch_size, seq_length = input_ids.size()

        # Create position ids
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Get embeddings
        inputs_embeds = self.token_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)

        # Sum token and position embeddings
        hidden_states = inputs_embeds + position_embeds
        hidden_states = self.embedding_dropout(hidden_states)

        # Default attention mask (all tokens can be attended to)
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, seq_length, device=device)

        # Apply transformer layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        # Language model head
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Move labels to correct device
            labels = labels.to(device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        if return_dict:
            return {
                "logits": lm_logits,
                "loss": loss,
                "hidden_states": hidden_states
            }

        return (lm_logits, loss, hidden_states)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Prepare inputs for text generation"""
        # Only keep inputs needed for forward pass
        inputs = {
            "input_ids": input_ids,
        }

        # Add attention mask if provided
        if attention_mask is not None:
            inputs["attention_mask"] = attention_mask

        return inputs

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """Reorder cached hidden states for beam search generation"""
        reordered_past = []
        for layer_past in past:
            reordered_past.append(
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            )
        return reordered_past

    def save_pretrained(self, save_directory):
        """Save model and config to directory"""
        import os
        import json

        os.makedirs(save_directory, exist_ok=True)

        # Save config
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(self.config.to_dict(), f, indent=2)

        # Save model weights
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        return [config_path, model_path]

    @classmethod
    def from_pretrained(cls, model_path):
        """Load model and config from directory"""
        import os
        import json

        # Load config
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        # Create config object
        config = HindiCausalLMConfig(**config_dict)

        # Create model
        model = cls(config)

        # Load model weights
        weights_path = os.path.join(model_path, "pytorch_model.bin")
        model.load_state_dict(torch.load(weights_path, map_location="cpu"))

        return model


class HindiTextGenerator:
    """Text generation utility for HindiCausalLM"""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

    def generate(
        self,
        prompt,
        max_length=100,
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
        **kwargs
    ):
        """Generate text from a prompt"""
        # Encode the prompt
        input_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.model.config.max_position_embeddings - max_length
        )["input_ids"].to(self.device)

        # Set the model to evaluation mode
        self.model.eval()

        # Set generation parameters
        gen_kwargs = {
            "max_length": input_ids.shape[1] + max_length,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": do_sample,
            "num_return_sequences": num_return_sequences,
            **kwargs
        }

        # Generate text
        with torch.no_grad():
            output_sequences = self._generate_text(input_ids, **gen_kwargs)

        # Decode generated sequences
        generated_texts = []
        for sequence in output_sequences:
            # Remove the prompt from the generated text
            sequence = sequence[input_ids.shape[1]:]
            text = self.tokenizer.sp_model.DecodeIds(sequence.tolist())
            generated_texts.append(text)

        if num_return_sequences == 1:
            return generated_texts[0]
        return generated_texts

    def _generate_text(
        self,
        input_ids,
        max_length,
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
        pad_token_id=None,
        eos_token_id=None,
        **kwargs
    ):
        """Core text generation logic"""
        # Set pad_token_id and eos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.model.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.model.config.eos_token_id

        batch_size = input_ids.shape[0]

        # Create attention mask
        attention_mask = torch.ones_like(input_ids)

        # Clone the input_ids for each return sequence
        input_ids = input_ids.repeat(num_return_sequences, 1)
        attention_mask = attention_mask.repeat(num_return_sequences, 1)

        # Keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        cur_len = input_ids.shape[1]
        while cur_len < max_length:
            # Prepare model inputs
            model_inputs = self.model.prepare_inputs_for_generation(
                input_ids, attention_mask=attention_mask
            )

            # Forward pass
            outputs = self.model(**model_inputs, return_dict=True)
            next_token_logits = outputs["logits"][:, -1, :]

            # Apply temperature scaling
            next_token_logits = next_token_logits / temperature

            # Apply repetition penalty; handle the sign of the logit so that
            # penalized tokens always become less likely, not more
            if repetition_penalty != 1.0:
                for i in range(input_ids.shape[0]):
                    for token_id in set(input_ids[i].tolist()):
                        if next_token_logits[i, token_id] > 0:
                            next_token_logits[i, token_id] /= repetition_penalty
                        else:
                            next_token_logits[i, token_id] *= repetition_penalty

            # Filter logits using top-k and top-p sampling
            if do_sample:
                # Top-k filtering
                if top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = -float("Inf")

                # Top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    # Remove tokens with cumulative probability above the threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift the indices to the right to keep the first token above the threshold
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    for i in range(next_token_logits.shape[0]):
                        indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
                        next_token_logits[i, indices_to_remove] = -float("Inf")

                # Sample from the filtered distribution
                probabilities = F.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probabilities, 1).squeeze(1)
            else:
                # Greedy decoding
                next_tokens = torch.argmax(next_token_logits, dim=-1)

            # For sequences that already finished, append the pad token instead of a new token
            if eos_token_id is not None:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # Concatenate next tokens to input_ids
            input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)

            # Expand attention mask
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )
            cur_len += 1

            # Set the unfinished flag to 0 for sequences that just produced EOS
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.ne(eos_token_id).long()
                )

                # Stop early once all sequences are finished
                if unfinished_sequences.max() == 0:
                    break

        return input_ids
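

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the original module): it
# builds a small config, runs a forward pass on random token ids, and
# round-trips the weights through save_pretrained / from_pretrained. The
# directory name "./hindi_lm_smoke_test" is hypothetical. HindiTextGenerator
# is not exercised here because it needs a real SentencePiece-backed tokenizer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Small configuration so the smoke test runs quickly on CPU
    config = HindiCausalLMConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        max_position_embeddings=64,
    )
    model = HindiCausalLM(config)

    # Random batch of token ids; reusing the inputs as labels gives a quick loss check
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print("logits:", tuple(outputs["logits"].shape), "loss:", outputs["loss"].item())

    # Round-trip the weights through disk
    saved_paths = model.save_pretrained("./hindi_lm_smoke_test")
    reloaded = HindiCausalLM.from_pretrained("./hindi_lm_smoke_test")
    print("saved:", saved_paths)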