"""Hindi retrieval utilities: a custom sentence-embedding model, a LangChain
embeddings adapter, FAISS indexing/querying, and a small CLI (``main``)."""

import os
import torch
import json
import argparse
import numpy as np
import re
from torch import nn
from torch.nn import functional as F
import sentencepiece as spm
import math
from safetensors.torch import save_file, load_file
from tqdm import tqdm
import faiss
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS as LangchainFAISS
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from typing import List, Dict, Any, Optional, Callable


# Tokenizer wrapper class - same as in original code
class SentencePieceTokenizerWrapper:
    """Minimal HuggingFace-style callable wrapper around a SentencePiece model.

    Calling an instance returns a dict with ``input_ids`` and
    ``attention_mask`` (1 for real tokens, 0 for padding), either as Python
    lists or, with ``return_tensors='pt'``, as ``torch.long`` tensors.
    """

    def __init__(self, sp_model_path):
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(sp_model_path)
        self.vocab_size = self.sp_model.GetPieceSize()

        # Special token IDs from tokenizer training (assumed to match the ids
        # the SentencePiece model was trained with -- TODO confirm).
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.unk_token_id = 3

        # NOTE(review): these strings are empty and are not read anywhere in
        # this file; they look like they were meant to hold the special piece
        # strings -- confirm before relying on them.
        self.pad_token = ""
        self.bos_token = ""
        self.eos_token = ""
        self.unk_token = ""
        self.mask_token = ""

    def __call__(self, text, padding=False, truncation=False, max_length=None, return_tensors=None):
        """Tokenize a single string or a list of strings.

        Args:
            text: a string, or a list of strings (a batch).
            padding: if truthy, pad with ``pad_token_id``. A single string is
                padded only when ``max_length`` is given. A batch is padded to
                ``max_length`` (or to its longest sequence if that is longer,
                so the batch is always rectangular), or to its longest
                sequence when ``max_length`` is not given.
            truncation: if truthy and ``max_length`` is set, cut every
                sequence to at most ``max_length`` ids.
            max_length: target length for truncation/padding.
            return_tensors: ``'pt'`` to return ``torch.long`` tensors; a single
                string then yields tensors of shape ``(1, seq_len)``.

        Returns:
            dict with keys ``'input_ids'`` and ``'attention_mask'``.
        """
        if isinstance(text, str):
            return self._encode_single(text, padding, truncation, max_length, return_tensors)
        return self._encode_batch(text, padding, truncation, max_length, return_tensors)

    def _encode_single(self, text, padding, truncation, max_length, return_tensors):
        """Encode one string (see ``__call__`` for argument semantics)."""
        ids = self.sp_model.EncodeAsIds(text)
        if truncation and max_length and len(ids) > max_length:
            ids = ids[:max_length]
        attention_mask = [1] * len(ids)

        if padding and max_length:
            padding_length = max(0, max_length - len(ids))
            ids = ids + [self.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length

        result = {'input_ids': ids, 'attention_mask': attention_mask}
        if return_tensors == 'pt':
            # Explicit dtype so an empty encoding is still an integer tensor.
            result = {k: torch.tensor([v], dtype=torch.long) for k, v in result.items()}
        return result

    def _encode_batch(self, texts, padding, truncation, max_length, return_tensors):
        """Encode a list of strings (see ``__call__`` for argument semantics)."""
        batch_encoded = [self.sp_model.EncodeAsIds(t) for t in texts]
        if truncation and max_length:
            batch_encoded = [ids[:max_length] for ids in batch_encoded]
        batch_attention_mask = [[1] * len(ids) for ids in batch_encoded]

        if padding:
            longest = max((len(ids) for ids in batch_encoded), default=0)
            # Never pad to less than the longest sequence: without truncation
            # a sequence may exceed max_length, and ragged rows cannot be
            # turned into a tensor.
            max_len = max(max_length or 0, longest)
            batch_encoded = [ids + [self.pad_token_id] * (max_len - len(ids)) for ids in batch_encoded]
            batch_attention_mask = [mask + [0] * (max_len - len(mask)) for mask in batch_attention_mask]

        result = {'input_ids': batch_encoded, 'attention_mask': batch_attention_mask}
        if return_tensors == 'pt':
            result = {k: torch.tensor(v, dtype=torch.long) for k, v in result.items()}
        return result


# Model architecture definitions for inference
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a learned relative-position bias.

    The bias is one scalar per head, looked up by the (clamped) relative
    distance ``i - j`` between query position ``i`` and key position ``j``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config["num_attention_heads"]
        self.attention_head_size = config["hidden_size"] // config["num_attention_heads"]
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Dropout rate for the attention probabilities, taken from config so it
        # agrees with the output-projection dropout below.
        self.attention_probs_dropout_prob = config["attention_probs_dropout_prob"]

        # Query, Key, Value projections
        self.query = nn.Linear(config["hidden_size"], self.all_head_size)
        self.key = nn.Linear(config["hidden_size"], self.all_head_size)
        self.value = nn.Linear(config["hidden_size"], self.all_head_size)

        # Output projection
        self.output = nn.Sequential(
            nn.Linear(self.all_head_size, config["hidden_size"]),
            nn.Dropout(config["attention_probs_dropout_prob"]),
        )

        # One bias per head for every relative distance in
        # [-(max_pos - 1), max_pos - 1], i.e. 2 * max_pos - 1 table rows.
        self.max_position_embeddings = config["max_position_embeddings"]
        self.relative_attention_bias = nn.Embedding(
            2 * config["max_position_embeddings"] - 1, config["num_attention_heads"]
        )

    def transpose_for_scores(self, x):
        """Reshape ``(batch, seq, all_head_size)`` to ``(batch, heads, seq, head_size)``."""
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*new_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        """Apply self-attention.

        Args:
            hidden_states: tensor of shape ``(batch, seq_len, hidden_size)``.
            attention_mask: optional additive mask (0 for visible keys, a large
                negative value for masked keys) broadcastable to
                ``(batch, heads, seq_len, seq_len)``; AdvancedTransformerModel
                passes shape ``(batch, 1, 1, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        seq_length = hidden_states.size(1)

        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Raw scores: (batch, heads, seq, seq)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # Relative distances i - j, shifted to be non-negative table indices
        # and clamped for sequences longer than max_position_embeddings.
        positions = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device)
        relative_position = positions.unsqueeze(1) - positions.unsqueeze(0)
        relative_position = relative_position + self.max_position_embeddings - 1
        relative_position = torch.clamp(relative_position, 0, 2 * self.max_position_embeddings - 2)

        # (seq, seq, heads) -> (1, heads, seq, seq)
        rel_attn_bias = self.relative_attention_bias(relative_position).permute(2, 0, 1).unsqueeze(0)

        # NOTE: the bias is added *before* the 1/sqrt(d) scaling, so it is
        # scaled too. Keep this order -- presumably the trained checkpoint
        # was produced with exactly this computation.
        attention_scores = attention_scores + rel_attn_bias
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = F.dropout(
            attention_probs, p=self.attention_probs_dropout_prob, training=self.training
        )

        context_layer = torch.matmul(attention_probs, value_layer)

        # Back to (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(*context_layer.size()[:-2], self.all_head_size)

        return self.output(context_layer)


class EnhancedTransformerLayer(nn.Module):
    """Pre-LayerNorm transformer block: attention and feed-forward sub-layers,
    each applied to a normalized input and added back as a residual."""

    def __init__(self, config):
        super().__init__()
        self.attention_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        self.attention = MultiHeadAttention(config)
        self.ffn_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])

        # Feed-forward network (module indices 0..4 are part of the
        # state_dict key names -- do not reorder).
        self.ffn = nn.Sequential(
            nn.Linear(config["hidden_size"], config["intermediate_size"]),
            nn.GELU(),
            nn.Dropout(config["hidden_dropout_prob"]),
            nn.Linear(config["intermediate_size"], config["hidden_size"]),
            nn.Dropout(config["hidden_dropout_prob"]),
        )

    def forward(self, hidden_states, attention_mask=None):
        """Return ``hidden_states`` after the attention and FFN residual sub-layers."""
        attention_output = self.attention(self.attention_pre_norm(hidden_states), attention_mask)
        hidden_states = hidden_states + attention_output

        ffn_output = self.ffn(self.ffn_pre_norm(hidden_states))
        return hidden_states + ffn_output


class AdvancedTransformerModel(nn.Module):
    """Transformer encoder: token + absolute position embeddings, a stack of
    EnhancedTransformerLayer blocks, and a final LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.word_embeddings = nn.Embedding(
            config["vocab_size"], config["hidden_size"], padding_idx=config["pad_token_id"]
        )
        self.position_embeddings = nn.Embedding(config["max_position_embeddings"], config["hidden_size"])
        self.embedding_dropout = nn.Dropout(config["hidden_dropout_prob"])

        self.layers = nn.ModuleList([
            EnhancedTransformerLayer(config) for _ in range(config["num_hidden_layers"])
        ])

        self.final_layer_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])

    def forward(self, input_ids, attention_mask=None):
        """Encode token ids.

        Args:
            input_ids: integer tensor of shape ``(batch, seq_len)``.
            attention_mask: optional ``(batch, seq_len)`` tensor, 1 for tokens
                to attend to and 0 for padding; defaults to all ones.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_position_embeddings``
                (the absolute position table has no entry for it).
        """
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape

        max_positions = self.config["max_position_embeddings"]
        if seq_length > max_positions:
            raise ValueError(
                f"Sequence length {seq_length} exceeds max_position_embeddings ({max_positions})"
            )

        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        embeddings = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        embeddings = self.embedding_dropout(embeddings)

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=input_ids.device)

        # (batch, seq) -> additive (batch, 1, 1, seq): 0 where attended,
        # -10000 where masked; cast to the activation dtype to avoid promotion.
        extended_attention_mask = attention_mask[:, None, None, :].to(dtype=embeddings.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        hidden_states = embeddings
        for layer in self.layers:
            hidden_states = layer(hidden_states, extended_attention_mask)

        return self.final_layer_norm(hidden_states)


class AdvancedPooling(nn.Module):
    """Pool token embeddings into one L2-normalized vector per sequence.

    ``config["pooling_mode"]`` selects the strategy: ``'cls'`` (first token),
    ``'max'``, ``'attention'`` (learned softmax weights), ``'weighted'``
    (learned sigmoid weights, renormalized); any other value uses masked mean
    pooling.
    """

    def __init__(self, config):
        super().__init__()
        self.pooling_mode = config["pooling_mode"]
        self.hidden_size = config["hidden_size"]

        if self.pooling_mode == 'attention':
            self.attention_weights = nn.Linear(config["hidden_size"], 1)
        elif self.pooling_mode == 'weighted':
            self.weight_layer = nn.Linear(config["hidden_size"], 1)

    def forward(self, token_embeddings, attention_mask=None):
        """Pool ``(batch, seq, hidden)`` embeddings to ``(batch, hidden)``.

        ``attention_mask`` is ``(batch, seq)`` with 0 marking padding that is
        excluded from pooling (except in ``'cls'`` mode, which ignores it).
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(token_embeddings[:, :, 0])
        mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

        if self.pooling_mode == 'cls':
            pooled = token_embeddings[:, 0]

        elif self.pooling_mode == 'max':
            token_embeddings = token_embeddings.clone()
            # Push padding positions far below any real activation.
            token_embeddings[mask_expanded == 0] = -1e9
            pooled = torch.max(token_embeddings, dim=1)[0]

        elif self.pooling_mode == 'attention':
            weights = self.attention_weights(token_embeddings).squeeze(-1)
            weights = weights.masked_fill(attention_mask == 0, -1e9)
            weights = F.softmax(weights, dim=1).unsqueeze(-1)
            pooled = torch.sum(token_embeddings * weights, dim=1)

        elif self.pooling_mode == 'weighted':
            weights = torch.sigmoid(self.weight_layer(token_embeddings)).squeeze(-1)
            weights = weights * attention_mask
            sum_weights = torch.clamp(torch.sum(weights, dim=1, keepdim=True), min=1e-9)
            weights = weights / sum_weights
            pooled = torch.sum(token_embeddings * weights.unsqueeze(-1), dim=1)

        else:
            # Masked mean pooling (clamp avoids division by zero on all-padding rows).
            sum_embeddings = torch.sum(token_embeddings * mask_expanded, dim=1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            pooled = sum_embeddings / sum_mask

        return F.normalize(pooled, p=2, dim=1)


class SentenceEmbeddingModel(nn.Module):
    """Full sentence encoder: transformer -> pooling -> optional projection.

    The output is always L2-normalized. A projection head is built only when
    ``config["projection_dim"]`` is present and positive.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = AdvancedTransformerModel(config)
        self.pooling = AdvancedPooling(config)

        if "projection_dim" in config and config["projection_dim"] > 0:
            self.use_projection = True
            # Module indices are part of the state_dict key names.
            self.projection = nn.Sequential(
                nn.Linear(config["hidden_size"], config["hidden_size"]),
                nn.GELU(),
                nn.Linear(config["hidden_size"], config["projection_dim"]),
                nn.LayerNorm(config["projection_dim"], eps=config["layer_norm_eps"]),
            )
        else:
            self.use_projection = False

    def forward(self, input_ids, attention_mask=None):
        """Return L2-normalized sentence embeddings of shape ``(batch, dim)``."""
        token_embeddings = self.transformer(input_ids, attention_mask)
        pooled_output = self.pooling(token_embeddings, attention_mask)

        if self.use_projection:
            pooled_output = F.normalize(self.projection(pooled_output), p=2, dim=1)

        return pooled_output


def convert_to_safetensors(model_path, output_path):
    """Convert a PyTorch checkpoint file into a safetensors file.

    Accepts a checkpoint dict containing ``"model_state_dict"``, a bare state
    dict, or a pickled ``nn.Module``. Non-tensor entries are skipped (with a
    message) because safetensors can only store tensors.

    SECURITY: ``torch.load(..., weights_only=False)`` unpickles arbitrary
    objects and can execute code -- only convert checkpoints you trust.
    """
    print(f"Converting model from {model_path} to safetensors format...")

    try:
        # weights_only=False so full training checkpoints load on PyTorch 2.6+
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
        print("Successfully loaded checkpoint with weights_only=False")
    except TypeError:
        # Older PyTorch versions do not accept the weights_only keyword.
        print("Falling back to default torch.load behavior for older PyTorch versions")
        checkpoint = torch.load(model_path, map_location="cpu")

    if isinstance(checkpoint, nn.Module):
        state_dict = checkpoint.state_dict()
        print("Extracted state_dict from pickled nn.Module")
    elif isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
        print("Extracted model_state_dict from checkpoint")
    else:
        state_dict = checkpoint
        print("Using entire checkpoint as state_dict")

    tensors = {}
    for name, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            # save_file rejects non-contiguous tensors and tensors that share
            # storage; a contiguous clone satisfies both requirements.
            tensors[name] = value.detach().clone().contiguous()
        else:
            print(f"Skipping non-tensor entry '{name}' ({type(value).__name__})")

    save_file(tensors, output_path)
    print(f"Model converted and saved to {output_path}")


def load_model_and_tokenizer(model_dir, tokenizer_dir="/home/ubuntu/hindi_tokenizer"):
    """Load the embedding model, its tokenizer and its config for inference.

    Reads ``config.json`` from ``model_dir``. The tokenizer is looked up as
    ``tokenizer.model`` in ``tokenizer_dir``, then in ``model_dir``. Weights
    come from ``embedding_model.safetensors``; if it is missing it is created
    from ``embedding_model.pt`` in the same directory.

    Returns:
        ``(model, tokenizer, config)``; the model is in eval mode on CPU.

    Raises:
        FileNotFoundError: if no tokenizer model or no weights file is found.
    """
    config_path = os.path.join(model_dir, "config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    tokenizer_candidates = [
        os.path.join(tokenizer_dir, "tokenizer.model"),
        os.path.join(model_dir, "tokenizer.model"),
    ]
    tokenizer_path = next((p for p in tokenizer_candidates if os.path.exists(p)), None)
    if tokenizer_path is None:
        raise FileNotFoundError(
            f"Could not find tokenizer model at any of: {', '.join(tokenizer_candidates)}"
        )

    tokenizer = SentencePieceTokenizerWrapper(tokenizer_path)
    print(f"Loaded tokenizer from {tokenizer_path} with vocabulary size: {tokenizer.vocab_size}")

    safetensors_path = os.path.join(model_dir, "embedding_model.safetensors")
    if not os.path.exists(safetensors_path):
        print(f"Safetensors model not found at {safetensors_path}, converting from PyTorch checkpoint...")
        pytorch_path = os.path.join(model_dir, "embedding_model.pt")
        if not os.path.exists(pytorch_path):
            raise FileNotFoundError(f"Could not find PyTorch model at {pytorch_path}")
        convert_to_safetensors(pytorch_path, safetensors_path)

    state_dict = load_file(safetensors_path)
    model = SentenceEmbeddingModel(config)

    try:
        # strict=False: report (rather than fail on) key mismatches.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print(f"Loaded model with missing keys: {missing_keys[:10]}{'...' if len(missing_keys) > 10 else ''}")
        print(f"Unexpected keys: {unexpected_keys[:10]}{'...' if len(unexpected_keys) > 10 else ''}")
    except Exception as e:
        # Best-effort: e.g. shape mismatches still raise with strict=False.
        # The model then keeps its random initialization, so results will be
        # meaningless -- the messages below make that visible.
        print(f"Error loading state dict: {e}")
        print("Model will be initialized with random weights")

    model.eval()
    return model, tokenizer, config


# LangChain Custom Embeddings Class
class HindiSentenceEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter around SentenceEmbeddingModel.

    Texts are tokenized with padding/truncation to ``max_length`` and encoded
    in batches of ``batch_size`` on ``device`` without gradient tracking.
    """

    def __init__(self, model, tokenizer, device="cuda", batch_size=32, max_length=128):
        """Store the model, tokenizer and inference parameters."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.batch_size = batch_size
        self.max_length = max_length

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts; returns one vector per text (``[]`` for no texts)."""
        if not texts:
            return []

        embeddings = []
        with torch.no_grad():
            for start in range(0, len(texts), self.batch_size):
                batch = texts[start:start + self.batch_size]
                inputs = self.tokenizer(
                    batch,
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt",
                )
                input_ids = inputs["input_ids"].to(self.device)
                attention_mask = inputs["attention_mask"].to(self.device)

                batch_embeddings = self.model(input_ids, attention_mask)
                embeddings.append(batch_embeddings.cpu().numpy())

        return np.vstack(embeddings).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text."""
        return self.embed_documents([text])[0]


# Common Hindi function words ignored when extracting query keywords.
HINDI_STOP_WORDS = frozenset([
    'और', 'का', 'के', 'को', 'में', 'से', 'है', 'हैं', 'था', 'थे',
    'की', 'कि', 'पर', 'एक', 'यह', 'वह', 'जो', 'ने', 'हो', 'कर',
])


def _split_sentences(text):
    """Split ``text`` into stripped sentences, keeping each terminator
    (``।``, ``॥``, ``!``, ``?``, ``.``) attached, including the trailing
    fragment after the last terminator."""
    # With a capturing group, re.split alternates: body, terminator, body, ..., tail
    parts = re.split(r'([।॥!?.])', text)
    sentences = []
    for i in range(0, len(parts), 2):
        body = parts[i]
        terminator = parts[i + 1] if i + 1 < len(parts) else ''
        if body.strip():
            sentences.append((body + terminator).strip())
        elif sentences:
            # Runs of terminators ("...", "?!") stay with the previous sentence.
            sentences[-1] += terminator
    return sentences


def _find_partial_match(sentences_lower, query_terms, min_stem_len=4):
    """Return the index of the first sentence with a word contained in a query term.

    Direct containment (term in sentence) has already failed when this is
    called, so this checks the reverse direction -- a sentence word of at
    least ``min_stem_len`` characters (not a stop word) that is a substring of
    a query term, e.g. sentence word "भारत" for query term "भारतीय".
    Returns ``None`` when nothing matches.
    """
    for idx, sentence in enumerate(sentences_lower):
        for word in sentence.split():
            word = re.sub(r'[?।॥!,.:]', '', word)
            if len(word) < min_stem_len or word in HINDI_STOP_WORDS:
                continue
            if any(word in term for term in query_terms):
                return idx
    return None


def extract_relevant_sentences(text, query, window_size=2):
    """
    Extract the most relevant sentences from text based on query keywords.

    Args:
        text: The full text content
        query: The user's query
        window_size: Number of sentences to include before and after the
            best-matching sentence

    Returns:
        A window of sentences around the sentence containing the most query
        keywords. The full text is returned when the query has no keywords,
        the text is shorter than 1000 characters, or the excerpt would be
        shorter than 50 characters; with no match at all, the text truncated
        to 1000 characters is returned.
    """
    query = re.sub(r'[?।॥!,.:]', '', query.strip().lower())
    query_terms = [word for word in query.split() if word not in HINDI_STOP_WORDS]
    if not query_terms:
        return text

    sentences = _split_sentences(text)
    if not sentences:
        # Blank text: nothing to excerpt.
        return text if len(text) <= 500 else text[:500] + "..."

    sentences_lower = [s.lower() for s in sentences]
    scores = [sum(1 for term in query_terms if term in s) for s in sentences_lower]
    best_score = max(scores)
    best_match_idx = scores.index(best_score)  # first sentence with the top score

    if best_score == 0:
        best_match_idx = _find_partial_match(sentences_lower, query_terms)
        if best_match_idx is None:
            return text[:1000] + "..." if len(text) > 1000 else text

    start_idx = max(0, best_match_idx - window_size)
    end_idx = min(len(sentences), best_match_idx + window_size + 1)
    relevant_text = ' '.join(sentences[start_idx:end_idx])

    # Widen a very short excerpt by up to two sentences on each side.
    if len(relevant_text) < 100 and len(text) > len(relevant_text):
        if end_idx < len(sentences):
            relevant_text += ' ' + ' '.join(sentences[end_idx:end_idx + 2])
        if start_idx > 0:
            relevant_text = ' '.join(sentences[max(0, start_idx - 2):start_idx]) + ' ' + relevant_text

    if len(relevant_text) < 50 or len(text) < 1000:
        return text
    return relevant_text


# Text processing and indexing functions
def _pack_paragraphs(content, chunk_size, chunk_overlap):
    """Greedily pack blank-line-separated paragraphs into chunks of at most
    ``chunk_size`` characters.

    A single paragraph longer than ``chunk_size`` is split with LangChain's
    RecursiveCharacterTextSplitter, whose pieces overlap by ``chunk_overlap``
    characters (clamped to half of ``chunk_size``, because the splitter
    rejects an overlap larger than the chunk size). Chunks assembled from
    whole paragraphs do not overlap.
    """
    overlap = max(0, min(chunk_overlap, chunk_size // 2))
    splitter = None
    chunks = []
    current_chunk = ""

    for para in re.split(r'\n\s*\n', content):
        if not para.strip():
            continue

        if len(para) > chunk_size:
            if current_chunk:
                chunks.append(current_chunk)
                current_chunk = ""
            if splitter is None:
                splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
            chunks.extend(splitter.split_text(para))
            continue

        if current_chunk and len(current_chunk) + len(para) > chunk_size:
            chunks.append(current_chunk)
            current_chunk = para
        else:
            current_chunk = current_chunk + "\n\n" + para if current_chunk else para

    if current_chunk:
        chunks.append(current_chunk)
    return chunks


def load_and_process_text_file(file_path, chunk_size=500, chunk_overlap=100):
    """Load a UTF-8 text file and split it into LangChain Documents.

    Files of at most ``2 * chunk_size`` characters become a single chunk;
    whitespace-only files produce no documents. Larger files are packed
    paragraph-by-paragraph (see ``_pack_paragraphs``). Each Document carries
    ``{"source": file_path, "chunk_id": i}`` metadata, with consecutive ids
    for consecutive chunks.
    """
    print(f"Loading and processing text file: {file_path}")

    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    if not content.strip():
        print(f"File is empty, skipping: {file_path}")
        return []

    if len(content) <= chunk_size * 2:
        print("File content is small, keeping as a single chunk")
        return [Document(page_content=content, metadata={"source": file_path, "chunk_id": 0})]

    chunks = _pack_paragraphs(content, chunk_size, chunk_overlap)
    print(f"Split text into {len(chunks)} chunks")

    return [
        Document(page_content=chunk, metadata={"source": file_path, "chunk_id": i})
        for i, chunk in enumerate(chunks)
    ]


def create_vector_store(documents, embeddings, store_path=None):
    """Build a FAISS vector store from ``documents``, saving it to
    ``store_path`` when given.

    Raises:
        ValueError: if ``documents`` is empty (FAISS needs at least one vector).
    """
    if not documents:
        raise ValueError("No documents to index; check that the input files are non-empty")

    print("Creating FAISS vector store...")
    vector_store = LangchainFAISS.from_documents(documents, embeddings)

    if store_path:
        print(f"Saving vector store to {store_path}")
        vector_store.save_local(store_path)

    return vector_store


def load_vector_store(store_path, embeddings):
    """Load a FAISS vector store previously written by ``create_vector_store``.

    SECURITY: ``allow_dangerous_deserialization=True`` unpickles the stored
    docstore -- only load indexes you created yourself.
    """
    print(f"Loading vector store from {store_path}")
    return LangchainFAISS.load_local(store_path, embeddings, allow_dangerous_deserialization=True)


def perform_similarity_search(vector_store, query, k=6):
    """Return the ``k`` nearest ``(Document, score)`` pairs for ``query``.

    LangChain's FAISS store defaults to Euclidean distance, in which case a
    lower score means a closer match (verify if a different
    ``distance_strategy`` is configured).
    """
    print(f"Searching for: {query}")
    return vector_store.similarity_search_with_score(query, k=k)


# Main RAG functions
def index_text_files(model, tokenizer, data_dir, output_dir, device="cuda", chunk_size=500):
    """Index every ``*.txt`` file in ``data_dir`` into a FAISS store saved at
    ``<output_dir>/faiss_index``.

    While fewer than 10 chunks are produced and ``chunk_size > 50``, the chunk
    size is halved and the files are re-chunked.

    Returns:
        ``(vector_store, embeddings)``.

    Raises:
        FileNotFoundError: if ``data_dir`` contains no ``.txt`` files.
    """
    print(f"Indexing text files from {data_dir} with chunk size ({chunk_size}) for fine-grained retrieval")

    embeddings = HindiSentenceEmbeddings(model, tokenizer, device=device)
    os.makedirs(output_dir, exist_ok=True)

    # Sorted for a deterministic chunk/index order across runs.
    text_files = sorted(os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.txt'))
    print(f"Found {len(text_files)} text files")
    if not text_files:
        raise FileNotFoundError(f"No .txt files found in {data_dir}")

    while True:
        all_documents = []
        for file_path in text_files:
            all_documents.extend(load_and_process_text_file(file_path, chunk_size=chunk_size))
        print(f"Total documents: {len(all_documents)}")

        if len(all_documents) >= 10 or chunk_size <= 50:
            break
        print("Not enough chunks created. Reducing chunk size and trying again...")
        chunk_size //= 2
        print(f"Indexing text files from {data_dir} with chunk size ({chunk_size}) for fine-grained retrieval")

    vector_store_path = os.path.join(output_dir, "faiss_index")
    vector_store = create_vector_store(all_documents, embeddings, vector_store_path)
    return vector_store, embeddings


def _merge_adjacent_chunks(results):
    """Merge each hit with neighbouring chunks (same source, chunk_id +/- 1)
    that also appear in ``results``.

    Every chunk is emitted at most once, inside the highest-ranked hit that
    claims it; the merged hit keeps that hit's score. Merged documents get
    ``is_combined=True`` in their metadata.
    """
    by_key = {}
    for doc, _ in results:
        by_key.setdefault((doc.metadata["source"], doc.metadata["chunk_id"]), doc)

    processed_results = []
    seen_chunks = set()
    for doc, score in results:
        source = doc.metadata["source"]
        chunk_id = doc.metadata["chunk_id"]
        if (source, chunk_id) in seen_chunks:
            continue
        seen_chunks.add((source, chunk_id))

        combined_content = doc.page_content
        for adj_id in (chunk_id - 1, chunk_id + 1):
            neighbour = by_key.get((source, adj_id))
            if neighbour is None or (source, adj_id) in seen_chunks:
                continue
            if adj_id < chunk_id:
                combined_content = neighbour.page_content + " " + combined_content
            else:
                combined_content = combined_content + " " + neighbour.page_content
            seen_chunks.add((source, adj_id))

        combined_doc = Document(
            page_content=combined_content,
            metadata={
                "source": source,
                "chunk_id": chunk_id,
                "is_combined": combined_content != doc.page_content,
            },
        )
        processed_results.append((combined_doc, score))

    return processed_results


def query_text_corpus(model, tokenizer, vector_store_path, query, k=6, device="cuda", vector_store=None):
    """Search the indexed corpus and merge adjacent chunks from the same source.

    Args:
        vector_store: optional already-loaded store; when given,
            ``vector_store_path`` is not read (avoids reloading the index
            from disk for every query).

    Returns:
        List of ``(Document, score)`` pairs in search-rank order.
    """
    if vector_store is None:
        embeddings = HindiSentenceEmbeddings(model, tokenizer, device=device)
        vector_store = load_vector_store(vector_store_path, embeddings)

    results = perform_similarity_search(vector_store, query, k=k)
    return _merge_adjacent_chunks(results)


def _print_results(results, query):
    """Print ranked results with a query-focused excerpt of each chunk."""
    print("\nSearch Results:")
    for rank, (doc, score) in enumerate(results, start=1):
        print(f"\nResult {rank} (Score: {score:.4f}):")
        print(f"Source: {doc.metadata['source']}, Chunk: {doc.metadata['chunk_id']}")
        relevant_text = extract_relevant_sentences(doc.page_content, query)
        print(f"Content: {relevant_text}")


def main():
    """CLI entry point: optionally (re)index, then answer a one-off query
    and/or run an interactive query loop."""
    parser = argparse.ArgumentParser(description="Hindi RAG System with LangChain and FAISS")
    parser.add_argument("--model_dir", type=str, default="/home/ubuntu/output/hindi-embeddings-custom-tokenizer/final",
                        help="Directory containing the model and tokenizer")
    parser.add_argument("--tokenizer_dir", type=str, default="/home/ubuntu/hindi_tokenizer",
                        help="Directory containing the tokenizer")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to run inference on ('cuda' or 'cpu')")
    parser.add_argument("--index", action="store_true", help="Index text files from data directory")
    parser.add_argument("--query", type=str, default=None, help="Query to search in the indexed corpus")
    parser.add_argument("--data_dir", type=str, default="./data", help="Directory containing text files for indexing")
    parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save the indexed vector store")
    parser.add_argument("--top_k", type=int, default=6, help="Number of top results to return")
    parser.add_argument("--chunk_size", type=int, default=500, help="Size of text chunks for indexing")
    parser.add_argument("--interactive", action="store_true", help="Run in interactive mode for querying")
    parser.add_argument("--reindex", action="store_true", help="Force reindexing even if index exists")
    args = parser.parse_args()

    model, tokenizer, _config = load_model_and_tokenizer(args.model_dir, args.tokenizer_dir)
    model = model.to(args.device)

    vector_store_path = os.path.join(args.output_dir, "faiss_index")
    vector_store = None

    if args.index or args.reindex:
        vector_store, _ = index_text_files(model, tokenizer, args.data_dir, args.output_dir,
                                           args.device, args.chunk_size)
        print(f"Indexing complete. Vector store saved to {vector_store_path}")

    if not (args.query or args.interactive):
        return

    # Load the index once and reuse it for every query.
    if vector_store is None:
        if not os.path.exists(vector_store_path):
            print(f"No vector store found at {vector_store_path}. Run with --index first.")
            return
        vector_store = load_vector_store(
            vector_store_path, HindiSentenceEmbeddings(model, tokenizer, device=args.device)
        )

    if args.query:
        results = query_text_corpus(model, tokenizer, vector_store_path, args.query, args.top_k,
                                    args.device, vector_store=vector_store)
        _print_results(results, args.query)

    if args.interactive:
        print("\nInteractive mode. Enter queries (or type 'quit' to exit).")
        while True:
            print("\nEnter query:")
            try:
                query = input()
            except (EOFError, KeyboardInterrupt):
                break
            if not query.strip():
                continue
            if query.strip().lower() == 'quit':
                break

            results = query_text_corpus(model, tokenizer, vector_store_path, query, args.top_k,
                                        args.device, vector_store=vector_store)
            _print_results(results, query)


if __name__ == "__main__":
    main()