import math
from typing import Optional, Tuple, List, Dict, Any, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

class HindiCausalLMConfig:
    """Configuration class for Hindi Causal Language Model"""

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 512,
        layer_norm_eps: float = 1e-12,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.tie_word_embeddings = tie_word_embeddings

        for key, value in kwargs.items():
            setattr(self, key, value)
    @classmethod
    def from_embedding_config(cls, config_dict, **kwargs):
        """Create an LM config from an embedding model config."""
        # Architecture parameters passed explicitly as kwargs take precedence
        # over values found in the embedding config.
        override_params = {}
        for key in ["num_hidden_layers", "hidden_size", "num_attention_heads",
                    "intermediate_size", "max_position_embeddings", "vocab_size"]:
            if key in kwargs:
                override_params[key] = kwargs.pop(key)

        hidden_size = override_params.get("hidden_size", config_dict.get("hidden_size", 768))

        # If no head count was given, prefer a common head count that divides
        # hidden_size evenly.
        if "num_attention_heads" not in override_params:
            for heads in [12, 16, 8, 4]:
                if hidden_size % heads == 0:
                    print(f"Automatically setting num_attention_heads to {heads} to match hidden_size {hidden_size}")
                    override_params["num_attention_heads"] = heads
                    break

        # Fallback: use the largest divisor of hidden_size that is at most 32.
        if "num_attention_heads" not in override_params:
            for heads in range(min(32, hidden_size), 0, -1):
                if hidden_size % heads == 0:
                    print(f"Automatically setting num_attention_heads to {heads} to match hidden_size {hidden_size}")
                    override_params["num_attention_heads"] = heads
                    break

        config_params = {
            "vocab_size": override_params.get("vocab_size", config_dict.get("vocab_size", 32000)),
            "hidden_size": hidden_size,
            "num_hidden_layers": override_params.get("num_hidden_layers", config_dict.get("num_hidden_layers", 12)),
            "num_attention_heads": override_params.get("num_attention_heads", config_dict.get("num_attention_heads", 12)),
            "intermediate_size": override_params.get("intermediate_size", config_dict.get("intermediate_size", 3072)),
            "hidden_dropout_prob": config_dict.get("hidden_dropout_prob", 0.1),
            "attention_probs_dropout_prob": config_dict.get("attention_probs_dropout_prob", 0.1),
            "max_position_embeddings": override_params.get("max_position_embeddings",
                                                           config_dict.get("max_position_embeddings", 512)),
            "layer_norm_eps": config_dict.get("layer_norm_eps", 1e-12),
            "pad_token_id": config_dict.get("pad_token_id", 0),
        }

        if config_params["hidden_size"] % config_params["num_attention_heads"] != 0:
            raise ValueError(
                f"Hidden size ({config_params['hidden_size']}) must be divisible by the number of attention "
                f"heads ({config_params['num_attention_heads']})"
            )

        # Any remaining kwargs (e.g. token ids, tie_word_embeddings) are passed through.
        config_params.update(kwargs)

        return cls(**config_params)

    def to_dict(self):
        """Convert the config to a plain dictionary."""
        return dict(self.__dict__)

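# Illustrative usage sketch for from_embedding_config (not executed; the embedding-config
# values below are assumptions for demonstration, not values required by this module):
#
#   embed_cfg = {"vocab_size": 16000, "hidden_size": 512, "num_hidden_layers": 6}
#   lm_cfg = HindiCausalLMConfig.from_embedding_config(embed_cfg, num_hidden_layers=8)
#   # 512 is not divisible by 12, so the helper falls through to 16 heads,
#   # while the explicit num_hidden_layers override wins over the embedding config.
#   assert lm_cfg.num_attention_heads == 16 and lm_cfg.num_hidden_layers == 8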

class CausalSelfAttention(nn.Module):
    """Causal (left-to-right) self-attention layer"""

    def __init__(self, config):
        super().__init__()
        assert config.hidden_size % config.num_attention_heads == 0

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # Dropout applied to the attention probabilities.
        self.attn_dropout = nn.Dropout(config.attention_probs_dropout_prob)

        # Output projection followed by dropout on the hidden states.
        self.output = nn.Sequential(
            nn.Linear(self.all_head_size, config.hidden_size),
            nn.Dropout(config.hidden_dropout_prob)
        )

        # Additive causal mask: positions above the diagonal receive a large negative
        # value so future tokens are ignored by the softmax.
        self.register_buffer(
            "causal_mask",
            torch.triu(
                torch.ones(config.max_position_embeddings, config.max_position_embeddings) * -1e10,
                diagonal=1
            )
        )

    def transpose_for_scores(self, x):
        # (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size)
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_length = hidden_states.size()[:2]

        # Project to per-head queries, keys and values.
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product attention scores: (batch, num_heads, seq_len, seq_len)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Causal mask: each position only attends to itself and earlier positions.
        causal_mask = self.causal_mask[:seq_length, :seq_length]
        attention_scores = attention_scores + causal_mask

        # Padding mask, if provided (1 = keep, 0 = mask out).
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=attention_scores.dtype)
            attention_mask = (1.0 - attention_mask) * -10000.0
            attention_scores = attention_scores + attention_mask

        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.attn_dropout(attention_probs)

        # Weighted sum of values, then merge the heads back together.
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_shape)

        return self.output(context_layer)
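
    # Minimal shape-check sketch (illustrative; the config values are arbitrary assumptions):
    #
    #   cfg = HindiCausalLMConfig(hidden_size=64, num_attention_heads=4, max_position_embeddings=16)
    #   attn = CausalSelfAttention(cfg)
    #   x = torch.randn(2, 10, 64)                    # (batch, seq_len, hidden_size)
    #   mask = torch.ones(2, 10)                      # 1 = real token, 0 = padding
    #   assert attn(x, mask).shape == (2, 10, 64)     # output keeps the input shape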

class TransformerBlock(nn.Module):
    """Transformer block with causal attention for language modeling"""

    def __init__(self, config):
        super().__init__()
        self.attention = CausalSelfAttention(config)
        self.attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.ffn = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
            nn.Dropout(config.hidden_dropout_prob)
        )
        self.ffn_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None):
        # Self-attention sub-layer with residual connection (post-LayerNorm).
        attn_output = self.attention(hidden_states, attention_mask)
        hidden_states = self.attention_layernorm(hidden_states + attn_output)

        # Feed-forward sub-layer with residual connection.
        ffn_output = self.ffn(hidden_states)
        hidden_states = self.ffn_layernorm(hidden_states + ffn_output)

        return hidden_states

class HindiCausalLM(nn.Module):
    """Hindi Causal Language Model for text generation"""

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Token and learned absolute position embeddings.
        self.token_embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size
        )
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

        # Stack of transformer blocks.
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_hidden_layers)
        ])

        # Language modeling head, optionally tied to the input embeddings.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.token_embeddings.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights with small random values"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
            # Keep the padding embedding at zero after the random init.
            if isinstance(module, nn.Embedding) and module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_input_embeddings(self):
        return self.token_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.token_embeddings = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=True
    ):
        device = input_ids.device
        batch_size, seq_length = input_ids.size()

        # Absolute position ids: 0 .. seq_length - 1 for every sequence in the batch.
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        inputs_embeds = self.token_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)

        hidden_states = inputs_embeds + position_embeds
        hidden_states = self.embedding_dropout(hidden_states)

        # Default to attending over all positions when no padding mask is given.
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, seq_length, device=device)

        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(device)

            # Shift so that tokens at position i predict the token at position i + 1.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        if return_dict:
            return {
                "logits": lm_logits,
                "loss": loss,
                "hidden_states": hidden_states
            }

        return (lm_logits, loss, hidden_states)
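
    # Illustrative training-step sketch (not executed; shapes and values are assumptions):
    #
    #   model = HindiCausalLM(HindiCausalLMConfig())
    #   batch = torch.randint(3, 32000, (4, 128))       # (batch, seq_len) token ids
    #   out = model(input_ids=batch, labels=batch)      # labels are shifted internally
    #   out["loss"].backward()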

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Prepare inputs for text generation"""
        inputs = {
            "input_ids": input_ids,
        }
        if attention_mask is not None:
            inputs["attention_mask"] = attention_mask
        return inputs

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """Reorder cached hidden states for beam search generation.

        Note: the model does not currently return a key/value cache; this helper is
        kept for API compatibility with generation utilities that expect it.
        """
        reordered_past = []
        for layer_past in past:
            reordered_past.append(
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            )
        return reordered_past

    def save_pretrained(self, save_directory):
        """Save model and config to directory"""
        import os
        import json

        os.makedirs(save_directory, exist_ok=True)

        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(self.config.to_dict(), f, indent=2)

        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        return [config_path, model_path]

    @classmethod
    def from_pretrained(cls, model_path):
        """Load model and config from directory"""
        import os
        import json

        config_path = os.path.join(model_path, "config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        config = HindiCausalLMConfig(**config_dict)
        model = cls(config)

        weights_path = os.path.join(model_path, "pytorch_model.bin")
        model.load_state_dict(torch.load(weights_path, map_location="cpu"))

        return model
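
# Illustrative save/load round trip (not executed; the directory path is an assumption):
#
#   model = HindiCausalLM(HindiCausalLMConfig())
#   model.save_pretrained("checkpoints/hindi-lm")        # writes config.json + pytorch_model.bin
#   restored = HindiCausalLM.from_pretrained("checkpoints/hindi-lm")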

class HindiTextGenerator:
    """Text generation utility for HindiCausalLM"""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

    def generate(
        self,
        prompt,
        max_length=100,
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
        **kwargs
    ):
        """Generate text from a prompt"""
        # Tokenize the prompt, leaving room for max_length newly generated tokens.
        input_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max(1, self.model.config.max_position_embeddings - max_length)
        )["input_ids"].to(self.device)

        self.model.eval()

        gen_kwargs = {
            "max_length": input_ids.shape[1] + max_length,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": do_sample,
            "num_return_sequences": num_return_sequences,
            **kwargs
        }

        with torch.no_grad():
            output_sequences = self._generate_text(input_ids, **gen_kwargs)

        # Decode only the newly generated part of each sequence.
        # Note: this assumes the tokenizer wraps a SentencePiece model exposed as `sp_model`.
        generated_texts = []
        for sequence in output_sequences:
            sequence = sequence[input_ids.shape[1]:]
            text = self.tokenizer.sp_model.DecodeIds(sequence.tolist())
            generated_texts.append(text)

        if num_return_sequences == 1:
            return generated_texts[0]

        return generated_texts
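
    # Illustrative usage sketch (not executed; assumes a SentencePiece-backed tokenizer
    # compatible with the calls above):
    #
    #   generator = HindiTextGenerator(model, tokenizer)
    #   text = generator.generate("भारत एक", max_length=50, temperature=0.8, top_p=0.9)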

    def _generate_text(
        self,
        input_ids,
        max_length,
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
        pad_token_id=None,
        eos_token_id=None,
        **kwargs
    ):
        """Core text generation logic"""
        pad_token_id = pad_token_id if pad_token_id is not None else self.model.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.model.config.eos_token_id

        attention_mask = torch.ones_like(input_ids)

        # Expand the batch for multiple return sequences.
        input_ids = input_ids.repeat(num_return_sequences, 1)
        attention_mask = attention_mask.repeat(num_return_sequences, 1)

        # 1 for sequences that are still generating, 0 once they have produced EOS.
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        cur_len = input_ids.shape[1]

        while cur_len < max_length:
            # No key/value cache: each step re-runs the model over the full prefix.
            model_inputs = self.model.prepare_inputs_for_generation(
                input_ids, attention_mask=attention_mask
            )
            outputs = self.model(**model_inputs, return_dict=True)
            next_token_logits = outputs["logits"][:, -1, :]

            # Temperature scaling.
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # Repetition penalty (CTRL-style): penalize tokens already present in the
            # sequence, handling positive and negative logits correctly.
            if repetition_penalty != 1.0:
                for i in range(input_ids.shape[0]):
                    for token_id in set(input_ids[i].tolist()):
                        if next_token_logits[i, token_id] > 0:
                            next_token_logits[i, token_id] /= repetition_penalty
                        else:
                            next_token_logits[i, token_id] *= repetition_penalty

            if do_sample:
                # Top-k filtering: keep only the k highest-scoring tokens.
                if top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = -float("Inf")

                # Top-p (nucleus) filtering: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token above the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    for i in range(next_token_logits.shape[0]):
                        indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
                        next_token_logits[i, indices_to_remove] = -float("Inf")

                probabilities = F.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probabilities, 1).squeeze(1)
            else:
                # Greedy decoding.
                next_tokens = torch.argmax(next_token_logits, dim=-1)

            # Sequences that already finished keep emitting the pad token.
            if eos_token_id is not None:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )
            cur_len += 1

            # Mark sequences that just produced EOS as finished; stop when all are done.
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.ne(eos_token_id).long()
                )
                if unfinished_sequences.max() == 0:
                    break

        return input_ids
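

if __name__ == "__main__":
    # Minimal smoke test (illustrative; the tiny hyperparameters below are arbitrary
    # assumptions chosen so the script runs quickly on CPU).
    config = HindiCausalLMConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        max_position_embeddings=32,
    )
    model = HindiCausalLM(config)

    input_ids = torch.randint(3, config.vocab_size, (2, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print("logits:", tuple(outputs["logits"].shape))   # expected: (2, 16, 1000)
    print("loss:", float(outputs["loss"]))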