cb1716pics committed on
Commit
99afa50
·
verified ·
1 Parent(s): 192559e

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +3 -0
  2. data_processing.py +66 -41
  3. requirements.txt +2 -1
  4. retrieval.py +12 -1
app.py CHANGED
@@ -8,6 +8,9 @@ import time
8
  # Page Title
9
  st.title("RAG7 - Real World RAG System")
10
 
 
 
 
11
  # @st.cache_data
12
  # def load_data():
13
  # load_data_from_faiss()
 
8
  # Page Title
9
  st.title("RAG7 - Real World RAG System")
10
 
11
# Module-level holder for the documents returned by the latest retrieval.
# A `global` statement has no effect at module scope, so a plain
# assignment is all that is needed here.
retrieved_documents = []
13
+
14
  # @st.cache_data
15
  # def load_data():
16
  # load_data_from_faiss()
data_processing.py CHANGED
@@ -1,60 +1,78 @@
1
- import numpy as np
2
  import faiss
3
- from langchain.embeddings import HuggingFaceEmbeddings
4
- from sentence_transformers import SentenceTransformer
5
- from datasets import load_dataset
6
  import torch
7
  import json
8
  import os
 
 
 
 
 
 
 
9
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # Load embedding model
13
  embedding_model = HuggingFaceEmbeddings(
14
- model_name="paraphrase-MiniLM-L3-v2",
15
  model_kwargs={"device": device}
16
  )
17
 
 
 
18
  all_documents = []
19
  ragbench = {}
20
  index = None
21
- actual_docs = []
 
22
 
23
  # Ensure data directory exists
24
  os.makedirs("data_local", exist_ok=True)
25
 
26
- def create_faiss_index_file():
27
- global index # Ensure we use the global FAISS index
28
- all_documents.clear() # Reset document list
 
 
29
 
30
- for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa',
31
- 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa',
32
- 'tatqa', 'techqa']:
33
- ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- for split in ragbench_dataset.keys():
36
- for row in ragbench_dataset[split]:
37
- doc = row["documents"]
38
- if isinstance(doc, list):
39
- doc = " ".join(doc) # Convert list to string if needed
40
- all_documents.append(doc)
41
 
42
- # Convert documents to embeddings
43
- embeddings = embedding_model.embed_documents(all_documents)
44
  embeddings_np = np.array(embeddings, dtype=np.float32)
45
 
46
- # Initialize and store in FAISS
47
- index = faiss.IndexFlatL2(embeddings_np.shape[1])
48
- index.add(embeddings_np)
49
 
50
  # Save FAISS index
51
- faiss.write_index(index, "data_local/rag7_index.faiss")
52
-
53
- # Save documents metadata
54
- with open("data_local/rag7_docs.json", "w") as f:
55
- json.dump(all_documents, f)
56
 
57
- print("FAISS index and metadata saved successfully!")
58
 
59
  def load_ragbench():
60
  global ragbench
@@ -64,26 +82,33 @@ def load_ragbench():
64
  'tatqa', 'techqa']:
65
  ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
66
 
67
- def load_faiss():
68
  global index
69
- faiss_index_path = "data_local/rag7_index.faiss"
70
  if os.path.exists(faiss_index_path):
71
  index = faiss.read_index(faiss_index_path)
72
  print("FAISS index loaded successfully.")
73
  else:
74
- print("FAISS index file not found. Run create_faiss_index_file() first.")
75
 
76
- def load_metadata():
77
- global actual_docs
78
- metadata_path = "data_local/rag7_docs.json"
79
  if os.path.exists(metadata_path):
80
  with open(metadata_path, "r") as f:
81
- actual_docs = json.load(f)
82
  print("Metadata loaded successfully.")
83
  else:
84
  print("Metadata file not found. Run create_faiss_index_file() first.")
85
 
86
- def load_data_from_faiss():
87
- load_faiss()
88
- load_metadata()
89
- #return index, actual_docs
 
 
 
 
 
 
 
 
 
1
  import faiss
 
 
 
2
  import torch
3
  import json
4
  import os
5
+ import numpy as np
6
+ from sentence_transformers import SentenceTransformer
7
+ from langchain.vectorstores import FAISS
8
+ from langchain.embeddings import HuggingFaceEmbeddings
9
+ from datasets import load_dataset
10
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+ from sentence_transformers import CrossEncoder
12
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
  # Load embedding model
16
  embedding_model = HuggingFaceEmbeddings(
17
+ model_name="all-MiniLM-L12-v2",
18
  model_kwargs={"device": device}
19
  )
20
 
21
+ reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
22
+
23
  all_documents = []
24
  ragbench = {}
25
  index = None
26
+ chunk_docs = []
27
+ documents = []
28
 
29
  # Ensure data directory exists
30
  os.makedirs("data_local", exist_ok=True)
31
 
32
+ # Initialize a text splitter
33
+ text_splitter = RecursiveCharacterTextSplitter(
34
+ chunk_size=1024,
35
+ chunk_overlap=100
36
+ )
37
 
38
def chunk_documents(docs):
    """Split every text in ``docs`` into chunks and return one flat list.

    Each document is handed to the module-level ``text_splitter``; the
    resulting chunks are concatenated in document order.
    """
    flattened = []
    for text in docs:
        flattened.extend(text_splitter.split_text(text))
    return flattened
41
+
42
def create_faiss_index(dataset):
    """Build and save a chunk-level FAISS index for one RAGBench subset.

    Loads ``dataset`` from ``rungalileo/ragbench``, collects the ``documents``
    field of every row in every split, splits the texts into chunks with
    ``chunk_documents``, embeds the chunks with ``embedding_model`` and adds
    them to an HNSW index.

    Args:
        dataset: Name of the RAGBench configuration to index.

    Side effects:
        * Refills the module-level ``documents`` list with this dataset's texts.
        * Writes ``{dataset}_chunked_docs.json`` and
          ``{dataset}_chunked_index.faiss`` to the current working directory.
        * Prints the number of chunks and a completion message.
    """
    # Start from an empty list: ``documents`` is shared module state, and
    # without this reset the texts of a previously indexed dataset would be
    # chunked and indexed again as part of this one.
    documents.clear()

    ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)

    for split in ragbench_dataset.keys():
        for row in ragbench_dataset[split]:
            doc = row["documents"]
            if isinstance(doc, list):
                # A row may carry several passages; index them as one text.
                doc = " ".join(doc)
            documents.append(doc)

    chunked_documents = chunk_documents(documents)

    # Save the chunk texts (position i corresponds to vector i in the index).
    # NOTE(review): load_chunks() reads "data_local/{name}_chunked_docs.json"
    # and load_faiss() reads "data_local/{name}_quantized.faiss", while this
    # function writes to the working directory under different index name.
    # Confirm how the files are expected to reach data_local/.
    with open(f"{dataset}_chunked_docs.json", "w") as f:
        json.dump(chunked_documents, f)

    print(len(chunked_documents))

    embeddings = embedding_model.embed_documents(chunked_documents)
    embeddings_np = np.array(embeddings, dtype=np.float32)

    # Second argument (32) is the HNSW neighbour count per node.
    # A local name is used on purpose: this function does not touch the
    # module-level ``index`` used for querying.
    hnsw_index = faiss.IndexHNSWFlat(embeddings_np.shape[1], 32)
    hnsw_index.add(embeddings_np)
    faiss.write_index(hnsw_index, f"{dataset}_chunked_index.faiss")

    print(f"{dataset} stored as individual FAISS index!")
76
 
77
  def load_ragbench():
78
  global ragbench
 
82
  'tatqa', 'techqa']:
83
  ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
84
 
85
def load_faiss(query_dataset):
    """Load the FAISS index for ``query_dataset`` into the module-level ``index``.

    The index is read from ``data_local/{query_dataset}_quantized.faiss``.
    If that file does not exist a message is printed and ``index`` is left
    unchanged.

    Args:
        query_dataset: Dataset name used to build the index file name.
    """
    global index
    faiss_index_path = f"data_local/{query_dataset}_quantized.faiss"
    if os.path.exists(faiss_index_path):
        index = faiss.read_index(faiss_index_path)
        print("FAISS index loaded successfully.")
    else:
        # Name the missing path and the current builder function.
        # NOTE(review): create_faiss_index() writes
        # "{dataset}_chunked_index.faiss" in the working directory, not the
        # "_quantized" file expected here - confirm where that file comes from.
        print(
            f"FAISS index file not found: {faiss_index_path}. "
            "Run create_faiss_index() first."
        )
93
 
94
def load_chunks(query_dataset):
    """Load the chunk texts for ``query_dataset`` into the module-level ``chunk_docs``.

    The chunks are read as JSON from
    ``data_local/{query_dataset}_chunked_docs.json``. If that file does not
    exist a message is printed and ``chunk_docs`` is left unchanged.

    Args:
        query_dataset: Dataset name used to build the JSON file name.
    """
    global chunk_docs
    metadata_path = f"data_local/{query_dataset}_chunked_docs.json"
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            chunk_docs = json.load(f)
        print("Metadata loaded successfully.")
    else:
        # Name the missing path and the current builder function.
        # NOTE(review): create_faiss_index() writes this JSON to the working
        # directory rather than data_local/ - confirm how it gets here.
        print(
            f"Metadata file not found: {metadata_path}. "
            "Run create_faiss_index() first."
        )
103
 
104
def load_data_from_faiss(query_dataset):
    """Load the FAISS index and then the chunk texts for ``query_dataset``.

    Delegates to ``load_faiss`` and ``load_chunks``, which store their
    results in module-level state; nothing is returned.
    """
    for loader in (load_faiss, load_chunks):
        loader(query_dataset)
108
+
109
def rerank_documents(query, retrieved_docs, top_k=5):
    """Re-order retrieved documents by cross-encoder relevance to ``query``.

    Every ``[query, doc]`` pair is scored with the module-level ``reranker``
    and the documents are returned from highest to lowest score.

    Args:
        query: The user question.
        retrieved_docs: Candidate documents to score.
        top_k: Maximum number of documents to return (defaults to 5).

    Returns:
        A list with at most ``top_k`` documents, best first. Documents with
        equal scores keep their original relative order.
    """
    docs = list(retrieved_docs)
    if not docs:
        # Nothing to score; avoid calling the model with an empty batch.
        return []

    scores = reranker.predict([[query, doc] for doc in docs])

    # Sort on the score only. Sorting the bare (score, doc) tuples would fall
    # back to comparing the documents themselves whenever two scores tie,
    # which raises TypeError for documents that do not support ordering.
    ranked = sorted(zip(scores, docs), key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in ranked[:top_k]]
114
+
requirements.txt CHANGED
@@ -14,4 +14,5 @@ rank_bm25
14
  nltk
15
  requests
16
  rouge-score
17
- numpy
 
 
14
  nltk
15
  requests
16
  rouge-score
17
+ numpy
18
+ rank_bm25
retrieval.py CHANGED
@@ -33,4 +33,15 @@ def remove_duplicate_documents(documents):
33
  if doc_content not in seen_documents:
34
  unique_documents.append(doc)
35
  seen_documents.add(doc_content)
36
- return unique_documents
 
 
 
 
 
 
 
 
 
 
 
 
33
  if doc_content not in seen_documents:
34
  unique_documents.append(doc)
35
  seen_documents.add(doc_content)
36
+ return unique_documents
37
+
38
def find_query_dataset(query):
    """Return the dataset whose indexed question is nearest to ``query``.

    Reads ``question_index.faiss`` and ``dataset_mapping.json`` from the
    working directory, embeds ``query`` with ``embedding_model`` and looks up
    the single nearest neighbour in the question index.

    Args:
        query: The user question.

    Returns:
        The entry of ``dataset_mapping.json`` at the position of the nearest
        indexed question.
        (assumes the mapping is a list aligned with the index rows - TODO confirm)

    Raises:
        ValueError: If the index search yields no neighbour.
    """
    question_index = faiss.read_index("question_index.faiss")

    with open("dataset_mapping.json", "r") as f:
        dataset_names = json.load(f)

    question_embedding = np.array(embedding_model.embed_documents([query]), dtype=np.float32)
    _, nearest_index = question_index.search(question_embedding, 1)

    # FAISS reports "no neighbour found" as -1. Used as a Python index that
    # would silently select the last entry of the mapping, so reject it.
    position = int(nearest_index[0][0])
    if position < 0:
        raise ValueError("question_index.faiss returned no nearest neighbour for the query")
    return dataset_names[position]