Spaces:
Sleeping
Sleeping
Upload 4 files
Browse files
- app.py +6 -3
- evaluation.py +2 -3
- generator.py +17 -6
- retrieval.py +15 -20
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import streamlit as st
|
2 |
from generator import generate_response_from_document
|
3 |
-
from retrieval import retrieve_documents_hybrid
|
4 |
from evaluation import calculate_metrics
|
5 |
from data_processing import load_recent_questions, save_recent_question
|
6 |
import time
|
@@ -63,6 +63,8 @@ if "time_taken_for_response" not in st.session_state:
|
|
63 |
st.session_state.time_taken_for_response = "N/A"
|
64 |
if "metrics" not in st.session_state:
|
65 |
st.session_state.metrics = {}
|
|
|
|
|
66 |
|
67 |
recent_data = load_recent_questions()
|
68 |
|
@@ -106,7 +108,8 @@ for q in reversed(recent_data["questions"]): # Show latest first
|
|
106 |
|
107 |
if st.button("Submit"):
|
108 |
start_time = time.time()
|
109 |
-
st.session_state.
|
|
|
110 |
st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
|
111 |
end_time = time.time()
|
112 |
st.session_state.time_taken_for_response = end_time - start_time
|
@@ -133,7 +136,7 @@ col1, col2 = st.columns([1, 3]) # Creating two columns for button and metrics d
|
|
133 |
# Calculate Metrics Button
|
134 |
with col1:
|
135 |
if st.button("Show Metrics"):
|
136 |
-
st.session_state.metrics = calculate_metrics(question, st.session_state.response, st.session_state.retrieved_documents, st.session_state.time_taken_for_response)
|
137 |
else:
|
138 |
metrics_ = {}
|
139 |
|
|
|
1 |
import streamlit as st
|
2 |
from generator import generate_response_from_document
|
3 |
+
from retrieval import retrieve_documents_hybrid,find_query_dataset
|
4 |
from evaluation import calculate_metrics
|
5 |
from data_processing import load_recent_questions, save_recent_question
|
6 |
import time
|
|
|
63 |
st.session_state.time_taken_for_response = "N/A"
|
64 |
if "metrics" not in st.session_state:
|
65 |
st.session_state.metrics = {}
|
66 |
+
if "metrics" not in st.session_state:
|
67 |
+
st.session_state.metrics = {}
|
68 |
|
69 |
recent_data = load_recent_questions()
|
70 |
|
|
|
108 |
|
109 |
if st.button("Submit"):
|
110 |
start_time = time.time()
|
111 |
+
st.session_state.query_dataset = find_query_dataset(question)
|
112 |
+
st.session_state.retrieved_documents = retrieve_documents_hybrid(question, st.session_state.query_dataset, 10)
|
113 |
st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
|
114 |
end_time = time.time()
|
115 |
st.session_state.time_taken_for_response = end_time - start_time
|
|
|
136 |
# Calculate Metrics Button
|
137 |
with col1:
|
138 |
if st.button("Show Metrics"):
|
139 |
+
st.session_state.metrics = calculate_metrics(question, st.session_state.query_dataset, st.session_state.response, st.session_state.retrieved_documents, st.session_state.time_taken_for_response)
|
140 |
else:
|
141 |
metrics_ = {}
|
142 |
|
evaluation.py
CHANGED
@@ -91,9 +91,8 @@ def adherence(response, relevant_documents):
|
|
91 |
def compute_rmse(predicted_values, ground_truth_values):
|
92 |
return np.sqrt(mean_squared_error(ground_truth_values, predicted_values))
|
93 |
|
94 |
-
def calculate_metrics(question, response, docs, time_taken):
|
95 |
-
|
96 |
-
data = load_query_dataset(query_dataset)
|
97 |
ground_truth_answer = retrieve_ground_truths(question, data) # Store the ground truth answer
|
98 |
|
99 |
# Ensure ground_truth_answer is not empty before proceeding
|
|
|
91 |
def compute_rmse(predicted_values, ground_truth_values):
|
92 |
return np.sqrt(mean_squared_error(ground_truth_values, predicted_values))
|
93 |
|
94 |
+
def calculate_metrics(question, q_dataset, response, docs, time_taken):
|
95 |
+
data = load_query_dataset(q_dataset)
|
|
|
96 |
ground_truth_answer = retrieve_ground_truths(question, data) # Store the ground truth answer
|
97 |
|
98 |
# Ensure ground_truth_answer is not empty before proceeding
|
generator.py
CHANGED
@@ -11,12 +11,23 @@ def generate_response_from_document(query, retrieved_docs):
|
|
11 |
|
12 |
# context = " ".join([doc.page_content for doc in retrieved_docs]) # Now iterates over Document objects
|
13 |
context = " ".join([doc for doc in retrieved_docs])
|
14 |
-
prompt =
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
try:
|
22 |
response = openai.chat.completions.create( # Use the new chat completions API
|
|
|
11 |
|
12 |
# context = " ".join([doc.page_content for doc in retrieved_docs]) # Now iterates over Document objects
|
13 |
context = " ".join([doc for doc in retrieved_docs])
|
14 |
+
prompt = """
|
15 |
+
"You are an accurate and reliable AI assistant that can answer questions with the help of external documents.
|
16 |
+
Please note that external documents may contain noisy or factually incorrect information.
|
17 |
+
If the information in the document contains the correct answer, you will give an accurate answer.
|
18 |
+
If the information in the document does not contain the answer, you will generate 'I can not answer the question because of the insufficient information in documents.'.
|
19 |
+
If there are inconsistencies with the facts in some of the documents, please generate the response 'There are factual errors in the provided documents.' and provide the correct answer."
|
20 |
+
|
21 |
+
Context or Document: {context}
|
22 |
+
Query: {query}
|
23 |
+
"""
|
24 |
+
# prompt = (
|
25 |
+
# "You are a highly intelligent assistant tasked with answering a question based strictly on the provided context. "
|
26 |
+
# f"Given Question: {query} \n\n"
|
27 |
+
# f"Context: {context} \n"
|
28 |
+
# f"Answer the question directly and concisely using only the information available in the context."
|
29 |
+
# "Do not include any other information which is not there in the context."
|
30 |
+
# )
|
31 |
|
32 |
try:
|
33 |
response = openai.chat.completions.create( # Use the new chat completions API
|
retrieval.py
CHANGED
@@ -3,23 +3,18 @@ import numpy as np
|
|
3 |
from langchain.schema import Document
|
4 |
import faiss
|
5 |
from rank_bm25 import BM25Okapi
|
6 |
-
from data_processing import embedding_model
|
7 |
from sentence_transformers import CrossEncoder
|
8 |
|
9 |
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
|
10 |
|
11 |
retrieved_docs = None
|
12 |
-
global query_dataset
|
13 |
-
query_dataset = ''
|
14 |
|
15 |
-
def retrieve_documents_hybrid(query, top_k=5):
|
16 |
-
|
17 |
-
query_dataset = find_query_dataset(query)
|
18 |
-
|
19 |
-
with open( f"data_local/{query_dataset}_chunked_docs.json", "r") as f:
|
20 |
chunked_documents = json.load(f) # Contains all documents for this dataset
|
21 |
|
22 |
-
faiss_index_path = f"data_local/{
|
23 |
index = faiss.read_index(faiss_index_path)
|
24 |
|
25 |
# Tokenize documents for BM25
|
@@ -47,23 +42,23 @@ def retrieve_documents_hybrid(query, top_k=5):
|
|
47 |
return reranked_docs
|
48 |
|
49 |
# Retrieval Function
|
50 |
-
def retrieve_documents(query, top_k=5):
|
51 |
-
|
52 |
-
|
53 |
|
54 |
-
|
55 |
-
|
56 |
|
57 |
-
|
58 |
-
|
59 |
|
60 |
-
|
61 |
|
62 |
-
|
63 |
|
64 |
-
|
65 |
|
66 |
-
|
67 |
|
68 |
def remove_duplicate_documents(documents):
|
69 |
unique_documents = []
|
|
|
3 |
from langchain.schema import Document
|
4 |
import faiss
|
5 |
from rank_bm25 import BM25Okapi
|
6 |
+
from data_processing import embedding_model
|
7 |
from sentence_transformers import CrossEncoder
|
8 |
|
9 |
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
|
10 |
|
11 |
retrieved_docs = None
|
|
|
|
|
12 |
|
13 |
+
def retrieve_documents_hybrid(query, q_dataset, top_k=5):
|
14 |
+
with open( f"data_local/{q_dataset}_chunked_docs.json", "r") as f:
|
|
|
|
|
|
|
15 |
chunked_documents = json.load(f) # Contains all documents for this dataset
|
16 |
|
17 |
+
faiss_index_path = f"data_local/{q_dataset}_quantized.faiss"
|
18 |
index = faiss.read_index(faiss_index_path)
|
19 |
|
20 |
# Tokenize documents for BM25
|
|
|
42 |
return reranked_docs
|
43 |
|
44 |
# Retrieval Function
|
45 |
+
# def retrieve_documents(query, top_k=5):
|
46 |
+
# query_dataset = find_query_dataset(query)
|
47 |
+
# #index, chunk_docs = load_data_from_faiss(query)
|
48 |
|
49 |
+
# with open( f"data_local/{query_dataset}_chunked_docs.json", "r") as f:
|
50 |
+
# documents = json.load(f) # Contains all documents for this dataset
|
51 |
|
52 |
+
# faiss_index_path = f"data_local/{query_dataset}_quantized.faiss"
|
53 |
+
# index = faiss.read_index(faiss_index_path)
|
54 |
|
55 |
+
# query_embedding = np.array(embedding_model.embed_documents([query]), dtype=np.float32)
|
56 |
|
57 |
+
# _, nearest_indices = index.search(query_embedding, top_k)
|
58 |
|
59 |
+
# retrieved_docs = [Document(page_content=documents[i]) for i in nearest_indices[0]]
|
60 |
|
61 |
+
# return retrieved_docs
|
62 |
|
63 |
def remove_duplicate_documents(documents):
|
64 |
unique_documents = []
|