File size: 5,462 Bytes
9bfc86c
1b04b96
 
d346441
99afa50
 
 
 
 
 
 
1b04b96
 
 
 
 
99afa50
1b04b96
 
 
99afa50
8848e89
99afa50
ce3af46
 
0e36212
 
 
 
ce3af46
1b04b96
43b460f
d346441
99afa50
 
da626d3
1b04b96
ece1395
99afa50
 
 
 
d346441
99afa50
 
 
 
 
 
 
 
 
 
 
 
 
ece1395
 
 
99afa50
 
 
 
 
d346441
99afa50
 
 
1b04b96
99afa50
d346441
1b04b96
 
99afa50
 
 
1b04b96
99afa50
1b04b96
43b460f
d346441
58a211a
3fd6562
58a211a
 
 
43b460f
3fd6562
43b460f
599d161
da626d3
8848e89
 
 
 
599d161
8848e89
 
 
 
599d161
da626d3
8848e89
7c78daa
a523549
7c78daa
d346441
 
 
 
99afa50
1b04b96
7c78daa
99afa50
7c78daa
d346441
 
99afa50
d346441
 
 
 
7c78daa
 
 
99afa50
 
 
 
 
 
 
0e36212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce3af46
 
 
 
0e36212
e6167f8
0e36212
 
4433c64
0e36212
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import faiss
import torch
import json
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from datasets import load_dataset
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Bi-encoder that embeds chunk texts for the FAISS index; placed on `device`.
embedding_model = HuggingFaceEmbeddings(
    model_name="all-MiniLM-L12-v2",
    model_kwargs={"device": device}
)

# Cross-encoder used by rerank_documents() to re-score (query, doc) pairs.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# JSON file read by load_recent_questions() and written by save_recent_questions().
RECENT_QUESTIONS_FILE = "data_local/recent_questions.json"

# Module-level state shared by the loader functions in this file.
all_documents = []       # NOTE(review): not referenced in this module — confirm external use
ragbench = {}            # subset name -> dataset, cached by load_ragbench()
index = None             # FAISS index, assigned by load_faiss()
chunk_docs = []          # chunk texts, assigned by load_chunks()
documents = []           # raw document list; presumably kept for importers — confirm
query_dataset_data = {}  # subset name -> dataset, cached by load_query_dataset()

# Splitter used by chunk_documents(). Sizes are measured with the splitter's
# default length function.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1024,
    chunk_overlap=100
)

def chunk_documents(docs):
    """Split each document with the module-level ``text_splitter``.

    Args:
        docs: Iterable of document strings.

    Returns:
        One flat list holding the chunks of every document, in input order.
    """
    flattened = []
    for document in docs:
        flattened.extend(text_splitter.split_text(document))
    return flattened

def create_faiss_index(dataset):
    """Build and save an HNSW FAISS index for one RAGBench subset.

    Every row of every split of ``rungalileo/ragbench`` / *dataset* contributes
    its ``documents`` field. List values are joined with a single space. The
    documents are split with ``chunk_documents``, and the outputs are written
    to the current working directory:

    * ``{dataset}_chunked_docs.json`` holds the chunk texts.
    * ``{dataset}_chunked_index.faiss`` holds an ``IndexHNSWFlat`` over their
      embeddings, added in the same order as the JSON list.

    NOTE(review): ``load_chunks``/``load_faiss`` read from ``data_local/`` and
    expect ``{dataset}_quantized.faiss``. The files written here must be moved
    or quantized separately — confirm this is intentional.

    Args:
        dataset: RAGBench configuration name, e.g. ``"covidqa"``.
    """
    ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)

    # Collect into a local list. Appending to a module-level list would leak
    # documents from earlier calls into this dataset's index.
    dataset_documents = []
    for split_data in ragbench_dataset.values():
        for row in split_data:
            doc = row["documents"]
            if isinstance(doc, list):
                doc = " ".join(doc)
            dataset_documents.append(doc)

    chunked_documents = chunk_documents(dataset_documents)
    print(len(chunked_documents))
    if not chunked_documents:
        # The index dimension comes from the embeddings; with no chunks there
        # is nothing to embed, so skip writing anything.
        print(f"{dataset} produced no chunks; no FAISS index written.")
        return

    # Save chunk texts (metadata storage); position i matches index row i.
    with open(f"{dataset}_chunked_docs.json", "w", encoding="utf-8") as f:
        json.dump(chunked_documents, f)

    embeddings_np = np.asarray(
        embedding_model.embed_documents(chunked_documents), dtype=np.float32
    )

    # Second argument is the HNSW M parameter (neighbours per graph node).
    index = faiss.IndexHNSWFlat(embeddings_np.shape[1], 32)
    index.add(embeddings_np)
    faiss.write_index(index, f"{dataset}_chunked_index.faiss")

    print(f"{dataset} stored as individual FAISS index!")

def load_ragbench():
    """Load every RAGBench subset used by the app and cache the result.

    Returns:
        The module-level ``ragbench`` dict, mapping subset name to the object
        returned by ``load_dataset``. Later calls return the cached dict
        without reloading.

    The cache is filled only after *all* subsets have loaded. A failure
    part-way (e.g. a network error) propagates without leaving a partial
    cache that a later call would mistake for a complete one.
    """
    if ragbench:
        return ragbench
    subset_names = (
        'covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa',
        'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa',
    )
    loaded = {name: load_dataset("rungalileo/ragbench", name) for name in subset_names}
    # Update in place so modules that imported the `ragbench` object see the data.
    ragbench.update(loaded)
    return ragbench

def load_query_dataset(q_dataset):
    """Return the RAGBench subset *q_dataset*, loading and caching it on first use.

    Successful loads are stored in the module-level ``query_dataset_data``
    dict. A cached entry is reused only when it is truthy.

    Args:
        q_dataset: RAGBench configuration name.

    Returns:
        The loaded dataset, or ``None`` if loading raised. The error is
        printed and nothing is cached.
    """
    cached = query_dataset_data.get(q_dataset)
    if cached:
        return cached

    try:
        loaded = load_dataset("rungalileo/ragbench", q_dataset)
    except Exception as e:
        print(f"Error loading dataset '{q_dataset}': {e}")
        return None  # Return None if the dataset fails to load

    query_dataset_data[q_dataset] = loaded
    return loaded


def load_faiss(q_dataset):
    """Load ``data_local/{q_dataset}_quantized.faiss`` into the global ``index``.

    NOTE(review): ``create_faiss_index`` writes ``{dataset}_chunked_index.faiss``
    to the working directory, not this path. Confirm the quantize/move step.

    Args:
        q_dataset: RAGBench subset name used to build the file name.

    Returns:
        The loaded FAISS index, or ``None`` when the file does not exist. The
        global ``index`` is set to the same value, so a missing file clears an
        index left over from a previously loaded dataset instead of keeping it.
    """
    global index
    faiss_index_path = f"data_local/{q_dataset}_quantized.faiss"
    if os.path.exists(faiss_index_path):
        index = faiss.read_index(faiss_index_path)
        print("FAISS index loaded successfully.")
    else:
        index = None
        print("FAISS index file not found. Run create_faiss_index() first.")
    return index

def load_chunks(q_dataset):
    """Load the chunk texts for *q_dataset* into the global ``chunk_docs``.

    Reads the JSON stored at ``data_local/{q_dataset}_chunked_docs.json``.
    It is expected to be a list of chunk strings — confirm against the writer.

    Args:
        q_dataset: RAGBench subset name used to build the file name.

    Returns:
        The parsed JSON, or ``[]`` when the file does not exist. The global
        ``chunk_docs`` is set to the same value, so a missing file no longer
        leaves chunks from a previously loaded dataset in place.
    """
    global chunk_docs
    metadata_path = f"data_local/{q_dataset}_chunked_docs.json"
    if os.path.exists(metadata_path):
        with open(metadata_path, "r", encoding="utf-8") as f:
            chunk_docs = json.load(f)
        print("Metadata loaded successfully.")
    else:
        chunk_docs = []
        print("Metadata file not found. Run create_faiss_index() first.")
    return chunk_docs

def load_data_from_faiss(q_dataset):
    """Load the FAISS index and the chunk texts for *q_dataset*.

    Calls ``load_faiss`` and then ``load_chunks`` with the same subset name.
    Returns ``None``.
    """
    for loader in (load_faiss, load_chunks):
        loader(q_dataset)

def rerank_documents(query, retrieved_docs, top_k=5):
    """Re-order *retrieved_docs* by cross-encoder relevance to *query*.

    Each document is scored with the module-level ``reranker`` on the pair
    ``[query, doc]``. Documents are returned highest score first.

    Args:
        query: The user question.
        retrieved_docs: Iterable of candidate documents. They are passed to
            the cross-encoder as-is, so presumably strings — confirm with
            callers.
        top_k: Maximum number of documents to return. Defaults to 5, the
            previously hard-coded value.

    Returns:
        A list of at most ``top_k`` documents, most relevant first. An empty
        input yields ``[]`` without calling the model.
    """
    # Materialise once: the docs are used for scoring and again for pairing
    # with scores, which would silently exhaust a one-shot iterator.
    docs = list(retrieved_docs)
    if not docs:
        return []
    scores = reranker.predict([[query, doc] for doc in docs])
    # Sort on the score only. Ties keep their retrieval order (sorted() is
    # stable, also with reverse=True) instead of comparing the documents
    # themselves, which may not be orderable.
    ranked = sorted(zip(scores, docs), key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in ranked[:top_k]]

# def load_recent_questions():
#     if os.path.exists(RECENT_QUESTIONS_FILE):
#         with open(RECENT_QUESTIONS_FILE, "r") as file:
#             return json.load(file)
#     return {"questions": []}  

# def save_recent_question(question, metrics_1):
#     data = load_recent_questions()

#     # Append new question & metrics
#     data["questions"].append({
#         "question": question,
#         "metrics": metrics_1
#     })

#     # # Keep only the last 5 questions
#     # data["questions"]  = data["questions"][-5:]

#     # Write back to file
#     with open(RECENT_QUESTIONS_FILE, "w") as file:
#         json.dump(data, file, indent=4)

# Load previous questions from file
def load_recent_questions():
    """Return the saved recent-questions data, or ``[]`` if none is stored yet.

    Parses ``RECENT_QUESTIONS_FILE`` as JSON and returns its contents
    unchanged.
    """
    if not os.path.exists(RECENT_QUESTIONS_FILE):
        return []
    with open(RECENT_QUESTIONS_FILE, "r") as handle:
        stored = json.load(handle)
    return stored

# Save questions to file
def save_recent_questions(data):
    """Write *data* as indented JSON to ``RECENT_QUESTIONS_FILE``.

    The parent directory is created if missing. The JSON is first written to
    ``RECENT_QUESTIONS_FILE + ".tmp"`` and then moved over the target with
    ``os.replace``. If serialisation fails (e.g. a non-JSON-serialisable
    metric value), the previous file is left intact rather than truncated.
    The error is re-raised.

    Args:
        data: JSON-serialisable object to persist.
    """
    target = RECENT_QUESTIONS_FILE
    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)

    tmp_path = f"{target}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as file:
            json.dump(data, file, indent=4)
        os.replace(tmp_path, target)
    except Exception:
        # Drop the partial temp file; the existing target was never touched.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise