# data_processing.py
import json
import os

import faiss
import numpy as np
import torch
from datasets import load_dataset
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Embedding model used to encode document chunks and queries
embedding_model = HuggingFaceEmbeddings(
    model_name="all-MiniLM-L12-v2",
    model_kwargs={"device": device},
)

# Cross-encoder used to rerank retrieved chunks against the query
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
RECENT_QUESTIONS_FILE = "data_local/recent_questions.json"

# # Ensure the file exists and initialize if empty
# if not os.path.exists(RECENT_QUESTIONS_FILE):
#     with open(RECENT_QUESTIONS_FILE, "w") as file:
#         json.dump({"questions": []}, file, indent=4)

# Module-level caches shared across the app
all_documents = []
ragbench = {}            # dataset name -> loaded ragbench DatasetDict
index = None             # FAISS index for the currently loaded dataset
chunk_docs = []          # chunked document texts aligned with the index
documents = []
query_dataset_data = {}  # cache of query datasets, keyed by dataset name

# Text splitter used to chunk raw documents before embedding
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1024,
    chunk_overlap=100,
)
def chunk_documents(docs):
    """Split each document into overlapping chunks using the text splitter."""
    return [chunk for doc in docs for chunk in text_splitter.split_text(doc)]
def create_faiss_index(dataset):
    """Build and persist an HNSW FAISS index plus chunk metadata for one ragbench subset."""
    # Load dataset
    ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)

    # Collect documents from every split; flatten lists of passages into single strings.
    # A local list is used so repeated calls do not accumulate documents across datasets.
    docs = []
    for split in ragbench_dataset.keys():
        for row in ragbench_dataset[split]:
            doc = row["documents"]
            if isinstance(doc, list):
                doc = " ".join(doc)
            docs.append(doc)

    # Chunking
    chunked_documents = chunk_documents(docs)

    # Save chunks in JSON (metadata storage)
    with open(f"{dataset}_chunked_docs.json", "w") as f:
        json.dump(chunked_documents, f)
    print(len(chunked_documents))

    # Convert chunks to embeddings
    embeddings = embedding_model.embed_documents(chunked_documents)
    embeddings_np = np.array(embeddings, dtype=np.float32)

    # Build and save the FAISS index (32 = number of neighbors per node in the HNSW graph)
    index = faiss.IndexHNSWFlat(embeddings_np.shape[1], 32)
    index.add(embeddings_np)
    faiss.write_index(index, f"{dataset}_chunked_index.faiss")
    print(f"{dataset} stored as individual FAISS index!")
def load_ragbench():
    """Load (and cache) every ragbench subset used by the app."""
    global ragbench
    if ragbench:
        return ragbench
    datasets = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid',
                'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']
    for dataset in datasets:
        ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
    return ragbench
def load_query_dataset(q_dataset):
    """Load (and cache) a single ragbench subset, returning None if it fails to load."""
    global query_dataset_data
    if query_dataset_data.get(q_dataset):
        return query_dataset_data[q_dataset]
    try:
        query_dataset_data[q_dataset] = load_dataset("rungalileo/ragbench", q_dataset)
    except Exception as e:
        print(f"Error loading dataset '{q_dataset}': {e}")
        return None  # Return None if the dataset fails to load
    return query_dataset_data[q_dataset]
def load_faiss(q_dataset):
    """Load the prebuilt FAISS index for a dataset from data_local/ into the global `index`."""
    global index
    faiss_index_path = f"data_local/{q_dataset}_quantized.faiss"
    if os.path.exists(faiss_index_path):
        index = faiss.read_index(faiss_index_path)
        print("FAISS index loaded successfully.")
    else:
        print("FAISS index file not found. Build the index first (see create_faiss_index()).")
def load_chunks(q_dataset):
    """Load the chunked document texts that correspond to the FAISS index into `chunk_docs`."""
    global chunk_docs
    metadata_path = f"data_local/{q_dataset}_chunked_docs.json"
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            chunk_docs = json.load(f)
        print("Metadata loaded successfully.")
    else:
        print("Metadata file not found. Build it first (see create_faiss_index()).")
def load_data_from_faiss(q_dataset):
    """Load both the FAISS index and its chunk metadata for a dataset."""
    load_faiss(q_dataset)
    load_chunks(q_dataset)
def rerank_documents(query, retrieved_docs):
    """Rescore retrieved chunks with the cross-encoder and keep the most relevant ones."""
    scores = reranker.predict([[query, doc] for doc in retrieved_docs])
    ranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs),
                                            key=lambda pair: pair[0], reverse=True)]
    return ranked_docs[:5]  # Return the 5 most relevant chunks
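
# Illustrative sketch (not part of the original pipeline): how the pieces above are
# typically combined. It assumes load_data_from_faiss(q_dataset) has already populated
# the global `index` and `chunk_docs`; the function name and `top_k` are hypothetical.
def example_retrieve_and_rerank(query, top_k=10):
    # Embed the query with the same model used for the document chunks
    query_embedding = np.array(embedding_model.embed_query(query), dtype=np.float32).reshape(1, -1)
    # Nearest-neighbour search over the HNSW index
    _, ids = index.search(query_embedding, top_k)
    # Map FAISS ids back to chunk texts (-1 marks missing neighbours)
    candidates = [chunk_docs[i] for i in ids[0] if i != -1]
    # Cross-encoder reranking keeps the 5 most relevant chunks
    return rerank_documents(query, candidates)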
# def load_recent_questions():
# if os.path.exists(RECENT_QUESTIONS_FILE):
# with open(RECENT_QUESTIONS_FILE, "r") as file:
# return json.load(file)
# return {"questions": []}
# def save_recent_question(question, metrics_1):
# data = load_recent_questions()
# # Append new question & metrics
# data["questions"].append({
# "question": question,
# "metrics": metrics_1
# })
# # # Keep only the last 5 questions
# # data["questions"] = data["questions"][-5:]
# # Write back to file
# with open(RECENT_QUESTIONS_FILE, "w") as file:
# json.dump(data, file, indent=4)
# Load previous questions from file
def load_recent_questions():
if os.path.exists(RECENT_QUESTIONS_FILE):
with open(RECENT_QUESTIONS_FILE, "r") as file:
return json.load(file)
return []
# Save questions to file
def save_recent_questions(data):
with open(RECENT_QUESTIONS_FILE, "w") as file:
json.dump(data, file, indent=4)
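
# Minimal end-to-end sketch, guarded so importing this module stays side-effect free.
# The dataset name "covidqa" and the sample query are illustrative assumptions, and the
# prebuilt index/metadata files are expected under data_local/.
if __name__ == "__main__":
    sample_dataset = "covidqa"
    load_data_from_faiss(sample_dataset)
    if index is not None and chunk_docs:
        sample_query = "What are common symptoms of COVID-19?"
        for rank, chunk in enumerate(example_retrieve_and_rerank(sample_query), start=1):
            print(f"[{rank}] {chunk[:200]}")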