Spaces:
Sleeping
Sleeping
Upload data_processing.py
Browse files- data_processing.py +55 -37
data_processing.py
CHANGED
@@ -5,6 +5,7 @@ from sentence_transformers import SentenceTransformer
|
|
5 |
from datasets import load_dataset
|
6 |
import torch
|
7 |
import json
|
|
|
8 |
|
9 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
10 |
|
@@ -16,56 +17,73 @@ embedding_model = HuggingFaceEmbeddings(
|
|
16 |
|
17 |
all_documents = []
|
18 |
ragbench = {}
|
|
|
|
|
19 |
|
|
|
|
|
20 |
|
21 |
def create_faiss_index_file():
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
|
37 |
-
# Convert
|
38 |
-
|
|
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
index_w.add(embeddings_np)
|
44 |
|
45 |
# Save FAISS index
|
46 |
-
faiss.write_index(index,
|
47 |
-
|
48 |
-
# Save documents
|
49 |
-
with open(
|
50 |
json.dump(all_documents, f)
|
51 |
|
52 |
-
print(
|
53 |
-
|
54 |
-
def load_data_from_faiss():
|
55 |
-
load_faiss()
|
56 |
-
load_metatdata()
|
57 |
|
58 |
def load_ragbench():
|
59 |
-
ragbench
|
60 |
-
|
|
|
|
|
|
|
61 |
ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
|
62 |
|
63 |
-
def load_faiss():
|
64 |
global index
|
65 |
-
faiss_index_path =
|
66 |
-
|
|
|
|
|
|
|
|
|
67 |
|
68 |
-
def
|
69 |
global actual_docs
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from datasets import load_dataset
|
6 |
import torch
|
7 |
import json
|
8 |
+
import os
|
9 |
|
10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
|
|
|
17 |
|
18 |
# Module-level state shared by the functions in this file.
all_documents = []  # document strings collected by create_faiss_index_file()
ragbench = {}  # dataset name -> loaded RAGBench dataset, filled by load_ragbench()
index = None  # FAISS index; set by create_faiss_index_file() / load_faiss()
actual_docs = []  # document strings read back from disk by load_metadata()

# Ensure the data directory exists before anything tries to write into it.
# NOTE: this runs at import time.
os.makedirs("data_local", exist_ok=True)
|
25 |
|
26 |
def create_faiss_index_file():
    """Build a FAISS index over the RAGBench documents and save it to disk.

    Loads every split of each listed ``rungalileo/ragbench`` subset, collects
    the ``documents`` field of every row into the module-level
    ``all_documents`` list (list-valued fields are joined with a single
    space), embeds them with the module-level ``embedding_model``, and stores
    the vectors in a flat L2 FAISS index bound to the module-level ``index``.

    Side effects:
        * ``all_documents`` is cleared and refilled in place.
        * ``index`` is rebound to the newly built index.
        * Writes ``data_local/rag7_index.faiss`` and
          ``data_local/rag7_docs.json``.

    Raises:
        ValueError: if no documents were collected, since an index dimension
            cannot be derived from an empty embedding matrix.

    NOTE(review): relies on ``np``, ``faiss`` and ``embedding_model`` being
    defined earlier in the file; those lines are outside this view.
    """
    global index  # rebound below, so the module-level name must be declared

    dataset_names = [
        'covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
        'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
        'tatqa', 'techqa',
    ]

    # Clear in place (rather than rebinding) so existing references to the
    # list observe the reset.
    all_documents.clear()

    for dataset in dataset_names:
        ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)

        for split in ragbench_dataset.keys():
            for row in ragbench_dataset[split]:
                doc = row["documents"]
                if isinstance(doc, list):
                    # One index entry per row: merge the row's documents.
                    doc = " ".join(doc)
                all_documents.append(doc)

    if not all_documents:
        # Without this guard the failure would surface later as an opaque
        # IndexError on ``embeddings_np.shape[1]``.
        raise ValueError("No documents were collected; cannot build a FAISS index.")

    # Convert documents to embeddings
    embeddings = embedding_model.embed_documents(all_documents)
    embeddings_np = np.array(embeddings, dtype=np.float32)

    # Initialize and store in FAISS; the vector dimension is the column count.
    index = faiss.IndexFlatL2(embeddings_np.shape[1])
    index.add(embeddings_np)

    # The directory is created at import time, but recreate it here in case it
    # was removed in the meantime.
    os.makedirs("data_local", exist_ok=True)

    # Save FAISS index
    faiss.write_index(index, "data_local/rag7_index.faiss")

    # Save documents metadata
    with open("data_local/rag7_docs.json", "w", encoding="utf-8") as f:
        json.dump(all_documents, f)

    print("FAISS index and metadata saved successfully!")
|
|
|
|
|
|
|
|
|
58 |
|
59 |
def load_ragbench():
    """Load every listed RAGBench subset into the module-level ``ragbench`` dict.

    The dict is emptied first and then filled with one entry per subset name,
    mapping the name to whatever ``load_dataset("rungalileo/ragbench", name)``
    returns.
    """
    # No ``global`` statement is needed: the dict is only mutated in place,
    # never rebound.
    ragbench.clear()  # Reset dictionary

    dataset_names = [
        'covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
        'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
        'tatqa', 'techqa',
    ]
    for dataset in dataset_names:
        ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
|
66 |
|
67 |
+
def load_faiss():
    """Load the saved FAISS index into the module-level ``index``.

    Reads ``data_local/rag7_index.faiss`` when it exists as a regular file and
    binds the result to ``index``. When the file is missing, a hint is printed
    and ``index`` is left untouched.
    """
    global index
    faiss_index_path = "data_local/rag7_index.faiss"

    # isfile (not exists): a directory at this path is not a usable index and
    # should be reported as missing rather than handed to faiss.read_index.
    if not os.path.isfile(faiss_index_path):
        print("FAISS index file not found. Run create_faiss_index_file() first.")
        return

    index = faiss.read_index(faiss_index_path)
    print("FAISS index loaded successfully.")
|
75 |
|
76 |
+
def load_metadata():
    """Load the saved document list into the module-level ``actual_docs``.

    Reads ``data_local/rag7_docs.json`` when it exists as a regular file and
    binds the parsed JSON to ``actual_docs``. When the file is missing, a hint
    is printed and ``actual_docs`` is left untouched.
    """
    global actual_docs
    metadata_path = "data_local/rag7_docs.json"

    # isfile (not exists): a directory at this path cannot be opened as JSON.
    if not os.path.isfile(metadata_path):
        print("Metadata file not found. Run create_faiss_index_file() first.")
        return

    # Explicit encoding so the read does not depend on the platform locale.
    with open(metadata_path, "r", encoding="utf-8") as f:
        actual_docs = json.load(f)
    print("Metadata loaded successfully.")
|
85 |
+
|
86 |
+
def load_data_from_faiss():
    """Load both the FAISS index and the document metadata from disk.

    Calls ``load_faiss()`` and then ``load_metadata()``. Nothing is returned;
    the results are left in the module-level ``index`` and ``actual_docs``
    that those two functions populate.
    """
    load_faiss()
    load_metadata()
|