DeepMostInnovations committed
Commit 4296bdb · verified · 1 Parent(s): 269183c

Upload Hindi embeddings model and all associated files

Files changed (2):
  1. hindi-rag-system.py +881 -0
  2. hindi-rag-system.py.amltmp +881 -0
hindi-rag-system.py ADDED
@@ -0,0 +1,881 @@

import os
import torch
import json
import argparse
import numpy as np
import re
from torch import nn
from torch.nn import functional as F
import sentencepiece as spm
import math
from safetensors.torch import save_file, load_file
from tqdm import tqdm
import faiss
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS as LangchainFAISS
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from typing import List, Dict, Any, Optional, Callable

# Tokenizer wrapper class - same as in original code
class SentencePieceTokenizerWrapper:
    def __init__(self, sp_model_path):
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(sp_model_path)
        self.vocab_size = self.sp_model.GetPieceSize()

        # Special token IDs from tokenizer training
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.unk_token_id = 3

        # Set special tokens
        self.pad_token = "<pad>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.unk_token = "<unk>"
        self.mask_token = "<mask>"

    def __call__(self, text, padding=False, truncation=False, max_length=None, return_tensors=None):
        # Handle both string and list inputs
        if isinstance(text, str):
            # Encode a single string
            ids = self.sp_model.EncodeAsIds(text)

            # Handle truncation
            if truncation and max_length and len(ids) > max_length:
                ids = ids[:max_length]

            attention_mask = [1] * len(ids)

            # Handle padding
            if padding and max_length:
                padding_length = max(0, max_length - len(ids))
                ids = ids + [self.pad_token_id] * padding_length
                attention_mask = attention_mask + [0] * padding_length

            result = {
                'input_ids': ids,
                'attention_mask': attention_mask
            }

            # Convert to tensors if requested
            if return_tensors == 'pt':
                result = {k: torch.tensor([v]) for k, v in result.items()}

            return result

        # Process a batch of texts
        batch_encoded = [self.sp_model.EncodeAsIds(t) for t in text]

        # Apply truncation if needed
        if truncation and max_length:
            batch_encoded = [ids[:max_length] for ids in batch_encoded]

        # Create attention masks
        batch_attention_mask = [[1] * len(ids) for ids in batch_encoded]

        # Apply padding if needed
        if padding:
            if max_length:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in batch_encoded)

            # Pad sequences to max_len
            batch_encoded = [ids + [self.pad_token_id] * (max_len - len(ids)) for ids in batch_encoded]
            batch_attention_mask = [mask + [0] * (max_len - len(mask)) for mask in batch_attention_mask]

        result = {
            'input_ids': batch_encoded,
            'attention_mask': batch_attention_mask
        }

        # Convert to tensors if requested
        if return_tensors == 'pt':
            result = {k: torch.tensor(v) for k, v in result.items()}

        return result

# Model architecture definitions for inference

class MultiHeadAttention(nn.Module):
    """Advanced multi-headed attention with relative positional encoding"""
    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config["num_attention_heads"]
        self.attention_head_size = config["hidden_size"] // config["num_attention_heads"]
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Query, Key, Value projections
        self.query = nn.Linear(config["hidden_size"], self.all_head_size)
        self.key = nn.Linear(config["hidden_size"], self.all_head_size)
        self.value = nn.Linear(config["hidden_size"], self.all_head_size)

        # Output projection
        self.output = nn.Sequential(
            nn.Linear(self.all_head_size, config["hidden_size"]),
            nn.Dropout(config["attention_probs_dropout_prob"])
        )

        # Simplified relative position bias approach
        self.max_position_embeddings = config["max_position_embeddings"]
        self.relative_attention_bias = nn.Embedding(
            2 * config["max_position_embeddings"] - 1,
            config["num_attention_heads"]
        )

    def transpose_for_scores(self, x):
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_length = hidden_states.size()[:2]

        # Project inputs to queries, keys, and values
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Take the dot product between query and key to get the raw attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # Generate relative position matrix
        position_ids = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device)
        relative_position = position_ids.unsqueeze(1) - position_ids.unsqueeze(0)  # [seq_len, seq_len]
        # Shift values to be >= 0
        relative_position = relative_position + self.max_position_embeddings - 1
        # Ensure indices are within bounds
        relative_position = torch.clamp(relative_position, 0, 2 * self.max_position_embeddings - 2)

        # Get relative position embeddings [seq_len, seq_len, num_heads]
        rel_attn_bias = self.relative_attention_bias(relative_position)

        # Reshape to add to attention heads [1, num_heads, seq_len, seq_len]
        rel_attn_bias = rel_attn_bias.permute(2, 0, 1).unsqueeze(0)

        # Add to attention scores - now dimensions will match
        attention_scores = attention_scores + rel_attn_bias

        # Scale attention scores
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Apply attention mask
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities
        attention_probs = F.softmax(attention_scores, dim=-1)

        # Apply dropout
        attention_probs = F.dropout(attention_probs, p=0.1, training=self.training)

        # Apply attention to values
        context_layer = torch.matmul(attention_probs, value_layer)

        # Reshape back to [batch_size, seq_length, hidden_size]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_shape)

        # Final output projection
        output = self.output(context_layer)

        return output

class EnhancedTransformerLayer(nn.Module):
    """Advanced transformer layer with pre-layer norm and enhanced attention"""
    def __init__(self, config):
        super().__init__()
        self.attention_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        self.attention = MultiHeadAttention(config)

        self.ffn_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(config["hidden_size"], config["intermediate_size"]),
            nn.GELU(),
            nn.Dropout(config["hidden_dropout_prob"]),
            nn.Linear(config["intermediate_size"], config["hidden_size"]),
            nn.Dropout(config["hidden_dropout_prob"])
        )

    def forward(self, hidden_states, attention_mask=None):
        # Pre-layer norm for attention
        attn_norm_hidden = self.attention_pre_norm(hidden_states)

        # Self-attention
        attention_output = self.attention(attn_norm_hidden, attention_mask)

        # Residual connection
        hidden_states = hidden_states + attention_output

        # Pre-layer norm for feed-forward
        ffn_norm_hidden = self.ffn_pre_norm(hidden_states)

        # Feed-forward
        ffn_output = self.ffn(ffn_norm_hidden)

        # Residual connection
        hidden_states = hidden_states + ffn_output

        return hidden_states

class AdvancedTransformerModel(nn.Module):
    """Advanced Transformer model for inference"""

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Embeddings
        self.word_embeddings = nn.Embedding(
            config["vocab_size"],
            config["hidden_size"],
            padding_idx=config["pad_token_id"]
        )

        # Position embeddings
        self.position_embeddings = nn.Embedding(config["max_position_embeddings"], config["hidden_size"])

        # Embedding dropout
        self.embedding_dropout = nn.Dropout(config["hidden_dropout_prob"])

        # Transformer layers
        self.layers = nn.ModuleList([
            EnhancedTransformerLayer(config) for _ in range(config["num_hidden_layers"])
        ])

        # Final layer norm
        self.final_layer_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])

    def forward(self, input_ids, attention_mask=None):
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape

        # Get position ids
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Get embeddings
        word_embeds = self.word_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)

        # Sum embeddings
        embeddings = word_embeds + position_embeds

        # Apply dropout
        embeddings = self.embedding_dropout(embeddings)

        # Default attention mask
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=input_ids.device)

        # Extended attention mask for transformer layers (1 for tokens to attend to, 0 for masked tokens)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Apply transformer layers
        hidden_states = embeddings
        for layer in self.layers:
            hidden_states = layer(hidden_states, extended_attention_mask)

        # Final layer norm
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states

class AdvancedPooling(nn.Module):
    """Advanced pooling module supporting multiple pooling strategies"""
    def __init__(self, config):
        super().__init__()
        self.pooling_mode = config["pooling_mode"]  # 'mean', 'max', 'cls', 'attention'
        self.hidden_size = config["hidden_size"]

        # For attention pooling
        if self.pooling_mode == 'attention':
            self.attention_weights = nn.Linear(config["hidden_size"], 1)

        # For weighted pooling
        elif self.pooling_mode == 'weighted':
            self.weight_layer = nn.Linear(config["hidden_size"], 1)

    def forward(self, token_embeddings, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(token_embeddings[:, :, 0])

        mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

        if self.pooling_mode == 'cls':
            # Use [CLS] token (first token)
            pooled = token_embeddings[:, 0]

        elif self.pooling_mode == 'max':
            # Max pooling
            token_embeddings = token_embeddings.clone()
            # Set padding tokens to large negative value to exclude them from max
            token_embeddings[mask_expanded == 0] = -1e9
            pooled = torch.max(token_embeddings, dim=1)[0]

        elif self.pooling_mode == 'attention':
            # Attention pooling
            weights = self.attention_weights(token_embeddings).squeeze(-1)
            # Mask out padding tokens
            weights = weights.masked_fill(attention_mask == 0, -1e9)
            weights = F.softmax(weights, dim=1).unsqueeze(-1)
            pooled = torch.sum(token_embeddings * weights, dim=1)

        elif self.pooling_mode == 'weighted':
            # Weighted average pooling
            weights = torch.sigmoid(self.weight_layer(token_embeddings)).squeeze(-1)
            # Apply mask
            weights = weights * attention_mask
            # Normalize weights
            sum_weights = torch.sum(weights, dim=1, keepdim=True)
            sum_weights = torch.clamp(sum_weights, min=1e-9)
            weights = weights / sum_weights
            # Apply weights
            pooled = torch.sum(token_embeddings * weights.unsqueeze(-1), dim=1)

        else:  # Default to mean pooling
            # Mean pooling
            sum_embeddings = torch.sum(token_embeddings * mask_expanded, dim=1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            pooled = sum_embeddings / sum_mask

        # L2 normalize
        pooled = F.normalize(pooled, p=2, dim=1)

        return pooled

class SentenceEmbeddingModel(nn.Module):
    """Complete sentence embedding model for inference"""
    def __init__(self, config):
        super(SentenceEmbeddingModel, self).__init__()
        self.config = config

        # Create transformer model
        self.transformer = AdvancedTransformerModel(config)

        # Create pooling module
        self.pooling = AdvancedPooling(config)

        # Build projection module if needed
        if "projection_dim" in config and config["projection_dim"] > 0:
            self.use_projection = True
            self.projection = nn.Sequential(
                nn.Linear(config["hidden_size"], config["hidden_size"]),
                nn.GELU(),
                nn.Linear(config["hidden_size"], config["projection_dim"]),
                nn.LayerNorm(config["projection_dim"], eps=config["layer_norm_eps"])
            )
        else:
            self.use_projection = False

    def forward(self, input_ids, attention_mask=None):
        # Get token embeddings from transformer
        token_embeddings = self.transformer(input_ids, attention_mask)

        # Pool token embeddings
        pooled_output = self.pooling(token_embeddings, attention_mask)

        # Apply projection if enabled
        if self.use_projection:
            pooled_output = self.projection(pooled_output)
            pooled_output = F.normalize(pooled_output, p=2, dim=1)

        return pooled_output

def convert_to_safetensors(model_path, output_path):
    """Convert PyTorch model to safetensors format"""
    print(f"Converting model from {model_path} to safetensors format...")

    try:
        # First try with weights_only=False to handle PyTorch 2.6+ checkpoints
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
        print("Successfully loaded checkpoint with weights_only=False")
    except TypeError:
        # For older PyTorch versions that don't have the weights_only parameter
        print("Falling back to default torch.load behavior for older PyTorch versions")
        checkpoint = torch.load(model_path, map_location="cpu")

    # Get model state dict
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
        print("Extracted model_state_dict from checkpoint")
    else:
        state_dict = checkpoint
        print("Using entire checkpoint as state_dict")

    # Save as safetensors
    save_file(state_dict, output_path)
    print(f"Model converted and saved to {output_path}")

def load_model_and_tokenizer(model_dir, tokenizer_dir="/home/ubuntu/hindi_tokenizer"):
    """Load the model and tokenizer for inference"""

    # Load the config
    config_path = os.path.join(model_dir, "config.json")
    with open(config_path, "r") as f:
        config = json.load(f)

    # Load the tokenizer - use specified tokenizer directory
    tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.model")
    if not os.path.exists(tokenizer_path):
        # Try other locations
        tokenizer_path = os.path.join(model_dir, "tokenizer.model")
        if not os.path.exists(tokenizer_path):
            raise FileNotFoundError(f"Could not find tokenizer model at {tokenizer_path}")

    tokenizer = SentencePieceTokenizerWrapper(tokenizer_path)
    print(f"Loaded tokenizer from {tokenizer_path} with vocabulary size: {tokenizer.vocab_size}")

    # Load the model
    safetensors_path = os.path.join(model_dir, "embedding_model.safetensors")

    if not os.path.exists(safetensors_path):
        print(f"Safetensors model not found at {safetensors_path}, converting from PyTorch checkpoint...")

        # Convert from PyTorch checkpoint
        pytorch_path = os.path.join(model_dir, "embedding_model.pt")
        if not os.path.exists(pytorch_path):
            raise FileNotFoundError(f"Could not find PyTorch model at {pytorch_path}")

        convert_to_safetensors(pytorch_path, safetensors_path)

    # Load state dict from safetensors
    state_dict = load_file(safetensors_path)

    # Create model
    model = SentenceEmbeddingModel(config)

    # Load state dict
    try:
        # Try direct loading
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print(f"Loaded model with missing keys: {missing_keys[:10]}{'...' if len(missing_keys) > 10 else ''}")
        print(f"Unexpected keys: {unexpected_keys[:10]}{'...' if len(unexpected_keys) > 10 else ''}")
    except Exception as e:
        print(f"Error loading state dict: {e}")
        print("Model will be initialized with random weights")

    model.eval()

    return model, tokenizer, config

# LangChain Custom Embeddings Class
class HindiSentenceEmbeddings(Embeddings):
    """
    Custom LangChain Embeddings class for the Hindi sentence embeddings model
    """
    def __init__(self, model, tokenizer, device="cuda", batch_size=32, max_length=128):
        """Initialize with model, tokenizer, and inference parameters"""
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.batch_size = batch_size
        self.max_length = max_length

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents/texts"""
        embeddings = []

        with torch.no_grad():
            for i in range(0, len(texts), self.batch_size):
                batch = texts[i:i+self.batch_size]

                # Tokenize
                inputs = self.tokenizer(
                    batch,
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt"
                )

                # Move to device
                input_ids = inputs["input_ids"].to(self.device)
                attention_mask = inputs["attention_mask"].to(self.device)

                # Get embeddings
                batch_embeddings = self.model(input_ids, attention_mask)

                # Move to CPU and convert to numpy
                batch_embeddings = batch_embeddings.cpu().numpy()
                embeddings.append(batch_embeddings)

        return np.vstack(embeddings).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query/text"""
        return self.embed_documents([text])[0]

def extract_relevant_sentences(text, query, window_size=2):
    """
    Extract the most relevant sentences from text based on query keywords

    Args:
        text: The full text content
        query: The user's query
        window_size: Number of sentences to include before and after matched sentence

    Returns:
        String containing the most relevant portion of the text
    """
    # Clean and normalize query and text for matching
    query = query.strip().lower()

    # Remove question marks and other punctuation from query for matching
    query = re.sub(r'[?।॥!,.:]', '', query)

    # Extract keywords from the query (remove common Hindi stop words)
    stop_words = ['और', 'का', 'के', 'को', 'में', 'से', 'है', 'हैं', 'था', 'थे', 'की', 'कि', 'पर', 'एक', 'यह', 'वह', 'जो', 'ने', 'हो', 'कर']
    query_terms = [word for word in query.split() if word not in stop_words]

    if not query_terms:
        return text  # If no meaningful terms left, return the full text

    # Split text into sentences (using Hindi sentence terminators)
    sentences = re.split(r'([।॥!?.])', text)

    # Rejoin sentences with their terminators
    complete_sentences = []
    for i in range(0, len(sentences)-1, 2):
        if i+1 < len(sentences):
            complete_sentences.append(sentences[i] + sentences[i+1])
        else:
            complete_sentences.append(sentences[i])

    # If the above didn't work properly, try a simpler approach
    if len(complete_sentences) <= 1:
        complete_sentences = re.split(r'[।॥!?.]', text)
        complete_sentences = [s.strip() for s in complete_sentences if s.strip()]

    # Score each sentence based on how many query terms it contains
    sentence_scores = []
    for i, sentence in enumerate(complete_sentences):
        sentence_lower = sentence.lower()
        # Calculate score based on number of query terms found
        score = sum(1 for term in query_terms if term in sentence_lower)
        sentence_scores.append((i, score))

    # Find the best matching sentence
    if not sentence_scores:
        return text[:500] + "..."  # Fallback

    # Get the index of the sentence with the highest score
    best_match_idx, best_score = max(sentence_scores, key=lambda x: x[1])

    # If no good match found, return the whole text (up to a limit)
    if best_score == 0:
        # Try partial word matching as a fallback
        for i, sentence in enumerate(complete_sentences):
            sentence_lower = sentence.lower()
            partial_score = sum(1 for term in query_terms if any(term in word.lower() for word in sentence_lower.split()))
            if partial_score > 0:
                best_match_idx = i
                break
        else:
            # If still no match, just return the first part of the text
            if len(text) > 1000:
                return text[:1000] + "..."
            return text

    # Get window of sentences around the best match
    start_idx = max(0, best_match_idx - window_size)
    end_idx = min(len(complete_sentences), best_match_idx + window_size + 1)

    # Create excerpt
    relevant_text = ' '.join(complete_sentences[start_idx:end_idx])

    # If the excerpt is short, return more context
    if len(relevant_text) < 100 and len(text) > len(relevant_text):
        # Add more context
        if end_idx < len(complete_sentences):
            relevant_text += ' ' + ' '.join(complete_sentences[end_idx:end_idx+2])
        if start_idx > 0:
            relevant_text = ' '.join(complete_sentences[max(0, start_idx-2):start_idx]) + ' ' + relevant_text

    # If the excerpt is too short or the whole text is small anyway, return whole text
    if len(relevant_text) < 50 or len(text) < 1000:
        return text

    return relevant_text

# Text processing and indexing functions
def load_and_process_text_file(file_path, chunk_size=500, chunk_overlap=100):
    """
    Load a text file and split it into semantically meaningful chunks
    """
    print(f"Loading and processing text file: {file_path}")

    # Read the file content
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # For small files, just keep the whole content as a single chunk
    if len(content) <= chunk_size * 2:
        print("File content is small, keeping as a single chunk")
        return [Document(
            page_content=content,
            metadata={
                "source": file_path,
                "chunk_id": 0
            }
        )]

    # Split by paragraphs first
    paragraphs = re.split(r'\n\s*\n', content)
    chunks = []

    current_chunk = ""
    current_size = 0

    for para in paragraphs:
        if not para.strip():
            continue

        # If adding this paragraph would exceed the chunk size, save current chunk and start a new one
        if current_size + len(para) > chunk_size and current_size > 0:
            chunks.append(current_chunk)
            current_chunk = para
            current_size = len(para)
        else:
            # Add paragraph to current chunk with a newline if not empty
            if current_size > 0:
                current_chunk += "\n\n" + para
            else:
                current_chunk = para
            current_size = len(current_chunk)

    # Add the last chunk if not empty
    if current_chunk:
        chunks.append(current_chunk)

    print(f"Split text into {len(chunks)} chunks")

    # Convert to LangChain documents with metadata
    documents = [
        Document(
            page_content=chunk,
            metadata={
                "source": file_path,
                "chunk_id": i
            }
        ) for i, chunk in enumerate(chunks)
    ]

    return documents

def create_vector_store(documents, embeddings, store_path=None):
    """
    Create a FAISS vector store from documents using the given embeddings
    """
    print("Creating FAISS vector store...")

    # Create vector store
    vector_store = LangchainFAISS.from_documents(documents, embeddings)

    # Save if path is provided
    if store_path:
        print(f"Saving vector store to {store_path}")
        vector_store.save_local(store_path)

    return vector_store

def load_vector_store(store_path, embeddings):
    """
    Load a FAISS vector store from disk
    """
    print(f"Loading vector store from {store_path}")
    return LangchainFAISS.load_local(store_path, embeddings, allow_dangerous_deserialization=True)

def perform_similarity_search(vector_store, query, k=6):
    """
    Perform basic similarity search on the vector store
    """
    print(f"Searching for: {query}")
    return vector_store.similarity_search_with_score(query, k=k)

# Main RAG functions
def index_text_files(model, tokenizer, data_dir, output_dir, device="cuda", chunk_size=500):
    """
    Index text files from a directory and create a FAISS vector store
    """
    print(f"Indexing text files from {data_dir} with chunk size ({chunk_size}) for fine-grained retrieval")

    # Create embedding model
    embeddings = HindiSentenceEmbeddings(model, tokenizer, device=device)

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Get all text files
    text_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.txt')]
    print(f"Found {len(text_files)} text files")

    # Process all text files
    all_documents = []
    for file_path in text_files:
        documents = load_and_process_text_file(file_path, chunk_size=chunk_size)
        all_documents.extend(documents)

    print(f"Total documents: {len(all_documents)}")

    # If we don't have enough chunks, reduce chunk size and try again
    if len(all_documents) < 10 and chunk_size > 50:
        print("Not enough chunks created. Reducing chunk size and trying again...")
        return index_text_files(model, tokenizer, data_dir, output_dir, device, chunk_size=chunk_size//2)

    # Create and save vector store
    vector_store_path = os.path.join(output_dir, "faiss_index")
    vector_store = create_vector_store(all_documents, embeddings, vector_store_path)

    return vector_store, embeddings

def query_text_corpus(model, tokenizer, vector_store_path, query, k=6, device="cuda"):
    """
    Query the text corpus using the indexed vector store
    """
    # Create embedding model
    embeddings = HindiSentenceEmbeddings(model, tokenizer, device=device)

    # Load vector store
    vector_store = load_vector_store(vector_store_path, embeddings)

    # Perform similarity search
    results = perform_similarity_search(vector_store, query, k=k)

    # Post-process results to combine adjacent chunks if they're from the same source
    processed_results = []
    seen_chunks = set()

    for doc, score in results:
        chunk_id = doc.metadata["chunk_id"]
        source = doc.metadata["source"]

        # Skip if we've already included this chunk
        if (source, chunk_id) in seen_chunks:
            continue

        seen_chunks.add((source, chunk_id))

        # Try to find adjacent chunks and combine them
        combined_content = doc.page_content

        # Look for adjacent chunks in results (both previous and next)
        for adj_id in [chunk_id-1, chunk_id+1]:
            for other_doc, _ in results:
                if (other_doc.metadata["source"] == source and
                        other_doc.metadata["chunk_id"] == adj_id and
                        (source, adj_id) not in seen_chunks):

                    # Add the adjacent chunk content
                    if adj_id < chunk_id:  # Previous chunk
                        combined_content = other_doc.page_content + " " + combined_content
                    else:  # Next chunk
                        combined_content = combined_content + " " + other_doc.page_content

                    seen_chunks.add((source, adj_id))

        # Create a new document with combined content
        combined_doc = Document(
            page_content=combined_content,
            metadata={
                "source": source,
                "chunk_id": chunk_id,
                "is_combined": True if combined_content != doc.page_content else False
            }
        )

        processed_results.append((combined_doc, score))

    return processed_results

def main():
    parser = argparse.ArgumentParser(description="Hindi RAG System with LangChain and FAISS")
    parser.add_argument("--model_dir", type=str, default="/home/ubuntu/output/hindi-embeddings-custom-tokenizer/final",
                        help="Directory containing the model and tokenizer")
    parser.add_argument("--tokenizer_dir", type=str, default="/home/ubuntu/hindi_tokenizer",
                        help="Directory containing the tokenizer")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to run inference on ('cuda' or 'cpu')")
    parser.add_argument("--index", action="store_true",
                        help="Index text files from data directory")
    parser.add_argument("--query", type=str, default=None,
                        help="Query to search in the indexed corpus")
    parser.add_argument("--data_dir", type=str, default="./data",
                        help="Directory containing text files for indexing")
    parser.add_argument("--output_dir", type=str, default="./output",
                        help="Directory to save the indexed vector store")
    parser.add_argument("--top_k", type=int, default=6,
                        help="Number of top results to return")
    parser.add_argument("--chunk_size", type=int, default=500,
                        help="Size of text chunks for indexing")
    parser.add_argument("--interactive", action="store_true",
                        help="Run in interactive mode for querying")
    parser.add_argument("--reindex", action="store_true",
                        help="Force reindexing even if index exists")
    args = parser.parse_args()

    # Load model and tokenizer
    model, tokenizer, config = load_model_and_tokenizer(args.model_dir, args.tokenizer_dir)

    # Move model to device
    model = model.to(args.device)

    # Create vector store path
    vector_store_path = os.path.join(args.output_dir, "faiss_index")

    if args.index or args.reindex:
        # Index text files
        index_text_files(model, tokenizer, args.data_dir, args.output_dir, args.device, args.chunk_size)
        print(f"Indexing complete. Vector store saved to {vector_store_path}")

    if args.query:
        # Query the corpus
        results = query_text_corpus(model, tokenizer, vector_store_path, args.query, args.top_k, args.device)

        # Print results
        print("\nSearch Results:")
        for i, (doc, score) in enumerate(results):
            print(f"\nResult {i+1} (Score: {score:.4f}):")
            print(f"Source: {doc.metadata['source']}, Chunk: {doc.metadata['chunk_id']}")

            # Extract and print only relevant sentences
            relevant_text = extract_relevant_sentences(doc.page_content, args.query)
            print(f"Content: {relevant_text}")

    if args.interactive:
        print("\nInteractive mode. Enter queries (or type 'quit' to exit).")

        while True:
            print("\nEnter query:")
            query = input()

            if not query.strip():
                continue

            if query.lower() == 'quit':
                break

            # Query the corpus
            results = query_text_corpus(model, tokenizer, vector_store_path, query, args.top_k, args.device)

            # Print results
            print("\nSearch Results:")
            for i, (doc, score) in enumerate(results):
                print(f"\nResult {i+1} (Score: {score:.4f}):")
                print(f"Source: {doc.metadata['source']}, Chunk: {doc.metadata['chunk_id']}")

                # Extract and print only relevant sentences
                relevant_text = extract_relevant_sentences(doc.page_content, query)
                print(f"Content: {relevant_text}")

if __name__ == "__main__":
    main()
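
For reference, a minimal sketch of how the functions above fit together when used programmatically rather than through the CLI flags handled by main() (--index, --query, --interactive). The directory paths and the example query below are placeholders, not files shipped in this commit.

# Hypothetical usage sketch: assumes the functions above are importable and that
# model/tokenizer directories with config.json, embedding_model.safetensors (or .pt)
# and tokenizer.model exist at the placeholder paths.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the embedding model, SentencePiece tokenizer, and config
model, tokenizer, config = load_model_and_tokenizer("./hindi-embeddings/final", "./hindi_tokenizer")
model = model.to(device)

# Build the FAISS index over a folder of .txt files, then query it
index_text_files(model, tokenizer, data_dir="./data", output_dir="./output", device=device)
results = query_text_corpus(model, tokenizer, "./output/faiss_index",
                            "हिंदी भाषा का इतिहास क्या है?", k=3, device=device)

# Show only the sentences most relevant to the query from each retrieved chunk
for doc, score in results:
    print(score, extract_relevant_sentences(doc.page_content, "हिंदी भाषा का इतिहास क्या है?"))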
hindi-rag-system.py.amltmp ADDED
@@ -0,0 +1,881 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import json
4
+ import argparse
5
+ import numpy as np
6
+ import re
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ import sentencepiece as spm
10
+ import math
11
+ from safetensors.torch import save_file, load_file
12
+ from tqdm import tqdm
13
+ import faiss
14
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
15
+ from langchain.vectorstores import FAISS as LangchainFAISS
16
+ from langchain.docstore.document import Document
17
+ from langchain.embeddings.base import Embeddings
18
+ from typing import List, Dict, Any, Optional, Callable
19
+
20
+ # Tokenizer wrapper class - same as in original code
21
+ class SentencePieceTokenizerWrapper:
22
+ def __init__(self, sp_model_path):
23
+ self.sp_model = spm.SentencePieceProcessor()
24
+ self.sp_model.Load(sp_model_path)
25
+ self.vocab_size = self.sp_model.GetPieceSize()
26
+
27
+ # Special token IDs from tokenizer training
28
+ self.pad_token_id = 0
29
+ self.bos_token_id = 1
30
+ self.eos_token_id = 2
31
+ self.unk_token_id = 3
32
+
33
+ # Set special tokens
34
+ self.pad_token = "<pad>"
35
+ self.bos_token = "<s>"
36
+ self.eos_token = "</s>"
37
+ self.unk_token = "<unk>"
38
+ self.mask_token = "<mask>"
39
+
40
+ def __call__(self, text, padding=False, truncation=False, max_length=None, return_tensors=None):
41
+ # Handle both string and list inputs
42
+ if isinstance(text, str):
43
+ # Encode a single string
44
+ ids = self.sp_model.EncodeAsIds(text)
45
+
46
+ # Handle truncation
47
+ if truncation and max_length and len(ids) > max_length:
48
+ ids = ids[:max_length]
49
+
50
+ attention_mask = [1] * len(ids)
51
+
52
+ # Handle padding
53
+ if padding and max_length:
54
+ padding_length = max(0, max_length - len(ids))
55
+ ids = ids + [self.pad_token_id] * padding_length
56
+ attention_mask = attention_mask + [0] * padding_length
57
+
58
+ result = {
59
+ 'input_ids': ids,
60
+ 'attention_mask': attention_mask
61
+ }
62
+
63
+ # Convert to tensors if requested
64
+ if return_tensors == 'pt':
65
+ import torch
66
+ result = {k: torch.tensor([v]) for k, v in result.items()}
67
+
68
+ return result
69
+
70
+ # Process a batch of texts
71
+ batch_encoded = [self.sp_model.EncodeAsIds(t) for t in text]
72
+
73
+ # Apply truncation if needed
74
+ if truncation and max_length:
75
+ batch_encoded = [ids[:max_length] for ids in batch_encoded]
76
+
77
+ # Create attention masks
78
+ batch_attention_mask = [[1] * len(ids) for ids in batch_encoded]
79
+
80
+ # Apply padding if needed
81
+ if padding:
82
+ if max_length:
83
+ max_len = max_length
84
+ else:
85
+ max_len = max(len(ids) for ids in batch_encoded)
86
+
87
+ # Pad sequences to max_len
88
+ batch_encoded = [ids + [self.pad_token_id] * (max_len - len(ids)) for ids in batch_encoded]
89
+ batch_attention_mask = [mask + [0] * (max_len - len(mask)) for mask in batch_attention_mask]
90
+
91
+ result = {
92
+ 'input_ids': batch_encoded,
93
+ 'attention_mask': batch_attention_mask
94
+ }
95
+
96
+ # Convert to tensors if requested
97
+ if return_tensors == 'pt':
98
+ import torch
99
+ result = {k: torch.tensor(v) for k, v in result.items()}
100
+
101
+ return result
102
+
103
+ # Model architecture definitions for inference
104
+
105
+ class MultiHeadAttention(nn.Module):
106
+ """Advanced multi-headed attention with relative positional encoding"""
107
+ def __init__(self, config):
108
+ super().__init__()
109
+ self.num_attention_heads = config["num_attention_heads"]
110
+ self.attention_head_size = config["hidden_size"] // config["num_attention_heads"]
111
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
112
+
113
+ # Query, Key, Value projections
114
+ self.query = nn.Linear(config["hidden_size"], self.all_head_size)
115
+ self.key = nn.Linear(config["hidden_size"], self.all_head_size)
116
+ self.value = nn.Linear(config["hidden_size"], self.all_head_size)
117
+
118
+ # Output projection
119
+ self.output = nn.Sequential(
120
+ nn.Linear(self.all_head_size, config["hidden_size"]),
121
+ nn.Dropout(config["attention_probs_dropout_prob"])
122
+ )
123
+
124
+ # Simplified relative position bias approach
125
+ self.max_position_embeddings = config["max_position_embeddings"]
126
+ self.relative_attention_bias = nn.Embedding(
127
+ 2 * config["max_position_embeddings"] - 1,
128
+ config["num_attention_heads"]
129
+ )
130
+
131
+ def transpose_for_scores(self, x):
132
+ new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
133
+ x = x.view(*new_shape)
134
+ return x.permute(0, 2, 1, 3)
135
+
136
+ def forward(self, hidden_states, attention_mask=None):
137
+ batch_size, seq_length = hidden_states.size()[:2]
138
+
139
+ # Project inputs to queries, keys, and values
140
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
141
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
142
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
143
+
144
+ # Take the dot product between query and key to get the raw attention scores
145
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
146
+
147
+ # Generate relative position matrix
148
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device)
149
+ relative_position = position_ids.unsqueeze(1) - position_ids.unsqueeze(0) # [seq_len, seq_len]
150
+ # Shift values to be >= 0
151
+ relative_position = relative_position + self.max_position_embeddings - 1
152
+ # Ensure indices are within bounds
153
+ relative_position = torch.clamp(relative_position, 0, 2 * self.max_position_embeddings - 2)
154
+
155
+ # Get relative position embeddings [seq_len, seq_len, num_heads]
156
+ rel_attn_bias = self.relative_attention_bias(relative_position) # [seq_len, seq_len, num_heads]
157
+
158
+ # Reshape to add to attention heads [1, num_heads, seq_len, seq_len]
159
+ rel_attn_bias = rel_attn_bias.permute(2, 0, 1).unsqueeze(0)
160
+
161
+ # Add to attention scores - now dimensions will match
162
+ attention_scores = attention_scores + rel_attn_bias
163
+
164
+ # Scale attention scores
165
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
166
+
167
+ # Apply attention mask
168
+ if attention_mask is not None:
169
+ attention_scores = attention_scores + attention_mask
170
+
171
+ # Normalize the attention scores to probabilities
172
+ attention_probs = F.softmax(attention_scores, dim=-1)
173
+
174
+ # Apply dropout
175
+ attention_probs = F.dropout(attention_probs, p=0.1, training=self.training)
176
+
177
+ # Apply attention to values
178
+ context_layer = torch.matmul(attention_probs, value_layer)
179
+
180
+ # Reshape back to [batch_size, seq_length, hidden_size]
181
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
182
+ new_shape = context_layer.size()[:-2] + (self.all_head_size,)
183
+ context_layer = context_layer.view(*new_shape)
184
+
185
+ # Final output projection
186
+ output = self.output(context_layer)
187
+
188
+ return output
189
+
190
+ class EnhancedTransformerLayer(nn.Module):
191
+ """Advanced transformer layer with pre-layer norm and enhanced attention"""
192
+ def __init__(self, config):
193
+ super().__init__()
194
+ self.attention_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
195
+ self.attention = MultiHeadAttention(config)
196
+
197
+ self.ffn_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
198
+
199
+ # Feed-forward network
200
+ self.ffn = nn.Sequential(
201
+ nn.Linear(config["hidden_size"], config["intermediate_size"]),
202
+ nn.GELU(),
203
+ nn.Dropout(config["hidden_dropout_prob"]),
204
+ nn.Linear(config["intermediate_size"], config["hidden_size"]),
205
+ nn.Dropout(config["hidden_dropout_prob"])
206
+ )
207
+
208
+ def forward(self, hidden_states, attention_mask=None):
209
+ # Pre-layer norm for attention
210
+ attn_norm_hidden = self.attention_pre_norm(hidden_states)
211
+
212
+ # Self-attention
213
+ attention_output = self.attention(attn_norm_hidden, attention_mask)
214
+
215
+ # Residual connection
216
+ hidden_states = hidden_states + attention_output
217
+
218
+ # Pre-layer norm for feed-forward
219
+ ffn_norm_hidden = self.ffn_pre_norm(hidden_states)
220
+
221
+ # Feed-forward
222
+ ffn_output = self.ffn(ffn_norm_hidden)
223
+
224
+ # Residual connection
225
+ hidden_states = hidden_states + ffn_output
226
+
227
+ return hidden_states
228
+
229
+ class AdvancedTransformerModel(nn.Module):
230
+ """Advanced Transformer model for inference"""
231
+
232
+ def __init__(self, config):
233
+ super().__init__()
234
+ self.config = config
235
+
236
+ # Embeddings
237
+ self.word_embeddings = nn.Embedding(
238
+ config["vocab_size"],
239
+ config["hidden_size"],
240
+ padding_idx=config["pad_token_id"]
241
+ )
242
+
243
+ # Position embeddings
244
+ self.position_embeddings = nn.Embedding(config["max_position_embeddings"], config["hidden_size"])
245
+
246
+ # Embedding dropout
247
+ self.embedding_dropout = nn.Dropout(config["hidden_dropout_prob"])
248
+
249
+ # Transformer layers
250
+ self.layers = nn.ModuleList([
251
+ EnhancedTransformerLayer(config) for _ in range(config["num_hidden_layers"])
252
+ ])
253
+
254
+ # Final layer norm
255
+ self.final_layer_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
256
+
257
+ def forward(self, input_ids, attention_mask=None):
258
+ input_shape = input_ids.size()
259
+ batch_size, seq_length = input_shape
260
+
261
+ # Get position ids
262
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
263
+ position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
264
+
265
+ # Get embeddings
266
+ word_embeds = self.word_embeddings(input_ids)
267
+ position_embeds = self.position_embeddings(position_ids)
268
+
269
+ # Sum embeddings
270
+ embeddings = word_embeds + position_embeds
271
+
272
+ # Apply dropout
273
+ embeddings = self.embedding_dropout(embeddings)
274
+
275
+ # Default attention mask
276
+ if attention_mask is None:
277
+ attention_mask = torch.ones(input_shape, device=input_ids.device)
278
+
279
+ # Extended attention mask for transformer layers (1 for tokens to attend to, 0 for masked tokens)
280
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
281
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
282
+
283
+ # Apply transformer layers
284
+ hidden_states = embeddings
285
+ for layer in self.layers:
286
+ hidden_states = layer(hidden_states, extended_attention_mask)
287
+
288
+ # Final layer norm
289
+ hidden_states = self.final_layer_norm(hidden_states)
290
+
291
+ return hidden_states
292
+
293
+ class AdvancedPooling(nn.Module):
294
+ """Advanced pooling module supporting multiple pooling strategies"""
295
+ def __init__(self, config):
296
+ super().__init__()
297
+ self.pooling_mode = config["pooling_mode"] # 'mean', 'max', 'cls', 'attention'
298
+ self.hidden_size = config["hidden_size"]
299
+
300
+ # For attention pooling
301
+ if self.pooling_mode == 'attention':
302
+ self.attention_weights = nn.Linear(config["hidden_size"], 1)
303
+
304
+ # For weighted pooling
305
+ elif self.pooling_mode == 'weighted':
306
+ self.weight_layer = nn.Linear(config["hidden_size"], 1)
307
+
308
+ def forward(self, token_embeddings, attention_mask=None):
309
+ if attention_mask is None:
310
+ attention_mask = torch.ones_like(token_embeddings[:, :, 0])
311
+
312
+ mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
313
+
314
+ if self.pooling_mode == 'cls':
315
+ # Use [CLS] token (first token)
316
+ pooled = token_embeddings[:, 0]
317
+
318
+ elif self.pooling_mode == 'max':
319
+ # Max pooling
320
+ token_embeddings = token_embeddings.clone()
321
+ # Set padding tokens to large negative value to exclude them from max
322
+ token_embeddings[mask_expanded == 0] = -1e9
323
+ pooled = torch.max(token_embeddings, dim=1)[0]
324
+
325
+ elif self.pooling_mode == 'attention':
326
+ # Attention pooling
327
+ weights = self.attention_weights(token_embeddings).squeeze(-1)
328
+ # Mask out padding tokens
329
+ weights = weights.masked_fill(attention_mask == 0, -1e9)
330
+ weights = F.softmax(weights, dim=1).unsqueeze(-1)
331
+ pooled = torch.sum(token_embeddings * weights, dim=1)
332
+
333
+ elif self.pooling_mode == 'weighted':
334
+ # Weighted average pooling
335
+ weights = torch.sigmoid(self.weight_layer(token_embeddings)).squeeze(-1)
336
+ # Apply mask
337
+ weights = weights * attention_mask
338
+ # Normalize weights
339
+ sum_weights = torch.sum(weights, dim=1, keepdim=True)
340
+ sum_weights = torch.clamp(sum_weights, min=1e-9)
341
+ weights = weights / sum_weights
342
+ # Apply weights
343
+ pooled = torch.sum(token_embeddings * weights.unsqueeze(-1), dim=1)
344
+
345
+ else: # Default to mean pooling
346
+ # Mean pooling
347
+ sum_embeddings = torch.sum(token_embeddings * mask_expanded, dim=1)
348
+ sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
349
+ pooled = sum_embeddings / sum_mask
350
+
351
+ # L2 normalize
352
+ pooled = F.normalize(pooled, p=2, dim=1)
353
+
354
+ return pooled
355
+
356
+ class SentenceEmbeddingModel(nn.Module):
357
+ """Complete sentence embedding model for inference"""
358
+ def __init__(self, config):
359
+ super(SentenceEmbeddingModel, self).__init__()
360
+ self.config = config
361
+
362
+ # Create transformer model
363
+ self.transformer = AdvancedTransformerModel(config)
364
+
365
+ # Create pooling module
366
+ self.pooling = AdvancedPooling(config)
367
+
368
+ # Build projection module if needed
369
+ if "projection_dim" in config and config["projection_dim"] > 0:
370
+ self.use_projection = True
371
+ self.projection = nn.Sequential(
372
+ nn.Linear(config["hidden_size"], config["hidden_size"]),
373
+ nn.GELU(),
374
+ nn.Linear(config["hidden_size"], config["projection_dim"]),
375
+ nn.LayerNorm(config["projection_dim"], eps=config["layer_norm_eps"])
376
+ )
377
+ else:
378
+ self.use_projection = False
379
+
380
+ def forward(self, input_ids, attention_mask=None):
381
+ # Get token embeddings from transformer
382
+ token_embeddings = self.transformer(input_ids, attention_mask)
383
+
384
+ # Pool token embeddings
385
+ pooled_output = self.pooling(token_embeddings, attention_mask)
386
+
387
+ # Apply projection if enabled
388
+ if self.use_projection:
389
+ pooled_output = self.projection(pooled_output)
390
+ pooled_output = F.normalize(pooled_output, p=2, dim=1)
391
+
392
+ return pooled_output
393
+
394
+ def convert_to_safetensors(model_path, output_path):
395
+ """Convert PyTorch model to safetensors format"""
396
+ print(f"Converting model from {model_path} to safetensors format...")
397
+
398
+ try:
399
+ # First try with weights_only=False to handle PyTorch 2.6+ checkpoints
400
+ checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
401
+ print("Successfully loaded checkpoint with weights_only=False")
402
+ except TypeError:
403
+ # For older PyTorch versions that don't have weights_only parameter
404
+ print("Falling back to default torch.load behavior for older PyTorch versions")
405
+ checkpoint = torch.load(model_path, map_location="cpu")
406
+
407
+ # Get model state dict
408
+ if "model_state_dict" in checkpoint:
409
+ state_dict = checkpoint["model_state_dict"]
410
+ print("Extracted model_state_dict from checkpoint")
411
+ else:
412
+ state_dict = checkpoint
413
+ print("Using entire checkpoint as state_dict")
414
+
415
+ # Save as safetensors
416
+ save_file(state_dict, output_path)
417
+ print(f"Model converted and saved to {output_path}")
418
+
419
+ def load_model_and_tokenizer(model_dir, tokenizer_dir="/home/ubuntu/hindi_tokenizer"):
420
+ """Load the model and tokenizer for inference"""
421
+
422
+ # Load the config
423
+ config_path = os.path.join(model_dir, "config.json")
424
+ with open(config_path, "r") as f:
425
+ config = json.load(f)
426
+
427
+ # Load the tokenizer - use specified tokenizer directory
428
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.model")
429
+ if not os.path.exists(tokenizer_path):
430
+ # Try other locations
431
+ tokenizer_path = os.path.join(model_dir, "tokenizer.model")
432
+ if not os.path.exists(tokenizer_path):
433
+ raise FileNotFoundError(f"Could not find tokenizer model at {tokenizer_path}")
434
+
435
+ tokenizer = SentencePieceTokenizerWrapper(tokenizer_path)
436
+ print(f"Loaded tokenizer from {tokenizer_path} with vocabulary size: {tokenizer.vocab_size}")
437
+
438
+ # Load the model
439
+ safetensors_path = os.path.join(model_dir, "embedding_model.safetensors")
440
+
441
+ if not os.path.exists(safetensors_path):
442
+ print(f"Safetensors model not found at {safetensors_path}, converting from PyTorch checkpoint...")
443
+
444
+ # Convert from PyTorch checkpoint
445
+ pytorch_path = os.path.join(model_dir, "embedding_model.pt")
446
+ if not os.path.exists(pytorch_path):
447
+ raise FileNotFoundError(f"Could not find PyTorch model at {pytorch_path}")
448
+
449
+ convert_to_safetensors(pytorch_path, safetensors_path)
450
+
451
+ # Load state dict from safetensors
452
+ state_dict = load_file(safetensors_path)
453
+
454
+ # Create model
455
+ model = SentenceEmbeddingModel(config)
456
+
457
+ # Load state dict
458
+ try:
459
+ # Try direct loading
460
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
461
+ print(f"Loaded model with missing keys: {missing_keys[:10]}{'...' if len(missing_keys) > 10 else ''}")
462
+ print(f"Unexpected keys: {unexpected_keys[:10]}{'...' if len(unexpected_keys) > 10 else ''}")
463
+ except Exception as e:
464
+ print(f"Error loading state dict: {e}")
465
+ print("Model will be initialized with random weights")
466
+
467
+ model.eval()
468
+
469
+ return model, tokenizer, config
470
+
471
+ # LangChain Custom Embeddings Class
472
+ class HindiSentenceEmbeddings(Embeddings):
473
+ """
474
+ Custom Langchain Embeddings class for Hindi sentence embeddings model
475
+ """
476
+ def __init__(self, model, tokenizer, device="cuda", batch_size=32, max_length=128):
477
+ """Initialize with model, tokenizer, and inference parameters"""
478
+ self.model = model
479
+ self.tokenizer = tokenizer
480
+ self.device = device
481
+ self.batch_size = batch_size
482
+ self.max_length = max_length
483
+
484
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
485
+ """Embed a list of documents/texts"""
486
+ embeddings = []
487
+
488
+ with torch.no_grad():
489
+ for i in range(0, len(texts), self.batch_size):
490
+ batch = texts[i:i+self.batch_size]
491
+
492
+ # Tokenize
493
+ inputs = self.tokenizer(
494
+ batch,
495
+ padding="max_length",
496
+ truncation=True,
497
+ max_length=self.max_length,
498
+ return_tensors="pt"
499
+ )
500
+
501
+ # Move to device
502
+ input_ids = inputs["input_ids"].to(self.device)
503
+ attention_mask = inputs["attention_mask"].to(self.device)
504
+
505
+ # Get embeddings
506
+ batch_embeddings = self.model(input_ids, attention_mask)
507
+
508
+ # Move to CPU and convert to numpy
509
+ batch_embeddings = batch_embeddings.cpu().numpy()
510
+ embeddings.append(batch_embeddings)
511
+
512
+ return np.vstack(embeddings).tolist()
513
+
514
+ def embed_query(self, text: str) -> List[float]:
515
+ """Embed a single query/text"""
516
+ return self.embed_documents([text])[0]
517
+
518
+ def extract_relevant_sentences(text, query, window_size=2):
+ """
+ Extract the most relevant sentences from text based on query keywords
+
+ Args:
+ text: The full text content
+ query: The user's query
+ window_size: Number of sentences to include before and after matched sentence
+
+ Returns:
+ String containing the most relevant portion of the text
+ """
+ # Clean and normalize query and text for matching
+ query = query.strip().lower()
+
+ # Remove question marks and other punctuation from query for matching
+ query = re.sub(r'[?।॥!,.:]', '', query)
+
+ # Extract keywords from the query (remove common Hindi stop words)
+ stop_words = ['और', 'का', 'के', 'को', 'में', 'से', 'है', 'हैं', 'था', 'थे', 'की', 'कि', 'पर', 'एक', 'यह', 'वह', 'जो', 'ने', 'हो', 'कर']
+ query_terms = [word for word in query.split() if word not in stop_words]
+
+ if not query_terms:
+ return text # If no meaningful terms left, return the full text
+
+ # Split text into sentences (using Hindi sentence terminators)
+ sentences = re.split(r'([।॥!?.])', text)
+
+ # Rejoin sentences with their terminators
+ complete_sentences = []
+ for i in range(0, len(sentences)-1, 2):
+ if i+1 < len(sentences):
+ complete_sentences.append(sentences[i] + sentences[i+1])
+ else:
+ complete_sentences.append(sentences[i])
+
+ # If the above didn't work properly, try simpler approach
+ if len(complete_sentences) <= 1:
+ complete_sentences = re.split(r'[।॥!?.]', text)
+ complete_sentences = [s.strip() for s in complete_sentences if s.strip()]
+
+ # Score each sentence based on how many query terms it contains
+ sentence_scores = []
+ for i, sentence in enumerate(complete_sentences):
+ sentence_lower = sentence.lower()
+ # Calculate score based on number of query terms found
+ score = sum(1 for term in query_terms if term in sentence_lower)
+ sentence_scores.append((i, score))
+
+ # Find the best matching sentence
+ if not sentence_scores:
+ return text[:500] + "..." # Fallback
+
+ # Get the index of sentence with highest score
+ best_match_idx, best_score = max(sentence_scores, key=lambda x: x[1])
+
+ # If no good match found, return the whole text (up to a limit)
+ if best_score == 0:
+ # Try partial word matching as a fallback
+ for i, sentence in enumerate(complete_sentences):
+ sentence_lower = sentence.lower()
+ partial_score = sum(1 for term in query_terms if any(term in word.lower() for word in sentence_lower.split()))
+ if partial_score > 0:
+ best_match_idx = i
+ break
+ else:
+ # If still no match, just return the first part of the text
+ if len(text) > 1000:
+ return text[:1000] + "..."
+ return text
+
+ # Get window of sentences around the best match
+ start_idx = max(0, best_match_idx - window_size)
+ end_idx = min(len(complete_sentences), best_match_idx + window_size + 1)
+
+ # Create excerpt
+ relevant_text = ' '.join(complete_sentences[start_idx:end_idx])
+
+ # If the excerpt is short, return more context
+ if len(relevant_text) < 100 and len(text) > len(relevant_text):
+ # Add more context
+ if end_idx < len(complete_sentences):
+ relevant_text += ' ' + ' '.join(complete_sentences[end_idx:end_idx+2])
+ if start_idx > 0:
+ relevant_text = ' '.join(complete_sentences[max(0, start_idx-2):start_idx]) + ' ' + relevant_text
+
+ # If the excerpt is too short or the whole text is small anyway, return whole text
+ if len(relevant_text) < 50 or len(text) < 1000:
+ return text
+
+ return relevant_text
+
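+ # Illustrative call (made-up text; shows the intended keyword-window behaviour):
+ #   text = "दिल्ली भारत की राजधानी है। यहाँ लाल किला है। मुंबई महाराष्ट्र की राजधानी है।"
+ #   extract_relevant_sentences(text, "भारत की राजधानी क्या है?", window_size=1)
+ #   # -> returns the best-matching sentence plus up to one neighbouring sentence on each side
+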
+ # Text processing and indexing functions
+ def load_and_process_text_file(file_path, chunk_size=500, chunk_overlap=100):
+ """
+ Load a text file and split it into semantically meaningful chunks
+ (note: chunk_overlap is accepted for API compatibility but not applied by this paragraph-based splitter)
+ """
+ print(f"Loading and processing text file: {file_path}")
+
+ # Read the file content
+ with open(file_path, 'r', encoding='utf-8') as f:
+ content = f.read()
+
+ # For small files, just keep the whole content as a single chunk
+ if len(content) <= chunk_size * 2:
+ print("File content is small, keeping as a single chunk")
+ return [Document(
+ page_content=content,
+ metadata={
+ "source": file_path,
+ "chunk_id": 0
+ }
+ )]
+
+ # Split by paragraphs first
+ paragraphs = re.split(r'\n\s*\n', content)
+ chunks = []
+
+ current_chunk = ""
+ current_size = 0
+
+ for para in paragraphs:
+ if not para.strip():
+ continue
+
+ # If adding this paragraph would exceed the chunk size, save current chunk and start new one
+ if current_size + len(para) > chunk_size and current_size > 0:
+ chunks.append(current_chunk)
+ current_chunk = para
+ current_size = len(para)
+ else:
+ # Add paragraph to current chunk with a newline if not empty
+ if current_size > 0:
+ current_chunk += "\n\n" + para
+ else:
+ current_chunk = para
+ current_size = len(current_chunk)
+
+ # Add the last chunk if not empty
+ if current_chunk:
+ chunks.append(current_chunk)
+
+ print(f"Split text into {len(chunks)} chunks")
+
+ # Convert to LangChain documents with metadata
+ documents = [
+ Document(
+ page_content=chunk,
+ metadata={
+ "source": file_path,
+ "chunk_id": i
+ }
+ ) for i, chunk in enumerate(chunks)
+ ]
+
+ return documents
+
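+ # Illustrative usage (file path is hypothetical):
+ #   docs = load_and_process_text_file("./data/sample.txt", chunk_size=500)
+ #   docs[0].page_content   # first paragraph-based chunk
+ #   docs[0].metadata       # {"source": "./data/sample.txt", "chunk_id": 0}
+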
+ def create_vector_store(documents, embeddings, store_path=None):
+ """
+ Create a FAISS vector store from documents using the given embeddings
+ """
+ print("Creating FAISS vector store...")
+
+ # Create vector store
+ vector_store = LangchainFAISS.from_documents(documents, embeddings)
+
+ # Save if path is provided
+ if store_path:
+ print(f"Saving vector store to {store_path}")
+ vector_store.save_local(store_path)
+
+ return vector_store
+
+ def load_vector_store(store_path, embeddings):
+ """
+ Load a FAISS vector store from disk
+ """
+ print(f"Loading vector store from {store_path}")
+ # allow_dangerous_deserialization is required because the docstore saved by save_local is
+ # pickled; only load index directories that this script created itself
+ return LangchainFAISS.load_local(store_path, embeddings, allow_dangerous_deserialization=True)
+
+ def perform_similarity_search(vector_store, query, k=6):
+ """
+ Perform basic similarity search on the vector store
+ """
+ print(f"Searching for: {query}")
+ return vector_store.similarity_search_with_score(query, k=k)
+
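+ # Illustrative round trip over the three helpers above (paths hypothetical, embedder as
+ # defined earlier in this file):
+ #   store = create_vector_store(docs, embedder, store_path="./output/faiss_index")
+ #   store = load_vector_store("./output/faiss_index", embedder)
+ #   hits = perform_similarity_search(store, "भारत की राजधानी", k=3)   # [(Document, score), ...]
+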
+ # Main RAG functions
+ def index_text_files(model, tokenizer, data_dir, output_dir, device="cuda", chunk_size=500):
+ """
+ Index text files from a directory and create a FAISS vector store
+ """
+ print(f"Indexing text files from {data_dir} with chunk size {chunk_size} for fine-grained retrieval")
+
+ # Create embedding model
+ embeddings = HindiSentenceEmbeddings(model, tokenizer, device=device)
+
+ # Create output directory if it doesn't exist
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Get all text files
+ text_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.txt')]
+ print(f"Found {len(text_files)} text files")
+
+ # Process all text files
+ all_documents = []
+ for file_path in text_files:
+ documents = load_and_process_text_file(file_path, chunk_size=chunk_size)
+ all_documents.extend(documents)
+
+ print(f"Total documents: {len(all_documents)}")
+
+ # If we don't have enough chunks, reduce chunk size and try again
+ if len(all_documents) < 10 and chunk_size > 50:
+ print("Not enough chunks created. Reducing chunk size and trying again...")
+ return index_text_files(model, tokenizer, data_dir, output_dir, device, chunk_size=chunk_size//2)
+
+ # Create and save vector store
+ vector_store_path = os.path.join(output_dir, "faiss_index")
+ vector_store = create_vector_store(all_documents, embeddings, vector_store_path)
+
+ return vector_store, embeddings
+
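+ # Illustrative indexing call (directories hypothetical); returns the vector store and the
+ # embeddings wrapper, and writes the index to <output_dir>/faiss_index:
+ #   vector_store, embedder = index_text_files(model, tokenizer, "./data", "./output", device="cpu")
+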
+ def query_text_corpus(model, tokenizer, vector_store_path, query, k=6, device="cuda"):
+ """
+ Query the text corpus using the indexed vector store
+ """
+ # Create embedding model
+ embeddings = HindiSentenceEmbeddings(model, tokenizer, device=device)
+
+ # Load vector store
+ vector_store = load_vector_store(vector_store_path, embeddings)
+
+ # Perform similarity search
+ results = perform_similarity_search(vector_store, query, k=k)
+
+ # Post-process results to combine adjacent chunks if they're from the same source
+ processed_results = []
+ seen_chunks = set()
+
+ for doc, score in results:
+ chunk_id = doc.metadata["chunk_id"]
+ source = doc.metadata["source"]
+
+ # Skip if we've already included this chunk
+ if (source, chunk_id) in seen_chunks:
+ continue
+
+ seen_chunks.add((source, chunk_id))
+
+ # Try to find adjacent chunks and combine them
+ combined_content = doc.page_content
+
+ # Look for adjacent chunks in results (both previous and next)
+ for adj_id in [chunk_id-1, chunk_id+1]:
+ for other_doc, _ in results:
+ if (other_doc.metadata["source"] == source and
+ other_doc.metadata["chunk_id"] == adj_id and
+ (source, adj_id) not in seen_chunks):
+
+ # Add the adjacent chunk content
+ if adj_id < chunk_id: # Previous chunk
+ combined_content = other_doc.page_content + " " + combined_content
+ else: # Next chunk
+ combined_content = combined_content + " " + other_doc.page_content
+
+ seen_chunks.add((source, adj_id))
+
+ # Create a new document with combined content
+ combined_doc = Document(
+ page_content=combined_content,
+ metadata={
+ "source": source,
+ "chunk_id": chunk_id,
+ "is_combined": combined_content != doc.page_content
+ }
+ )
+
+ processed_results.append((combined_doc, score))
+
+ return processed_results
+
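+ # Illustrative query (assumes ./output/faiss_index was created by index_text_files):
+ #   results = query_text_corpus(model, tokenizer, "./output/faiss_index", "भारत की राजधानी", k=3, device="cpu")
+ #   for doc, score in results:
+ #       print(score, doc.metadata["source"], doc.metadata["chunk_id"])
+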
+ def main():
+ parser = argparse.ArgumentParser(description="Hindi RAG System with LangChain and FAISS")
+ parser.add_argument("--model_dir", type=str, default="/home/ubuntu/output/hindi-embeddings-custom-tokenizer/final",
+ help="Directory containing the model and tokenizer")
+ parser.add_argument("--tokenizer_dir", type=str, default="/home/ubuntu/hindi_tokenizer",
+ help="Directory containing the tokenizer")
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
+ help="Device to run inference on ('cuda' or 'cpu')")
+ parser.add_argument("--index", action="store_true",
+ help="Index text files from data directory")
+ parser.add_argument("--query", type=str, default=None,
+ help="Query to search in the indexed corpus")
+ parser.add_argument("--data_dir", type=str, default="./data",
+ help="Directory containing text files for indexing")
+ parser.add_argument("--output_dir", type=str, default="./output",
+ help="Directory to save the indexed vector store")
+ parser.add_argument("--top_k", type=int, default=6,
+ help="Number of top results to return")
+ parser.add_argument("--chunk_size", type=int, default=500,
+ help="Size of text chunks for indexing")
+ parser.add_argument("--interactive", action="store_true",
+ help="Run in interactive mode for querying")
+ parser.add_argument("--reindex", action="store_true",
+ help="Force reindexing even if index exists")
+ args = parser.parse_args()
+
+ # Load model and tokenizer
+ model, tokenizer, config = load_model_and_tokenizer(args.model_dir, args.tokenizer_dir)
+
+ # Move model to device
+ model = model.to(args.device)
+
+ # Create vector store path
+ vector_store_path = os.path.join(args.output_dir, "faiss_index")
+
+ if args.index or args.reindex:
+ # Index text files
+ index_text_files(model, tokenizer, args.data_dir, args.output_dir, args.device, args.chunk_size)
+ print(f"Indexing complete. Vector store saved to {vector_store_path}")
+
+ if args.query:
+ # Query the corpus
+ results = query_text_corpus(model, tokenizer, vector_store_path, args.query, args.top_k, args.device)
+
+ # Print results
+ print("\nSearch Results:")
+ for i, (doc, score) in enumerate(results):
+ print(f"\nResult {i+1} (Score: {score:.4f}):")
+ print(f"Source: {doc.metadata['source']}, Chunk: {doc.metadata['chunk_id']}")
+
+ # Extract and print only relevant sentences
+ relevant_text = extract_relevant_sentences(doc.page_content, args.query)
+ print(f"Content: {relevant_text}")
+
+ if args.interactive:
+ print("\nInteractive mode. Enter queries (or type 'quit' to exit).")
+
+ while True:
+ print("\nEnter query:")
+ query = input()
+
+ if not query.strip():
+ continue
+
+ if query.lower() == 'quit':
+ break
+
+ # Query the corpus
+ results = query_text_corpus(model, tokenizer, vector_store_path, query, args.top_k, args.device)
+
+ # Print results
+ print("\nSearch Results:")
+ for i, (doc, score) in enumerate(results):
+ print(f"\nResult {i+1} (Score: {score:.4f}):")
+ print(f"Source: {doc.metadata['source']}, Chunk: {doc.metadata['chunk_id']}")
+
+ # Extract and print only relevant sentences
+ relevant_text = extract_relevant_sentences(doc.page_content, query)
+ print(f"Content: {relevant_text}")
+
+ if __name__ == "__main__":
+ main()
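+
+ # Illustrative command lines (paths are examples; defaults are defined in main()):
+ #   python hindi-rag-system.py --index --data_dir ./data --output_dir ./output
+ #   python hindi-rag-system.py --query "भारत की राजधानी क्या है?" --output_dir ./output --top_k 3
+ #   python hindi-rag-system.py --interactive --device cpu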