cb1716pics commited on
Commit
ece1395
·
verified ·
1 Parent(s): 4433c64

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +23 -25
  2. data_processing.py +7 -15
  3. evaluation.py +12 -5
  4. retrieval.py +34 -10
app.py CHANGED
@@ -14,7 +14,7 @@ st.markdown(
14
  <style>
15
  .stTextArea textarea {
16
  background-color: white !important;
17
- font-size: 20px !important;
18
  color: black !important;
19
  }
20
  </style>
@@ -82,19 +82,19 @@ if "query_dataset" not in st.session_state:
82
  st.session_state.query_dataset = ''
83
 
84
  recent_questions = load_recent_questions()
 
85
 
86
- # for visualization
87
-
88
- # response_time = [q["response_time"] for q in recent_data["questions"]]
89
- # labels = [f"Q{i+1}" for i in range(len(response_time))] # Labels for X-axis
90
-
91
- # fig, ax = plt.subplots()
92
- # ax.set_xlabel("Recent Questions")
93
- # ax.set_ylabel("Time Taken for Response")
94
- # ax.legend()
95
- # st.sidebar.pyplot(fig)
96
  if recent_questions and "questions" in recent_questions and recent_questions["questions"]:
97
  recent_qns = list(reversed(recent_questions["questions"]))
 
 
 
 
 
 
 
 
 
98
  st.sidebar.title("Analytics")
99
 
100
  # Extract response times and labels
@@ -119,18 +119,6 @@ if recent_questions and "questions" in recent_questions and recent_questions["qu
119
  st.sidebar.write(f"🔹 {q['question']}")
120
  else:
121
  st.sidebar.write("No recent questions")
122
- # Separator
123
-
124
- # Streamlit Sidebar for Recent Questions
125
-
126
-
127
- # Submit Button
128
- # if st.button("Submit"):
129
- # start_time = time.time()
130
- # st.session_state.retrieved_documents = retrieve_documents_hybrid(question, 10)
131
- # st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
132
- # end_time = time.time()
133
- # st.session_state.time_taken_for_response = end_time - start_time
134
 
135
  if st.button("Submit"):
136
  start_time = time.time()
@@ -140,7 +128,12 @@ if st.button("Submit"):
140
  st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
141
  end_time = time.time()
142
  st.session_state.time_taken_for_response = end_time - start_time
143
- save_recent_question(question, st.session_state.time_taken_for_response)
 
 
 
 
 
144
 
145
  # Display stored response
146
  st.subheader("Response")
@@ -164,10 +157,15 @@ col1, col2 = st.columns([1, 3]) # Creating two columns for button and metrics d
164
  with col1:
165
  if st.button("Show Metrics"):
166
  st.session_state.metrics = calculate_metrics(question, st.session_state.query_dataset, st.session_state.response, st.session_state.retrieved_documents, st.session_state.time_taken_for_response)
 
167
  else:
168
  metrics_ = {}
169
 
170
  with col2:
171
  #st.text_area("Metrics:", value=metrics, height=100, disabled=True)
172
- st.json(st.session_state.metrics)
 
 
 
 
173
 
 
14
  <style>
15
  .stTextArea textarea {
16
  background-color: white !important;
17
+ font-size: 24px !important;
18
  color: black !important;
19
  }
20
  </style>
 
82
  st.session_state.query_dataset = ''
83
 
84
  recent_questions = load_recent_questions()
85
+ print(recent_questions)
86
 
 
 
 
 
 
 
 
 
 
 
87
  if recent_questions and "questions" in recent_questions and recent_questions["questions"]:
88
  recent_qns = list(reversed(recent_questions["questions"]))
89
+
90
+ print(recent_qns)
91
+
92
+ # Display Recent Questions
93
+ st.sidebar.title("Recent Questions")
94
+ for q in recent_qns: # Show latest first
95
+ st.sidebar.write(f"🔹 {q['question']}")
96
+
97
+ st.sidebar.markdown("---")
98
  st.sidebar.title("Analytics")
99
 
100
  # Extract response times and labels
 
119
  st.sidebar.write(f"🔹 {q['question']}")
120
  else:
121
  st.sidebar.write("No recent questions")
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  if st.button("Submit"):
124
  start_time = time.time()
 
128
  st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
129
  end_time = time.time()
130
  st.session_state.time_taken_for_response = end_time - start_time
131
+
132
+ # Store in session state
133
+ st.session_state.recent_questions.append({
134
+ "question": question,
135
+ "response_time": st.session_state.time_taken_for_response
136
+ })
137
 
138
  # Display stored response
139
  st.subheader("Response")
 
157
  with col1:
158
  if st.button("Show Metrics"):
159
  st.session_state.metrics = calculate_metrics(question, st.session_state.query_dataset, st.session_state.response, st.session_state.retrieved_documents, st.session_state.time_taken_for_response)
160
+ metrics_ = st.session_state.metrics
161
  else:
162
  metrics_ = {}
163
 
164
  with col2:
165
  #st.text_area("Metrics:", value=metrics, height=100, disabled=True)
166
+ if len(metrics_) > 0:
167
+ st.json(metrics_)
168
+
169
+ save_recent_question(question, st.session_state.metrics)
170
+
171
 
data_processing.py CHANGED
@@ -21,7 +21,6 @@ embedding_model = HuggingFaceEmbeddings(
21
  reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
22
  query_dataset_data = {}
23
 
24
- # File path for storing recently asked questions and metrics
25
  RECENT_QUESTIONS_FILE = "data_local/recent_questions.json"
26
 
27
  # Ensure the file exists and initialize if empty
@@ -36,10 +35,7 @@ chunk_docs = []
36
  documents = []
37
  query_dataset_data = {}
38
 
39
- # Ensure data directory exists
40
- os.makedirs("data_local", exist_ok=True)
41
-
42
- # Initialize a text splitter
43
  text_splitter = RecursiveCharacterTextSplitter(
44
  chunk_size=1024,
45
  chunk_overlap=100
@@ -55,14 +51,12 @@ def create_faiss_index(dataset):
55
 
56
  for split in ragbench_dataset.keys():
57
  for row in ragbench_dataset[split]:
58
- # Ensure document is a string before appending
59
  doc = row["documents"]
60
  if isinstance(doc, list):
61
- # If doc is a list, join its elements into a single string
62
  doc = " ".join(doc)
63
- documents.append(doc) # Extract document text
64
- # Chunking
65
-
66
  chunked_documents = chunk_documents(documents)
67
 
68
  # Save documents in JSON (metadata storage)
@@ -76,7 +70,6 @@ def create_faiss_index(dataset):
76
  # Convert embeddings to a NumPy array
77
  embeddings_np = np.array(embeddings, dtype=np.float32)
78
 
79
-
80
  # Save FAISS index
81
  index = faiss.IndexHNSWFlat(embeddings_np.shape[1], 32) # 32 is the graph size
82
  index.add(embeddings_np)
@@ -141,17 +134,16 @@ def load_recent_questions():
141
  if os.path.exists(RECENT_QUESTIONS_FILE):
142
  with open(RECENT_QUESTIONS_FILE, "r") as file:
143
  return json.load(file)
144
- return {"questions": []} # Default structure if file doesn't exist
145
 
146
- def save_recent_question(question, response_time):
147
  data = load_recent_questions()
148
 
149
- #data["questions"] = [q for q in data["questions"] if q["question"] != question]
150
  if "question" in data["questions"] and question not in data["questions"]["question"]:
151
  # Append new question & metrics
152
  data["questions"].append({
153
  "question": question,
154
- "response_time": response_time
155
  })
156
 
157
  # Keep only the last 5 questions
 
21
  reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
22
  query_dataset_data = {}
23
 
 
24
  RECENT_QUESTIONS_FILE = "data_local/recent_questions.json"
25
 
26
  # Ensure the file exists and initialize if empty
 
35
  documents = []
36
  query_dataset_data = {}
37
 
38
+ # Text splitter
 
 
 
39
  text_splitter = RecursiveCharacterTextSplitter(
40
  chunk_size=1024,
41
  chunk_overlap=100
 
51
 
52
  for split in ragbench_dataset.keys():
53
  for row in ragbench_dataset[split]:
 
54
  doc = row["documents"]
55
  if isinstance(doc, list):
 
56
  doc = " ".join(doc)
57
+ documents.append(doc) #
58
+
59
+ # Chunking
60
  chunked_documents = chunk_documents(documents)
61
 
62
  # Save documents in JSON (metadata storage)
 
70
  # Convert embeddings to a NumPy array
71
  embeddings_np = np.array(embeddings, dtype=np.float32)
72
 
 
73
  # Save FAISS index
74
  index = faiss.IndexHNSWFlat(embeddings_np.shape[1], 32) # 32 is the graph size
75
  index.add(embeddings_np)
 
134
  if os.path.exists(RECENT_QUESTIONS_FILE):
135
  with open(RECENT_QUESTIONS_FILE, "r") as file:
136
  return json.load(file)
137
+ return {"questions": []}
138
 
139
+ def save_recent_question(question, metrics):
140
  data = load_recent_questions()
141
 
 
142
  if "question" in data["questions"] and question not in data["questions"]["question"]:
143
  # Append new question & metrics
144
  data["questions"].append({
145
  "question": question,
146
+ "metrics": metrics
147
  })
148
 
149
  # Keep only the last 5 questions
evaluation.py CHANGED
@@ -101,13 +101,13 @@ def calculate_metrics(question, q_dataset, response, docs, time_taken):
101
 
102
  # Predicted metrics
103
  predicted_metrics = {
104
- "RAG_model_response": response,
105
- "ground_truth": ground_truth_answer,
106
  "context_relevance": context_relevance(question, docs),
107
  "context_utilization": context_utilization(response, docs),
108
  "completeness": completeness(response, ground_truth_answer),
109
  "adherence": adherence(response, docs),
110
- "response_time": time_taken
 
 
111
  }
112
  return predicted_metrics
113
 
@@ -115,7 +115,8 @@ def retrieve_ground_truths(question, dataset):
115
  for split_name, instances in dataset.items():
116
  print(f"Processing {split_name} split")
117
  for instance in instances:
118
- if instance['question'] == question:
 
119
  instance_id = instance['id']
120
  instance_response = instance['response']
121
  # ground_truth_metrics = {
@@ -128,4 +129,10 @@ def retrieve_ground_truths(question, dataset):
128
  print(f"ID: {instance_id}, Response: {instance_response}")
129
  return instance_response # Return ground truth response immediately
130
 
131
- return None # Return None if no match is found
 
 
 
 
 
 
 
101
 
102
  # Predicted metrics
103
  predicted_metrics = {
 
 
104
  "context_relevance": context_relevance(question, docs),
105
  "context_utilization": context_utilization(response, docs),
106
  "completeness": completeness(response, ground_truth_answer),
107
  "adherence": adherence(response, docs),
108
+ "response_time": time_taken,
109
+ "ground_truth": ground_truth_answer,
110
+ "RAG_model_response": response
111
  }
112
  return predicted_metrics
113
 
 
115
  for split_name, instances in dataset.items():
116
  print(f"Processing {split_name} split")
117
  for instance in instances:
118
+ #if instance['question'] == question:
119
+ if is_similar(instance['question'], question):
120
  instance_id = instance['id']
121
  instance_response = instance['response']
122
  # ground_truth_metrics = {
 
129
  print(f"ID: {instance_id}, Response: {instance_response}")
130
  return instance_response # Return ground truth response immediately
131
 
132
+ return None
133
+
134
+ def is_similar(question1, question2, threshold=0.85):
135
+ vectorizer = TfidfVectorizer()
136
+ vectors = vectorizer.fit_transform([question1, question2])
137
+ similarity = cosine_similarity(vectors[0], vectors[1])[0][0]
138
+ return similarity >= threshold
retrieval.py CHANGED
@@ -5,11 +5,17 @@ import faiss
5
  from rank_bm25 import BM25Okapi
6
  from data_processing import embedding_model
7
  from sentence_transformers import CrossEncoder
 
 
8
 
9
  reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
10
 
11
  retrieved_docs = None
12
 
 
 
 
 
13
  def retrieve_documents_hybrid(query, q_dataset, top_k=5):
14
  with open( f"data_local/{q_dataset}_chunked_docs.json", "r") as f:
15
  chunked_documents = json.load(f) # Contains all documents for this dataset
@@ -18,29 +24,48 @@ def retrieve_documents_hybrid(query, q_dataset, top_k=5):
18
  index = faiss.read_index(faiss_index_path)
19
 
20
  # Tokenize documents for BM25
21
- tokenized_docs = [doc.split() for doc in chunked_documents]
22
  bm25 = BM25Okapi(tokenized_docs)
23
 
24
  query_embedding = np.array(embedding_model.embed_documents([query]), dtype=np.float32)
25
  query_embedding = query_embedding.reshape(1, -1)
26
 
27
  # FAISS Search
28
- _, nearest_indices = index.search(query_embedding, top_k)
29
- faiss_docs = [chunked_documents[i] for i in nearest_indices[0]]
30
 
31
  # BM25 Search
32
- tokenized_query = query.split()
33
  bm25_scores = bm25.get_scores(tokenized_query)
34
  bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
35
  bm25_docs = [chunked_documents[i] for i in bm25_top_indices]
36
 
37
- # Merge FAISS + BM25 Results
38
- retrieved_docs = list(set(faiss_docs + bm25_docs))[:top_k]
39
-
40
- reranked_docs = rerank_documents(query, retrieved_docs)
 
 
 
 
 
41
 
42
  return reranked_docs
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  # Retrieval Function
45
  # def retrieve_documents(query, top_k=5):
46
  # query_dataset = find_query_dataset(query)
@@ -62,9 +87,8 @@ def retrieve_documents_hybrid(query, q_dataset, top_k=5):
62
 
63
  def remove_duplicate_documents(documents):
64
  unique_documents = []
65
- seen_documents = set() # To keep track of seen documents
66
  for doc in documents:
67
- # Using the page_content as a unique identifier for deduplication
68
  doc_content = doc.page_content
69
  if doc_content not in seen_documents:
70
  unique_documents.append(doc)
 
5
  from rank_bm25 import BM25Okapi
6
  from data_processing import embedding_model
7
  from sentence_transformers import CrossEncoder
8
+ from nltk.tokenize import word_tokenize
9
+ import string
10
 
11
  reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
12
 
13
  retrieved_docs = None
14
 
15
+ # Tokenize the documents and remove punctuation
16
+ def preprocess(doc):
17
+ return [word.lower() for word in word_tokenize(doc) if word not in string.punctuation]
18
+
19
  def retrieve_documents_hybrid(query, q_dataset, top_k=5):
20
  with open( f"data_local/{q_dataset}_chunked_docs.json", "r") as f:
21
  chunked_documents = json.load(f) # Contains all documents for this dataset
 
24
  index = faiss.read_index(faiss_index_path)
25
 
26
  # Tokenize documents for BM25
27
+ tokenized_docs = [preprocess(doc) for doc in chunked_documents]
28
  bm25 = BM25Okapi(tokenized_docs)
29
 
30
  query_embedding = np.array(embedding_model.embed_documents([query]), dtype=np.float32)
31
  query_embedding = query_embedding.reshape(1, -1)
32
 
33
  # FAISS Search
34
+ faiss_distances, faiss_indices = index.search(query_embedding, top_k)
35
+ faiss_docs = [chunked_documents[i] for i in faiss_indices[0]]
36
 
37
  # BM25 Search
38
+ tokenized_query = preprocess(query)
39
  bm25_scores = bm25.get_scores(tokenized_query)
40
  bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
41
  bm25_docs = [chunked_documents[i] for i in bm25_top_indices]
42
 
43
+ # Combine FAISS + BM25 scores and retrieve docs
44
+ combined_results = set(bm25_top_indices).union(set(faiss_indices[0]))
45
+
46
+ combined_scores = rerank_docs_bm25faiss_scores(combined_results,bm25_scores, faiss_distances,faiss_indices)
47
+ reranked_docs = [chunked_documents[result[0]] for result in combined_scores[:top_k]]
48
+
49
+ # Merge FAISS + BM25 Results and re-rank
50
+ #retrieved_docs = list(set(faiss_docs + bm25_docs))[:top_k]
51
+ #reranked_docs = rerank_documents(query, retrieved_docs)
52
 
53
  return reranked_docs
54
 
55
+ def rerank_docs_bm25faiss_scores(combined_results_,bm25_scores_, faiss_distances_,faiss_indices_):
56
+ final_results = []
57
+ for idx in combined_results_:
58
+ # Combine BM25 score and FAISS score for ranking (this could be more sophisticated)
59
+ bm25_score = bm25_scores_[idx]
60
+ faiss_score = 1 / (1 + faiss_distances_[0][np.where(faiss_indices_[0] == idx)]) # Inverse distance for relevance
61
+ final_results.append((idx, bm25_score, faiss_score))
62
+
63
+ # Sort final results by combined score (you can adjust the ranking strategy here)
64
+ final_results.sort(key=lambda x: (x[1] + x[2]), reverse=True)
65
+
66
+ return final_results
67
+
68
+
69
  # Retrieval Function
70
  # def retrieve_documents(query, top_k=5):
71
  # query_dataset = find_query_dataset(query)
 
87
 
88
  def remove_duplicate_documents(documents):
89
  unique_documents = []
90
+ seen_documents = set()
91
  for doc in documents:
 
92
  doc_content = doc.page_content
93
  if doc_content not in seen_documents:
94
  unique_documents.append(doc)