cb1716pics committed
Commit 1d3e5ce · verified · 1 Parent(s): c4f2afd

Upload 3 files

Files changed (2):
  app.py       +8 -4
  retrieval.py +36 -4
app.py CHANGED
@@ -1,6 +1,6 @@
 import streamlit as st
 from generator import generate_response_from_document
-from retrieval import retrieve_documents
+from retrieval import retrieve_documents_hybrid
 from evaluation import calculate_metrics
 #from data_processing import load_data_from_faiss
 import time
@@ -11,14 +11,18 @@ st.title("RAG7 - Real World RAG System")
 global retrieved_documents
 retrieved_documents = []
 
+global response
+response = ""
+
+global time_taken_for_response
+time_taken_for_response = 'N/A'
+
 # @st.cache_data
 # def load_data():
 #     load_data_from_faiss()
 
 # data_status = load_data()
 
-time_taken_for_response = 'N/A'
-
 # Question Section
 st.subheader("Hi, What do you want to know today?")
 question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
@@ -26,7 +30,7 @@ question = st.text_area("Enter your question:", placeholder="Type your question
 # Submit Button
 if st.button("Submit"):
     start_time = time.time()
-    retrieved_documents = retrieve_documents(question, 5)
+    retrieved_documents = retrieve_documents_hybrid(question, 10)
     response = generate_response_from_document(question, retrieved_documents)
     end_time = time.time()
     time_taken_for_response = end_time-start_time
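
Note on the app.py change: every click of Submit now calls retrieve_documents_hybrid, which (per the retrieval.py diff below) re-reads the FAISS index and the chunked-docs JSON from disk and rebuilds the BM25 index before searching. If that cost shows up in time_taken_for_response, Streamlit's st.cache_resource can keep those objects in memory across reruns. A minimal sketch, assuming the same data_local layout; the load_retrieval_assets helper is illustrative, not part of this commit:

import json

import faiss
import streamlit as st
from rank_bm25 import BM25Okapi

@st.cache_resource  # built once per process, reused across Streamlit reruns
def load_retrieval_assets(dataset: str):
    # Hypothetical helper: load the quantized FAISS index, the chunked
    # documents, and a prebuilt BM25 index so each Submit skips disk I/O.
    index = faiss.read_index(f"data_local/{dataset}_quantized.faiss")
    with open(f"data_local/{dataset}_chunked_docs.json", "r") as f:
        chunked_documents = json.load(f)
    bm25 = BM25Okapi([doc.split() for doc in chunked_documents])
    return index, chunked_documents, bm25
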
retrieval.py CHANGED
@@ -2,16 +2,51 @@ import json
 import numpy as np
 from langchain.schema import Document
 import faiss
-
+from rank_bm25 import BM25Okapi
 from data_processing import embedding_model #, index, actual_docs
 
 retrieved_docs = None
 
+
+def retrieve_documents_hybrid(query, top_k=5):
+    query_dataset = find_query_dataset(query)
+
+    with open( f"data_local/{query_dataset}_chunked_docs.json", "r") as f:
+        chunked_documents = json.load(f) # Contains all documents for this dataset
+
+    faiss_index_path = f"data_local/{query_dataset}_quantized.faiss"
+    index = faiss.read_index(faiss_index_path)
+
+    # Tokenize documents for BM25
+    tokenized_docs = [doc.split() for doc in chunked_documents]
+    bm25 = BM25Okapi(tokenized_docs)
+
+    query_embedding = np.array(embedding_model.embed_documents([query]), dtype=np.float32)
+    query_embedding = query_embedding.reshape(1, -1)
+
+    # FAISS Search
+    _, nearest_indices = index.search(query_embedding, top_k)
+    faiss_docs = [chunked_documents[i] for i in nearest_indices[0]]
+
+    # BM25 Search
+    tokenized_query = query.split()
+    bm25_scores = bm25.get_scores(tokenized_query)
+    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
+    bm25_docs = [chunked_documents[i] for i in bm25_top_indices]
+
+    # Merge FAISS + BM25 Results
+    retrieved_docs = list(set(faiss_docs + bm25_docs))[:top_k]
+
+    return retrieved_docs
+
 # Retrieval Function
 def retrieve_documents(query, top_k=5):
     query_dataset = find_query_dataset(query)
     #index, chunk_docs = load_data_from_faiss(query)
 
+    with open( f"data_local/{query_dataset}_chunked_docs.json", "r") as f:
+        documents = json.load(f) # Contains all documents for this dataset
+
     faiss_index_path = f"data_local/{query_dataset}_quantized.faiss"
     index = faiss.read_index(faiss_index_path)
 
@@ -19,9 +54,6 @@ def retrieve_documents(query, top_k=5):
 
     _, nearest_indices = index.search(query_embedding, top_k)
 
-    with open( f"data_local/{query_dataset}_chunked_docs.json", "r") as f:
-        documents = json.load(f) # Contains all documents for this dataset
-
     retrieved_docs = [Document(page_content=documents[i]) for i in nearest_indices[0]]
 
     return retrieved_docs
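
Two review notes on retrieve_documents_hybrid. First, it returns plain strings, while retrieve_documents wraps results as langchain Document objects, so generate_response_from_document has to accept both shapes. (Relatedly, embedding_model.embed_documents([query]) works, but LangChain embeddings also expose embed_query for single queries.) Second, the merge list(set(faiss_docs + bm25_docs))[:top_k] discards both retrievers' rankings and orders by Python's randomized string hash, so the final top_k can differ between runs on identical inputs. A rank-aware, deterministic alternative is reciprocal rank fusion; a minimal sketch, with rrf_merge as an illustrative helper that is not part of this commit:

def rrf_merge(faiss_docs, bm25_docs, top_k=5, k=60):
    # Reciprocal rank fusion: each list contributes 1 / (k + rank + 1)
    # per document, so items ranked highly by either retriever rise to
    # the top and duplicates accumulate score from both lists.
    scores = {}
    for ranked in (faiss_docs, bm25_docs):
        for rank, doc in enumerate(ranked):
            scores[doc] = scores.get(doc, 0.0) + 1.0 / (k + rank + 1)
    # Highest fused score first; deterministic for fixed inputs.
    return sorted(scores, key=scores.get, reverse=True)[:top_k]

Dropping it in would be a one-line change: retrieved_docs = rrf_merge(faiss_docs, bm25_docs, top_k).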