import numpy as np
import faiss
from langchain.embeddings import HuggingFaceEmbeddings
from datasets import load_dataset
import torch
import json
import os
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the embedding model
embedding_model = HuggingFaceEmbeddings(
    model_name="paraphrase-MiniLM-L3-v2",
    model_kwargs={"device": device}
)

all_documents = []
ragbench = {}
def create_faiss_index_file():
    for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
                    'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
                    'tatqa', 'techqa']:
        ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)
        for split in ragbench_dataset.keys():
            for row in ragbench_dataset[split]:
                # Each row may store its documents as a list of passages
                doc = row["documents"]
                if isinstance(doc, list):
                    doc = " ".join(doc)
                all_documents.append(doc)
    # Convert the collected documents to embeddings
    embeddings = embedding_model.embed_documents(all_documents)

    # Convert embeddings to a NumPy array
    embeddings_np = np.array(embeddings, dtype=np.float32)

    # Store in FAISS using the embedding dimensionality
    global index_w
    index_w = faiss.IndexFlatL2(embeddings_np.shape[1])
    index_w.add(embeddings_np)

    # Save FAISS index (ensure the target directory exists first)
    os.makedirs("data_local", exist_ok=True)
    faiss.write_index(index_w, "data_local/rag7_index.faiss")

    # Save documents in JSON (metadata storage)
    with open("data_local/rag7_docs.json", "w") as f:
        json.dump(all_documents, f)

    print("Data is stored!")
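
# A minimal in-memory query sketch (not part of the original script): once the
# index is built, a question can be embedded with the same model and searched
# directly against index_w. The helper name and its parameters are illustrative.
def query_index(query, k=5):
    query_vec = np.array([embedding_model.embed_query(query)], dtype=np.float32)
    distances, indices = index_w.search(query_vec, k)
    return [all_documents[i] for i in indices[0]]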
def load_data_from_faiss():
    load_faiss()
    load_metadata()
def load_ragbench():
    # Use the module-level dict; the original created a local `ragbench = {}`
    # that shadowed it, so the loaded datasets were silently discarded
    global ragbench
    for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
                    'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
                    'tatqa', 'techqa']:
        ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
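
# Assumed usage (not in the original script): after load_ragbench(), each entry
# is a regular Hugging Face DatasetDict, e.g.
#   ragbench["covidqa"]["test"][0]["question"]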
def load_faiss():
    global index
    faiss_index_path = "data_local/rag7_index.faiss"
    index = faiss.read_index(faiss_index_path)
def load_metadata():
    global actual_docs
    with open("data_local/rag7_docs.json", "r") as f:
        actual_docs = json.load(f)  # Contains all indexed documents
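
# A minimal end-to-end retrieval sketch (assumed usage, not in the original
# script): load the persisted index and metadata, then return the k nearest
# documents for a query. `retrieve` is an illustrative helper name.
def retrieve(query, k=5):
    load_data_from_faiss()
    query_vec = np.array([embedding_model.embed_query(query)], dtype=np.float32)
    distances, indices = index.search(query_vec, k)
    return [actual_docs[i] for i in indices[0]]

# Example:
# print(retrieve("What are the symptoms of COVID-19?", k=3))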