import numpy as np
import faiss
from langchain.embeddings import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import torch
import json
import os
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load embedding model
embedding_model = HuggingFaceEmbeddings(
    model_name="paraphrase-MiniLM-L3-v2",
    model_kwargs={"device": device},
)

# Module-level state shared by the functions below
all_documents = []  # document texts, in the same order as the FAISS vectors
ragbench = {}       # dataset name -> loaded RAGBench DatasetDict
index = None        # FAISS index over the document embeddings
actual_docs = []    # document texts reloaded from the saved metadata file

# RAGBench subsets to index
RAGBENCH_DATASETS = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
                     'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
                     'tatqa', 'techqa']

# Ensure data directory exists
os.makedirs("data_local", exist_ok=True)

def create_faiss_index_file():
    global index  # Ensure we update the module-level FAISS index
    all_documents.clear()  # Reset the document list
    for dataset in RAGBENCH_DATASETS:
        ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)
        for split in ragbench_dataset.keys():
            for row in ragbench_dataset[split]:
                doc = row["documents"]
                if isinstance(doc, list):
                    doc = " ".join(doc)  # Convert list to string if needed
                all_documents.append(doc)

    # Convert documents to embeddings
    embeddings = embedding_model.embed_documents(all_documents)
    embeddings_np = np.array(embeddings, dtype=np.float32)

    # Initialize a FAISS index and store the embeddings
    index = faiss.IndexFlatL2(embeddings_np.shape[1])
    index.add(embeddings_np)

    # Save FAISS index
    faiss.write_index(index, "data_local/rag7_index.faiss")

    # Save documents metadata
    with open("data_local/rag7_docs.json", "w") as f:
        json.dump(all_documents, f)

    print("FAISS index and metadata saved successfully!")

def load_ragbench():
    global ragbench
    ragbench.clear()  # Reset the dictionary
    for dataset in RAGBENCH_DATASETS:
        ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)

def load_faiss():
    global index
    faiss_index_path = "data_local/rag7_index.faiss"
    if os.path.exists(faiss_index_path):
        index = faiss.read_index(faiss_index_path)
        print("FAISS index loaded successfully.")
    else:
        print("FAISS index file not found. Run create_faiss_index_file() first.")

def load_metadata():
    global actual_docs
    metadata_path = "data_local/rag7_docs.json"
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            actual_docs = json.load(f)
        print("Metadata loaded successfully.")
    else:
        print("Metadata file not found. Run create_faiss_index_file() first.")

def load_data_from_faiss():
    load_faiss()
    load_metadata()
    # return index, actual_docs

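# Example entry point (a sketch, not in the original file): build the index on
# the first run, then reuse the saved artifacts on subsequent runs.
if __name__ == "__main__":
    if os.path.exists("data_local/rag7_index.faiss"):
        load_data_from_faiss()
    else:
        create_faiss_index_file()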