23RAG7 / data_processing.py
cb1716pics's picture
Upload data_processing.py
d346441 verified
raw
history blame
3.02 kB
import numpy as np
import faiss
from langchain.embeddings import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import torch
import json
import os
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Embedding model used to vectorise documents before they go into FAISS.
embedding_model = HuggingFaceEmbeddings(
    model_name="paraphrase-MiniLM-L3-v2",
    model_kwargs=dict(device=device),
)

# Module-level state shared by the functions below:
#   all_documents - document texts gathered by create_faiss_index_file()
#   ragbench      - subset name -> dataset, filled by load_ragbench()
#   index         - FAISS index, set by create_faiss_index_file()/load_faiss()
#   actual_docs   - document texts read back by load_metadata()
all_documents = []
ragbench = {}
index = None
actual_docs = []

# The index and metadata files live in this directory; create it up front.
os.makedirs("data_local", exist_ok=True)
def create_faiss_index_file():
    """Build a FAISS index over all RAGBench documents and save it to disk.

    Every split of every RAGBench subset listed below is loaded; each row's
    ``documents`` field is flattened into a single string (list entries are
    joined with spaces) and collected in the module-level ``all_documents``
    list.  The texts are embedded with the module-level ``embedding_model``
    and stored in a flat L2 FAISS index.

    Outputs:
        data_local/rag7_index.faiss -- the serialised FAISS index
        data_local/rag7_docs.json   -- the document texts, in index order

    Side effects:
        Clears and refills ``all_documents`` and rebinds the module-level
        ``index`` once the new index has been fully populated.

    Raises:
        ValueError: If no documents were collected, since an index cannot
            be dimensioned without at least one embedding.
    """
    global index  # Rebind the module-level FAISS index

    dataset_names = (
        'covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
        'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
        'tatqa', 'techqa',
    )

    # Cleared in place (not rebound) so existing references see the new data.
    all_documents.clear()
    for dataset in dataset_names:
        ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)
        for split in ragbench_dataset.keys():
            for row in ragbench_dataset[split]:
                doc = row["documents"]
                if isinstance(doc, list):
                    doc = " ".join(doc)  # Convert list to string if needed
                all_documents.append(doc)

    # Without this guard an empty corpus fails later with an opaque
    # IndexError on ``embeddings_np.shape[1]``.
    if not all_documents:
        raise ValueError(
            "No documents were collected from RAGBench; cannot build a FAISS index."
        )

    # Convert documents to embeddings (one float32 row per document)
    embeddings = embedding_model.embed_documents(all_documents)
    embeddings_np = np.array(embeddings, dtype=np.float32)

    # Populate a local index first so the global is only replaced by a
    # completely built index.
    new_index = faiss.IndexFlatL2(embeddings_np.shape[1])
    new_index.add(embeddings_np)
    index = new_index

    # Do not rely on the directory having been created at import time.
    os.makedirs("data_local", exist_ok=True)

    # Save FAISS index
    faiss.write_index(index, "data_local/rag7_index.faiss")

    # Save documents metadata
    with open("data_local/rag7_docs.json", "w", encoding="utf-8") as f:
        json.dump(all_documents, f)

    print("FAISS index and metadata saved successfully!")
def load_ragbench():
    """Reload every RAGBench subset into the module-level ``ragbench`` dict.

    The dict is emptied first and then filled one subset at a time, keyed by
    subset name, with whatever ``load_dataset`` returns for that subset.
    """
    subset_names = (
        'covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
        'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
        'tatqa', 'techqa',
    )

    # The dict is mutated in place, so no ``global`` declaration is required.
    ragbench.clear()
    for subset in subset_names:
        ragbench[subset] = load_dataset("rungalileo/ragbench", subset)
def load_faiss():
    """Read the saved FAISS index into the module-level ``index``.

    Looks for ``data_local/rag7_index.faiss``.  When the file is missing, a
    hint is printed and ``index`` is left untouched.
    """
    global index

    index_file = "data_local/rag7_index.faiss"

    # Guard clause: nothing to load until the index file has been created.
    if not os.path.exists(index_file):
        print("FAISS index file not found. Run create_faiss_index_file() first.")
        return

    index = faiss.read_index(index_file)
    print("FAISS index loaded successfully.")
def load_metadata():
    """Read the saved document list into the module-level ``actual_docs``.

    Looks for ``data_local/rag7_docs.json``.  When the file is missing, a
    hint is printed and ``actual_docs`` is left untouched.
    """
    global actual_docs

    docs_file = "data_local/rag7_docs.json"

    # Guard clause: nothing to load until the metadata file has been created.
    if not os.path.exists(docs_file):
        print("Metadata file not found. Run create_faiss_index_file() first.")
        return

    with open(docs_file, "r") as handle:
        actual_docs = json.load(handle)
    print("Metadata loaded successfully.")
def load_data_from_faiss():
    """Load both persisted artifacts: the FAISS index, then the document list.

    Delegates to ``load_faiss()`` and ``load_metadata()`` in that order.
    Nothing is returned; the loaders store their results in the module-level
    ``index`` and ``actual_docs`` variables.
    """
    for loader in (load_faiss, load_metadata):
        loader()