# 23RAG7 / data_processing.py
# Uploaded by cb1716pics in commit ce3af46 ("Upload 2 files", verified); 4.9 kB.
import faiss
import torch
import json
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from datasets import load_dataset
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Sentence-embedding model used by create_faiss_index to embed document chunks.
embedding_model = HuggingFaceEmbeddings(
    model_name="all-MiniLM-L12-v2",
    model_kwargs=dict(device=device),
)

# Cross-encoder used by rerank_documents to rescore (query, chunk) pairs.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
# File path for storing recently asked questions and metrics
RECENT_QUESTIONS_FILE = "data_local/recent_questions.json"

# Module-level state filled in by the functions below.
all_documents = []  # NOTE(review): not referenced elsewhere in this file
ragbench = {}  # config name -> dataset; cache used by load_ragbench
index = None  # FAISS index set by load_faiss
chunk_docs = []  # chunk texts set by load_chunks
documents = []  # raw document texts collected by create_faiss_index

# The data directory must exist *before* any file inside it is created;
# otherwise the very first run fails with FileNotFoundError.
os.makedirs("data_local", exist_ok=True)

# Create the recent-questions store with an empty history if it is missing.
if not os.path.exists(RECENT_QUESTIONS_FILE):
    with open(RECENT_QUESTIONS_FILE, "w") as file:
        json.dump({"questions": []}, file, indent=4)
# One splitter instance, reused by chunk_documents for every document
# (chunk_size=1024, chunk_overlap=100).
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=100, chunk_size=1024)
def chunk_documents(docs):
    """Split every text in *docs* with the module-level ``text_splitter``.

    Returns:
        A single flat list holding the chunks of all documents, in document
        order.
    """
    pieces = []
    for text in docs:
        pieces.extend(text_splitter.split_text(text))
    return pieces
def _collect_ragbench_documents(ragbench_dataset):
    """Return one string per row holding that row's ``documents`` field, across all splits."""
    collected = []
    for split_rows in ragbench_dataset.values():
        for row in split_rows:
            doc = row["documents"]
            # A list of passages is joined with single spaces so each row
            # contributes exactly one string.
            if isinstance(doc, list):
                doc = " ".join(doc)
            collected.append(doc)
    return collected


def create_faiss_index(dataset):
    """Chunk a RAGBench config's documents and save them with an HNSW FAISS index.

    Loads ``rungalileo/ragbench`` config *dataset* and gathers the ``documents``
    text from every split. The text is chunked with ``chunk_documents`` and
    embedded with ``embedding_model``. Two files are written to the current
    working directory:

    * ``{dataset}_chunked_docs.json`` holds the chunk texts.
    * ``{dataset}_chunked_index.faiss`` holds the index.

    Args:
        dataset: RAGBench config name, e.g. ``"covidqa"``.

    Raises:
        ValueError: if the dataset yields no chunks to index.
    """
    ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)
    dataset_docs = _collect_ragbench_documents(ragbench_dataset)
    # Replace (rather than append to) the module-level list. Appending made
    # every call re-index all documents from earlier calls as well.
    documents[:] = dataset_docs

    chunked_documents = chunk_documents(dataset_docs)

    # NOTE(review): load_chunks/load_faiss read from data_local/ and expect
    # "{dataset}_quantized.faiss"; these outputs land in the CWD. Confirm that
    # a separate step moves/quantizes them.
    with open(f"{dataset}_chunked_docs.json", "w") as f:
        json.dump(chunked_documents, f)
    print(len(chunked_documents))

    if not chunked_documents:
        # Without this, np.array([]).shape[1] fails with an opaque IndexError.
        raise ValueError(f"No documents to index for RAGBench dataset {dataset!r}")

    embeddings = embedding_model.embed_documents(chunked_documents)
    embeddings_np = np.array(embeddings, dtype=np.float32)

    # 32 is the HNSW "M" parameter (graph neighbours per node).
    index = faiss.IndexHNSWFlat(embeddings_np.shape[1], 32)
    index.add(embeddings_np)
    faiss.write_index(index, f"{dataset}_chunked_index.faiss")
    print(f"{dataset} stored as individual FAISS index!")
def load_ragbench():
    """Load each listed RAGBench config once and cache it in the global ``ragbench``.

    Returns:
        The module-level ``ragbench`` dict (config name -> loaded dataset).
        Once it is populated, later calls return it without reloading.
    """
    global ragbench
    if ragbench:
        return ragbench
    datasets = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
                'tatqa', 'techqa']
    # Load everything into a local dict first. If any load raises, the cache
    # stays empty instead of holding a partial set that later calls would
    # return as if complete.
    loaded = {name: load_dataset("rungalileo/ragbench", name) for name in datasets}
    # update() in place keeps the same dict object that other code may already
    # hold a reference to.
    ragbench.update(loaded)
    return ragbench
def load_faiss(query_dataset):
    """Load ``data_local/{query_dataset}_quantized.faiss`` into the global ``index``.

    If the file does not exist, ``index`` is reset to ``None``. This stops an
    index loaded earlier for a different dataset from being reused silently.
    """
    global index
    faiss_index_path = f"data_local/{query_dataset}_quantized.faiss"
    if os.path.exists(faiss_index_path):
        index = faiss.read_index(faiss_index_path)
        print("FAISS index loaded successfully.")
    else:
        index = None
        # NOTE(review): create_faiss_index in this file writes
        # "{dataset}_chunked_index.faiss" to the CWD, and no
        # create_faiss_index_file() is defined here. Confirm where this file
        # is produced.
        print("FAISS index file not found. Run create_faiss_index_file() first.")
def load_chunks(query_dataset):
    """Load the chunk texts for *query_dataset* into the global ``chunk_docs``.

    Reads ``data_local/{query_dataset}_chunked_docs.json``. If the file does
    not exist, ``chunk_docs`` is reset to an empty list. This stops chunks from
    a dataset loaded earlier from being reused silently.
    """
    global chunk_docs
    metadata_path = f"data_local/{query_dataset}_chunked_docs.json"
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            chunk_docs = json.load(f)
        print("Metadata loaded successfully.")
    else:
        chunk_docs = []
        # NOTE(review): create_faiss_index writes "{dataset}_chunked_docs.json"
        # to the current directory, not data_local/. Confirm how it gets here.
        print("Metadata file not found. Run create_faiss_index_file() first.")
def load_data_from_faiss(query_dataset):
    """Load the FAISS index, then the chunk texts, for *query_dataset*."""
    for loader in (load_faiss, load_chunks):
        loader(query_dataset)
def rerank_documents(query, retrieved_docs, top_k=5):
    """Rescore retrieved chunks with the cross-encoder and keep the best ones.

    Args:
        query: The user question.
        retrieved_docs: Iterable of chunk texts from first-stage retrieval.
        top_k: Maximum number of chunks to return (default 5).

    Returns:
        Up to *top_k* chunks, ordered by descending ``reranker`` score. Chunks
        with equal scores keep their original retrieval order.
    """
    # Materialize once. The docs are needed both to build the pairs and to map
    # scores back, and a one-shot iterator would be exhausted after the first
    # pass.
    docs = list(retrieved_docs)
    if not docs:
        # Nothing to score; do not call the model with an empty batch.
        return []
    scores = reranker.predict([[query, doc] for doc in docs])
    # Sort indices by score alone so ties never fall back to comparing the docs
    # themselves. sorted() is stable even with reverse=True.
    order = sorted(range(len(docs)), key=lambda i: scores[i], reverse=True)
    return [docs[i] for i in order[:max(top_k, 0)]]
def load_recent_questions():
    """Read the recent-questions store.

    Returns:
        The parsed JSON from ``RECENT_QUESTIONS_FILE``. If the file does not
        exist, returns an empty ``{"questions": []}`` structure instead.
    """
    if not os.path.exists(RECENT_QUESTIONS_FILE):
        # Nothing stored yet: hand back the empty default structure.
        return {"questions": []}
    with open(RECENT_QUESTIONS_FILE, "r") as handle:
        return json.load(handle)
def save_recent_question(question, metrics, max_questions=5):
    """Append a question and its metrics to the recent-questions store.

    Only the newest *max_questions* entries are kept. The JSON is first written
    to a temporary sibling file and then moved over ``RECENT_QUESTIONS_FILE``.
    An interrupted or failed write therefore leaves the previous file intact
    instead of leaving truncated JSON that ``load_recent_questions`` cannot
    parse.

    Args:
        question: The question text to record.
        metrics: Metrics for the question; must be JSON-serializable.
        max_questions: Number of most recent entries to retain (default 5).
    """
    data = load_recent_questions()
    # Tolerate a stored object that lacks the "questions" list.
    history = data.setdefault("questions", [])
    history.append({
        "question": question,
        "metrics": metrics,
    })
    # Guard the 0 case explicitly: history[-0:] would keep the whole list.
    data["questions"] = history[-max_questions:] if max_questions > 0 else []

    tmp_path = RECENT_QUESTIONS_FILE + ".tmp"
    with open(tmp_path, "w") as file:
        json.dump(data, file, indent=4)
    # os.replace swaps the file in atomically when both paths share a filesystem.
    os.replace(tmp_path, RECENT_QUESTIONS_FILE)