"""hindi_language_model.py

Custom Hindi causal language model (v1) by convaiinnovations: configuration,
decoder-only transformer blocks, and a sampling-based text generator.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, List, Dict, Any, Union
class HindiCausalLMConfig:
"""Configuration class for Hindi Causal Language Model"""
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 768,
num_hidden_layers: int = 12,
num_attention_heads: int = 12,
intermediate_size: int = 3072,
hidden_dropout_prob: float = 0.1,
attention_probs_dropout_prob: float = 0.1,
max_position_embeddings: int = 512,
layer_norm_eps: float = 1e-12,
pad_token_id: int = 0,
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = True,
**kwargs
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.layer_norm_eps = layer_norm_eps
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.tie_word_embeddings = tie_word_embeddings
# Add any additional kwargs as attributes
for key, value in kwargs.items():
setattr(self, key, value)
@classmethod
def from_embedding_config(cls, config_dict, **kwargs):
"""Create LM config from embedding model config"""
# Check if override parameters are provided
override_params = {}
for key in ["num_hidden_layers", "hidden_size", "num_attention_heads",
"intermediate_size", "max_position_embeddings", "vocab_size"]:
if key in kwargs:
override_params[key] = kwargs.pop(key)
# Get hidden size first to calculate appropriate number of attention heads
hidden_size = override_params.get("hidden_size", config_dict.get("hidden_size", 768))
# If num_attention_heads is not provided, choose a value that divides hidden_size evenly
if "num_attention_heads" not in override_params:
# Default options to try: 12, 16, 8, 4
for heads in [12, 16, 8, 4]:
if hidden_size % heads == 0:
print(f"Automatically setting num_attention_heads to {heads} to match hidden_size {hidden_size}")
override_params["num_attention_heads"] = heads
break
# If none of the defaults work, find the largest factor <= 32
if "num_attention_heads" not in override_params:
# Find the largest factor of hidden_size that is <= 32
for heads in range(min(32, hidden_size), 0, -1):
if hidden_size % heads == 0:
print(f"Automatically setting num_attention_heads to {heads} to match hidden_size {hidden_size}")
override_params["num_attention_heads"] = heads
break
# Build the config, with overrides taking precedence
config_params = {
"vocab_size": override_params.get("vocab_size", config_dict.get("vocab_size", 32000)),
"hidden_size": hidden_size,
"num_hidden_layers": override_params.get("num_hidden_layers", config_dict.get("num_hidden_layers", 12)),
"num_attention_heads": override_params.get("num_attention_heads", config_dict.get("num_attention_heads", 12)),
"intermediate_size": override_params.get("intermediate_size", config_dict.get("intermediate_size", 3072)),
"hidden_dropout_prob": config_dict.get("hidden_dropout_prob", 0.1),
"attention_probs_dropout_prob": config_dict.get("attention_probs_dropout_prob", 0.1),
"max_position_embeddings": override_params.get("max_position_embeddings",
config_dict.get("max_position_embeddings", 512)),
"layer_norm_eps": config_dict.get("layer_norm_eps", 1e-12),
"pad_token_id": config_dict.get("pad_token_id", 0),
}
# Verify that hidden_size is divisible by num_attention_heads
if config_params["hidden_size"] % config_params["num_attention_heads"] != 0:
raise ValueError(
f"Hidden size ({config_params['hidden_size']}) must be divisible by the number of attention "
f"heads ({config_params['num_attention_heads']})"
)
# Add remaining kwargs
config_params.update(kwargs)
# Create and return the config
lm_config = cls(**config_params)
return lm_config
def to_dict(self):
"""Convert config to dictionary"""
        return dict(self.__dict__)
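
# Usage sketch (hypothetical values): deriving an LM config from an embedding-model
# config dict. When no head count is supplied, from_embedding_config picks one that
# divides hidden_size evenly (16 heads for hidden_size=512 below):
#
#   emb_cfg = {"hidden_size": 512, "vocab_size": 16000, "max_position_embeddings": 256}
#   lm_cfg = HindiCausalLMConfig.from_embedding_config(emb_cfg, num_hidden_layers=8)
#   assert lm_cfg.hidden_size % lm_cfg.num_attention_heads == 0
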
class CausalSelfAttention(nn.Module):
"""Causal self-attention layer"""
def __init__(self, config):
super().__init__()
assert config.hidden_size % config.num_attention_heads == 0
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.attention_dropout_prob = config.attention_probs_dropout_prob
# Query, Key, Value projections
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
# Output projection
self.output = nn.Sequential(
nn.Linear(self.all_head_size, config.hidden_size),
nn.Dropout(config.attention_probs_dropout_prob)
)
# Causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer(
"causal_mask",
torch.triu(
torch.ones(config.max_position_embeddings, config.max_position_embeddings) * -1e10,
diagonal=1
)
)
def transpose_for_scores(self, x):
# Reshape from [batch_size, seq_length, hidden_size] to [batch_size, seq_length, num_heads, head_size]
new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_shape)
# Transpose to [batch_size, num_heads, seq_length, head_size]
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask=None):
batch_size, seq_length = hidden_states.size()[:2]
# Project inputs to queries, keys, and values
query_layer = self.transpose_for_scores(self.query(hidden_states))
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
# Scale dot-product attention
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply causal mask - prevents attending to future tokens
causal_mask = self.causal_mask[:seq_length, :seq_length]
attention_scores = attention_scores + causal_mask
# Apply attention mask if provided
if attention_mask is not None:
            # Expand mask to [batch_size, 1, 1, seq_length] and convert to additive form
            attention_mask = attention_mask[:, None, None, :].to(dtype=attention_scores.dtype)
            attention_mask = (1.0 - attention_mask) * -10000.0
attention_scores = attention_scores + attention_mask
# Softmax normalization
attention_probs = F.softmax(attention_scores, dim=-1)
        # Dropout on attention probabilities, using the configured rate
        attention_probs = F.dropout(attention_probs, p=self.attention_dropout_prob, training=self.training)
# Apply attention to values
context_layer = torch.matmul(attention_probs, value_layer)
# Reshape back to [batch_size, seq_length, hidden_size]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_shape)
# Final output projection
output = self.output(context_layer)
return output
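
# Shape sketch (illustrative): with hidden_size=768 and 12 heads, each head has size 64.
# The additive causal mask keeps position i from attending to any position j > i:
#
#   mask = torch.triu(torch.ones(4, 4) * -1e10, diagonal=1)
#   # row 0 -> [0, -1e10, -1e10, -1e10]; row 3 -> [0, 0, 0, 0]
#
# Adding this to the raw scores before the softmax drives the probability of
# future tokens to (numerically) zero.
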
class TransformerBlock(nn.Module):
"""Transformer block with causal attention for language modeling"""
def __init__(self, config):
super().__init__()
self.attention = CausalSelfAttention(config)
self.attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Feed-forward network
self.ffn = nn.Sequential(
nn.Linear(config.hidden_size, config.intermediate_size),
nn.GELU(),
nn.Linear(config.intermediate_size, config.hidden_size),
nn.Dropout(config.hidden_dropout_prob)
)
self.ffn_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, attention_mask=None):
# Self-attention block with residual connection and layer norm
attn_output = self.attention(hidden_states, attention_mask)
hidden_states = self.attention_layernorm(hidden_states + attn_output)
# Feed-forward block with residual connection and layer norm
ffn_output = self.ffn(hidden_states)
hidden_states = self.ffn_layernorm(hidden_states + ffn_output)
return hidden_states
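
# Block sketch (illustrative): each TransformerBlock is post-LayerNorm, i.e.
#   x = LayerNorm(x + Attention(x));  x = LayerNorm(x + FFN(x))
# so hidden_states keeps shape [batch_size, seq_length, hidden_size] throughout the stack.
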
class HindiCausalLM(nn.Module):
"""Hindi Causal Language Model for text generation"""
def __init__(self, config):
super().__init__()
self.config = config
# Embeddings
self.token_embeddings = nn.Embedding(
config.vocab_size,
config.hidden_size,
padding_idx=config.pad_token_id
)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings,
config.hidden_size
)
self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
# Transformer layers
self.layers = nn.ModuleList([
TransformerBlock(config) for _ in range(config.num_hidden_layers)
])
# LM head
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Tie weights if configured
if config.tie_word_embeddings:
self.lm_head.weight = self.token_embeddings.weight
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
"""Initialize weights with small random values"""
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def get_input_embeddings(self):
return self.token_embeddings
def set_input_embeddings(self, new_embeddings):
self.token_embeddings = new_embeddings
def forward(
self,
input_ids=None,
attention_mask=None,
labels=None,
return_dict=True
):
device = input_ids.device
batch_size, seq_length = input_ids.size()
# Create position ids
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Get embeddings
inputs_embeds = self.token_embeddings(input_ids)
position_embeds = self.position_embeddings(position_ids)
# Sum token and position embeddings
hidden_states = inputs_embeds + position_embeds
hidden_states = self.embedding_dropout(hidden_states)
# Default attention mask (all tokens can be attended to)
if attention_mask is None:
attention_mask = torch.ones(batch_size, seq_length, device=device)
# Apply transformer layers
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask)
# Language model head
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Move labels to correct device
labels = labels.to(device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens; padding positions are ignored in the loss
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
)
if return_dict:
return {
"logits": lm_logits,
"loss": loss,
"hidden_states": hidden_states
}
return (lm_logits, loss, hidden_states)
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
"""Prepare inputs for text generation"""
# Only keep inputs needed for forward pass
inputs = {
"input_ids": input_ids,
}
# Add attention mask if provided
if attention_mask is not None:
inputs["attention_mask"] = attention_mask
return inputs
@staticmethod
def _reorder_cache(past, beam_idx):
"""Reorder cached hidden states for beam search generation"""
reordered_past = []
for layer_past in past:
reordered_past.append(
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
)
return reordered_past
def save_pretrained(self, save_directory):
"""Save model and config to directory"""
import os
import json
import torch
os.makedirs(save_directory, exist_ok=True)
# Save config
config_path = os.path.join(save_directory, "config.json")
with open(config_path, "w", encoding="utf-8") as f:
json.dump(self.config.to_dict(), f, indent=2)
# Save model weights
model_path = os.path.join(save_directory, "pytorch_model.bin")
torch.save(self.state_dict(), model_path)
return [config_path, model_path]
@classmethod
def from_pretrained(cls, model_path):
"""Load model and config from directory"""
import os
import json
import torch
# Load config
config_path = os.path.join(model_path, "config.json")
with open(config_path, "r", encoding="utf-8") as f:
config_dict = json.load(f)
# Create config object
config = HindiCausalLMConfig(**config_dict)
# Create model
model = cls(config)
        # Load model weights
        weights_path = os.path.join(model_path, "pytorch_model.bin")
        model.load_state_dict(torch.load(weights_path, map_location="cpu"))
return model
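
# Round-trip sketch (hypothetical path): persisting and restoring the model.
#
#   model = HindiCausalLM(HindiCausalLMConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4))
#   model.save_pretrained("./hindi_lm_demo")          # writes config.json + pytorch_model.bin
#   restored = HindiCausalLM.from_pretrained("./hindi_lm_demo")
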
class HindiTextGenerator:
"""Text generation utility for HindiCausalLM"""
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
self.device = next(model.parameters()).device
def generate(
self,
prompt,
max_length=100,
temperature=1.0,
top_k=50,
top_p=0.95,
repetition_penalty=1.0,
do_sample=True,
num_return_sequences=1,
**kwargs
):
"""Generate text from a prompt"""
# Encode the prompt
input_ids = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=self.model.config.max_position_embeddings - max_length
)["input_ids"].to(self.device)
# Set the model to evaluation mode
self.model.eval()
# Set generation parameters
gen_kwargs = {
"max_length": input_ids.shape[1] + max_length,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"do_sample": do_sample,
"num_return_sequences": num_return_sequences,
**kwargs
}
# Generate text
with torch.no_grad():
output_sequences = self._generate_text(input_ids, **gen_kwargs)
# Decode generated sequences
generated_texts = []
for sequence in output_sequences:
# Remove the prompt from the generated text
sequence = sequence[input_ids.shape[1]:]
text = self.tokenizer.sp_model.DecodeIds(sequence.tolist())
generated_texts.append(text)
if num_return_sequences == 1:
return generated_texts[0]
return generated_texts
def _generate_text(
self,
input_ids,
max_length,
temperature=1.0,
top_k=50,
top_p=0.95,
repetition_penalty=1.0,
do_sample=True,
num_return_sequences=1,
pad_token_id=None,
eos_token_id=None,
**kwargs
):
"""Core text generation logic"""
# Set pad_token_id and eos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.model.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.model.config.eos_token_id
batch_size = input_ids.shape[0]
# Create attention mask
attention_mask = torch.ones_like(input_ids)
# Clone the input_ids for each return sequence
input_ids = input_ids.repeat(num_return_sequences, 1)
attention_mask = attention_mask.repeat(num_return_sequences, 1)
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
# Keep track of which sequences are already finished
cur_len = input_ids.shape[1]
while cur_len < max_length:
# Prepare model inputs
model_inputs = self.model.prepare_inputs_for_generation(
input_ids, attention_mask=attention_mask
)
# Forward pass
outputs = self.model(**model_inputs, return_dict=True)
next_token_logits = outputs["logits"][:, -1, :]
# Apply temperature scaling
next_token_logits = next_token_logits / temperature
            # Apply repetition penalty to tokens that already appear in the sequence
            if repetition_penalty != 1.0:
                for i in range(input_ids.shape[0]):
                    for token_id in set(input_ids[i].tolist()):
                        # Dividing a negative logit would boost it, so multiply instead
                        if next_token_logits[i, token_id] < 0:
                            next_token_logits[i, token_id] *= repetition_penalty
                        else:
                            next_token_logits[i, token_id] /= repetition_penalty
# Filter logits using top-k and top-p sampling
if do_sample:
# Top-k filtering
if top_k > 0:
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
next_token_logits[indices_to_remove] = -float("Inf")
# Top-p (nucleus) filtering
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep the first token above threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(next_token_logits.shape[0]):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
next_token_logits[i, indices_to_remove] = -float("Inf")
# Sample from the filtered distribution
probabilities = F.softmax(next_token_logits, dim=-1)
next_tokens = torch.multinomial(probabilities, 1).squeeze(1)
else:
# Greedy decoding
next_tokens = torch.argmax(next_token_logits, dim=-1)
            # Replace tokens of sequences that have already finished with the pad token
            if eos_token_id is not None:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
            # Concatenate next tokens to input_ids
            input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
            # Expand attention mask
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )
            cur_len += 1
            # Update unfinished sequences based on the EOS token
            if eos_token_id is not None:
                # Set the unfinished flag to 0 once a sequence emits EOS
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.ne(eos_token_id).long()
                )
                # Stop early once every sequence has finished
                if unfinished_sequences.max() == 0:
                    break
return input_ids
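

# Minimal smoke test (sketch, illustrative sizes only): instantiate a tiny HindiCausalLM
# on random token ids and check that the forward pass returns [batch, seq, vocab] logits
# and a scalar loss. Running HindiTextGenerator additionally requires a SentencePiece
# tokenizer exposing `sp_model`, which is not constructed here.
if __name__ == "__main__":
    cfg = HindiCausalLMConfig(
        vocab_size=1000,
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=256,
        max_position_embeddings=64,
    )
    model = HindiCausalLM(cfg)
    ids = torch.randint(3, cfg.vocab_size, (2, 16))
    out = model(input_ids=ids, labels=ids)
    print("logits:", tuple(out["logits"].shape), "loss:", float(out["loss"]))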