gourisankar85 committed · Commit efb5c9e · verified · 1 Parent(s): 870f650

Upload 3 files

Files changed (3):
  1. app.py +195 -169
  2. config.json +2 -2
  3. main.py +25 -56
app.py CHANGED
@@ -1,181 +1,207 @@
  import gradio as gr
- import logging
- import threading
- import time
- from generator.compute_metrics import get_attributes_text
- from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
- from config import AppConfig, ConfigConstants
- from generator.initialize_llm import initialize_generation_llm, initialize_validation_llm
-
- def launch_gradio(config : AppConfig):
-     """
-     Launch the Gradio app with pre-initialized objects.
-     """
-     logger = logging.getLogger()
-     logger.setLevel(logging.INFO)
-
-     # Create a list to store logs
-     logs = []
-
-     # Custom log handler to capture logs and add them to the logs list
-     class LogHandler(logging.Handler):
-         def emit(self, record):
-             log_entry = self.format(record)
-             logs.append(log_entry)
-
-     # Add custom log handler to the logger
-     log_handler = LogHandler()
-     log_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
-     logger.addHandler(log_handler)
-
-     def log_updater():
-         """Background function to add logs."""
-         while True:
-             time.sleep(2)  # Update logs every 2 seconds
-             pass  # Log capture is now handled by the logging system
-
-     def get_logs():
-         """Retrieve logs for display."""
-         return "\n".join(logs[-50:])  # Only show the last 50 logs for example
-
-     # Start the logging thread
-     threading.Thread(target=log_updater, daemon=True).start()
-
-     def answer_question(query, state):
-         try:
-             # Generate response using the passed objects
-             response, source_docs = retrieve_and_generate_response(config.gen_llm, config.vector_store, query)
-
-             # Update state with the response and source documents
-             state["query"] = query
-             state["response"] = response
-             state["source_docs"] = source_docs
-
-             response_text = f"Response: {response}\n\n"
-             return response_text, state
-         except Exception as e:
-             logging.error(f"Error processing query: {e}")
-             return f"An error occurred: {e}", state
-
-     def compute_metrics(state):
-         try:
-             logging.info(f"Computing metrics")
-
-             # Retrieve response and source documents from state
-             response = state.get("response", "")
-             source_docs = state.get("source_docs", {})
-             query = state.get("query", "")
-
-             # Generate metrics using the passed objects
-             attributes, metrics = generate_metrics(config.val_llm, response, source_docs, query, 1)
-
-             attributes_text = get_attributes_text(attributes)
-
-             metrics_text = "Metrics:\n"
-             for key, value in metrics.items():
-                 if key != 'response':
-                     metrics_text += f"{key}: {value}\n"
-
-             return attributes_text, metrics_text
-         except Exception as e:
-             logging.error(f"Error computing metrics: {e}")
-             return f"An error occurred: {e}", ""
-
-     def reinitialize_gen_llm(gen_llm_name):
-         """Reinitialize the generation LLM and return updated model info."""
-         if gen_llm_name.strip():  # Only update if input is not empty
-             config.gen_llm = initialize_generation_llm(gen_llm_name)
-
-         # Return updated model information
-         updated_model_info = (
-             f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
-             f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
-             f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
-         )
-         return updated_model_info
-
-     def reinitialize_val_llm(val_llm_name):
-         """Reinitialize the generation LLM and return updated model info."""
-         if val_llm_name.strip():  # Only update if input is not empty
-             config.val_llm = initialize_validation_llm(val_llm_name)
-
-         # Return updated model information
-         updated_model_info = (
-             f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
-             f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
-             f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
-         )
-         return updated_model_info

-     # Define Gradio Blocks layout
-     with gr.Blocks() as interface:
-         interface.title = "Real Time RAG Pipeline Q&A"
-         gr.Markdown("### Real Time RAG Pipeline Q&A")  # Heading
-
-         # Textbox for new generation LLM name
          with gr.Row():
-             new_gen_llm_input = gr.Textbox(label="New Generation LLM Name", placeholder="Enter LLM name to update")
-             update_gen_llm_button = gr.Button("Update Generation LLM")
-             new_val_llm_input = gr.Textbox(label="New Validation LLM Name", placeholder="Enter LLM name to update")
-             update_val_llm_button = gr.Button("Update Validation LLM")
-
-         # Section to display LLM names
          with gr.Row():
-             model_info = f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
-             model_info += f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
-             model_info += f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
-             model_info_display = gr.Textbox(value=model_info, label="Model Information", interactive=False)  # Read-only textbox
-
-         # State to store response and source documents
-         state = gr.State(value={"query": "","response": "", "source_docs": {}})
-         gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.")  # Description
          with gr.Row():
-             query_input = gr.Textbox(label="Ask a question", placeholder="Type your query here")
          with gr.Row():
-             submit_button = gr.Button("Submit", variant="primary")  # Submit button
-             clear_query_button = gr.Button("Clear")  # Clear button
          with gr.Row():
-             answer_output = gr.Textbox(label="Response", placeholder="Response will appear here")

-         with gr.Row():
-             compute_metrics_button = gr.Button("Compute metrics", variant="primary")
-             attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
-             metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")
-
-         #with gr.Row():
-
-         # Define button actions
-         submit_button.click(
-             fn=answer_question,
-             inputs=[query_input, state],
-             outputs=[answer_output, state]
-         )
-         clear_query_button.click(fn=lambda: "", outputs=[query_input])  # Clear query input
-         compute_metrics_button.click(
-             fn=compute_metrics,
-             inputs=[state],
-             outputs=[attr_output, metrics_output]
-         )

-         update_gen_llm_button.click(
-             fn=reinitialize_gen_llm,
-             inputs=[new_gen_llm_input],
-             outputs=[model_info_display]  # Update the displayed model info
-         )
-
-         update_val_llm_button.click(
-             fn=reinitialize_val_llm,
-             inputs=[new_val_llm_input],
-             outputs=[model_info_display]  # Update the displayed model info
-         )

-         # Section to display logs
-         with gr.Row():
-             start_log_button = gr.Button("Start Log Update", elem_id="start_btn")  # Button to start log updates
-         with gr.Row():
-             log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10)  # Log section

-         # Set button click to trigger log updates
-         start_log_button.click(fn=get_logs, outputs=log_section)

-     interface.launch()
  import gradio as gr
+ import os
+ import json
+ import pandas as pd
+ from scripts.evaluate_information_integration import evaluate_information_integration
+ from scripts.evaluate_negative_rejection import evaluate_negative_rejection
+ from scripts.helper import update_config
+ from scripts.evaluate_noise_robustness import evaluate_noise_robustness
+ from scripts.evaluate_factual_robustness import evaluate_factual_robustness
+
+ # Paths to score files
+ Noise_Robustness_DIR = "results/Noise Robustness/"
+ Negative_Rejection_DIR = "results/Negative Rejection/"
+ Counterfactual_Robustness_DIR = "results/Counterfactual Robustness/"
+ Information_Integration_DIR = "results/Information Integration/"
+
+ # Function to read and aggregate score data
+ def load_scores(file_dir):
+     models = set()
+     noise_rates = set()
+
+     if not os.path.exists(file_dir):
+         return pd.DataFrame(columns=["Noise Ratio"])
+
+     score_data = {}
+
+     # Read all JSON score files
+     for filename in os.listdir(file_dir):
+         if filename.startswith("scores_") and filename.endswith(".json"):
+             filepath = os.path.join(file_dir, filename)
+             with open(filepath, "r") as f:
+                 score = json.load(f)
+                 model = score["model"]
+                 noise_rate = str(score['noise_rate'])
+
+                 models.add(model)
+                 noise_rates.add(noise_rate)
+
+                 score_data[(model, noise_rate)] = score["accuracy"]
+
+     # Convert to DataFrame: one row per model, one column per noise rate
+     df = pd.DataFrame([
+         {
+             "Noise Ratio": model,
+             **{
+                 rate: f"{score_data.get((model, rate), 'N/A') * 100:.2f}"
+                 if score_data.get((model, rate), "N/A") != "N/A"
+                 else "N/A"
+                 for rate in sorted(noise_rates, key=float)
+             }
+         }
+         for model in sorted(models)
+     ])
+
+     return df
+
+ # Function to load Negative Rejection scores (Only for Noise Rate = 1.0)
+ def load_negative_rejection_scores():
+     if not os.path.exists(Negative_Rejection_DIR):
+         return pd.DataFrame()
+
+     score_data = {}
+     models = set()
+
+     for filename in os.listdir(Negative_Rejection_DIR):
+         if filename.startswith("scores_") and filename.endswith(".json") and "_noise_1.0_" in filename:
+             filepath = os.path.join(Negative_Rejection_DIR, filename)
+             with open(filepath, "r") as f:
+                 score = json.load(f)
+                 model = filename.split("_")[1]  # Extract model name
+                 models.add(model)
+                 score_data[model] = score.get("reject_rate", "N/A")
+
+     df = pd.DataFrame([
+         {"Model": model,
+          "Rejection Rate": f"{score_data.get(model, 'N/A') * 100:.2f}%"
+          if score_data.get(model, "N/A") != "N/A"
+          else "N/A"}
+         for model in sorted(models)
+     ])
+
+     return df if not df.empty else pd.DataFrame(columns=["Model", "Rejection Rate"])
+
+ def load_counterfactual_robustness_scores():
+     models = set()
+
+     if not os.path.exists(Counterfactual_Robustness_DIR):
+         return pd.DataFrame(columns=["Noise Ratio"])
+
+     score_data = {}
+
+     # Read all JSON score files
+     for filename in os.listdir(Counterfactual_Robustness_DIR):
+         if filename.startswith("scores_") and filename.endswith(".json"):
+             filepath = os.path.join(Counterfactual_Robustness_DIR, filename)
+             with open(filepath, "r") as f:
+                 score = json.load(f)
+                 model = filename.split("_")[1]
+                 models.add(model)
+                 score_data[model] = {
+                     "Accuracy (%)": int(score["all_rate"] * 100),  # No decimal
+                     "Error Detection Rate": int(score["reject_rate"] * 100),
+                     "Correction Rate (%)": round(score["correct_rate"] * 100, 2)  # 2 decimal places
+                 }
+
+     # Convert to DataFrame
+     df = pd.DataFrame([
+         {
+             "Model": model,
+             "Accuracy (%)": score_data.get(model, {}).get("Accuracy (%)", "N/A"),
+             "Error Detection Rate": score_data.get(model, {}).get("Error Detection Rate", "N/A"),
+             "Correction Rate (%)": f"{score_data.get(model, {}).get('Correction Rate (%)', 'N/A'):.2f}"
+         }
+         for model in sorted(models)
+     ])
+
+     return df
+
+ # Gradio UI
+ def launch_gradio_app(config):
+     with gr.Blocks() as app:
+         app.title = "RAG System Evaluation"
+         gr.Markdown("# RAG System Evaluation on RGB Dataset")
+
+         # Top Section - Inputs and Controls
          with gr.Row():
+             model_name_input = gr.Dropdown(
+                 label="Model Name",
+                 choices=config["models"],
+                 value="llama3-8b-8192",
+                 interactive=True
+             )
+             noise_rate_input = gr.Slider(label="Noise Rate", minimum=0, maximum=1.0, step=0.2, value=0.2, interactive=True)
+             num_queries_input = gr.Number(label="Number of Queries", value=50, interactive=True)
+
+         # Bottom Section - Action Buttons
          with gr.Row():
+             recalculate_noise_btn = gr.Button("Evaluate Noise Robustness")
+             recalculate_negative_btn = gr.Button("Evaluate Negative Rejection")
+             recalculate_counterfactual_btn = gr.Button("Evaluate Counterfactual Robustness")
+             recalculate_integration_btn = gr.Button("Evaluate Information Integration")
+
          with gr.Row():
+             refresh_btn = gr.Button("Refresh", variant="primary", scale=0)
+
+         # Middle Section - Data Tables
          with gr.Row():
+             with gr.Column():
+                 gr.Markdown("### 📊 Noise Robustness\n**Description:** The experimental results of noise robustness, measured by accuracy (%) under different noise ratios. Results show that an increasing noise rate poses a challenge for RAG in LLMs.")
+                 noise_table = gr.Dataframe(value=load_scores(Noise_Robustness_DIR), interactive=False)
+             with gr.Column():
+                 gr.Markdown("### 🚫 Negative Rejection\n**Description:** This measures the model's ability to reject invalid or nonsensical queries instead of generating incorrect responses. A higher rejection rate means the model is better at filtering unreliable inputs.")
+                 rejection_table = gr.Dataframe(value=load_negative_rejection_scores(), interactive=False)
+
          with gr.Row():
+             with gr.Column():
+                 gr.Markdown("""
+                 ### 🔄 Counterfactual Robustness
+                 **Description:**
+                 Counterfactual Robustness evaluates a model's ability to handle **errors in external knowledge** while ensuring reliable responses.
+
+                 **Key Metrics in this Report:**
+                 - **Accuracy (%)** → Measures the accuracy (%) of LLMs with counterfactual documents.
+                 - **Error Detection Rate (%)** → Measures how often the model **rejects** incorrect or misleading queries instead of responding.
+                 - **Correction Rate (%)** → Measures how often the model provides accurate responses despite **potential misinformation**.
+                 """)
+                 counter_factual_table = gr.Dataframe(value=load_counterfactual_robustness_scores(), interactive=False)
+             with gr.Column():
+                 gr.Markdown("### 🧠 Information Integration\n**Description:** The experimental results of information integration, measured by accuracy (%) under different noise ratios. Results show that information integration poses a challenge for RAG in LLMs.")
+                 integration_table = gr.Dataframe(value=load_scores(Information_Integration_DIR), interactive=False)

+         # Refresh Scores Function
+         def refresh_scores():
+             return load_scores(Noise_Robustness_DIR), load_negative_rejection_scores(), load_counterfactual_robustness_scores(), load_scores(Information_Integration_DIR)
+
+         refresh_btn.click(refresh_scores, outputs=[noise_table, rejection_table, counter_factual_table, integration_table])
+
+         # Button Functions
+         def recalculate_noise_robustness(model_name, noise_rate, num_queries):
+             update_config(config, model_name, noise_rate, num_queries)
+             evaluate_noise_robustness(config)
+             return load_scores(Noise_Robustness_DIR)

+         recalculate_noise_btn.click(recalculate_noise_robustness, inputs=[model_name_input, noise_rate_input, num_queries_input], outputs=[noise_table])
+
+         def recalculate_counterfactual_robustness(model_name, noise_rate, num_queries):
+             update_config(config, model_name, noise_rate, num_queries)
+             evaluate_factual_robustness(config)
+             return load_counterfactual_robustness_scores()

+         recalculate_counterfactual_btn.click(recalculate_counterfactual_robustness, inputs=[model_name_input, noise_rate_input, num_queries_input], outputs=[counter_factual_table])
+
+         def recalculate_negative_rejection(model_name, noise_rate, num_queries):
+             update_config(config, model_name, noise_rate, num_queries)
+             evaluate_negative_rejection(config)
+             return load_negative_rejection_scores()
+
+         recalculate_negative_btn.click(recalculate_negative_rejection, inputs=[model_name_input, noise_rate_input, num_queries_input], outputs=[rejection_table])

+         def recalculate_integration_info(model_name, noise_rate, num_queries):
+             update_config(config, model_name, noise_rate, num_queries)
+             evaluate_information_integration(config)
+             return load_scores(Information_Integration_DIR)
+
+         recalculate_integration_btn.click(recalculate_integration_info, inputs=[model_name_input, noise_rate_input, num_queries_input], outputs=[integration_table])

+     app.launch()
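For context, `load_scores` aggregates one JSON file per (model, noise rate) run, named `scores_*.json` and carrying at least the `model`, `noise_rate`, and `accuracy` keys used above. A minimal sketch of that contract (the directory name and keys come from the code; the concrete file name and values below are illustrative, not from the repository):

```python
# Illustrative only: fabricate one score file in the layout load_scores reads.
import json
import os

os.makedirs("results/Noise Robustness", exist_ok=True)
sample = {"model": "llama3-8b-8192", "noise_rate": 0.2, "accuracy": 0.86}
# Hypothetical file name, consistent with the "scores_<model>_noise_<rate>_" checks above
with open("results/Noise Robustness/scores_llama3-8b-8192_noise_0.2.json", "w") as f:
    json.dump(sample, f)

# load_scores("results/Noise Robustness/") would then render one row for the
# model, with "86.00" in the 0.2 column.
```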
config.json CHANGED
@@ -3,11 +3,11 @@
   "factual_file_name":"en_fact.json",
   "integration_file_name":"en_int.json",
   "result_path": "results/",
-  "models": ["llama3-8b-8192","qwen-2.5-32b", "mixtral-8x7b-32768", "gemma2-9b-it", "deepseek-r1-distill-llama-70b" ],
+  "models": ["llama3-8b-8192", "qwen-2.5-32b", "mixtral-8x7b-32768", "gemma2-9b-it", "deepseek-r1-distill-llama-70b" ],
   "model_name":"gemma2-9b-it",
   "noise_rate": 0.4,
   "passage_num": 5,
-  "num_queries": 10,
+  "num_queries": 50,
   "retry_attempts": 3,
   "timeout_limit": 60
 }
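The dashboard's recalculate callbacks push the UI selections back into this config via `scripts.helper.update_config`, whose body is not part of this commit. A plausible minimal sketch, assuming it only overwrites the three keys shown here (the real helper may also validate values or persist the file):

```python
# Hypothetical sketch of scripts/helper.update_config; not from the commit.
# Assumes it mutates the in-memory config dict using the config.json keys.
def update_config(config, model_name, noise_rate, num_queries):
    config["model_name"] = model_name          # e.g. "gemma2-9b-it"
    config["noise_rate"] = noise_rate          # 0.0 .. 1.0, slider step 0.2
    config["num_queries"] = int(num_queries)   # gr.Number may deliver a float
    return config
```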
main.py CHANGED
@@ -1,64 +1,33 @@
  import logging
- from config import AppConfig, ConfigConstants
- from data.load_dataset import load_data
- from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
- from retriever.chunk_documents import chunk_documents
- from retriever.embed_documents import embed_documents
- from generator.initialize_llm import initialize_generation_llm
- from generator.initialize_llm import initialize_validation_llm
- from app import launch_gradio
-
- # Configure logging
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

  def main():
-     logging.info("Starting the RAG pipeline")
-
-     # Dictionary to store chunked documents
-     all_chunked_documents = []
-     datasets = {}

-     # Load multiple datasets
-     for data_set_name in ConfigConstants.DATA_SET_NAMES:
-         logging.info(f"Loading dataset: {data_set_name}")
-         datasets[data_set_name] = load_data(data_set_name)

-         # Set chunk size based on dataset name
-         chunk_size = ConfigConstants.DEFAULT_CHUNK_SIZE
-         if data_set_name == 'cuad':
-             chunk_size = 4000  # Custom chunk size for 'cuad'
-
-         # Chunk documents
-         chunked_documents = chunk_documents(datasets[data_set_name], chunk_size=chunk_size, chunk_overlap=ConfigConstants.CHUNK_OVERLAP)
-         all_chunked_documents.extend(chunked_documents)  # Combine all chunks

-     # Access individual datasets
-     #for name, dataset in datasets.items():
-         #logging.info(f"Loaded {name} with {dataset.num_rows} rows")
-
-     # Logging final count
-     logging.info(f"Total chunked documents: {len(all_chunked_documents)}")

-     # Embed the documents
-     vector_store = embed_documents(all_chunked_documents)
-     logging.info("Documents embedded")
-
-     # Initialize the Generation LLM
-     gen_llm = initialize_generation_llm(ConfigConstants.GENERATION_MODEL_NAME)
-
-     # Initialize the Validation LLM
-     val_llm = initialize_validation_llm(ConfigConstants.VALIDATION_MODEL_NAME)
-
-     #Compute RMSE and AUC-ROC for entire dataset
-     #Enable below code for calculation
-     #data_set_name = 'covidqa'
-     #compute_rmse_auc_roc_metrics(gen_llm, val_llm, datasets[data_set_name], vector_store, 10)
-
-     # Launch the Gradio app
-     config = AppConfig(vector_store= vector_store, gen_llm = gen_llm, val_llm = val_llm)
-     launch_gradio(config)
-
-     logging.info("Finished!!!")
-
  if __name__ == "__main__":
-     main()

+ import json
  import logging
+ from app import launch_gradio_app
+ from scripts.download_files import download_file, get_file_list
+
+ def load_config(config_file="config.json"):
+     """Load configuration from the config file."""
+     try:
+         with open(config_file, "r", encoding="utf-8") as f:
+             config = json.load(f)
+             return config
+     except Exception as e:
+         logging.info(f"Error loading config: {e}")
+         return {}

  def main():
+     # Load configuration
+     config = load_config()

+     logging.info(f"Model: {config['model_name']}")
+     logging.info(f"Noise Rate: {config['noise_rate']}")
+     logging.info(f"Passage Number: {config['passage_num']}")
+     logging.info(f"Number of Queries: {config['num_queries']}")

+     # Download files from the GitHub repository
+     files = get_file_list()
+     for file in files:
+         download_file(file)

+     launch_gradio_app(config)

  if __name__ == "__main__":
+     main()
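`scripts.download_files` is imported here but not included in this commit. A hedged sketch of what it might look like, assuming the data files named in `config.json` (such as `en_fact.json` and `en_int.json`) are fetched from a GitHub raw URL; the base URL and file list below are placeholders:

```python
# Hypothetical sketch of scripts/download_files.py; not part of this commit.
import os
import urllib.request

# Placeholder: the real repository URL is not shown in the diff.
BASE_URL = "https://raw.githubusercontent.com/<owner>/<repo>/main/data/"

def get_file_list():
    # File names taken from config.json; the real list may differ.
    return ["en_fact.json", "en_int.json"]

def download_file(filename, dest_dir="data"):
    os.makedirs(dest_dir, exist_ok=True)
    dest = os.path.join(dest_dir, filename)
    if not os.path.exists(dest):  # skip files that were already fetched
        urllib.request.urlretrieve(BASE_URL + filename, dest)
```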