import numpy as np
import faiss
from langchain.embeddings import HuggingFaceEmbeddings
from datasets import load_dataset
import torch
import json
import os

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the embedding model (384-dimensional sentence embeddings)
embedding_model = HuggingFaceEmbeddings(
    model_name="paraphrase-MiniLM-L3-v2",
    model_kwargs={"device": device},
)
# Subsets of the RAGBench benchmark to index
RAGBENCH_DATASETS = [
    'covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
    'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
    'tatqa', 'techqa',
]

all_documents = []
ragbench = {}
index = None
actual_docs = []

# Ensure data directory exists
os.makedirs("data_local", exist_ok=True)
def create_faiss_index_file():
    global index  # Use the module-level FAISS index
    all_documents.clear()  # Reset document list

    for dataset in RAGBENCH_DATASETS:
        ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)
        for split in ragbench_dataset.keys():
            for row in ragbench_dataset[split]:
                doc = row["documents"]
                if isinstance(doc, list):
                    doc = " ".join(doc)  # Flatten a list of passages into one string
                all_documents.append(doc)

    # Convert documents to embeddings
    embeddings = embedding_model.embed_documents(all_documents)
    embeddings_np = np.array(embeddings, dtype=np.float32)

    # Initialize an exact L2 (Euclidean) index and add the embeddings
    index = faiss.IndexFlatL2(embeddings_np.shape[1])
    index.add(embeddings_np)

    # Save FAISS index
    faiss.write_index(index, "data_local/rag7_index.faiss")

    # Save documents metadata alongside the index; row i of the index
    # corresponds to all_documents[i]
    with open("data_local/rag7_docs.json", "w") as f:
        json.dump(all_documents, f)

    print("FAISS index and metadata saved successfully!")
def load_ragbench():
    global ragbench
    ragbench.clear()  # Reset dictionary
    for dataset in RAGBENCH_DATASETS:
        ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
def load_faiss():
    global index
    faiss_index_path = "data_local/rag7_index.faiss"
    if os.path.exists(faiss_index_path):
        index = faiss.read_index(faiss_index_path)
        print("FAISS index loaded successfully.")
    else:
        print("FAISS index file not found. Run create_faiss_index_file() first.")
def load_metadata():
    global actual_docs
    metadata_path = "data_local/rag7_docs.json"
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            actual_docs = json.load(f)
        print("Metadata loaded successfully.")
    else:
        print("Metadata file not found. Run create_faiss_index_file() first.")
def load_data_from_faiss():
    load_faiss()
    load_metadata()
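
# --- Example usage (a sketch, not in the original file) ---
# Build the index on the first run, then reuse the saved artifacts afterwards.
# The sample query string is purely illustrative.
if __name__ == "__main__":
    if os.path.exists("data_local/rag7_index.faiss"):
        load_data_from_faiss()
        docs = actual_docs
    else:
        create_faiss_index_file()
        docs = all_documents
    print(retrieve_documents("What are the symptoms of COVID-19?", docs, k=3))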