cb1716pics committed on
Commit
d346441
·
verified ·
1 Parent(s): a523549

Upload data_processing.py

Browse files
Files changed (1) hide show
  1. data_processing.py +55 -37
data_processing.py CHANGED
@@ -5,6 +5,7 @@ from sentence_transformers import SentenceTransformer
5
  from datasets import load_dataset
6
  import torch
7
  import json
 
8
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
@@ -16,56 +17,73 @@ embedding_model = HuggingFaceEmbeddings(
16
 
17
  all_documents = []
18
  ragbench = {}
 
 
19
 
 
 
20
 
21
  def create_faiss_index_file():
22
- for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
23
- 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
24
- 'tatqa', 'techqa']:
25
- ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)
26
- for split in ragbench_dataset.keys():
27
- for row in ragbench_dataset[split]:
28
- doc = row["documents"]
29
- if isinstance(doc, list):
30
- doc = " ".join(doc)
31
-
32
- all_documents.append(doc)
33
-
34
- # Convert to embeddings
35
- embeddings = embedding_model.embed_documents(all_documents)
36
 
37
- # Convert embeddings to a NumPy array
38
- embeddings_np = np.array(embeddings, dtype=np.float32)
 
39
 
40
- global index_w
41
- # Store in FAISS using the NumPy array's shape
42
- index_w = faiss.IndexFlatL2(embeddings_np.shape[1])
43
- index_w.add(embeddings_np)
44
 
45
  # Save FAISS index
46
- faiss.write_index(index, f"data_local/rag7_index.faiss")
47
-
48
- # Save documents in JSON (metadata storage)
49
- with open(f"data_local/rag7_docs.json", "w") as f:
50
  json.dump(all_documents, f)
51
 
52
- print(f"data is stored!")
53
-
54
- def load_data_from_faiss():
55
- load_faiss()
56
- load_metatdata()
57
 
58
  def load_ragbench():
59
- ragbench = {}
60
- for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']:
 
 
 
61
  ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
62
 
63
- def load_faiss():
64
  global index
65
- faiss_index_path = f"data_local/rag7_index.faiss"
66
- index = faiss.read_index(faiss_index_path)
 
 
 
 
67
 
68
- def load_metatdata():
69
  global actual_docs
70
- with open(f"data_local/rag7_docs.json", "r") as f:
71
- actual_docs = json.load(f) # Contains all documents for this dataset
 
 
 
 
 
 
 
 
 
 
 
5
  from datasets import load_dataset
6
  import torch
7
  import json
8
+ import os
9
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
 
17
 
18
  all_documents = []
19
  ragbench = {}
20
+ index = None
21
+ actual_docs = []
22
 
23
+ # Ensure data directory exists
24
+ os.makedirs("data_local", exist_ok=True)
25
 
26
  def create_faiss_index_file():
27
+ global index # Ensure we use the global FAISS index
28
+ all_documents.clear() # Reset document list
29
+
30
+ for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
31
+ 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
32
+ 'tatqa', 'techqa']:
33
+ ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)
34
+
35
+ for split in ragbench_dataset.keys():
36
+ for row in ragbench_dataset[split]:
37
+ doc = row["documents"]
38
+ if isinstance(doc, list):
39
+ doc = " ".join(doc) # Convert list to string if needed
40
+ all_documents.append(doc)
41
 
42
+ # Convert documents to embeddings
43
+ embeddings = embedding_model.embed_documents(all_documents)
44
+ embeddings_np = np.array(embeddings, dtype=np.float32)
45
 
46
+ # Initialize and store in FAISS
47
+ index = faiss.IndexFlatL2(embeddings_np.shape[1])
48
+ index.add(embeddings_np)
 
49
 
50
  # Save FAISS index
51
+ faiss.write_index(index, "data_local/rag7_index.faiss")
52
+
53
+ # Save documents metadata
54
+ with open("data_local/rag7_docs.json", "w") as f:
55
  json.dump(all_documents, f)
56
 
57
+ print("FAISS index and metadata saved successfully!")
 
 
 
 
58
 
59
  def load_ragbench():
60
+ global ragbench
61
+ ragbench.clear() # Reset dictionary
62
+ for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
63
+ 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
64
+ 'tatqa', 'techqa']:
65
  ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
66
 
67
+ def load_faiss():
68
  global index
69
+ faiss_index_path = "data_local/rag7_index.faiss"
70
+ if os.path.exists(faiss_index_path):
71
+ index = faiss.read_index(faiss_index_path)
72
+ print("FAISS index loaded successfully.")
73
+ else:
74
+ print("FAISS index file not found. Run create_faiss_index_file() first.")
75
 
76
+ def load_metadata():
77
  global actual_docs
78
+ metadata_path = "data_local/rag7_docs.json"
79
+ if os.path.exists(metadata_path):
80
+ with open(metadata_path, "r") as f:
81
+ actual_docs = json.load(f)
82
+ print("Metadata loaded successfully.")
83
+ else:
84
+ print("Metadata file not found. Run create_faiss_index_file() first.")
85
+
86
+ def load_data_from_faiss():
87
+ load_faiss()
88
+ load_metadata()
89
+ #return index, actual_docs