import numpy as np
import faiss
from langchain.embeddings import HuggingFaceEmbeddings
from datasets import load_dataset
import torch
import json
import os
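
# This module builds and loads a FAISS vector index over the rungalileo/ragbench
# corpora using MiniLM sentence embeddings; artifacts are persisted under data_local/.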

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load embedding model
embedding_model = HuggingFaceEmbeddings(
    model_name="paraphrase-MiniLM-L3-v2",
    model_kwargs={"device": device}
)

# RAGBench subsets indexed by this module
RAGBENCH_DATASETS = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
                     'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
                     'tatqa', 'techqa']

all_documents = []   # Document texts, in the order they are added to the index
ragbench = {}        # Loaded RAGBench datasets, keyed by subset name
index = None         # Global FAISS index (built or loaded lazily)
actual_docs = []     # Document texts reloaded from the saved metadata file

# Ensure data directory exists
os.makedirs("data_local", exist_ok=True)

def create_faiss_index_file():
    global index  # Ensure we use the global FAISS index
    all_documents.clear()  # Reset document list

    for dataset in RAGBENCH_DATASETS:
        ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)

        for split in ragbench_dataset.keys():
            for row in ragbench_dataset[split]:
                doc = row["documents"]
                if isinstance(doc, list):
                    doc = " ".join(doc)  # Convert list to string if needed
                all_documents.append(doc)  

    # Convert documents to embeddings
    embeddings = embedding_model.embed_documents(all_documents)
    embeddings_np = np.array(embeddings, dtype=np.float32)

    # Initialize and store in FAISS
    index = faiss.IndexFlatL2(embeddings_np.shape[1])
    index.add(embeddings_np)

    # Save FAISS index
    faiss.write_index(index, "data_local/rag7_index.faiss")

    # Save documents metadata
    with open("data_local/rag7_docs.json", "w") as f:
        json.dump(all_documents, f)

    print("FAISS index and metadata saved successfully!")

def load_ragbench():
    global ragbench
    ragbench.clear()  # Reset dictionary
    for dataset in RAGBENCH_DATASETS:
        ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)

def load_faiss():
    global index
    faiss_index_path = "data_local/rag7_index.faiss"
    if os.path.exists(faiss_index_path):
        index = faiss.read_index(faiss_index_path)
        print("FAISS index loaded successfully.")
    else:
        print("FAISS index file not found. Run create_faiss_index_file() first.")

def load_metadata():
    global actual_docs
    metadata_path = "data_local/rag7_docs.json"
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            actual_docs = json.load(f)
        print("Metadata loaded successfully.")
    else:
        print("Metadata file not found. Run create_faiss_index_file() first.")

def load_data_from_faiss():
    load_faiss()
    load_metadata()
    return index, actual_docs
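
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): one way the
# saved index and metadata could be queried. The helper name `retrieve_top_k`
# and its parameters are assumptions for demonstration purposes.
# ---------------------------------------------------------------------------
def retrieve_top_k(query, k=5):
    """Embed a query and return the k nearest documents from the FAISS index."""
    query_vec = np.array([embedding_model.embed_query(query)], dtype=np.float32)
    _distances, indices = index.search(query_vec, k)
    return [actual_docs[i] for i in indices[0]]

# Example (assumes the index and metadata files already exist):
# load_data_from_faiss()
# print(retrieve_top_k("What are the symptoms of COVID-19?", k=3))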