gourisankar85 committed (verified)
Commit 870f650 · Parent(s): a1710c6

Upload 3 files
Files changed (3):
  1. app.py +169 -197
  2. config.py +16 -0
  3. main.py +54 -33
app.py CHANGED
@@ -1,209 +1,181 @@
 import gradio as gr
-import os
-import json
-import pandas as pd
-from scripts.evaluate_information_integration import evaluate_information_integration
-from scripts.evaluate_negative_rejection import evaluate_negative_rejection
-from scripts.helper import update_config
-from scripts.evaluate_noise_robustness import evaluate_noise_robustness
-from scripts.evaluate_factual_robustness import evaluate_factual_robustness
-
-# Path to score files
-Noise_Robustness_DIR = "results/Noise Robustness/"
-Negative_Rejection_DIR = "results/Negative Rejection/"
-Counterfactual_Robustness_DIR = "results/Counterfactual Robustness/"
-Infomration_Integration_DIR = "results/Information Integration/"
-
-# Function to read and aggregate score data
-def load_scores(file_dir):
-    models = set()
-    noise_rates = set()
-
-    if not os.path.exists(file_dir):
-        return pd.DataFrame(columns=["Noise Ratio"])
-
-    score_data = {}
-
-    # Read all JSON score files
-    for filename in os.listdir(file_dir):
-        if filename.startswith("scores_") and filename.endswith(".json"):
-            filepath = os.path.join(file_dir, filename)
-            with open(filepath, "r") as f:
-                score = json.load(f)
-                model = score["model"]
-                noise_rate = str(score["noise_rate"])
-
-                models.add(model)
-                noise_rates.add(noise_rate)
-
-                score_data[(model, noise_rate)] = score["accuracy"]
-
-    # Convert to DataFrame
-    df = pd.DataFrame([
-        {
-            "Noise Ratio": model,
-            **{
-                rate: f"{score_data.get((model, rate), 'N/A') * 100:.2f}"
-                if score_data.get((model, rate), "N/A") != "N/A"
-                else "N/A"
-                for rate in sorted(noise_rates, key=float)
-            }
-        }
-        for model in sorted(models)
-    ])
-
-    return df
-
-# Function to load Negative Rejection scores (Only for Noise Rate = 1.0)
-def load_negative_rejection_scores():
-    if not os.path.exists(Negative_Rejection_DIR):
-        return pd.DataFrame()
-
-    score_data = {}
-    models = set()
-
-    for filename in os.listdir(Negative_Rejection_DIR):
-        if filename.startswith("scores_") and filename.endswith(".json") and "_noise_1.0_" in filename:
-            filepath = os.path.join(Negative_Rejection_DIR, filename)
-            with open(filepath, "r") as f:
-                score = json.load(f)
-                model = filename.split("_")[1]  # Extract model name
-                models.add(model)
-                score_data[model] = score.get("reject_rate", "N/A")
-
-    df = pd.DataFrame([
-        {"Model": model, "Rejection Rate": f"{score_data.get(model, 'N/A') * 100:.2f}%"
-            if score_data.get(model, "N/A") != "N/A"
-            else "N/A"}
-        for model in sorted(models)
-    ])
-
-    return df if not df.empty else pd.DataFrame(columns=["Model", "Rejection Rate"])
-
-def load_counterfactual_robustness_scores():
-    models = set()
-
-    if not os.path.exists(Counterfactual_Robustness_DIR):
-        return pd.DataFrame(columns=["Noise Ratio"])
-
-    score_data = {}
-
-    # Read all JSON score files
-    for filename in os.listdir(Counterfactual_Robustness_DIR):
-        if filename.startswith("scores_") and filename.endswith(".json"):
-            filepath = os.path.join(Counterfactual_Robustness_DIR, filename)
-            with open(filepath, "r") as f:
-                score = json.load(f)
-                model = filename.split("_")[1]
-                #noise_rate = str(score["noise_rate"])
-
-                models.add(model)
-                score_data[model] = {
-                    "Accuracy (%)": int(score["all_rate"] * 100),  # No decimal
-                    "Error Detection Rate": int(score["reject_rate"] * 10),
-                    "Correction Rate (%)": round(score["correct_rate"] * 100, 2)  # 2 decimal places
-                }
-
-    # Convert to DataFrame
-    df = pd.DataFrame([
-        {
-            "Model": model,
-            "Accuracy (%)": score_data.get(model, {}).get("Accuracy (%)", "N/A"),
-            "Error Detection Rate": score_data.get(model, {}).get("Error Detection Rate", "N/A"),
-            "Correction Rate (%)": f"{score_data.get(model, {}).get('Correction Rate (%)', 'N/A'):.2f}"
-        }
-        for model in sorted(models)
-    ])
-
-    return df
-
-# Gradio UI
-def launch_gradio_app(config):
-    with gr.Blocks() as app:
-        app.title = "RAG System Evaluation"
-        gr.Markdown("# RAG System Evaluation on RGB Dataset")
-
-        # Top Section - Inputs and Controls
-        with gr.Row():
-            model_name_input = gr.Dropdown(
-                label="Model Name",
-                choices=config["models"],
-                value="llama3-8b-8192",
-                interactive=True
-            )
-            noise_rate_input = gr.Slider(label="Noise Rate", minimum=0, maximum=1.0, step=0.2, value=0.2, interactive=True)
-            num_queries_input = gr.Number(label="Number of Queries", value=50, interactive=True)
-
-        # Bottom Section - Action Buttons
         with gr.Row():
-            recalculate_noise_btn = gr.Button("Evaluate Noise Robustness")
-            recalculate_negative_btn = gr.Button("Evaluate Negative Rejection")
-            recalculate_counterfactual_btn = gr.Button("Evaluate Counterfactual Robustness")
-            recalculate_integration_btn = gr.Button("Evaluate Integration Information")

         with gr.Row():
-            refresh_btn = gr.Button("Refresh", variant="primary", scale=0)
-
-        # Middle Section - Data Tables
         with gr.Row():
-            with gr.Column():
-                gr.Markdown("### 📊 Noise Robustness\n**Description:** The experimental result of noise robustness measured by accuracy (%) under different noise ratios. Result show that the increasing noise rate poses a challenge for RAG in LLMs.")
-                noise_table = gr.Dataframe(value=load_scores(Noise_Robustness_DIR), interactive=False)
-            with gr.Column():
-                gr.Markdown("### 🚫 Negative Rejection\n**Description:** This measures the model's ability to reject invalid or nonsensical queries instead of generating incorrect responses. A higher rejection rate means the model is better at filtering unreliable inputs.")
-                rejection_table = gr.Dataframe(value=load_negative_rejection_scores(), interactive=False)
-
         with gr.Row():
-            with gr.Column():
-                gr.Markdown("""
-                ### 🔄 Counterfactual Robustness
-                **Description:**
-                Counterfactual Robustness evaluates a model's ability to handle **errors in external knowledge** while ensuring reliable responses.
-
-                **Key Metrics in this Report:**
-                - **Accuracy (%)** → Measures the accuracy (%) of LLMs with counterfactual documents.
-                - **Error Detection Rate (%)** → Measures how often the model **rejects** incorrect or misleading queries instead of responding.
-                - **Correct Rate (%)** → Measures how often the model provides accurate responses despite **potential misinformation**.
-                """)
-                counter_factual_table = gr.Dataframe(value=load_counterfactual_robustness_scores(), interactive=False)
-            with gr.Column():
-                gr.Markdown("### 🧠 Information Integration\n**Description:** The experimental result of information integration measured by accuracy (%) under different noise ratios. The result show that information integration poses a challenge for RAG in LLMs")
-                integration_table = gr.Dataframe(value=load_scores(Infomration_Integration_DIR), interactive=False)
-
-        # Refresh Scores Function
-        def refresh_scores():
-            return load_scores(Noise_Robustness_DIR), load_negative_rejection_scores(), load_counterfactual_robustness_scores(), load_scores(Infomration_Integration_DIR)
-
-        refresh_btn.click(refresh_scores, outputs=[noise_table, rejection_table, counter_factual_table, integration_table])
-
-        # Button Functions
-        def recalculate_noise_robustness(model_name, noise_rate, num_queries):
-            update_config(config, model_name, noise_rate, num_queries)
-            evaluate_noise_robustness(config)
-            return load_scores(Noise_Robustness_DIR)

-        recalculate_noise_btn.click(recalculate_noise_robustness, inputs=[model_name_input, noise_rate_input, num_queries_input], outputs=[noise_table])
-
-        def recalculate_counterfactual_robustness(model_name, noise_rate, num_queries):
-            update_config(config, model_name, noise_rate, num_queries)
-            evaluate_factual_robustness(config)
-            return load_counterfactual_robustness_scores()

-        recalculate_counterfactual_btn.click(recalculate_counterfactual_robustness, inputs=[model_name_input, noise_rate_input, num_queries_input], outputs=[counter_factual_table])
-
-        def recalculate_negative_rejection(model_name, noise_rate, num_queries):
-            update_config(config, model_name, noise_rate, num_queries)
-            evaluate_negative_rejection(config)
-            return load_negative_rejection_scores()

-        recalculate_negative_btn.click(recalculate_negative_rejection, inputs=[model_name_input, noise_rate_input, num_queries_input], outputs=[rejection_table])

-        def recalculate_integration_info(model_name, noise_rate, num_queries):
-            update_config(config, model_name, noise_rate, num_queries)
-            evaluate_information_integration(config)
-            return load_scores(Infomration_Integration_DIR)
-
-        recalculate_integration_btn.click(recalculate_integration_info, inputs=[model_name_input, noise_rate_input, num_queries_input], outputs=[integration_table])

-    app.launch()
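For reference while reading the removed loaders: `load_scores` only touches three keys per `scores_*.json` file, so a score file presumably looks like the sketch below (hypothetical values, inferred from the keys the code reads; real files may carry more fields). Note the quirk that the first table column is labeled "Noise Ratio" but is actually populated with the model name.

```python
# Hypothetical contents of a results/.../scores_*.json file (illustration only).
score = {
    "model": "llama3-8b-8192",  # placed in the "Noise Ratio" column (a naming quirk)
    "noise_rate": 0.2,          # becomes a column header after str() conversion
    "accuracy": 0.85            # rendered as "85.00" in the table
}
```

The replacement app.py follows: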
 import gradio as gr
+import logging
+import threading
+import time
+from generator.compute_metrics import get_attributes_text
+from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
+from config import AppConfig, ConfigConstants
+from generator.initialize_llm import initialize_generation_llm, initialize_validation_llm
+
+def launch_gradio(config: AppConfig):
+    """
+    Launch the Gradio app with pre-initialized objects.
+    """
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+
+    # Create a list to store logs
+    logs = []
+
+    # Custom log handler to capture logs and add them to the logs list
+    class LogHandler(logging.Handler):
+        def emit(self, record):
+            log_entry = self.format(record)
+            logs.append(log_entry)
+
+    # Add custom log handler to the logger
+    log_handler = LogHandler()
+    log_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
+    logger.addHandler(log_handler)
+
+    def log_updater():
+        """Background function to add logs."""
+        while True:
+            time.sleep(2)  # Update logs every 2 seconds
+            pass  # Log capture is now handled by the logging system
+
+    def get_logs():
+        """Retrieve logs for display."""
+        return "\n".join(logs[-50:])  # Only show the last 50 logs for example
+
+    # Start the logging thread
+    threading.Thread(target=log_updater, daemon=True).start()
+
+    def answer_question(query, state):
+        try:
+            # Generate response using the passed objects
+            response, source_docs = retrieve_and_generate_response(config.gen_llm, config.vector_store, query)
+
+            # Update state with the response and source documents
+            state["query"] = query
+            state["response"] = response
+            state["source_docs"] = source_docs
+
+            response_text = f"Response: {response}\n\n"
+            return response_text, state
+        except Exception as e:
+            logging.error(f"Error processing query: {e}")
+            return f"An error occurred: {e}", state
+
+    def compute_metrics(state):
+        try:
+            logging.info(f"Computing metrics")
+
+            # Retrieve response and source documents from state
+            response = state.get("response", "")
+            source_docs = state.get("source_docs", {})
+            query = state.get("query", "")
+
+            # Generate metrics using the passed objects
+            attributes, metrics = generate_metrics(config.val_llm, response, source_docs, query, 1)
+
+            attributes_text = get_attributes_text(attributes)
+
+            metrics_text = "Metrics:\n"
+            for key, value in metrics.items():
+                if key != 'response':
+                    metrics_text += f"{key}: {value}\n"
+
+            return attributes_text, metrics_text
+        except Exception as e:
+            logging.error(f"Error computing metrics: {e}")
+            return f"An error occurred: {e}", ""
+
+    def reinitialize_gen_llm(gen_llm_name):
+        """Reinitialize the generation LLM and return updated model info."""
+        if gen_llm_name.strip():  # Only update if input is not empty
+            config.gen_llm = initialize_generation_llm(gen_llm_name)
+
+        # Return updated model information
+        updated_model_info = (
+            f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
+            f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
+            f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
+        )
+        return updated_model_info
+
+    def reinitialize_val_llm(val_llm_name):
+ """Reinitialize the generation LLM and return updated model info."""
+        if val_llm_name.strip():  # Only update if input is not empty
+            config.val_llm = initialize_validation_llm(val_llm_name)
+
+        # Return updated model information
+        updated_model_info = (
+            f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
+            f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
+            f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
+        )
+        return updated_model_info

+    # Define Gradio Blocks layout
+    with gr.Blocks() as interface:
+        interface.title = "Real Time RAG Pipeline Q&A"
+        gr.Markdown("### Real Time RAG Pipeline Q&A")  # Heading
+
+        # Textbox for new generation LLM name
         with gr.Row():
+            new_gen_llm_input = gr.Textbox(label="New Generation LLM Name", placeholder="Enter LLM name to update")
+            update_gen_llm_button = gr.Button("Update Generation LLM")
+            new_val_llm_input = gr.Textbox(label="New Validation LLM Name", placeholder="Enter LLM name to update")
+            update_val_llm_button = gr.Button("Update Validation LLM")

+        # Section to display LLM names
         with gr.Row():
+            model_info = f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
+            model_info += f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
+            model_info += f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
+            model_info_display = gr.Textbox(value=model_info, label="Model Information", interactive=False)  # Read-only textbox
+
+        # State to store response and source documents
+        state = gr.State(value={"query": "", "response": "", "source_docs": {}})
+        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.")  # Description
         with gr.Row():
+            query_input = gr.Textbox(label="Ask a question", placeholder="Type your query here")
         with gr.Row():
+            submit_button = gr.Button("Submit", variant="primary")  # Submit button
+            clear_query_button = gr.Button("Clear")  # Clear button
+        with gr.Row():
+            answer_output = gr.Textbox(label="Response", placeholder="Response will appear here")

+        with gr.Row():
+            compute_metrics_button = gr.Button("Compute metrics", variant="primary")
+            attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
+            metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")
+
+        #with gr.Row():
+
+        # Define button actions
+        submit_button.click(
+            fn=answer_question,
+            inputs=[query_input, state],
+            outputs=[answer_output, state]
+        )
+        clear_query_button.click(fn=lambda: "", outputs=[query_input])  # Clear query input
+        compute_metrics_button.click(
+            fn=compute_metrics,
+            inputs=[state],
+            outputs=[attr_output, metrics_output]
+        )

+        update_gen_llm_button.click(
+            fn=reinitialize_gen_llm,
+            inputs=[new_gen_llm_input],
+            outputs=[model_info_display]  # Update the displayed model info
+        )
+
+        update_val_llm_button.click(
+            fn=reinitialize_val_llm,
+            inputs=[new_val_llm_input],
+            outputs=[model_info_display]  # Update the displayed model info
+        )

+        # Section to display logs
+        with gr.Row():
+            start_log_button = gr.Button("Start Log Update", elem_id="start_btn")  # Button to start log updates
+        with gr.Row():
+            log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10)  # Log section

+        # Set button click to trigger log updates
+        start_log_button.click(fn=get_logs, outputs=log_section)

+    interface.launch()
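A note on the log panel in the new app.py: the capture is done entirely by the custom `LogHandler`, which appends every formatted record to the in-memory `logs` list as it is emitted. The `log_updater` background thread is effectively a no-op (its loop only sleeps), and the "Start Log Update" button re-reads the tail of the list on each click rather than streaming. A minimal standalone sketch of that pattern, with illustrative names not taken from the commit:

```python
import logging

# Sketch of the in-memory log capture used in launch_gradio (illustrative names).
# A Handler subclass appends each formatted record to a list; the UI reads the tail.
captured = []

class ListHandler(logging.Handler):
    def emit(self, record):
        captured.append(self.format(record))

handler = ListHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
root = logging.getLogger()
root.addHandler(handler)
root.setLevel(logging.INFO)

logging.info("pipeline step finished")  # captured via ListHandler.emit
print("\n".join(captured[-50:]))        # equivalent of get_logs()
```

Since the `logs` list only ever grows, a `collections.deque(maxlen=...)` would be a natural drop-in if memory became a concern, though the committed code keeps a plain list.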
config.py ADDED
@@ -0,0 +1,16 @@
+
+class ConfigConstants:
+    # Constants related to datasets and models
+    DATA_SET_NAMES = ['covidqa', 'cuad', 'techqa']#, 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa']
+    EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2"
+    RE_RANKER_MODEL_NAME = 'cross-encoder/ms-marco-electra-base'
+    GENERATION_MODEL_NAME = 'mixtral-8x7b-32768'
+    VALIDATION_MODEL_NAME = 'llama3-70b-8192'
+    DEFAULT_CHUNK_SIZE = 1000
+    CHUNK_OVERLAP = 200
+
+class AppConfig:
+    def __init__(self, vector_store, gen_llm, val_llm):
+        self.vector_store = vector_store
+        self.gen_llm = gen_llm
+        self.val_llm = val_llm
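`AppConfig` is a plain mutable container: main.py fills it with the vector store and the two LLM clients, and app.py reassigns `gen_llm`/`val_llm` in place when the user updates a model name. A minimal usage sketch (the placeholder objects stand in for what `embed_documents` and the `initialize_*_llm` helpers return; only the `config` import is taken from the commit):

```python
from config import AppConfig, ConfigConstants

# Placeholders for illustration; in main.py these come from embed_documents()
# and initialize_generation_llm()/initialize_validation_llm().
vector_store, gen_llm, val_llm = object(), object(), object()

config = AppConfig(vector_store=vector_store, gen_llm=gen_llm, val_llm=val_llm)
print(ConfigConstants.GENERATION_MODEL_NAME)  # 'mixtral-8x7b-32768'

# Plain attribute assignment is all app.py's reinitialize_gen_llm() does:
config.gen_llm = object()  # swap in a newly initialized LLM at runtime
```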
main.py CHANGED
@@ -1,43 +1,64 @@
-import json
 import logging
-from app import launch_gradio_app
-from scripts.evaluate_factual_robustness import evaluate_factual_robustness
-from scripts.evaluate_negative_rejection import evaluate_negative_rejection
-from scripts.evaluate_noise_robustness import evaluate_noise_robustness
-from scripts.download_files import download_file, get_file_list
-
-def load_config(config_file="config.json"):
-    """Load configuration from the config file."""
-    try:
-        with open(config_file, "r", encoding="utf-8") as f:
-            config = json.load(f)
-            return config
-    except Exception as e:
-        logging.info(f"Error loading config: {e}")
-        return {}

 def main():
-    # Load configuration
-    config = load_config()

-    logging.info(f"Model: {config["model_name"]}")
-    logging.info(f"Noise Rate: {config["noise_rate"]}")
-    logging.info(f"Passage Number: {config["passage_num"]}")
-    logging.info(f"Number of Queries: {config["num_queries"]}")

-    # Download files from the GitHub repository
-    files = get_file_list()
-    for file in files:
-        download_file(file)

-    # Load dataset from the local JSON file

-    # Call evaluate_noise_robustness for each noise rate and model
-    #evaluate_noise_robustness(config)
-    #evaluate_negative_rejection(config)
-    #evaluate_factual_robustness(config)
-    launch_gradio_app(config)

 if __name__ == "__main__":
-    main()
 import logging
+from config import AppConfig, ConfigConstants
+from data.load_dataset import load_data
+from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
+from retriever.chunk_documents import chunk_documents
+from retriever.embed_documents import embed_documents
+from generator.initialize_llm import initialize_generation_llm
+from generator.initialize_llm import initialize_validation_llm
+from app import launch_gradio
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

 def main():
+    logging.info("Starting the RAG pipeline")
+
+    # List to store chunked documents
+    all_chunked_documents = []
+    datasets = {}

+    # Load multiple datasets
+    for data_set_name in ConfigConstants.DATA_SET_NAMES:
+        logging.info(f"Loading dataset: {data_set_name}")
+        datasets[data_set_name] = load_data(data_set_name)

+        # Set chunk size based on dataset name
+        chunk_size = ConfigConstants.DEFAULT_CHUNK_SIZE
+        if data_set_name == 'cuad':
+            chunk_size = 4000  # Custom chunk size for 'cuad'
+
+        # Chunk documents
+        chunked_documents = chunk_documents(datasets[data_set_name], chunk_size=chunk_size, chunk_overlap=ConfigConstants.CHUNK_OVERLAP)
+        all_chunked_documents.extend(chunked_documents)  # Combine all chunks

+    # Access individual datasets
+    #for name, dataset in datasets.items():
+        #logging.info(f"Loaded {name} with {dataset.num_rows} rows")
+
+    # Logging final count
+    logging.info(f"Total chunked documents: {len(all_chunked_documents)}")
+
+    # Embed the documents
+    vector_store = embed_documents(all_chunked_documents)
+    logging.info("Documents embedded")

+    # Initialize the Generation LLM
+    gen_llm = initialize_generation_llm(ConfigConstants.GENERATION_MODEL_NAME)

+    # Initialize the Validation LLM
+    val_llm = initialize_validation_llm(ConfigConstants.VALIDATION_MODEL_NAME)
+
+    # Compute RMSE and AUC-ROC for entire dataset
+    # Enable below code for calculation
+    #data_set_name = 'covidqa'
+    #compute_rmse_auc_roc_metrics(gen_llm, val_llm, datasets[data_set_name], vector_store, 10)

+    # Launch the Gradio app
+    config = AppConfig(vector_store=vector_store, gen_llm=gen_llm, val_llm=val_llm)
+    launch_gradio(config)
+
+    logging.info("Finished!!!")
+
 if __name__ == "__main__":
+    main()
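The dataset-level RMSE/AUC-ROC evaluation ships commented out; per the comment, enabling it is just a matter of uncommenting the two lines inside main() after the validation LLM is initialized:

```python
# Uncommented from the block above; runs the dataset-level evaluation on 'covidqa'.
# The trailing 10 appears to be the number of queries to evaluate (not confirmed here).
data_set_name = 'covidqa'
compute_rmse_auc_roc_metrics(gen_llm, val_llm, datasets[data_set_name], vector_store, 10)
```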