cb1716pics committed (verified)
Commit 01c5a73 · Parent(s): ffaa00c

Upload 2 files

Files changed (2):
  1. app.py +12 -10
  2. retrieval.py +14 -15
app.py CHANGED

@@ -78,14 +78,16 @@ if "time_taken_for_response" not in st.session_state:
     st.session_state.time_taken_for_response = "N/A"
 if "metrics" not in st.session_state:
     st.session_state.metrics = {}
-if "query_dataset" not in st.session_state:
+if "query_dataset" not in st.session_state:
     st.session_state.query_dataset = ''
+if "recent_questions" not in st.session_state:
+    st.session_state.recent_questions = {}

-recent_questions = load_recent_questions()
-print(recent_questions)
+st.session_state.recent_questions = load_recent_questions()
+print(st.session_state.recent_questions)

-if recent_questions and "questions" in recent_questions and recent_questions["questions"]:
-    recent_qns = list(reversed(recent_questions["questions"]))
+if st.session_state.recent_questions and "questions" in st.session_state.recent_questions and st.session_state.recent_questions["questions"]:
+    recent_qns = list(reversed(st.session_state.recent_questions["questions"]))

     print(recent_qns)

@@ -98,7 +100,7 @@ if recent_questions and "questions" in recent_questions and recent_questions["questions"]:
     st.sidebar.title("Analytics")

     # Extract response times and labels
-    response_time = [q["response_time"] for q in recent_qns]
+    response_time = [q['metrics']["response_time"] for q in recent_qns]
     labels = [f"Q{i+1}" for i in range(len(response_time))]

     # Plot graph

@@ -130,10 +132,10 @@ if st.button("Submit"):
     st.session_state.time_taken_for_response = end_time - start_time

     # Store in session state
-    st.session_state.recent_questions.append({
-        "question": question,
-        "response_time": st.session_state.time_taken_for_response
-    })
+    # st.session_state.recent_questions.append({
+    #     "question": question,
+    #     "response_time": st.session_state.time_taken_for_response
+    # })

     # Display stored response
     st.subheader("Response")
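The app.py changes move the question history into st.session_state (so it survives Streamlit reruns) and read each entry's timing from a nested "metrics" dict. load_recent_questions() itself is not part of this diff; the following is a minimal sketch, assuming a hypothetical JSON file, of the loader and the return shape the new code expects:

import json
import os

# Hypothetical storage location; only the returned shape is implied by the diff.
RECENT_QUESTIONS_PATH = "data_local/recent_questions.json"

def load_recent_questions():
    """Return the saved history dict, or None if nothing has been stored yet."""
    if not os.path.exists(RECENT_QUESTIONS_PATH):
        return None
    with open(RECENT_QUESTIONS_PATH, "r") as f:
        return json.load(f)

# Shape inferred from q['metrics']["response_time"] in the analytics hunk:
# {"questions": [{"question": "...", "metrics": {"response_time": 1.42}}, ...]}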
retrieval.py CHANGED

@@ -5,12 +5,11 @@ import faiss
 from rank_bm25 import BM25Okapi
 from data_processing import embedding_model
 from sentence_transformers import CrossEncoder
-import string
-import nltk

-import nltk
-nltk.download('punkt')
-nltk.download('punkt_tab')
+#import string
+# import nltk
+# nltk.download('punkt')
+# nltk.download('punkt_tab')

 from nltk.tokenize import word_tokenize

@@ -19,8 +18,8 @@ reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
 retrieved_docs = None

 # Tokenize the documents and remove punctuation
-def preprocess(doc):
-    return [word.lower() for word in word_tokenize(doc) if word not in string.punctuation]
+# def preprocess(doc):
+#     return [word.lower() for word in word_tokenize(doc) if word not in string.punctuation]

 def retrieve_documents_hybrid(query, q_dataset, top_k=5):
     with open(f"data_local/{q_dataset}_chunked_docs.json", "r") as f:

@@ -30,31 +29,31 @@ def retrieve_documents_hybrid(query, q_dataset, top_k=5):
     index = faiss.read_index(faiss_index_path)

     # Tokenize documents for BM25
-    tokenized_docs = [preprocess(doc) for doc in chunked_documents]
+    tokenized_docs = [doc.split() for doc in chunked_documents]
     bm25 = BM25Okapi(tokenized_docs)

     query_embedding = np.array(embedding_model.embed_documents([query]), dtype=np.float32)
     query_embedding = query_embedding.reshape(1, -1)

     # FAISS Search
-    faiss_distances, faiss_indices = index.search(query_embedding, top_k)
+    _, faiss_indices = index.search(query_embedding, top_k)
     faiss_docs = [chunked_documents[i] for i in faiss_indices[0]]

     # BM25 Search
-    tokenized_query = preprocess(query)
+    tokenized_query = query.split()  # preprocess(query)
     bm25_scores = bm25.get_scores(tokenized_query)
     bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
     bm25_docs = [chunked_documents[i] for i in bm25_top_indices]

     # Combine FAISS + BM25 scores and retrieve docs
-    combined_results = set(bm25_top_indices).union(set(faiss_indices[0]))
+    # combined_results = set(bm25_top_indices).union(set(faiss_indices[0]))

-    combined_scores = rerank_docs_bm25faiss_scores(combined_results, bm25_scores, faiss_distances, faiss_indices)
-    reranked_docs = [chunked_documents[result[0]] for result in combined_scores[:top_k]]
+    # combined_scores = rerank_docs_bm25faiss_scores(combined_results, bm25_scores, faiss_distances, faiss_indices)
+    # reranked_docs = [chunked_documents[result[0]] for result in combined_scores[:top_k]]

     # Merge FAISS + BM25 Results and re-rank
-    #retrieved_docs = list(set(faiss_docs + bm25_docs))[:top_k]
-    #reranked_docs = rerank_documents(query, retrieved_docs)
+    retrieved_docs = list(set(faiss_docs + bm25_docs))[:top_k]
+    reranked_docs = rerank_documents(query, retrieved_docs)

     return reranked_docs
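The retrieval.py change abandons the score-fusion path (rerank_docs_bm25faiss_scores over the union of BM25 and FAISS hits) in favor of a simple merge: deduplicate the two candidate lists, truncate to top_k, and let a cross-encoder produce the final order. Note that list(set(...)) discards any retrieval-score ordering, so the final ranking rests entirely on rerank_documents. That function is not shown in the diff; a plausible sketch, assuming it scores (query, doc) pairs with the module-level reranker CrossEncoder, is:

import numpy as np

def rerank_documents(query, docs):
    # Sketch only: the name matches the call site above, but the internals
    # are assumptions, not the commit's actual code.
    # Score every (query, doc) pair with the cross-encoder loaded at module level.
    scores = reranker.predict([(query, doc) for doc in docs])
    # Return documents sorted by descending relevance score.
    order = np.argsort(scores)[::-1]
    return [docs[i] for i in order]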