cb1716pics committed on
Commit
a523549
·
verified ·
1 Parent(s): 973db40

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +8 -2
  2. data_processing.py +5 -7
  3. evaluation.py +5 -2
app.py CHANGED
@@ -2,7 +2,8 @@ import streamlit as st
2
  from generator import generate_response_from_document
3
  from retrieval import retrieve_documents
4
  from evaluation import calculate_metrics
5
- from data_processing import load_data_from_faiss, ragbench
 
6
 
7
  # Page Title
8
  st.title("RAG7 - Real World RAG System")
@@ -13,14 +14,19 @@ def load_data():
13
 
14
  data_status = load_data()
15
 
 
 
16
  # Question Section
17
  st.subheader("Hi, What do you want to know today?")
18
  question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
19
 
20
  # Submit Button
21
  if st.button("Submit"):
 
22
  retrieved_documents = retrieve_documents(question, 5)
23
  response = generate_response_from_document(question, retrieved_documents)
 
 
24
  else:
25
  response = ""
26
 
@@ -35,7 +41,7 @@ col1, col2 = st.columns([1, 3]) # Creating two columns for button and metrics d
35
 
36
  with col1:
37
  if st.button("Calculate Metrics"):
38
- metrics = calculate_metrics(question, response, retrieved_documents, ragbench)
39
  else:
40
  metrics = ""
41
 
 
2
  from generator import generate_response_from_document
3
  from retrieval import retrieve_documents
4
  from evaluation import calculate_metrics
5
+ from data_processing import load_data_from_faiss
6
+ import time
7
 
8
  # Page Title
9
  st.title("RAG7 - Real World RAG System")
 
14
 
15
  data_status = load_data()
16
 
17
+ time_taken_for_response = 'N/A'
18
+
19
  # Question Section
20
  st.subheader("Hi, What do you want to know today?")
21
  question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
22
 
23
  # Submit Button
24
  if st.button("Submit"):
25
+ start_time = time.time()
26
  retrieved_documents = retrieve_documents(question, 5)
27
  response = generate_response_from_document(question, retrieved_documents)
28
+ end_time = time.time()
29
+ time_taken_for_response = end_time-start_time
30
  else:
31
  response = ""
32
 
 
41
 
42
  with col1:
43
  if st.button("Calculate Metrics"):
44
+ metrics = calculate_metrics(question, response, retrieved_documents, time_taken_for_response)
45
  else:
46
  metrics = ""
47
 
data_processing.py CHANGED
@@ -15,8 +15,6 @@ embedding_model = HuggingFaceEmbeddings(
15
  )
16
 
17
  all_documents = []
18
- index = None
19
- actual_docs = None
20
  ragbench = {}
21
 
22
 
@@ -39,9 +37,10 @@ def create_faiss_index_file():
39
  # Convert embeddings to a NumPy array
40
  embeddings_np = np.array(embeddings, dtype=np.float32)
41
 
 
42
  # Store in FAISS using the NumPy array's shape
43
- index = faiss.IndexFlatL2(embeddings_np.shape[1])
44
- index.add(embeddings_np)
45
 
46
  # Save FAISS index
47
  faiss.write_index(index, f"data_local/rag7_index.faiss")
@@ -53,7 +52,6 @@ def create_faiss_index_file():
53
  print(f"data is stored!")
54
 
55
  def load_data_from_faiss():
56
- load_ragbench()
57
  load_faiss()
58
  load_metatdata()
59
 
@@ -63,11 +61,11 @@ def load_ragbench():
63
  ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
64
 
65
  def load_faiss():
66
- # Load the correct FAISS index
67
  faiss_index_path = f"data_local/rag7_index.faiss"
68
  index = faiss.read_index(faiss_index_path)
69
 
70
  def load_metatdata():
71
- # Load document metadata
72
  with open(f"data_local/rag7_docs.json", "r") as f:
73
  actual_docs = json.load(f) # Contains all documents for this dataset
 
15
  )
16
 
17
  all_documents = []
 
 
18
  ragbench = {}
19
 
20
 
 
37
  # Convert embeddings to a NumPy array
38
  embeddings_np = np.array(embeddings, dtype=np.float32)
39
 
40
+ global index_w
41
  # Store in FAISS using the NumPy array's shape
42
+ index_w = faiss.IndexFlatL2(embeddings_np.shape[1])
43
+ index_w.add(embeddings_np)
44
 
45
  # Save FAISS index
46
  faiss.write_index(index, f"data_local/rag7_index.faiss")
 
52
  print(f"data is stored!")
53
 
54
  def load_data_from_faiss():
 
55
  load_faiss()
56
  load_metatdata()
57
 
 
61
  ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
62
 
63
  def load_faiss():
64
+ global index
65
  faiss_index_path = f"data_local/rag7_index.faiss"
66
  index = faiss.read_index(faiss_index_path)
67
 
68
  def load_metatdata():
69
+ global actual_docs
70
  with open(f"data_local/rag7_docs.json", "r") as f:
71
  actual_docs = json.load(f) # Contains all documents for this dataset
evaluation.py CHANGED
@@ -4,12 +4,15 @@ from sklearn.metrics import mean_squared_error, roc_auc_score
4
  from sklearn.feature_extraction.text import TfidfVectorizer
5
  from sklearn.metrics.pairwise import cosine_similarity
6
 
 
 
7
  ground_truth_answer = ''
8
  ground_truth_metrics = {}
9
 
10
 
11
- def calculate_metrics(question, response, docs,data, time_taken):
12
- retrieve_ground_truths(question,data)
 
13
  # Predicted metrics
14
  predicted_metrics = {
15
  "context_relevance": context_relevance(question, docs),
 
4
  from sklearn.feature_extraction.text import TfidfVectorizer
5
  from sklearn.metrics.pairwise import cosine_similarity
6
 
7
+ from data_processing import load_ragbench
8
+
9
  ground_truth_answer = ''
10
  ground_truth_metrics = {}
11
 
12
 
13
+ def calculate_metrics(question, response, docs, time_taken):
14
+ data = load_ragbench()
15
+ retrieve_ground_truths(question, data)
16
  # Predicted metrics
17
  predicted_metrics = {
18
  "context_relevance": context_relevance(question, docs),