Initial upload of custom Hindi LM v1
Files added:
- README.md (+190 -0)
- config.json (+23 -0)
- hindi_embeddings.py (+730 -0)
- hindi_language_model.py (+547 -0)
- model.safetensors (+3 -0)
- step_loss_lr.png (+0 -0)
- tokenizer.model (+3 -0)
- tokenizer_config.json (+9 -0)
README.md
ADDED
@@ -0,0 +1,190 @@

---
license: mit # Or choose another appropriate license: https://huggingface.co/docs/hub/repositories-licenses
language: hi
tags:
- hindi
- text-generation
- causal-lm
- custom-model
# Add more specific tags if applicable, e.g., based on training data domain
pipeline_tag: text-generation
# Specify model size if known from config

---

# Hindi Causal Language Model (convaiinnovations/hindi-foundational-model-base)

This repository contains a custom-trained Hindi Causal Language Model.

## Model Description

* **Architecture:** Custom Transformer (12 layers, hidden=768, 16 heads, ffn=3072, act=swiglu, norm=rmsnorm) based on the `HindiCausalLM` class, with modified attention mechanisms and Hindi-specific optimizations: multi-resolution attention to capture both character-level and word-level patterns, morphology-aware feed-forward layers, and script-mix processing for Hindi-English code-mixing.
* **Language:** Hindi (hi)
* **Training Data:** A diverse corpus of 2.7 million high-quality Hindi text samples from multiple sources: IITB Parallel Corpus (1.2M sentences), Samanantar (750K samples), OSCAR Hindi (450K sentences), CC-100 Hindi (300K sentences), Hindi Wikipedia (150K articles), Hindi news articles (100K pieces), XNLI Hindi (50K premise-hypothesis pairs), IndicGLUE (30K samples), and Hindi literature (5K passages).
* **Tokenizer:** SentencePiece model trained on Hindi text (`tokenizer.model`); vocabulary size 16000. A standalone encode/decode sketch appears just below this list.
* **Model Details:** Trained for [Unknown] epochs [Add more training details like batch size, learning rate if desired].

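A quick way to sanity-check the tokenizer on its own, before loading the full model, is to drive it directly through the `sentencepiece` package. This is a minimal sketch; it assumes `tokenizer.model` has been downloaded into the working directory and relies on the special-token IDs listed in `config.json` (pad=0, bos=1, eos=2, unk=3).

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("tokenizer.model")            # SentencePiece model shipped in this repo

print(sp.GetPieceSize())              # expected: 16000, matching vocab_size in config.json

ids = sp.EncodeAsIds("भारत की संस्कृति")  # encode Hindi text to token IDs
print(ids)

print(sp.DecodeIds(ids))              # decode the IDs back to text
```
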
## How to Use

**⚠️ Important:** This model uses custom Python classes (`HindiCausalLM`, `HindiCausalLMConfig`, `SentencePieceTokenizerWrapper`) which are **not** part of the standard Hugging Face `transformers` library. To use this model, you **must** have the Python files defining these classes (e.g., `hindi_language_model.py`, `hindi_embeddings.py`) available in your Python environment.

```python
import os
import json
import torch
import numpy as np
from huggingface_hub import hf_hub_download

# --- ENSURE THESE CLASSES ARE DEFINED OR IMPORTED ---
# You MUST have hindi_language_model.py and hindi_embeddings.py available
try:
    from hindi_language_model import HindiCausalLM, HindiCausalLMConfig
    from hindi_embeddings import SentencePieceTokenizerWrapper
    print("Custom classes imported successfully.")
except ImportError:
    print("ERROR: Cannot import custom classes.")
    print("Please place hindi_language_model.py and hindi_embeddings.py in your working directory or Python path.")
    # Define minimal dummy classes to potentially allow script execution, but loading will fail
    class SentencePieceTokenizerWrapper: pass
    class HindiCausalLMConfig: pass
    class HindiCausalLM(torch.nn.Module): pass
    # Exit if classes are truly needed
    # exit()


# --- Configuration ---
repo_id = "convaiinnovations/hindi-foundational-model-base"
# model_dir = "./downloaded_model"  # Example download location
# os.makedirs(model_dir, exist_ok=True)
# Use current directory if preferred
model_dir = "."

# --- Download Files ---
print(f"Downloading files for {repo_id} to '{os.path.abspath(model_dir)}'...")
try:
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json", local_dir=model_dir, local_dir_use_symlinks=False)
    tokenizer_path = hf_hub_download(repo_id=repo_id, filename="tokenizer.model", local_dir=model_dir, local_dir_use_symlinks=False)
    # Try safetensors first, then bin
    using_safetensors = True
    try:
        weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", local_dir=model_dir, local_dir_use_symlinks=False)
    except Exception:  # More specific: from huggingface_hub.utils import EntryNotFoundError
        try:
            weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", local_dir=model_dir, local_dir_use_symlinks=False)
            using_safetensors = False
        except Exception as e_inner:
            raise FileNotFoundError(f"Could not download weights (.safetensors or .bin): {e_inner}") from e_inner
except Exception as e:
    raise RuntimeError(f"Failed to download files from Hub: {e}") from e
print("Files downloaded.")

# --- Load Components ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

try:
    # 1. Load Tokenizer
    tokenizer = SentencePieceTokenizerWrapper(tokenizer_path)  # Assumes constructor takes path

    # 2. Load Config
    with open(config_path, 'r', encoding='utf-8') as f:
        config_dict = json.load(f)
    if hasattr(HindiCausalLMConfig, 'from_dict'):
        config = HindiCausalLMConfig.from_dict(config_dict)
    else:
        config = HindiCausalLMConfig(**config_dict)  # Assumes __init__ takes kwargs

    # 3. Load Model
    model = HindiCausalLM(config)  # Instantiate model
    if using_safetensors:
        from safetensors.torch import load_file
        state_dict = load_file(weights_path, device="cpu")
    else:
        state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)  # Use strict=True after training
    del state_dict
    model.to(device)
    model.eval()
    print("Model and tokenizer loaded successfully.")

except Exception as e:
    print(f"ERROR: Failed loading model components: {e}")
    # Add more specific error handling if needed
    exit()


# --- Example Inference ---
prompt = "भारत की संस्कृति"  # Example prompt
max_new_tokens = 60  # Generate N new tokens
temperature = 0.7
top_k = 50
seed = 42

print(f"\nGenerating text for prompt: '{prompt}'...")

torch.manual_seed(seed)
np.random.seed(seed)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(seed)

try:
    # Encoding (use the correct method from your wrapper)
    if hasattr(tokenizer, '__call__'):
        encoded = tokenizer(prompt, return_tensors=None)
        input_ids = encoded.get('input_ids')
    elif hasattr(tokenizer, 'sp_model') and hasattr(tokenizer.sp_model, 'EncodeAsIds'):
        input_ids = tokenizer.sp_model.EncodeAsIds(prompt)
    else:
        raise AttributeError("Tokenizer lacks encoding method.")
    assert input_ids, "Encoding failed"

    # Special token IDs (fall back to the values from config.json)
    bos_id = getattr(tokenizer, 'bos_token_id', 1)
    eos_id = getattr(tokenizer, 'eos_token_id', 2)
    if bos_id is not None:
        input_ids = [bos_id] + input_ids

    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    generated_ids = input_tensor

    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(input_ids=generated_ids)
            # Access logits
            if isinstance(outputs, dict) and 'logits' in outputs:
                logits = outputs['logits']
            elif hasattr(outputs, 'logits'):
                logits = outputs.logits
            else:
                raise TypeError("Model output format error.")

            next_token_logits = logits[:, -1, :]

            # Sampling: temperature scaling followed by top-k filtering
            if temperature > 0:
                scaled_logits = next_token_logits / temperature
            else:
                scaled_logits = next_token_logits
            if top_k > 0:
                kth_vals, _ = torch.topk(scaled_logits, k=top_k)
                scaled_logits[scaled_logits < kth_vals[:, -1].unsqueeze(-1)] = -float("Inf")
            probs = torch.softmax(scaled_logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

            # Check EOS
            if eos_id is not None and next_token_id.item() == eos_id:
                break

    # Decoding
    output_ids = generated_ids[0].cpu().tolist()
    # Remove special tokens
    if bos_id and output_ids and output_ids[0] == bos_id:
        output_ids = output_ids[1:]
    if eos_id and output_ids and output_ids[-1] == eos_id:
        output_ids = output_ids[:-1]

    # Use appropriate decode method
    if hasattr(tokenizer, 'sp_model') and hasattr(tokenizer.sp_model, 'DecodeIds'):
        generated_text = tokenizer.sp_model.DecodeIds(output_ids)
    elif hasattr(tokenizer, 'decode'):
        generated_text = tokenizer.decode(output_ids)
    else:
        raise AttributeError("Tokenizer lacks decoding method.")

    print("\n--- Generated Text ---")
    # Print prompt + generated text for context
    print(prompt + generated_text)
    print("----------------------")

except Exception as e:
    print(f"\nERROR during example inference: {e}")
```
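
The model's forward pass also accepts `labels` and returns a cross-entropy loss over shifted positions (see `HindiCausalLM.forward` in `hindi_language_model.py`), so you can score Hindi text or estimate perplexity without writing your own loss. A minimal sketch, reusing the `model`, `tokenizer`, and `device` objects from the example above:

```python
import math
import torch

text = "भारत की संस्कृति बहुत पुरानी है।"  # any Hindi sentence to score

# Encode with the same wrapper as above; add BOS/EOS as in the generation example.
ids = [tokenizer.bos_token_id] + tokenizer(text)["input_ids"] + [tokenizer.eos_token_id]
input_ids = torch.tensor([ids], dtype=torch.long, device=device)

with torch.no_grad():
    outputs = model(input_ids=input_ids, labels=input_ids)

loss = outputs["loss"]  # mean cross-entropy over the predicted tokens
print(f"loss = {loss.item():.4f}, perplexity = {math.exp(loss.item()):.2f}")
```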

## Limitations and Biases

This model was trained on a diverse corpus of Hindi text from sources including the IITB Parallel Corpus, Samanantar, OSCAR Hindi, CC-100 Hindi, Hindi Wikipedia, news articles, XNLI Hindi, IndicGLUE, and Hindi literature. As such, it may reflect biases present in that data, including cultural, gender, or regional biases found in these source materials.

The model's performance is limited by its architecture (12 layers, hidden=768, 16 heads, ffn=3072, act=swiglu, norm=rmsnorm) and the size of the training dataset.

It may generate repetitive, nonsensical, or factually incorrect text.

The model uses a weighted pooling strategy with sensitivity to Hindi's SOV (subject-object-verb) structure, but may still struggle with complex semantic relationships in longer texts.

As noted in the DeepRAG research paper, the model may have particular difficulty with cultural concepts that lack direct English translations, idiomatic expressions specific to Hindi, and formal/informal speech distinctions.

Please use this model responsibly.

Model trained using custom scripts.
config.json
ADDED
@@ -0,0 +1,23 @@

```json
{
  "vocab_size": 16000,
  "hidden_size": 768,
  "num_hidden_layers": 12,
  "num_attention_heads": 16,
  "intermediate_size": 3072,
  "hidden_dropout_prob": 0.1,
  "attention_probs_dropout_prob": 0.1,
  "max_position_embeddings": 512,
  "layer_norm_eps": 1e-12,
  "pad_token_id": 0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "tie_word_embeddings": true,
  "activation_function": "swiglu",
  "normalization_layer": "rmsnorm",
  "positional_encoding_type": "rope",
  "unk_token_id": 3,
  "architectures": [
    "HindiCausalLM"
  ],
  "model_type": "hindi_causal_lm"
}
```
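This configuration maps directly onto the `HindiCausalLMConfig` constructor in `hindi_language_model.py`: the core keys (`vocab_size`, `hidden_size`, `num_hidden_layers`, `num_attention_heads`, `intermediate_size`, dropout values, token IDs, and so on) are named parameters, while extra keys such as `activation_function`, `normalization_layer`, `positional_encoding_type`, `unk_token_id`, `architectures`, and `model_type` are absorbed via `**kwargs` and stored as plain attributes. A minimal sketch of turning the file into a config object, assuming `config.json` is in the current directory:

```python
import json

from hindi_language_model import HindiCausalLMConfig

with open("config.json", "r", encoding="utf-8") as f:
    config_dict = json.load(f)

config = HindiCausalLMConfig(**config_dict)             # extra keys become attributes via **kwargs
print(config.hidden_size, config.num_attention_heads)   # 768 16
print(config.positional_encoding_type)                  # rope, kept as an extra attribute
```
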
hindi_embeddings.py
ADDED
@@ -0,0 +1,730 @@

```python
import os
import torch
import json
import numpy as np
from torch import nn
from torch.nn import functional as F
import sentencepiece as spm
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Tokenizer wrapper class
class SentencePieceTokenizerWrapper:
    def __init__(self, sp_model_path):
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(sp_model_path)
        self.vocab_size = self.sp_model.GetPieceSize()

        # Special token IDs from tokenizer training
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.unk_token_id = 3

        # Set special tokens
        self.pad_token = "<pad>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.unk_token = "<unk>"
        self.mask_token = "<mask>"

    def __call__(self, text, padding=False, truncation=False, max_length=None, return_tensors=None):
        # Handle both string and list inputs
        if isinstance(text, str):
            # Encode a single string
            ids = self.sp_model.EncodeAsIds(text)

            # Handle truncation
            if truncation and max_length and len(ids) > max_length:
                ids = ids[:max_length]

            attention_mask = [1] * len(ids)

            # Handle padding
            if padding and max_length:
                padding_length = max(0, max_length - len(ids))
                ids = ids + [self.pad_token_id] * padding_length
                attention_mask = attention_mask + [0] * padding_length

            result = {
                'input_ids': ids,
                'attention_mask': attention_mask
            }

            # Convert to tensors if requested
            if return_tensors == 'pt':
                result = {k: torch.tensor([v]) for k, v in result.items()}

            return result

        # Process a batch of texts
        batch_encoded = [self.sp_model.EncodeAsIds(t) for t in text]

        # Apply truncation if needed
        if truncation and max_length:
            batch_encoded = [ids[:max_length] for ids in batch_encoded]

        # Create attention masks
        batch_attention_mask = [[1] * len(ids) for ids in batch_encoded]

        # Apply padding if needed
        if padding:
            if max_length:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in batch_encoded)

            # Pad sequences to max_len
            batch_encoded = [ids + [self.pad_token_id] * (max_len - len(ids)) for ids in batch_encoded]
            batch_attention_mask = [mask + [0] * (max_len - len(mask)) for mask in batch_attention_mask]

        result = {
            'input_ids': batch_encoded,
            'attention_mask': batch_attention_mask
        }

        # Convert to tensors if requested
        if return_tensors == 'pt':
            result = {k: torch.tensor(v) for k, v in result.items()}

        return result

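# Example usage of the wrapper above (illustrative):
#
#   tok = SentencePieceTokenizerWrapper("tokenizer.model")
#   batch = tok(["भारत एक विशाल देश है।", "आज मौसम अच्छा है।"],
#               padding=True, truncation=True, max_length=16, return_tensors='pt')
#   batch['input_ids'].shape       -> torch.Size([2, 16])
#   batch['attention_mask'].shape  -> torch.Size([2, 16])
#
# With padding=True and no max_length, the batch is padded to its longest sequence instead.
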
# Model architecture components
class MultiHeadAttention(nn.Module):
    """Multi-headed attention mechanism"""
    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config["num_attention_heads"]
        self.attention_head_size = config["hidden_size"] // config["num_attention_heads"]
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Query, Key, Value projections
        self.query = nn.Linear(config["hidden_size"], self.all_head_size)
        self.key = nn.Linear(config["hidden_size"], self.all_head_size)
        self.value = nn.Linear(config["hidden_size"], self.all_head_size)

        # Output projection
        self.output = nn.Sequential(
            nn.Linear(self.all_head_size, config["hidden_size"]),
            nn.Dropout(config["attention_probs_dropout_prob"])
        )

        # Simplified relative position bias
        self.max_position_embeddings = config["max_position_embeddings"]
        self.relative_attention_bias = nn.Embedding(
            2 * config["max_position_embeddings"] - 1,
            config["num_attention_heads"]
        )

    def transpose_for_scores(self, x):
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_length = hidden_states.size()[:2]

        # Project inputs to queries, keys, and values
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Take the dot product between query and key to get the raw attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # Generate relative position matrix
        position_ids = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device)
        relative_position = position_ids.unsqueeze(1) - position_ids.unsqueeze(0)  # [seq_len, seq_len]
        # Shift values to be >= 0
        relative_position = relative_position + self.max_position_embeddings - 1
        # Ensure indices are within bounds
        relative_position = torch.clamp(relative_position, 0, 2 * self.max_position_embeddings - 2)

        # Get relative position embeddings [seq_len, seq_len, num_heads]
        rel_attn_bias = self.relative_attention_bias(relative_position)

        # Reshape to add to attention heads [1, num_heads, seq_len, seq_len]
        rel_attn_bias = rel_attn_bias.permute(2, 0, 1).unsqueeze(0)

        # Add to attention scores - now dimensions will match
        attention_scores = attention_scores + rel_attn_bias

        # Scale attention scores
        attention_scores = attention_scores / (self.attention_head_size ** 0.5)

        # Apply attention mask
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities
        attention_probs = F.softmax(attention_scores, dim=-1)

        # Apply dropout
        attention_probs = F.dropout(attention_probs, p=0.1, training=self.training)

        # Apply attention to values
        context_layer = torch.matmul(attention_probs, value_layer)

        # Reshape back to [batch_size, seq_length, hidden_size]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_shape)

        # Final output projection
        output = self.output(context_layer)

        return output

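# Note on the relative-position bias above: with max_position_embeddings = 512 the bias
# table has 2 * 512 - 1 = 1023 rows, one per possible query/key offset. An offset of
# (i - j) = -511 maps to row 0, an offset of 0 maps to row 511, and an offset of +511
# maps to row 1022, so every token pair in a 512-token window gets its own learned
# per-head bias added to the attention scores.
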
class EnhancedTransformerLayer(nn.Module):
    """Advanced transformer layer with pre-layer norm and enhanced attention"""
    def __init__(self, config):
        super().__init__()
        self.attention_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        self.attention = MultiHeadAttention(config)

        self.ffn_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(config["hidden_size"], config["intermediate_size"]),
            nn.GELU(),
            nn.Dropout(config["hidden_dropout_prob"]),
            nn.Linear(config["intermediate_size"], config["hidden_size"]),
            nn.Dropout(config["hidden_dropout_prob"])
        )

    def forward(self, hidden_states, attention_mask=None):
        # Pre-layer norm for attention
        attn_norm_hidden = self.attention_pre_norm(hidden_states)

        # Self-attention
        attention_output = self.attention(attn_norm_hidden, attention_mask)

        # Residual connection
        hidden_states = hidden_states + attention_output

        # Pre-layer norm for feed-forward
        ffn_norm_hidden = self.ffn_pre_norm(hidden_states)

        # Feed-forward
        ffn_output = self.ffn(ffn_norm_hidden)

        # Residual connection
        hidden_states = hidden_states + ffn_output

        return hidden_states

class AdvancedTransformerModel(nn.Module):
    """Advanced Transformer model for inference"""

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Embeddings
        self.word_embeddings = nn.Embedding(
            config["vocab_size"],
            config["hidden_size"],
            padding_idx=config["pad_token_id"]
        )

        # Position embeddings
        self.position_embeddings = nn.Embedding(config["max_position_embeddings"], config["hidden_size"])

        # Embedding dropout
        self.embedding_dropout = nn.Dropout(config["hidden_dropout_prob"])

        # Transformer layers
        self.layers = nn.ModuleList([
            EnhancedTransformerLayer(config) for _ in range(config["num_hidden_layers"])
        ])

        # Final layer norm
        self.final_layer_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])

    def forward(self, input_ids, attention_mask=None):
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape

        # Get position ids
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Get embeddings
        word_embeds = self.word_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)

        # Sum embeddings
        embeddings = word_embeds + position_embeds

        # Apply dropout
        embeddings = self.embedding_dropout(embeddings)

        # Default attention mask
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=input_ids.device)

        # Extended attention mask for transformer layers (1 for tokens to attend to, 0 for masked tokens)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Apply transformer layers
        hidden_states = embeddings
        for layer in self.layers:
            hidden_states = layer(hidden_states, extended_attention_mask)

        # Final layer norm
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states

class AdvancedPooling(nn.Module):
    """Advanced pooling module supporting multiple pooling strategies"""
    def __init__(self, config):
        super().__init__()
        self.pooling_mode = config["pooling_mode"]  # 'mean', 'max', 'cls', 'attention'
        self.hidden_size = config["hidden_size"]

        # For attention pooling
        if self.pooling_mode == 'attention':
            self.attention_weights = nn.Linear(config["hidden_size"], 1)

        # For weighted pooling
        elif self.pooling_mode == 'weighted':
            self.weight_layer = nn.Linear(config["hidden_size"], 1)

    def forward(self, token_embeddings, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(token_embeddings[:, :, 0])

        mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

        if self.pooling_mode == 'cls':
            # Use [CLS] token (first token)
            pooled = token_embeddings[:, 0]

        elif self.pooling_mode == 'max':
            # Max pooling
            token_embeddings = token_embeddings.clone()
            # Set padding tokens to large negative value to exclude them from max
            token_embeddings[mask_expanded == 0] = -1e9
            pooled = torch.max(token_embeddings, dim=1)[0]

        elif self.pooling_mode == 'attention':
            # Attention pooling
            weights = self.attention_weights(token_embeddings).squeeze(-1)
            # Mask out padding tokens
            weights = weights.masked_fill(attention_mask == 0, -1e9)
            weights = F.softmax(weights, dim=1).unsqueeze(-1)
            pooled = torch.sum(token_embeddings * weights, dim=1)

        elif self.pooling_mode == 'weighted':
            # Weighted average pooling
            weights = torch.sigmoid(self.weight_layer(token_embeddings)).squeeze(-1)
            # Apply mask
            weights = weights * attention_mask
            # Normalize weights
            sum_weights = torch.sum(weights, dim=1, keepdim=True)
            sum_weights = torch.clamp(sum_weights, min=1e-9)
            weights = weights / sum_weights
            # Apply weights
            pooled = torch.sum(token_embeddings * weights.unsqueeze(-1), dim=1)

        else:  # Default to mean pooling
            # Mean pooling
            sum_embeddings = torch.sum(token_embeddings * mask_expanded, dim=1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            pooled = sum_embeddings / sum_mask

        # L2 normalize
        pooled = F.normalize(pooled, p=2, dim=1)

        return pooled

class SentenceEmbeddingModel(nn.Module):
    """Complete sentence embedding model for inference"""
    def __init__(self, config):
        super(SentenceEmbeddingModel, self).__init__()
        self.config = config

        # Create transformer model
        self.transformer = AdvancedTransformerModel(config)

        # Create pooling module
        self.pooling = AdvancedPooling(config)

        # Build projection module if needed
        if "projection_dim" in config and config["projection_dim"] > 0:
            self.use_projection = True
            self.projection = nn.Sequential(
                nn.Linear(config["hidden_size"], config["hidden_size"]),
                nn.GELU(),
                nn.Linear(config["hidden_size"], config["projection_dim"]),
                nn.LayerNorm(config["projection_dim"], eps=config["layer_norm_eps"])
            )
        else:
            self.use_projection = False

    def forward(self, input_ids, attention_mask=None):
        # Get token embeddings from transformer
        token_embeddings = self.transformer(input_ids, attention_mask)

        # Pool token embeddings
        pooled_output = self.pooling(token_embeddings, attention_mask)

        # Apply projection if enabled
        if self.use_projection:
            pooled_output = self.projection(pooled_output)
            pooled_output = F.normalize(pooled_output, p=2, dim=1)

        return pooled_output

class HindiEmbedder:
    def __init__(self, model_path="/home/ubuntu/output/hindi-embeddings-custom-tokenizer/final"):
        """
        Initialize the Hindi sentence embedder.

        Args:
            model_path: Path to the model directory
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load tokenizer - look for it in the model directory
        tokenizer_path = os.path.join(model_path, "tokenizer.model")

        if not os.path.exists(tokenizer_path):
            raise FileNotFoundError(f"Could not find tokenizer at {tokenizer_path}")

        self.tokenizer = SentencePieceTokenizerWrapper(tokenizer_path)
        print(f"Loaded tokenizer from {tokenizer_path} with vocabulary size: {self.tokenizer.vocab_size}")

        # Load model config
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, "r") as f:
            self.config = json.load(f)
        print(f"Loaded model config with hidden_size={self.config['hidden_size']}")

        # Load model
        model_pt_path = os.path.join(model_path, "embedding_model.pt")

        try:
            # Support both PyTorch 2.6+ and older versions
            try:
                checkpoint = torch.load(model_pt_path, map_location=self.device, weights_only=False)
                print("Loaded model using PyTorch 2.6+ style loading")
            except TypeError:
                checkpoint = torch.load(model_pt_path, map_location=self.device)
                print("Loaded model using older PyTorch style loading")

            # Create model
            self.model = SentenceEmbeddingModel(self.config)

            # Load state dict
            if "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
            else:
                state_dict = checkpoint

            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
            print(f"Loaded model with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys")

            # Move to device
            self.model.to(self.device)
            self.model.eval()
            print("Model loaded successfully and placed in evaluation mode")

        except Exception as e:
            print(f"Error loading model: {e}")
            raise RuntimeError(f"Failed to load the model: {e}")

    def encode(self, sentences, batch_size=32, normalize=True):
        """
        Encode sentences to embeddings.

        Args:
            sentences: A string or list of strings to encode
            batch_size: Batch size for encoding
            normalize: Whether to normalize the embeddings

        Returns:
            Numpy array of embeddings
        """
        # Handle single sentence
        if isinstance(sentences, str):
            sentences = [sentences]

        all_embeddings = []

        # Process in batches
        with torch.no_grad():
            for i in range(0, len(sentences), batch_size):
                batch = sentences[i:i+batch_size]

                # Tokenize
                inputs = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=self.config.get("max_position_embeddings", 128),
                    return_tensors="pt"
                )

                # Move to device
                input_ids = inputs["input_ids"].to(self.device)
                attention_mask = inputs["attention_mask"].to(self.device)

                # Get embeddings
                embeddings = self.model(input_ids, attention_mask)

                # Move to CPU and convert to numpy
                all_embeddings.append(embeddings.cpu().numpy())

        # Concatenate all embeddings
        all_embeddings = np.vstack(all_embeddings)

        # Normalize if requested
        if normalize:
            all_embeddings = all_embeddings / np.linalg.norm(all_embeddings, axis=1, keepdims=True)

        return all_embeddings

    def compute_similarity(self, texts1, texts2=None):
        """
        Compute cosine similarity between texts.

        Args:
            texts1: First set of texts
            texts2: Second set of texts. If None, compute similarity matrix within texts1.

        Returns:
            Similarity scores
        """
        # Convert single strings to lists for consistent handling
        if isinstance(texts1, str):
            texts1 = [texts1]

        if texts2 is not None and isinstance(texts2, str):
            texts2 = [texts2]

        embeddings1 = self.encode(texts1)

        if texts2 is None:
            # Compute similarity matrix within texts1
            similarities = cosine_similarity(embeddings1)
            return similarities
        else:
            # Compute similarity between texts1 and texts2
            embeddings2 = self.encode(texts2)

            if len(texts1) == len(texts2):
                # Compute pairwise similarity when the number of texts match
                similarities = np.array([
                    cosine_similarity([e1], [e2])[0][0]
                    for e1, e2 in zip(embeddings1, embeddings2)
                ])

                # If there's just one pair, return a scalar
                if len(similarities) == 1:
                    return similarities[0]
                return similarities
            else:
                # Return full similarity matrix
                return cosine_similarity(embeddings1, embeddings2)

    def search(self, query, documents, top_k=5):
        """
        Search for similar documents to a query.

        Args:
            query: The query text
            documents: List of documents to search
            top_k: Number of top results to return

        Returns:
            List of dictionaries with document and score
        """
        # Get embeddings
        query_embedding = self.encode([query])[0]
        document_embeddings = self.encode(documents)

        # Compute similarities
        similarities = np.dot(document_embeddings, query_embedding)

        # Get top indices
        top_indices = np.argsort(similarities)[-top_k:][::-1]

        # Return results
        results = []
        for idx in top_indices:
            results.append({
                "document": documents[idx],
                "score": float(similarities[idx])
            })

        return results

    def evaluate_similarity_samples(self):
        """Evaluate model on some standard similarity examples for Hindi"""
        test_pairs = [
            (
                "मुझे हिंदी में पढ़ना बहुत पसंद है।",
                "मैं हिंदी किताबें बहुत पसंद करता हूँ।"
            ),
            (
                "आज मौसम बहुत अच्छा है।",
                "आज बारिश हो रही है।"
            ),
            (
                "भारत एक विशाल देश है।",
                "भारत में कई भाषाएँ बोली जाती हैं।"
            ),
            (
                "कंप्यूटर विज्ञान एक रोचक विषय है।",
                "मैं कंप्यूटर साइंस का छात्र हूँ।"
            ),
            (
                "मैं रोज सुबह योग करता हूँ।",
                "स्वस्थ रहने के लिए व्यायाम जरूरी है।"
            ),
            # Add contrasting pairs to test discrimination
            (
                "मुझे हिंदी में पढ़ना बहुत पसंद है।",
                "क्रिकेट भारत में सबसे लोकप्रिय खेल है।"
            ),
            (
                "आज मौसम बहुत अच्छा है।",
                "भारतीय व्यंजन दुनिया भर में मशहूर हैं।"
            ),
            (
                "कंप्यूटर विज्ञान एक रोचक विषय है।",
                "हिमालय दुनिया का सबसे ऊंचा पर्वत है।"
            )
        ]

        print("Evaluating model on standard similarity samples:")
        for i, (text1, text2) in enumerate(test_pairs):
            # compute_similarity returns a scalar for a single pair of strings
            similarity = self.compute_similarity(text1, text2)
            print(f"\nPair {i+1}:")
            print(f"  Sentence 1: {text1}")
            print(f"  Sentence 2: {text2}")
            print(f"  Similarity: {similarity:.4f}")

        return

    def visualize_embeddings(self, sentences, labels=None, output_path="hindi_embeddings_visualization.png"):
        """
        Create a t-SNE visualization of the embeddings.

        Args:
            sentences: List of sentences to visualize
            labels: Optional list of labels for the points
            output_path: Path to save the visualization

        Returns:
            Path to the saved visualization
        """
        # Encode sentences
        embeddings = self.encode(sentences)

        # Apply t-SNE
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings)-1))
        reduced_embeddings = tsne.fit_transform(embeddings)

        # Create plot
        plt.figure(figsize=(12, 10))

        # Plot points
        scatter = plt.scatter(
            reduced_embeddings[:, 0],
            reduced_embeddings[:, 1],
            c=range(len(reduced_embeddings)),
            cmap='viridis',
            alpha=0.8,
            s=100
        )

        # Add labels if provided
        if labels:
            for i, label in enumerate(labels):
                plt.annotate(
                    label,
                    (reduced_embeddings[i, 0], reduced_embeddings[i, 1]),
                    fontsize=10,
                    alpha=0.7
                )

        plt.title("t-SNE Visualization of Hindi Sentence Embeddings", fontsize=16)
        plt.xlabel("Dimension 1", fontsize=12)
        plt.ylabel("Dimension 2", fontsize=12)
        plt.colorbar(scatter, label="Sentence Index")
        plt.grid(alpha=0.3)

        # Save the figure
        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()

        print(f"Visualization saved to {output_path}")
        return output_path

def main():
    # Create embedder
    embedder = HindiEmbedder()

    # Run sample evaluation
    embedder.evaluate_similarity_samples()

    # Example of semantic search
    print("\nSemantic Search Example:")
    query = "भारत की संस्कृति"
    documents = [
        "भारतीय संस्कृति दुनिया की सबसे प्राचीन संस्कृतियों में से एक है।",
        "भारत की आबादी 1.3 अरब से अधिक है।",
        "हिमालय पर्वत श्रृंखला भारत के उत्तर में स्थित है।",
        "भारतीय व्यंजन में मसालों का प्रयोग किया जाता है।",
        "भारत में 22 आधिकारिक भाषाएँ हैं।",
        "संस्कृति लोगों के रहन-सहन का तरीका है।",
        "भारत के विभिन्न राज्यों की अपनी अलग संस्कृति है।",
        "रामायण और महाभारत भारतीय संस्कृति के महत्वपूर्ण हिस्से हैं।",
    ]

    results = embedder.search(query, documents)

    print(f"Query: {query}")
    print("Top results:")
    for i, result in enumerate(results):
        print(f"{i+1}. Score: {result['score']:.4f}")
        print(f"   {result['document']}")

    # Create visualization example
    print("\nCreating embedding visualization...")
    visualization_sentences = [
        "मुझे हिंदी में पढ़ना बहुत पसंद है।",
        "मैं हिंदी किताबें बहुत पसंद करता हूँ।",
        "आज मौसम बहुत अच्छा है।",
        "आज बारिश हो रही है।",
        "भारत एक विशाल देश है।",
        "भारत में कई भाषाएँ बोली जाती हैं।",
        "कंप्यूटर विज्ञान एक रोचक विषय है।",
        "मैं कंप्यूटर साइंस का छात्र हूँ।",
        "क्रिकेट भारत में सबसे लोकप्रिय खेल है।",
        "भारतीय व्यंजन दुनिया भर में मशहूर हैं।",
        "हिमालय दुनिया का सबसे ऊंचा पर्वत है।",
        "गंगा भारत की सबसे पवित्र नदी है।",
        "दिल्ली भारत की राजधानी है।",
        "मुंबई भारत का आर्थिक केंद्र है।",
        "तमिल, तेलुगु, कन्नड़ और मलयालम दक्षिण भारत की प्रमुख भाषाएँ हैं।"
    ]

    labels = ["पढ़ना", "किताबें", "मौसम", "बारिश", "भारत", "भाषाएँ", "कंप्यूटर",
              "छात्र", "क्रिकेट", "व्यंजन", "हिमालय", "गंगा", "दिल्ली", "मुंबई", "भाषाएँ"]

    embedder.visualize_embeddings(visualization_sentences, labels)

if __name__ == "__main__":
    main()
```
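`hindi_embeddings.py` doubles as a small sentence-embedding toolkit: `HindiEmbedder` loads the tokenizer, a `config.json`, and an `embedding_model.pt` checkpoint from one directory and exposes `encode`, `compute_similarity`, and `search` (its `main()` above is a fuller walkthrough). A minimal sketch, assuming such a directory exists locally; note that an `embedding_model.pt` checkpoint is not part of this upload:

```python
from hindi_embeddings import HindiEmbedder

# Assumed path: a directory containing tokenizer.model, config.json and
# embedding_model.pt, as HindiEmbedder.__init__ expects.
embedder = HindiEmbedder(model_path="./hindi-embeddings")

vectors = embedder.encode(["भारत एक विशाल देश है।", "आज मौसम अच्छा है।"])
print(vectors.shape)  # (2, embedding_dim); rows are L2-normalized

score = embedder.compute_similarity("मुझे हिंदी पढ़ना पसंद है।", "मैं हिंदी किताबें पसंद करता हूँ।")
print(score)  # single cosine-similarity value

hits = embedder.search("भारत की संस्कृति",
                       ["भारतीय संस्कृति बहुत प्राचीन है।", "हिमालय उत्तर में है।"],
                       top_k=1)
print(hits[0]["score"], hits[0]["document"])
```
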
hindi_language_model.py
ADDED
@@ -0,0 +1,547 @@

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, List, Dict, Any, Union

class HindiCausalLMConfig:
    """Configuration class for Hindi Causal Language Model"""
    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 512,
        layer_norm_eps: float = 1e-12,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.tie_word_embeddings = tie_word_embeddings

        # Add any additional kwargs as attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_embedding_config(cls, config_dict, **kwargs):
        """Create LM config from embedding model config"""
        # Check if override parameters are provided
        override_params = {}
        for key in ["num_hidden_layers", "hidden_size", "num_attention_heads",
                    "intermediate_size", "max_position_embeddings", "vocab_size"]:
            if key in kwargs:
                override_params[key] = kwargs.pop(key)

        # Get hidden size first to calculate appropriate number of attention heads
        hidden_size = override_params.get("hidden_size", config_dict.get("hidden_size", 768))

        # If num_attention_heads is not provided, choose a value that divides hidden_size evenly
        if "num_attention_heads" not in override_params:
            # Default options to try: 12, 16, 8, 4
            for heads in [12, 16, 8, 4]:
                if hidden_size % heads == 0:
                    print(f"Automatically setting num_attention_heads to {heads} to match hidden_size {hidden_size}")
                    override_params["num_attention_heads"] = heads
                    break

            # If none of the defaults work, find the largest factor <= 32
            if "num_attention_heads" not in override_params:
                # Find the largest factor of hidden_size that is <= 32
                for heads in range(min(32, hidden_size), 0, -1):
                    if hidden_size % heads == 0:
                        print(f"Automatically setting num_attention_heads to {heads} to match hidden_size {hidden_size}")
                        override_params["num_attention_heads"] = heads
                        break

        # Build the config, with overrides taking precedence
        config_params = {
            "vocab_size": override_params.get("vocab_size", config_dict.get("vocab_size", 32000)),
            "hidden_size": hidden_size,
            "num_hidden_layers": override_params.get("num_hidden_layers", config_dict.get("num_hidden_layers", 12)),
            "num_attention_heads": override_params.get("num_attention_heads", config_dict.get("num_attention_heads", 12)),
            "intermediate_size": override_params.get("intermediate_size", config_dict.get("intermediate_size", 3072)),
            "hidden_dropout_prob": config_dict.get("hidden_dropout_prob", 0.1),
            "attention_probs_dropout_prob": config_dict.get("attention_probs_dropout_prob", 0.1),
            "max_position_embeddings": override_params.get("max_position_embeddings",
                                                           config_dict.get("max_position_embeddings", 512)),
            "layer_norm_eps": config_dict.get("layer_norm_eps", 1e-12),
            "pad_token_id": config_dict.get("pad_token_id", 0),
        }

        # Verify that hidden_size is divisible by num_attention_heads
        if config_params["hidden_size"] % config_params["num_attention_heads"] != 0:
            raise ValueError(
                f"Hidden size ({config_params['hidden_size']}) must be divisible by the number of attention "
                f"heads ({config_params['num_attention_heads']})"
            )

        # Add remaining kwargs
        config_params.update(kwargs)

        # Create and return the config
        lm_config = cls(**config_params)
        return lm_config

    def to_dict(self):
        """Convert config to dictionary"""
        return {k: v for k, v in self.__dict__.items()}

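# Example of the head auto-selection above (illustrative values): with hidden_size=768 and
# no num_attention_heads override, the loop tries 12 first and 768 % 12 == 0, so the LM
# config gets 12 heads; a hidden_size of 1024 would pick 16 instead. Any kwargs left over
# after the overrides are attached to the config as plain attributes.
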
108 |
+
class CausalSelfAttention(nn.Module):
|
109 |
+
"""Causal self-attention layer"""
|
110 |
+
def __init__(self, config):
|
111 |
+
super().__init__()
|
112 |
+
assert config.hidden_size % config.num_attention_heads == 0
|
113 |
+
|
114 |
+
self.num_attention_heads = config.num_attention_heads
|
115 |
+
self.attention_head_size = config.hidden_size // config.num_attention_heads
|
116 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
117 |
+
|
118 |
+
# Query, Key, Value projections
|
119 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
120 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
121 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
122 |
+
|
123 |
+
# Output projection
|
124 |
+
self.output = nn.Sequential(
|
125 |
+
nn.Linear(self.all_head_size, config.hidden_size),
|
126 |
+
nn.Dropout(config.attention_probs_dropout_prob)
|
127 |
+
)
|
128 |
+
|
129 |
+
# Causal mask to ensure that attention is only applied to the left in the input sequence
|
130 |
+
self.register_buffer(
|
131 |
+
"causal_mask",
|
132 |
+
torch.triu(
|
133 |
+
torch.ones(config.max_position_embeddings, config.max_position_embeddings) * -1e10,
|
134 |
+
diagonal=1
|
135 |
+
)
|
136 |
+
)
|
137 |
+
|
138 |
+
def transpose_for_scores(self, x):
|
139 |
+
# Reshape from [batch_size, seq_length, hidden_size] to [batch_size, seq_length, num_heads, head_size]
|
140 |
+
new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
141 |
+
x = x.view(*new_shape)
|
142 |
+
# Transpose to [batch_size, num_heads, seq_length, head_size]
|
143 |
+
return x.permute(0, 2, 1, 3)
|
144 |
+
|
145 |
+
def forward(self, hidden_states, attention_mask=None):
|
146 |
+
batch_size, seq_length = hidden_states.size()[:2]
|
147 |
+
|
148 |
+
# Project inputs to queries, keys, and values
|
149 |
+
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
150 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
151 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
152 |
+
|
153 |
+
# Scale dot-product attention
|
154 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
155 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
156 |
+
|
157 |
+
# Apply causal mask - prevents attending to future tokens
|
158 |
+
causal_mask = self.causal_mask[:seq_length, :seq_length]
|
159 |
+
attention_scores = attention_scores + causal_mask
|
160 |
+
|
161 |
+
# Apply attention mask if provided
|
162 |
+
if attention_mask is not None:
|
163 |
+
# Expand mask to match attention_scores shape
|
164 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
165 |
+
attention_mask = (1.0 - attention_mask) * -10000.0
|
166 |
+
attention_scores = attention_scores + attention_mask
|
167 |
+
|
168 |
+
# Softmax normalization
|
169 |
+
attention_probs = F.softmax(attention_scores, dim=-1)
|
170 |
+
attention_probs = F.dropout(attention_probs, p=0.1, training=self.training)
|
171 |
+
|
172 |
+
# Apply attention to values
|
173 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
174 |
+
|
175 |
+
# Reshape back to [batch_size, seq_length, hidden_size]
|
176 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
177 |
+
new_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
178 |
+
context_layer = context_layer.view(*new_shape)
|
179 |
+
|
180 |
+
# Final output projection
|
181 |
+
output = self.output(context_layer)
|
182 |
+
|
183 |
+
return output
|
184 |
+
|
185 |
+
class TransformerBlock(nn.Module):
|
186 |
+
"""Transformer block with causal attention for language modeling"""
|
187 |
+
def __init__(self, config):
|
188 |
+
super().__init__()
|
189 |
+
self.attention = CausalSelfAttention(config)
|
190 |
+
self.attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
191 |
+
|
192 |
+
# Feed-forward network
|
193 |
+
self.ffn = nn.Sequential(
|
194 |
+
nn.Linear(config.hidden_size, config.intermediate_size),
|
195 |
+
nn.GELU(),
|
196 |
+
nn.Linear(config.intermediate_size, config.hidden_size),
|
197 |
+
nn.Dropout(config.hidden_dropout_prob)
|
198 |
+
)
|
199 |
+
self.ffn_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
200 |
+
|
201 |
+
def forward(self, hidden_states, attention_mask=None):
|
202 |
+
# Self-attention block with residual connection and layer norm
|
203 |
+
attn_output = self.attention(hidden_states, attention_mask)
|
204 |
+
hidden_states = self.attention_layernorm(hidden_states + attn_output)
|
205 |
+
|
206 |
+
# Feed-forward block with residual connection and layer norm
|
207 |
+
ffn_output = self.ffn(hidden_states)
|
208 |
+
hidden_states = self.ffn_layernorm(hidden_states + ffn_output)
|
209 |
+
|
210 |
+
return hidden_states
|
211 |
+
|
212 |
+
class HindiCausalLM(nn.Module):
    """Hindi Causal Language Model for text generation"""
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Embeddings
        self.token_embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size
        )
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_hidden_layers)
        ])

        # LM head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Tie weights if configured
        if config.tie_word_embeddings:
            self.lm_head.weight = self.token_embeddings.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights with small random values"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_input_embeddings(self):
        return self.token_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.token_embeddings = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=True
    ):
        device = input_ids.device
        batch_size, seq_length = input_ids.size()

        # Create position ids
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Get embeddings
        inputs_embeds = self.token_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)

        # Sum token and position embeddings
        hidden_states = inputs_embeds + position_embeds
        hidden_states = self.embedding_dropout(hidden_states)

        # Default attention mask (all tokens can be attended to)
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, seq_length, device=device)

        # Apply transformer layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        # Language model head
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Move labels to correct device
            labels = labels.to(device)

            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        if return_dict:
            return {
                "logits": lm_logits,
                "loss": loss,
                "hidden_states": hidden_states
            }

        return (lm_logits, loss, hidden_states)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Prepare inputs for text generation"""
        # Only keep inputs needed for forward pass
        inputs = {
            "input_ids": input_ids,
        }

        # Add attention mask if provided
        if attention_mask is not None:
            inputs["attention_mask"] = attention_mask

        return inputs

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """Reorder cached hidden states for beam search generation"""
        reordered_past = []
        for layer_past in past:
            reordered_past.append(
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            )
        return reordered_past

    def save_pretrained(self, save_directory):
        """Save model and config to directory"""
        import os
        import json
        import torch

        os.makedirs(save_directory, exist_ok=True)

        # Save config
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(self.config.to_dict(), f, indent=2)

        # Save model weights
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        return [config_path, model_path]

    @classmethod
    def from_pretrained(cls, model_path):
        """Load model and config from directory"""
        import os
        import json
        import torch

        # Load config
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        # Create config object
        config = HindiCausalLMConfig(**config_dict)

        # Create model
        model = cls(config)

        # Load model weights
        weights_path = os.path.join(model_path, "pytorch_model.bin")
        model.load_state_dict(torch.load(weights_path, map_location="cpu"))

        return model

class HindiTextGenerator:
    """Text generation utility for HindiCausalLM"""
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

    def generate(
        self,
        prompt,
        max_length=100,
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
        **kwargs
    ):
        """Generate text from a prompt"""
        # Encode the prompt
        input_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.model.config.max_position_embeddings - max_length
        )["input_ids"].to(self.device)

        # Set the model to evaluation mode
        self.model.eval()

        # Set generation parameters
        gen_kwargs = {
            "max_length": input_ids.shape[1] + max_length,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": do_sample,
            "num_return_sequences": num_return_sequences,
            **kwargs
        }

        # Generate text
        with torch.no_grad():
            output_sequences = self._generate_text(input_ids, **gen_kwargs)

        # Decode generated sequences
        generated_texts = []
        for sequence in output_sequences:
            # Remove the prompt from the generated text
            sequence = sequence[input_ids.shape[1]:]
            text = self.tokenizer.sp_model.DecodeIds(sequence.tolist())
            generated_texts.append(text)

        if num_return_sequences == 1:
            return generated_texts[0]

        return generated_texts

    def _generate_text(
        self,
        input_ids,
        max_length,
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
        pad_token_id=None,
        eos_token_id=None,
        **kwargs
    ):
        """Core text generation logic"""
        # Set pad_token_id and eos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.model.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.model.config.eos_token_id

        # Create attention mask
        attention_mask = torch.ones_like(input_ids)

        # Clone the input_ids for each return sequence
        input_ids = input_ids.repeat(num_return_sequences, 1)
        attention_mask = attention_mask.repeat(num_return_sequences, 1)

        # Keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        cur_len = input_ids.shape[1]

        while cur_len < max_length:
            # Prepare model inputs
            model_inputs = self.model.prepare_inputs_for_generation(
                input_ids, attention_mask=attention_mask
            )

            # Forward pass
            outputs = self.model(**model_inputs, return_dict=True)
            next_token_logits = outputs["logits"][:, -1, :]

            # Apply temperature scaling
            next_token_logits = next_token_logits / temperature

            # Apply repetition penalty: divide positive logits and multiply negative
            # ones so previously generated tokens always become less likely
            if repetition_penalty != 1.0:
                for i in range(input_ids.shape[0]):
                    for token_id in set(input_ids[i].tolist()):
                        if next_token_logits[i, token_id] > 0:
                            next_token_logits[i, token_id] /= repetition_penalty
                        else:
                            next_token_logits[i, token_id] *= repetition_penalty

            # Filter logits using top-k and top-p sampling
            if do_sample:
                # Top-k filtering
                if top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = -float("Inf")

                # Top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    # Remove tokens with cumulative probability above the threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift the indices to the right to keep the first token above threshold
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    for i in range(next_token_logits.shape[0]):
                        indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
                        next_token_logits[i, indices_to_remove] = -float("Inf")

                # Sample from the filtered distribution
                probabilities = F.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probabilities, 1).squeeze(1)
            else:
                # Greedy decoding
                next_tokens = torch.argmax(next_token_logits, dim=-1)

            # Pad sequences that already emitted EOS instead of generating past it
            if eos_token_id is not None and pad_token_id is not None:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # Update unfinished sequences based on EOS token
            if eos_token_id is not None:
                # Set the unfinished flag to 0 if the sequence reaches EOS
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.ne(eos_token_id).long()
                )

                # Check if all sequences are finished
                if unfinished_sequences.max() == 0:
                    break

            # Concatenate next tokens to input_ids
            input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)

            # Expand attention mask
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

            cur_len += 1

        return input_ids
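For orientation, here is a minimal, hypothetical sketch of how the classes defined in this file fit together once the repository files are available locally. It assumes `hindi_language_model.py` and `hindi_embeddings.py` are importable, that `SentencePieceTokenizerWrapper` can be built directly from the `tokenizer.model` path (its exact constructor is not shown in this section), and that a directory written by `save_pretrained` (called `./hindi_lm_local` below) holds `config.json` and `pytorch_model.bin`:

```python
# Hypothetical usage sketch; file paths and the tokenizer constructor are assumptions.
from hindi_language_model import HindiCausalLM, HindiTextGenerator
from hindi_embeddings import SentencePieceTokenizerWrapper

tokenizer = SentencePieceTokenizerWrapper("tokenizer.model")   # assumed constructor
model = HindiCausalLM.from_pretrained("./hindi_lm_local")      # dir written by save_pretrained
model.eval()

generator = HindiTextGenerator(model, tokenizer)
text = generator.generate(
    "भारत की राजधानी",   # "The capital of India"
    max_length=50,
    temperature=0.8,
    top_k=50,
    top_p=0.95,
    do_sample=True,
)
print(text)
```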
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1ebf86827285c98325d2434abc44fd95f79fcc56161e3b7019faa49799e9eb7c
size 452698216
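Note that the checkpoint is uploaded as `model.safetensors`, while `HindiCausalLM.from_pretrained` above looks for a `pytorch_model.bin`. A minimal sketch of loading the safetensors file directly, assuming the `safetensors` package is installed, `HindiCausalLMConfig` is defined alongside `HindiCausalLM`, and the checkpoint keys match the parameter names in this file:

```python
# Sketch only: assumes config.json and model.safetensors are in the current directory
# and that the state-dict keys line up with HindiCausalLM's parameter names.
import json
from safetensors.torch import load_file

from hindi_language_model import HindiCausalLM, HindiCausalLMConfig

with open("config.json", "r", encoding="utf-8") as f:
    config = HindiCausalLMConfig(**json.load(f))

model = HindiCausalLM(config)
state_dict = load_file("model.safetensors")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing:", missing, "unexpected:", unexpected)
model.eval()
```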
step_loss_lr.png
ADDED
(binary image: training loss and learning rate vs. training step)
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f1658727f4ea5c571f69a60f6defcd180014a1d69be6e9c1ec360d9510aa6b5e
size 642200
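`tokenizer.model` is a plain SentencePiece model, so it can also be inspected directly with the `sentencepiece` library, independently of the `SentencePieceTokenizerWrapper` class; a small sketch (paths are assumptions):

```python
# Sketch: inspect the SentencePiece model directly.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("tokenizer.model")

print(sp.GetPieceSize())                 # expected to be 16000 per the config
ids = sp.EncodeAsIds("नमस्ते दुनिया")      # "Hello world"
print(ids)
print(sp.DecodeIds(ids))
```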
tokenizer_config.json
ADDED
@@ -0,0 +1,9 @@
{
  "tokenizer_class": "SentencePieceTokenizerWrapper",
  "vocab_size": 16000,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 0,
  "unk_token_id": 3,
  "model_max_length": 512
}
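Since `HindiTextGenerator.generate()` forwards extra keyword arguments through to `_generate_text`, the special-token ids recorded above can also be passed explicitly at generation time. A small sketch reusing the hypothetical `generator` object from the earlier example:

```python
# Sketch: feed the ids from tokenizer_config.json straight into generation.
import json

with open("tokenizer_config.json", "r", encoding="utf-8") as f:
    tok_cfg = json.load(f)

text = generator.generate(
    "नमस्ते",                               # "Hello"
    max_length=100,
    eos_token_id=tok_cfg["eos_token_id"],   # 2
    pad_token_id=tok_cfg["pad_token_id"],   # 0
)
print(text)
```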