cb1716pics committed on
Commit
fdc80c8
·
verified ·
1 Parent(s): eea3fa6

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +6 -3
  2. evaluation.py +2 -3
  3. generator.py +17 -6
  4. retrieval.py +15 -20
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  from generator import generate_response_from_document
3
- from retrieval import retrieve_documents_hybrid
4
  from evaluation import calculate_metrics
5
  from data_processing import load_recent_questions, save_recent_question
6
  import time
@@ -63,6 +63,8 @@ if "time_taken_for_response" not in st.session_state:
63
  st.session_state.time_taken_for_response = "N/A"
64
  if "metrics" not in st.session_state:
65
  st.session_state.metrics = {}
 
 
66
 
67
  recent_data = load_recent_questions()
68
 
@@ -106,7 +108,8 @@ for q in reversed(recent_data["questions"]): # Show latest first
106
 
107
  if st.button("Submit"):
108
  start_time = time.time()
109
- st.session_state.retrieved_documents = retrieve_documents_hybrid(question, 10)
 
110
  st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
111
  end_time = time.time()
112
  st.session_state.time_taken_for_response = end_time - start_time
@@ -133,7 +136,7 @@ col1, col2 = st.columns([1, 3]) # Creating two columns for button and metrics d
133
  # Calculate Metrics Button
134
  with col1:
135
  if st.button("Show Metrics"):
136
- st.session_state.metrics = calculate_metrics(question, st.session_state.response, st.session_state.retrieved_documents, st.session_state.time_taken_for_response)
137
  else:
138
  metrics_ = {}
139
 
 
1
  import streamlit as st
2
  from generator import generate_response_from_document
3
+ from retrieval import retrieve_documents_hybrid,find_query_dataset
4
  from evaluation import calculate_metrics
5
  from data_processing import load_recent_questions, save_recent_question
6
  import time
 
63
  st.session_state.time_taken_for_response = "N/A"
64
  if "metrics" not in st.session_state:
65
  st.session_state.metrics = {}
66
+ if "metrics" not in st.session_state:
67
+ st.session_state.metrics = {}
68
 
69
  recent_data = load_recent_questions()
70
 
 
108
 
109
  if st.button("Submit"):
110
  start_time = time.time()
111
+ st.session_state.query_dataset = find_query_dataset(question)
112
+ st.session_state.retrieved_documents = retrieve_documents_hybrid(question, st.session_state.query_dataset, 10)
113
  st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
114
  end_time = time.time()
115
  st.session_state.time_taken_for_response = end_time - start_time
 
136
  # Calculate Metrics Button
137
  with col1:
138
  if st.button("Show Metrics"):
139
+ st.session_state.metrics = calculate_metrics(question, st.session_state.query_dataset, st.session_state.response, st.session_state.retrieved_documents, st.session_state.time_taken_for_response)
140
  else:
141
  metrics_ = {}
142
 
evaluation.py CHANGED
@@ -91,9 +91,8 @@ def adherence(response, relevant_documents):
91
  def compute_rmse(predicted_values, ground_truth_values):
92
  return np.sqrt(mean_squared_error(ground_truth_values, predicted_values))
93
 
94
- def calculate_metrics(question, response, docs, time_taken):
95
- from retrieval import query_dataset
96
- data = load_query_dataset(query_dataset)
97
  ground_truth_answer = retrieve_ground_truths(question, data) # Store the ground truth answer
98
 
99
  # Ensure ground_truth_answer is not empty before proceeding
 
91
  def compute_rmse(predicted_values, ground_truth_values):
92
  return np.sqrt(mean_squared_error(ground_truth_values, predicted_values))
93
 
94
+ def calculate_metrics(question, q_dataset, response, docs, time_taken):
95
+ data = load_query_dataset(q_dataset)
 
96
  ground_truth_answer = retrieve_ground_truths(question, data) # Store the ground truth answer
97
 
98
  # Ensure ground_truth_answer is not empty before proceeding
generator.py CHANGED
@@ -11,12 +11,23 @@ def generate_response_from_document(query, retrieved_docs):
11
 
12
  # context = " ".join([doc.page_content for doc in retrieved_docs]) # Now iterates over Document objects
13
  context = " ".join([doc for doc in retrieved_docs])
14
- prompt = (
15
- "You are a highly intelligent assistant tasked with answering a question based strictly on the provided context. "
16
- f"Given Question: {query} \n\n"
17
- f"Context: {context} \n"
18
- "Answer the question directly and concisely using only the information available in the context."
19
- )
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  try:
22
  response = openai.chat.completions.create( # Use the new chat completions API
 
11
 
12
  # context = " ".join([doc.page_content for doc in retrieved_docs]) # Now iterates over Document objects
13
  context = " ".join([doc for doc in retrieved_docs])
14
+ prompt = """
15
+ "You are an accurate and reliable AI assistant that can answer questions with the help of external documents.
16
+ Please note that external documents may contain noisy or factually incorrect information.
17
+ If the information in the document contains the correct answer, you will give an accurate answer.
18
+ If the information in the document does not contain the answer, you will generate ’I can not answer the question because of the insufficient information in documents.‘.
19
+ If there are inconsistencies with the facts in some of the documents, please generate the response 'There are factual errors in the provided documents.' and provide the correct answer."
20
+
21
+ Context or Document: {context}
22
+ Query: {query}
23
+ """
24
+ # prompt = (
25
+ # "You are a highly intelligent assistant tasked with answering a question based strictly on the provided context. "
26
+ # f"Given Question: {query} \n\n"
27
+ # f"Context: {context} \n"
28
+ # f"Answer the question directly and concisely using only the information available in the context."
29
+ # "Do not include any other information which is not there in the context."
30
+ # )
31
 
32
  try:
33
  response = openai.chat.completions.create( # Use the new chat completions API
retrieval.py CHANGED
@@ -3,23 +3,18 @@ import numpy as np
3
  from langchain.schema import Document
4
  import faiss
5
  from rank_bm25 import BM25Okapi
6
- from data_processing import embedding_model #, index, actual_docs
7
  from sentence_transformers import CrossEncoder
8
 
9
  reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
10
 
11
  retrieved_docs = None
12
- global query_dataset
13
- query_dataset = ''
14
 
15
- def retrieve_documents_hybrid(query, top_k=5):
16
- #global query_dataset
17
- query_dataset = find_query_dataset(query)
18
-
19
- with open( f"data_local/{query_dataset}_chunked_docs.json", "r") as f:
20
  chunked_documents = json.load(f) # Contains all documents for this dataset
21
 
22
- faiss_index_path = f"data_local/{query_dataset}_quantized.faiss"
23
  index = faiss.read_index(faiss_index_path)
24
 
25
  # Tokenize documents for BM25
@@ -47,23 +42,23 @@ def retrieve_documents_hybrid(query, top_k=5):
47
  return reranked_docs
48
 
49
  # Retrieval Function
50
- def retrieve_documents(query, top_k=5):
51
- query_dataset = find_query_dataset(query)
52
- #index, chunk_docs = load_data_from_faiss(query)
53
 
54
- with open( f"data_local/{query_dataset}_chunked_docs.json", "r") as f:
55
- documents = json.load(f) # Contains all documents for this dataset
56
 
57
- faiss_index_path = f"data_local/{query_dataset}_quantized.faiss"
58
- index = faiss.read_index(faiss_index_path)
59
 
60
- query_embedding = np.array(embedding_model.embed_documents([query]), dtype=np.float32)
61
 
62
- _, nearest_indices = index.search(query_embedding, top_k)
63
 
64
- retrieved_docs = [Document(page_content=documents[i]) for i in nearest_indices[0]]
65
 
66
- return retrieved_docs
67
 
68
  def remove_duplicate_documents(documents):
69
  unique_documents = []
 
3
  from langchain.schema import Document
4
  import faiss
5
  from rank_bm25 import BM25Okapi
6
+ from data_processing import embedding_model
7
  from sentence_transformers import CrossEncoder
8
 
9
  reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
10
 
11
  retrieved_docs = None
 
 
12
 
13
+ def retrieve_documents_hybrid(query, q_dataset, top_k=5):
14
+ with open( f"data_local/{q_dataset}_chunked_docs.json", "r") as f:
 
 
 
15
  chunked_documents = json.load(f) # Contains all documents for this dataset
16
 
17
+ faiss_index_path = f"data_local/{q_dataset}_quantized.faiss"
18
  index = faiss.read_index(faiss_index_path)
19
 
20
  # Tokenize documents for BM25
 
42
  return reranked_docs
43
 
44
  # Retrieval Function
45
+ # def retrieve_documents(query, top_k=5):
46
+ # query_dataset = find_query_dataset(query)
47
+ # #index, chunk_docs = load_data_from_faiss(query)
48
 
49
+ # with open( f"data_local/{query_dataset}_chunked_docs.json", "r") as f:
50
+ # documents = json.load(f) # Contains all documents for this dataset
51
 
52
+ # faiss_index_path = f"data_local/{query_dataset}_quantized.faiss"
53
+ # index = faiss.read_index(faiss_index_path)
54
 
55
+ # query_embedding = np.array(embedding_model.embed_documents([query]), dtype=np.float32)
56
 
57
+ # _, nearest_indices = index.search(query_embedding, top_k)
58
 
59
+ # retrieved_docs = [Document(page_content=documents[i]) for i in nearest_indices[0]]
60
 
61
+ # return retrieved_docs
62
 
63
  def remove_duplicate_documents(documents):
64
  unique_documents = []