cb1716pics committed on
Commit 0e36212 · verified · 1 Parent(s): 694551d

Upload 4 files

Files changed (3)
  1. app.py +67 -83
  2. data_processing.py +31 -22
  3. evaluation.py +23 -15
app.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
  from generator import generate_response_from_document
  from retrieval import retrieve_documents_hybrid,find_query_dataset
  from evaluation import calculate_metrics
- from data_processing import load_recent_questions, save_recent_question
+ from data_processing import load_recent_questions, save_recent_questions
  import time
  import matplotlib.pyplot as plt
 
@@ -22,26 +22,6 @@ st.markdown(
      unsafe_allow_html=True
  )
 
- # global retrieved_documents
- # retrieved_documents = []
-
- # global response
- # response = ""
-
- # global time_taken_for_response
- # time_taken_for_response = 'N/A'
-
- # @st.cache_data
- # def load_data():
- #     load_data_from_faiss()
-
- # data_status = load_data()
-
- # Question Section
- st.subheader("Hi, What do you want to know today?")
- question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
- question = question.strip()
-
  # # Submit Button
  # if st.button("Submit"):
  #     start_time = time.time()
@@ -70,48 +50,49 @@ question = question.strip()
  # with col2:
  #     st.text_area("Metrics:", value=metrics, height=100, disabled=True)
 
- if "retrieved_documents" not in st.session_state:
-     st.session_state.retrieved_documents = []
- if "response" not in st.session_state:
-     st.session_state.response = ""
- if "time_taken_for_response" not in st.session_state:
-     st.session_state.time_taken_for_response = "N/A"
- if "metrics" not in st.session_state:
-     st.session_state.metrics = {}
- if "query_dataset" not in st.session_state:
-     st.session_state.query_dataset = ''
+ # Initialize session state
  if "recent_questions" not in st.session_state:
-     st.session_state.recent_questions = {}
+     st.session_state.recent_questions = load_recent_questions()
+
+ if "last_question" not in st.session_state:
+     st.session_state.last_question = None
+
+ if "response_time" not in st.session_state:
+     st.session_state.response_time = None
+
+ if "retrieved_documents" not in st.session_state:
+     st.session_state.retrieved_documents = None
 
+ if "response" not in st.session_state:
+     st.session_state.response = None
 
- st.session_state.recent_questions = load_recent_questions()
- print(st.session_state.recent_questions )
-
- if st.session_state.recent_questions and "questions" in st.session_state.recent_questions and st.session_state.recent_questions ["questions"]:
-     recent_qns = list(reversed(st.session_state.recent_questions ["questions"]))
+ if st.session_state.recent_questions:
+     recent_qns = list(reversed(st.session_state.recent_questions))
 
      print(recent_qns)
 
      # Display Recent Questions
-     st.sidebar.title("Recent Questions")
-     for q in recent_qns: # Show latest first
-         st.sidebar.write(f"🔹 {q['question']}")
+     st.sidebar.title("Overall RMSE")
+     rmse_values = [q["metrics"]["rmse"] for q in recent_qns if "metrics" in q and "rmse" in q["metrics"]]
+     average_rmse = sum(rmse_values) / len(rmse_values) if rmse_values else 0
+     st.sidebar.write(f"📊 **Average RMSE:** {average_rmse:.4f}")
 
      st.sidebar.markdown("---")
+
      st.sidebar.title("Analytics")
 
      # Extract response times and labels
-     response_time = [q['metrics']["response_time"] for q in recent_qns]
+     response_time = [q.get('metrics').get('response_time') for q in recent_qns]
      labels = [f"Q{i+1}" for i in range(len(response_time))]
 
      # Plot graph
-     fig, ax = plt.subplots()
-     ax.plot(labels, response_time, marker="o", linestyle="-", color="skyblue")
-     ax.set_xlabel("Recent Questions")
-     ax.set_ylabel("Time Taken for Response (seconds)")
-     ax.set_title("Response Time Analysis")
-
-     # Display the plot in the sidebar
-     st.sidebar.pyplot(fig)
+     if any(response_time):
+         fig, ax = plt.subplots()
+         ax.plot(labels, response_time, marker="o", linestyle="-", color="skyblue")
+         ax.set_xlabel("Recent Questions")
+         ax.set_ylabel("Time Taken for Response (seconds)")
+         ax.set_title("Response Time Analysis")
+         st.sidebar.pyplot(fig)
 
      st.sidebar.markdown("---")
 
@@ -122,52 +103,55 @@ if st.session_state.recent_questions and "questions" in st.session_state.recent
  else:
      st.sidebar.title("No recent questions")
 
- if st.button("Submit"):
-     start_time = time.time()
-     st.session_state.metrics = {}
-     st.session_state.query_dataset = find_query_dataset(question)
-     st.session_state.retrieved_documents = retrieve_documents_hybrid(question, st.session_state.query_dataset, 10)
-     st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
-     end_time = time.time()
-     st.session_state.time_taken_for_response = end_time - start_time
-
-     # Store in session state
-     # st.session_state.recent_questions.append({
-     #     "question": question,
-     #     "response_time": st.session_state.time_taken_for_response
-     # })
+ # Question Section
+ st.subheader("Hi, What do you want to know today?")
+ question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
+ question = question.strip()
+
+ if st.button("Submit"):
+     if question:
+         st.session_state.last_question = question
+
+         start_time = time.time()
+         st.session_state.metrics = {}
+         st.session_state.response = ""
+         st.session_state.query_dataset = find_query_dataset(question)
+         st.session_state.retrieved_documents = retrieve_documents_hybrid(question, st.session_state.query_dataset, 10)
+         st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
+         end_time = time.time()
+         st.session_state.time_taken_for_response = end_time - start_time
+
+         # Check if question already exists
+         existing_questions = [q["question"] for q in st.session_state.recent_questions]
+
+         if question not in existing_questions:
+             new_entry = {
+                 "question": question,
+                 "metrics": st.session_state.metrics
+             }
+             st.session_state.recent_questions.append(new_entry)
+             save_recent_questions(st.session_state.recent_questions)
+     else:
+         st.error("Please enter a question before submitting.")
 
  # Display stored response
  st.subheader("Response")
  st.text_area("Generated Response:", value=st.session_state.response, height=150, disabled=True)
 
  col1, col2 = st.columns([1, 3]) # Creating two columns for button and metrics display
 
- # # Calculate Metrics Button
- # with col1:
- #     if st.button("Calculate Metrics"):
- #         metrics = calculate_metrics(question, st.session_state.response, st.session_state.retrieved_documents, st.session_state.time_taken_for_response)
- #     else:
- #         metrics = {}
-
- # with col2:
- #     #st.text_area("Metrics:", value=metrics, height=100, disabled=True)
- #     st.json(metrics)
-
  # Calculate Metrics Button
  with col1:
      if st.button("Show Metrics"):
          st.session_state.metrics = calculate_metrics(question, st.session_state.query_dataset, st.session_state.response, st.session_state.retrieved_documents, st.session_state.time_taken_for_response)
-         metrics_ = st.session_state.metrics
-     else:
-         metrics_ = {}
-
- with col2:
-     #st.text_area("Metrics:", value=metrics, height=100, disabled=True)
-     if len(metrics_) > 0:
-         st.json(metrics_)
+         for q in st.session_state.recent_questions:
+             if q["question"] == st.session_state.last_question:
+                 q["metrics"] = {"metrics": st.session_state.metrics}
 
- save_recent_question(question, st.session_state.metrics)
+         # Save updated data to file
+         save_recent_questions(st.session_state.recent_questions)
+
+ with col2:
+     st.text_area("Metrics:", value=st.session_state.metrics, height=100, disabled=True)
+
+ st.experimental_rerun()
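
Note (not part of the commit): a minimal sketch of the flat list-of-entries shape that app.py now expects load_recent_questions() to return, and of how the sidebar's "Overall RMSE" and response-time figures are derived from it. The questions and metric values below are invented for illustration only.

# Hypothetical contents of data_local/recent_questions.json after two submissions
recent_questions = [
    {"question": "What is hybrid retrieval?", "metrics": {"rmse": 0.21, "response_time": 3.4}},
    {"question": "How is adherence scored?", "metrics": {"rmse": 0.18, "response_time": 2.9}},
]

# Newest entries first, as the sidebar displays them
recent_qns = list(reversed(recent_questions))

# Aggregation mirroring the "Overall RMSE" panel
rmse_values = [q["metrics"]["rmse"] for q in recent_qns if "metrics" in q and "rmse" in q["metrics"]]
average_rmse = sum(rmse_values) / len(rmse_values) if rmse_values else 0
print(f"Average RMSE: {average_rmse:.4f}")

# Response times feeding the "Response Time Analysis" plot
response_time = [q.get("metrics").get("response_time") for q in recent_qns]
labels = [f"Q{i+1}" for i in range(len(response_time))]
print(labels, response_time)
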
data_processing.py CHANGED
@@ -23,10 +23,10 @@ query_dataset_data = {}
 
  RECENT_QUESTIONS_FILE = "data_local/recent_questions.json"
 
- # Ensure the file exists and initialize if empty
- if not os.path.exists(RECENT_QUESTIONS_FILE):
-     with open(RECENT_QUESTIONS_FILE, "w") as file:
-         json.dump({"questions": []}, file, indent=4)
+ # # Ensure the file exists and initialize if empty
+ # if not os.path.exists(RECENT_QUESTIONS_FILE):
+ #     with open(RECENT_QUESTIONS_FILE, "w") as file:
+ #         json.dump({"questions": []}, file, indent=4)
 
  all_documents = []
  ragbench = {}
@@ -130,27 +130,36 @@ def rerank_documents(query, retrieved_docs):
      ranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs), reverse=True)]
      return ranked_docs[:5] # Return top 5 most relevant
 
+ # def load_recent_questions():
+ #     if os.path.exists(RECENT_QUESTIONS_FILE):
+ #         with open(RECENT_QUESTIONS_FILE, "r") as file:
+ #             return json.load(file)
+ #     return {"questions": []}
+
+ # def save_recent_question(question, metrics_1):
+ #     data = load_recent_questions()
+
+ #     # Append new question & metrics
+ #     data["questions"].append({
+ #         "question": question,
+ #         "metrics": metrics_1
+ #     })
+
+ #     # # Keep only the last 5 questions
+ #     # data["questions"] = data["questions"][-5:]
+
+ #     # Write back to file
+ #     with open(RECENT_QUESTIONS_FILE, "w") as file:
+ #         json.dump(data, file, indent=4)
+
+ # Load previous questions from file
  def load_recent_questions():
      if os.path.exists(RECENT_QUESTIONS_FILE):
          with open(RECENT_QUESTIONS_FILE, "r") as file:
              return json.load(file)
-     return {"questions": []}
-
- def save_recent_question(question, metrics):
-     data = load_recent_questions()
-
-     if "question" in data["questions"] and question not in data["questions"]["question"]:
-         # Append new question & metrics
-         data["questions"].append({
-             "question": question,
-             "metrics": metrics
-         })
+     return []
 
-     # Keep only the last 5 questions
-     data["questions"] = data["questions"][-5:]
-
-     # Write back to file
+ # Save questions to file
+ def save_recent_questions(data):
      with open(RECENT_QUESTIONS_FILE, "w") as file:
-         json.dump(data, file, indent=4)
-
-
+         json.dump(data, file, indent=4)
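
Note (not part of the commit): a short usage sketch of the reworked persistence helpers, assuming the same RECENT_QUESTIONS_FILE path and the new list-based file format. The appended entry is illustrative, and the directory-creation step is an assumption for a fresh checkout where data_local/ may not exist yet.

import json
import os

RECENT_QUESTIONS_FILE = "data_local/recent_questions.json"

# Load previous questions from file (returns an empty list when the file is missing)
def load_recent_questions():
    if os.path.exists(RECENT_QUESTIONS_FILE):
        with open(RECENT_QUESTIONS_FILE, "r") as file:
            return json.load(file)
    return []

# Save questions to file (overwrites with the full list of entries)
def save_recent_questions(data):
    with open(RECENT_QUESTIONS_FILE, "w") as file:
        json.dump(data, file, indent=4)

# Example round trip with an illustrative entry
os.makedirs(os.path.dirname(RECENT_QUESTIONS_FILE), exist_ok=True)  # assumption: folder may need creating
questions = load_recent_questions()
questions.append({"question": "What is RAGBench?", "metrics": {}})
save_recent_questions(questions)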
 
 
evaluation.py CHANGED
@@ -85,7 +85,7 @@ def adherence(response, relevant_documents):
      response_tokens = set(response.split())
      relevant_tokens = set(combined_docs.split())
      supported_tokens = response_tokens.intersection(relevant_tokens)
-     return len(supported_tokens) / len(response_tokens)
+     return len(supported_tokens) / len(response_tokens) >= 0.5
 
  # Step 6: Compute RMSE for metrics
  def compute_rmse(predicted_values, ground_truth_values):
@@ -93,7 +93,7 @@ def compute_rmse(predicted_values, ground_truth_values):
 
  def calculate_metrics(question, q_dataset, response, docs, time_taken):
      data = load_query_dataset(q_dataset)
-     ground_truth_answer = retrieve_ground_truths(question, data) # Store the ground truth answer
+     ground_truth_answer, ground_truth_metrics = retrieve_ground_truths(question, data) # Store the ground truth answer
 
      # Ensure ground_truth_answer is not empty before proceeding
      if ground_truth_answer is None:
@@ -104,12 +104,20 @@ def calculate_metrics(question, q_dataset, response, docs, time_taken):
          "context_relevance": context_relevance(question, docs),
          "context_utilization": context_utilization(response, docs),
          "completeness": completeness(response, ground_truth_answer),
-         "adherence": adherence(response, docs),
+         "adherence": adherence(response, docs),
+     }
+
+     rmse = compute_rmse(predicted_metrics, ground_truth_metrics),
+
+     metrics = {
+         "RMSE": rmse,
+         "metrics": predicted_metrics,
          "response_time": time_taken,
          "ground_truth": ground_truth_answer,
-         "RAG_model_response": response
+         "RAG_model_response": response,
      }
-     return predicted_metrics
+
+     return metrics
 
  def retrieve_ground_truths(question, dataset):
      for split_name, instances in dataset.items():
@@ -118,18 +126,18 @@ def retrieve_ground_truths(question, dataset):
              #if instance['question'] == question:
              if is_similar(instance['question'], question):
                  instance_id = instance['id']
-                 instance_response = instance['response']
-                 # ground_truth_metrics = {
-                 #     "context_relevance": instance['relevance_score'],
-                 #     "context_utilization": instance['utilization_score'],
-                 #     "completeness": instance['completeness_score'],
-                 #     "adherence": instance['adherence_score']
-                 # }
+                 ground_truth = instance['response']
+                 ground_truth_metrics_ = {
+                     "context_relevance": instance['relevance_score'],
+                     "context_utilization": instance['utilization_score'],
+                     "completeness": instance['completeness_score'],
+                     "adherence": instance['adherence_score']
+                 }
                  print(f"Match found in {split_name} split!")
-                 print(f"ID: {instance_id}, Response: {instance_response}")
-                 return instance_response # Return ground truth response immediately
+                 print(f"ID: {instance_id}, Response: {ground_truth}")
+                 return ground_truth, ground_truth_metrics_ # Return ground truth response immediately
 
-     return None
+     return None, None
 
  def is_similar(question1, question2, threshold=0.85):
      vectorizer = TfidfVectorizer()
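
Note (not part of the commit): compute_rmse's body is outside this diff, but calculate_metrics now passes it two dictionaries (predicted_metrics and the per-instance ground_truth_metrics). The sketch below shows one way such a dictionary-to-dictionary comparison could work; the function name compute_rmse_from_dicts and the sample values are illustrative, not the repository's implementation.

import math

def compute_rmse_from_dicts(predicted_metrics, ground_truth_metrics):
    # Compare only the metric names present in both dicts with numeric values
    shared = [
        k for k in predicted_metrics
        if k in ground_truth_metrics
        and isinstance(predicted_metrics[k], (int, float))
        and isinstance(ground_truth_metrics[k], (int, float))
    ]
    if not shared:
        return None
    squared_errors = [(predicted_metrics[k] - ground_truth_metrics[k]) ** 2 for k in shared]
    return math.sqrt(sum(squared_errors) / len(squared_errors))

# Illustrative values only
predicted = {"context_relevance": 0.72, "context_utilization": 0.58, "completeness": 0.64, "adherence": 1.0}
ground_truth = {"context_relevance": 0.80, "context_utilization": 0.55, "completeness": 0.70, "adherence": 1.0}
print(f"RMSE: {compute_rmse_from_dicts(predicted, ground_truth):.4f}")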