Upload 4 files
- app.py +67 -83
- data_processing.py +31 -22
- evaluation.py +23 -15
app.py
CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
 from generator import generate_response_from_document
 from retrieval import retrieve_documents_hybrid,find_query_dataset
 from evaluation import calculate_metrics
-from data_processing import load_recent_questions,
+from data_processing import load_recent_questions, save_recent_questions
 import time
 import matplotlib.pyplot as plt
 
@@ -22,26 +22,6 @@ st.markdown(
     unsafe_allow_html=True
 )
 
-# global retrieved_documents
-# retrieved_documents = []
-
-# global response
-# response = ""
-
-# global time_taken_for_response
-# time_taken_for_response = 'N/A'
-
-# @st.cache_data
-# def load_data():
-#     load_data_from_faiss()
-
-# data_status = load_data()
-
-# Question Section
-st.subheader("Hi, What do you want to know today?")
-question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
-question = question.strip()
-
 # # Submit Button
 # if st.button("Submit"):
 #     start_time = time.time()
@@ -70,48 +50,49 @@ question = question.strip()
 # with col2:
 #     st.text_area("Metrics:", value=metrics, height=100, disabled=True)
 
-
-st.session_state.retrieved_documents = []
-if "response" not in st.session_state:
-    st.session_state.response = ""
-if "time_taken_for_response" not in st.session_state:
-    st.session_state.time_taken_for_response = "N/A"
-if "metrics" not in st.session_state:
-    st.session_state.metrics = {}
-if "query_dataset" not in st.session_state:
-    st.session_state.query_dataset = ''
+# Initialize session state
 if "recent_questions" not in st.session_state:
-    st.session_state.recent_questions =
+    st.session_state.recent_questions = load_recent_questions()
 
-    st.session_state
-
+if "last_question" not in st.session_state:
+    st.session_state.last_question = None
+
+if "response_time" not in st.session_state:
+    st.session_state.response_time = None
+
+if "retrieved_documents" not in st.session_state:
+    st.session_state.retrieved_documents = None
 
-if st.session_state.recent_questions
-    recent_qns = list(reversed(st.session_state.recent_questions
+if "response" not in st.session_state:
+    st.session_state.response = None
 
+if st.session_state.recent_questions:
+    recent_qns = list(reversed(st.session_state.recent_questions))
 
     print(recent_qns)
 
     # Display Recent Questions
-    st.sidebar.title("
-    for q in recent_qns
-
+    st.sidebar.title("Overall RMSE")
+    rmse_values = [q["metrics"]["rmse"] for q in recent_qns if "metrics" in q and "rmse" in q["metrics"]]
+    average_rmse = sum(rmse_values) / len(rmse_values) if rmse_values else 0
+    st.sidebar.write(f"📊 **Average RMSE:** {average_rmse:.4f}")
 
     st.sidebar.markdown("---")
+
     st.sidebar.title("Analytics")
 
     # Extract response times and labels
-    response_time = [q
+    response_time = [q.get('metrics').get('response_time') for q in recent_qns]
    labels = [f"Q{i+1}" for i in range(len(response_time))]
 
     # Plot graph
-
-
-
-
-
-
-
-    st.sidebar.pyplot(fig)
+    if any(response_time):
+        fig, ax = plt.subplots()
+        ax.plot(labels, response_time, marker="o", linestyle="-", color="skyblue")
+        ax.set_xlabel("Recent Questions")
+        ax.set_ylabel("Time Taken for Response (seconds)")
+        ax.set_title("Response Time Analysis")
+        st.sidebar.pyplot(fig)
 
     st.sidebar.markdown("---")
 
@@ -122,52 +103,55 @@ if st.session_state.recent_questions and "questions" in st.session_state.recent
 else:
     st.sidebar.title("No recent questions")
 
-
-
-
-
-    st.session_state.retrieved_documents = retrieve_documents_hybrid(question, st.session_state.query_dataset, 10)
-    st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
-    end_time = time.time()
-    st.session_state.time_taken_for_response = end_time - start_time
-
-    # Store in session state
-    # st.session_state.recent_questions.append({
-    #     "question": question,
-    #     "response_time": st.session_state.time_taken_for_response
-    # })
+# Question Section
+st.subheader("Hi, What do you want to know today?")
+question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
+question = question.strip()
+
+if st.button("Submit"):
+    if question:
+        st.session_state.last_question = question
+
+        start_time = time.time()
+        st.session_state.metrics = {}
+        st.session_state.response = ""
+        st.session_state.query_dataset = find_query_dataset(question)
+        st.session_state.retrieved_documents = retrieve_documents_hybrid(question, st.session_state.query_dataset, 10)
+        st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
+        end_time = time.time()
+        st.session_state.time_taken_for_response = end_time - start_time
+
+        # Check if question already exists
+        existing_questions = [q["question"] for q in st.session_state.recent_questions]
+
+        if question not in existing_questions:
+            new_entry = {
+                "question": question,
+                "metrics": st.session_state.metrics
+            }
+            st.session_state.recent_questions.append(new_entry)
+            save_recent_questions(st.session_state.recent_questions)
+    else:
+        st.error("Please enter a question before submitting.")
 
 # Display stored response
 st.subheader("Response")
 st.text_area("Generated Response:", value=st.session_state.response, height=150, disabled=True)
 
 col1, col2 = st.columns([1, 3])  # Creating two columns for button and metrics display
 
-# # Calculate Metrics Button
-# with col1:
-#     if st.button("Calculate Metrics"):
-#         metrics = calculate_metrics(question, st.session_state.response, st.session_state.retrieved_documents, st.session_state.time_taken_for_response)
-#     else:
-#         metrics = {}
-
-# with col2:
-#     #st.text_area("Metrics:", value=metrics, height=100, disabled=True)
-#     st.json(metrics)
-
-
 # Calculate Metrics Button
 with col1:
     if st.button("Show Metrics"):
         st.session_state.metrics = calculate_metrics(question, st.session_state.query_dataset, st.session_state.response, st.session_state.retrieved_documents, st.session_state.time_taken_for_response)
-
-
-
-
-with col2:
-    #st.text_area("Metrics:", value=metrics, height=100, disabled=True)
-    if len(metrics_) > 0:
-        st.json(metrics_)
-
+        for q in st.session_state.recent_questions:
+            if q["question"] == st.session_state.last_question:
+                q["metrics"] = {"metrics": st.session_state.metrics}
+
+        # Save updated data to file
+        save_recent_questions(st.session_state.recent_questions)
+
+with col2:
+    st.text_area("Metrics:", value=st.session_state.metrics, height=100, disabled=True)
+
+    st.experimental_rerun()
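For reference, the Submit handler above appends one dict per new question and persists the whole list through save_recent_questions, and the Show Metrics branch later rewrites that entry's "metrics" field. A minimal sketch of the shape that ends up in data_local/recent_questions.json, written as a Python literal; the question text and score values are illustrative placeholders, not data from the repo:

# Illustrative only: structure inferred from app.py above, values are made up.
recent_questions = [
    {"question": "example question A", "metrics": {}},   # right after Submit
    {"question": "example question B",                    # after "Show Metrics"
     "metrics": {"metrics": {"RMSE": 0.12,
                             "metrics": {"context_relevance": 0.8},
                             "response_time": 2.3,
                             "ground_truth": "example answer",
                             "RAG_model_response": "example response"}}},
]

As written, the sidebar lookups q["metrics"]["rmse"] and q.get('metrics').get('response_time') do not match this nesting (or the "RMSE" key casing), so the Overall RMSE figure and the response-time chart would likely stay empty until the stored shape and the lookups agree.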
data_processing.py
CHANGED
@@ -23,10 +23,10 @@ query_dataset_data = {}
 
 RECENT_QUESTIONS_FILE = "data_local/recent_questions.json"
 
-# Ensure the file exists and initialize if empty
-if not os.path.exists(RECENT_QUESTIONS_FILE):
-    with open(RECENT_QUESTIONS_FILE, "w") as file:
-        json.dump({"questions": []}, file, indent=4)
+# # Ensure the file exists and initialize if empty
+# if not os.path.exists(RECENT_QUESTIONS_FILE):
+#     with open(RECENT_QUESTIONS_FILE, "w") as file:
+#         json.dump({"questions": []}, file, indent=4)
 
 all_documents = []
 ragbench = {}
@@ -130,27 +130,36 @@ def rerank_documents(query, retrieved_docs):
     ranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs), reverse=True)]
     return ranked_docs[:5]  # Return top 5 most relevant
 
+# def load_recent_questions():
+#     if os.path.exists(RECENT_QUESTIONS_FILE):
+#         with open(RECENT_QUESTIONS_FILE, "r") as file:
+#             return json.load(file)
+#     return {"questions": []}
+
+# def save_recent_question(question, metrics_1):
+#     data = load_recent_questions()
+
+#     # Append new question & metrics
+#     data["questions"].append({
+#         "question": question,
+#         "metrics": metrics_1
+#     })
+
+#     # # Keep only the last 5 questions
+#     # data["questions"] = data["questions"][-5:]
+
+#     # Write back to file
+#     with open(RECENT_QUESTIONS_FILE, "w") as file:
+#         json.dump(data, file, indent=4)
+
+# Load previous questions from file
 def load_recent_questions():
     if os.path.exists(RECENT_QUESTIONS_FILE):
         with open(RECENT_QUESTIONS_FILE, "r") as file:
             return json.load(file)
-    return
+    return []
 
-def save_recent_question(question, metrics):
-    data = load_recent_questions()
-
-    if "question" in data["questions"] and question not in data["questions"]["question"]:
-        # Append new question & metrics
-        data["questions"].append({
-            "question": question,
-            "metrics": metrics
-        })
-
-    # Write back to file
+# Save questions to file
+def save_recent_questions(data):
     with open(RECENT_QUESTIONS_FILE, "w") as file:
-        json.dump(data, file, indent=4)
-
+        json.dump(data, file, indent=4)
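A minimal, self-contained round trip of the two new helpers, assuming the data_local directory exists (the new code no longer pre-creates the file, so the first load simply returns an empty list); the example entry is a placeholder:

import json
import os

RECENT_QUESTIONS_FILE = "data_local/recent_questions.json"

def load_recent_questions():
    # Mirrors the new data_processing.load_recent_questions: [] when the file is missing.
    if os.path.exists(RECENT_QUESTIONS_FILE):
        with open(RECENT_QUESTIONS_FILE, "r") as file:
            return json.load(file)
    return []

def save_recent_questions(data):
    # Mirrors the new data_processing.save_recent_questions: overwrite the whole list.
    with open(RECENT_QUESTIONS_FILE, "w") as file:
        json.dump(data, file, indent=4)

if __name__ == "__main__":
    os.makedirs("data_local", exist_ok=True)  # assumption: the folder may not ship with the Space
    history = load_recent_questions()
    history.append({"question": "example question", "metrics": {}})  # placeholder entry
    save_recent_questions(history)
    print(load_recent_questions())

The old save_recent_question appended into a {"questions": [...]} wrapper; the replacement treats the file as a plain list, which matches how app.py now iterates st.session_state.recent_questions directly.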
evaluation.py
CHANGED
@@ -85,7 +85,7 @@ def adherence(response, relevant_documents):
     response_tokens = set(response.split())
     relevant_tokens = set(combined_docs.split())
     supported_tokens = response_tokens.intersection(relevant_tokens)
-    return len(supported_tokens) / len(response_tokens)
+    return len(supported_tokens) / len(response_tokens) >= 0.5
 
 # Step 6: Compute RMSE for metrics
 def compute_rmse(predicted_values, ground_truth_values):
@@ -93,7 +93,7 @@ def compute_rmse(predicted_values, ground_truth_values):
 
 def calculate_metrics(question, q_dataset, response, docs, time_taken):
     data = load_query_dataset(q_dataset)
-    ground_truth_answer = retrieve_ground_truths(question, data)  # Store the ground truth answer
+    ground_truth_answer, ground_truth_metrics = retrieve_ground_truths(question, data)  # Store the ground truth answer
 
     # Ensure ground_truth_answer is not empty before proceeding
     if ground_truth_answer is None:
@@ -104,12 +104,20 @@ def calculate_metrics(question, q_dataset, response, docs, time_taken):
         "context_relevance": context_relevance(question, docs),
         "context_utilization": context_utilization(response, docs),
         "completeness": completeness(response, ground_truth_answer),
-        "adherence": adherence(response, docs),
+        "adherence": adherence(response, docs),
+    }
+
+    rmse = compute_rmse(predicted_metrics, ground_truth_metrics),
+
+    metrics = {
+        "RMSE": rmse,
+        "metrics": predicted_metrics,
         "response_time": time_taken,
         "ground_truth": ground_truth_answer,
-        "RAG_model_response": response
+        "RAG_model_response": response,
     }
-
+
+    return metrics
 
 def retrieve_ground_truths(question, dataset):
     for split_name, instances in dataset.items():
@@ -118,18 +126,18 @@ def retrieve_ground_truths(question, dataset):
             #if instance['question'] == question:
             if is_similar(instance['question'], question):
                 instance_id = instance['id']
-
-
-
-
-
-
-
+                ground_truth = instance['response']
+                ground_truth_metrics_ = {
+                    "context_relevance": instance['relevance_score'],
+                    "context_utilization": instance['utilization_score'],
+                    "completeness": instance['completeness_score'],
+                    "adherence": instance['adherence_score']
+                }
                 print(f"Match found in {split_name} split!")
-                print(f"ID: {instance_id}, Response: {
-                return
+                print(f"ID: {instance_id}, Response: {ground_truth}")
+                return ground_truth, ground_truth_metrics_  # Return ground truth response immediately
 
-    return None
+    return None, None
 
 def is_similar(question1, question2, threshold=0.85):
     vectorizer = TfidfVectorizer()
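compute_rmse is only called in this diff, never shown; calculate_metrics passes it the predicted_metrics dict and the ground_truth_metrics dict returned by retrieve_ground_truths. The sketch below is an assumption about a plausible implementation over the shared score keys, not the repo's code. Note also that the added line rmse = compute_rmse(predicted_metrics, ground_truth_metrics), ends with a comma, so rmse is stored as a one-element tuple rather than a bare number.

import math

def compute_rmse(predicted_values, ground_truth_values):
    # Hypothetical sketch: RMSE over the metric names present in both dicts.
    # Assumes numeric scores (adherence is a bool after this change; float(True) == 1.0).
    shared_keys = [k for k in predicted_values if k in ground_truth_values]
    if not shared_keys:
        return None
    squared_errors = [
        (float(predicted_values[k]) - float(ground_truth_values[k])) ** 2
        for k in shared_keys
    ]
    return math.sqrt(sum(squared_errors) / len(squared_errors))

# Example with the four scores retrieve_ground_truths collects (values are illustrative):
predicted = {"context_relevance": 0.8, "context_utilization": 0.6,
             "completeness": 0.7, "adherence": True}
ground_truth = {"context_relevance": 0.9, "context_utilization": 0.5,
                "completeness": 0.75, "adherence": 1.0}
print(compute_rmse(predicted, ground_truth))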