Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files
- app.py +8 -4
- retrieval.py +36 -4
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import streamlit as st
|
2 |
from generator import generate_response_from_document
|
3 |
-
from retrieval import
|
4 |
from evaluation import calculate_metrics
|
5 |
#from data_processing import load_data_from_faiss
|
6 |
import time
|
@@ -11,14 +11,18 @@ st.title("RAG7 - Real World RAG System")
|
|
11 |
global retrieved_documents
|
12 |
retrieved_documents = []
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
# @st.cache_data
|
15 |
# def load_data():
|
16 |
# load_data_from_faiss()
|
17 |
|
18 |
# data_status = load_data()
|
19 |
|
20 |
-
time_taken_for_response = 'N/A'
|
21 |
-
|
22 |
# Question Section
|
23 |
st.subheader("Hi, What do you want to know today?")
|
24 |
question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
|
@@ -26,7 +30,7 @@ question = st.text_area("Enter your question:", placeholder="Type your question
|
|
26 |
# Submit Button
|
27 |
if st.button("Submit"):
|
28 |
start_time = time.time()
|
29 |
-
retrieved_documents =
|
30 |
response = generate_response_from_document(question, retrieved_documents)
|
31 |
end_time = time.time()
|
32 |
time_taken_for_response = end_time-start_time
|
|
|
1 |
import streamlit as st
|
2 |
from generator import generate_response_from_document
|
3 |
+
from retrieval import retrieve_documents_hybrid
|
4 |
from evaluation import calculate_metrics
|
5 |
#from data_processing import load_data_from_faiss
|
6 |
import time
|
|
|
11 |
global retrieved_documents
|
12 |
retrieved_documents = []
|
13 |
|
14 |
+
global response
|
15 |
+
response = ""
|
16 |
+
|
17 |
+
global time_taken_for_response
|
18 |
+
time_taken_for_response = 'N/A'
|
19 |
+
|
20 |
# @st.cache_data
|
21 |
# def load_data():
|
22 |
# load_data_from_faiss()
|
23 |
|
24 |
# data_status = load_data()
|
25 |
|
|
|
|
|
26 |
# Question Section
|
27 |
st.subheader("Hi, What do you want to know today?")
|
28 |
question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
|
|
|
30 |
# Submit Button
|
31 |
if st.button("Submit"):
|
32 |
start_time = time.time()
|
33 |
+
retrieved_documents = retrieve_documents_hybrid(question, 10)
|
34 |
response = generate_response_from_document(question, retrieved_documents)
|
35 |
end_time = time.time()
|
36 |
time_taken_for_response = end_time-start_time
|
retrieval.py
CHANGED
@@ -2,16 +2,51 @@ import json
|
|
2 |
import numpy as np
|
3 |
from langchain.schema import Document
|
4 |
import faiss
|
5 |
-
|
6 |
from data_processing import embedding_model #, index, actual_docs
|
7 |
|
8 |
retrieved_docs = None
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
# Retrieval Function
|
11 |
def retrieve_documents(query, top_k=5):
|
12 |
query_dataset = find_query_dataset(query)
|
13 |
#index, chunk_docs = load_data_from_faiss(query)
|
14 |
|
|
|
|
|
|
|
15 |
faiss_index_path = f"data_local/{query_dataset}_quantized.faiss"
|
16 |
index = faiss.read_index(faiss_index_path)
|
17 |
|
@@ -19,9 +54,6 @@ def retrieve_documents(query, top_k=5):
|
|
19 |
|
20 |
_, nearest_indices = index.search(query_embedding, top_k)
|
21 |
|
22 |
-
with open( f"data_local/{query_dataset}_chunked_docs.json", "r") as f:
|
23 |
-
documents = json.load(f) # Contains all documents for this dataset
|
24 |
-
|
25 |
retrieved_docs = [Document(page_content=documents[i]) for i in nearest_indices[0]]
|
26 |
|
27 |
return retrieved_docs
|
|
|
2 |
import numpy as np
|
3 |
from langchain.schema import Document
|
4 |
import faiss
|
5 |
+
from rank_bm25 import BM25Okapi
|
6 |
from data_processing import embedding_model #, index, actual_docs
|
7 |
|
8 |
retrieved_docs = None
|
9 |
|
10 |
+
|
11 |
+
def retrieve_documents_hybrid(query, top_k=5):
    """Retrieve up to ``top_k`` chunks for ``query`` by combining FAISS and BM25.

    The dataset is chosen with ``find_query_dataset(query)``. Its chunk texts
    are read from ``data_local/<dataset>_chunked_docs.json`` and its vector
    index from ``data_local/<dataset>_quantized.faiss``. The query is embedded
    with ``embedding_model.embed_documents`` for the FAISS search and split on
    whitespace for the BM25 search.

    The two ranked lists are merged round-robin (FAISS rank 1, BM25 rank 1,
    FAISS rank 2, ...), so the best hits of both retrievers survive the
    ``top_k`` cut and the result order is the same on every run. Chunks with
    identical text are returned only once.

    Args:
        query: The user question.
        top_k: Number of hits requested from each retriever and the maximum
            number of chunks returned.

    Returns:
        A list of at most ``top_k`` chunk texts, best-ranked first.
        NOTE(review): these are the raw entries of the JSON file, whereas
        ``retrieve_documents`` wraps each chunk in ``Document`` -- confirm
        which form the caller expects.
    """
    query_dataset = find_query_dataset(query)

    with open(f"data_local/{query_dataset}_chunked_docs.json", "r") as f:
        chunked_documents = json.load(f)  # Contains all documents for this dataset

    if not chunked_documents:
        # Nothing to rank; also avoids building BM25 over an empty corpus.
        return []

    faiss_index_path = f"data_local/{query_dataset}_quantized.faiss"
    index = faiss.read_index(faiss_index_path)

    # FAISS search (dense)
    query_embedding = np.array(
        embedding_model.embed_documents([query]), dtype=np.float32
    ).reshape(1, -1)
    _, nearest_indices = index.search(query_embedding, top_k)
    # FAISS fills missing neighbours with -1; indexing with it would silently
    # return the *last* chunk, so negative ids are dropped.
    faiss_docs = [chunked_documents[int(i)] for i in nearest_indices[0] if i >= 0]

    # BM25 search (lexical), highest score first
    bm25 = BM25Okapi([doc.split() for doc in chunked_documents])
    bm25_scores = bm25.get_scores(query.split())
    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
    bm25_docs = [chunked_documents[int(i)] for i in bm25_top_indices]

    # Merge FAISS + BM25 results: interleave by rank, skip duplicates.
    # A plain set() would lose the ranking and give a run-dependent order.
    retrieved_docs = []
    seen = set()
    for rank in range(max(len(faiss_docs), len(bm25_docs))):
        for ranked in (faiss_docs, bm25_docs):
            if rank >= len(ranked):
                continue
            doc = ranked[rank]
            if doc in seen:
                continue
            seen.add(doc)
            retrieved_docs.append(doc)
            if len(retrieved_docs) >= top_k:
                return retrieved_docs

    return retrieved_docs
|
41 |
+
|
42 |
# Retrieval Function
|
43 |
def retrieve_documents(query, top_k=5):
|
44 |
query_dataset = find_query_dataset(query)
|
45 |
#index, chunk_docs = load_data_from_faiss(query)
|
46 |
|
47 |
+
with open( f"data_local/{query_dataset}_chunked_docs.json", "r") as f:
|
48 |
+
documents = json.load(f) # Contains all documents for this dataset
|
49 |
+
|
50 |
faiss_index_path = f"data_local/{query_dataset}_quantized.faiss"
|
51 |
index = faiss.read_index(faiss_index_path)
|
52 |
|
|
|
54 |
|
55 |
_, nearest_indices = index.search(query_embedding, top_k)
|
56 |
|
|
|
|
|
|
|
57 |
retrieved_docs = [Document(page_content=documents[i]) for i in nearest_indices[0]]
|
58 |
|
59 |
return retrieved_docs
|