# qatch-demo / app.py
import os
import sys
import time
import re
import csv
import gradio as gr
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
from qatch.connectors.sqlite_connector import SqliteConnector
from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
from prediction import ModelPrediction
import utils_get_db_tables_info
import utilities as us
# @spaces.GPU
# def model_prediction():
# pass
# # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
# if os.environ.get("SPACES_ZERO_GPU") is not None:
# import spaces
# else:
# class spaces:
# @staticmethod
# def GPU(func):
# def wrapper(*args, **kwargs):
# return func(*args, **kwargs)
# return wrapper
pnp_path = "evaluation_p_np_metrics.csv"
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
"""
reset_flag = False
with open('style.css', 'r') as file:
css = file.read()
# Default DataFrame
df_default = pd.DataFrame({
'Name': ['Alice', 'Bob', 'Charlie'],
'Age': [25, 30, 35],
'City': ['New York', 'Los Angeles', 'Chicago']
})
models_path = "models.csv"
# Global variable tracking the current table data
df_current = df_default.copy()
description = """## πŸ“Š Comparison of Proprietary and Non-Proprietary Databases
### ➀ **Proprietary** (πŸ’° Economic, πŸ₯ Medical, πŸ’³ Financial, πŸ“‚ Miscellaneous)
### ➀ **Non-Proprietary** (πŸ•·οΈ Spider 1.0)"""
prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n"
input_data = {
'input_method': "",
'data_path': "",
'db_name': "",
'data': {
'data_frames': {}, # dictionary of dataframes
'db': None, # SQLITE3 database object
'selected_tables': []
},
'models': [],
'prompt': prompt_default
}
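# Illustrative snapshot of the state after a CSV named "sales.csv" has been
# uploaded and one model picked (names and values are examples only):
# input_data = {
#     'input_method': 'uploaded_file',
#     'data_path': './sales.sqlite',
#     'db_name': 'sales',
#     'data': {'data_frames': {'sales': <DataFrame>}, 'db': <SqliteConnector>, 'selected_tables': ['sales']},
#     'models': ['<model-code-from-models.csv>'],
#     'prompt': prompt_default,
# }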
def load_data(file, path, use_default):
"""Carica i dati da un file, un percorso o usa il DataFrame di default."""
global df_current
if file is not None:
try:
input_data["input_method"] = 'uploaded_file'
input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
if file.endswith('.sqlite'):
#return 'Error: The uploaded file is not a valid SQLite database.'
input_data["data_path"] = file #os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
else:
# Non-SQLite uploads are converted into a SQLite file next to the app
input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite")
input_data["data"] = us.load_data(file, input_data["db_name"])
df_current = input_data["data"]['data_frames'].get('MyTable', df_default)  # Load the DataFrame
if input_data["data"]['data_frames'] and input_data["data"]["db"] is None:  # for csv and xlsx files
table2primary_key = {}
for table_name, df in input_data["data"]['data_frames'].items():
# Assign primary keys for each table
table2primary_key[table_name] = 'id'
input_data["data"]["db"] = SqliteConnector(
relative_db_path=input_data["data_path"],
db_name=input_data["db_name"],
tables=input_data["data"]['data_frames'],
table2primary_key=table2primary_key
)
return input_data["data"]['data_frames']
except Exception as e:
return f'Error while loading the file: {e}'
if use_default:
if use_default == 'Custom':
input_data["input_method"] = 'custom'
#input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable_0.sqlite")
input_data["data_path"] = os.path.join(".","mytable_0.sqlite")
# If the file already exists, use an incremented filename
while os.path.exists(input_data["data_path"]):
input_data["data_path"] = us.increment_filename(input_data["data_path"])
input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0]
input_data["data"]['data_frames'] = {'MyTable': df_current}
if(input_data["data"]['data_frames']):
table2primary_key = {}
for table_name, df in input_data["data"]['data_frames'].items():
# Assign primary keys for each table
table2primary_key[table_name] = 'id'
input_data["data"]["db"] = SqliteConnector(
relative_db_path=input_data["data_path"],
db_name=input_data["db_name"],
tables= input_data["data"]['data_frames'],
table2primary_key=table2primary_key
)
df_current = df_default.copy() # Ripristina i dati di default
return input_data["data"]['data_frames']
if use_default == 'Proprietary vs Non-proprietary':
input_data["input_method"] = 'default'
#input_data["data_path"] = os.path.join(".", "data", "data_interface", "default.sqlite")
#input_data["data_path"] = os.path.join(".", "data", "spider_databases", "defeault.sqlite")
#input_data["db_name"] = "default"
#input_data["data"]['db'] = SqliteConnector(relative_db_path=input_data["data_path"], db_name=input_data["db_name"])
input_data["data"]['data_frames'] = us.load_tables_dict_from_pkl('tables_dict.pkl')
return input_data["data"]['data_frames']
selected_inputs = sum([file is not None, bool(path), bool(use_default)])
if selected_inputs > 1:
return 'Error: Select only one input method at a time.'
return input_data["data"]['data_frames']
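# Illustrative calls covering the three input methods (file names are examples):
#   load_data("sales.csv", None, None)                       # upload: converted to ./sales.sqlite
#   load_data(None, None, "Custom")                          # persist the editable toy table
#   load_data(None, None, "Proprietary vs Non-proprietary")  # load precomputed tables_dict.pkl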
def preview_default(use_default, file):
if file:
return gr.DataFrame(interactive=True, visible=False, value=df_default), gr.update(value="## ✅ File successfully uploaded!", visible=True)
else:
if use_default == 'Custom':
return gr.DataFrame(interactive=True, visible=True, value=df_default), gr.update(value="## 📝 Toy Table", visible=True)
else:
return gr.DataFrame(interactive=False, visible=False, value=df_default), gr.update(value=description, visible=True)
def update_df(new_df):
"""Aggiorna il DataFrame corrente."""
global df_current # Usa la variabile globale per aggiornarla
df_current = new_df
return df_current
def open_accordion(target):
# Open one accordion and close the others
global df_current  # needed so that "reset" restores the module-level DataFrame
if target == "reset":
df_current = df_default.copy()
input_data['input_method'] = ""
input_data['data_path'] = ""
input_data['db_name'] = ""
input_data['data']['data_frames'] = {}
input_data['data']['selected_tables'] = []
input_data['data']['db'] = None
input_data['models'] = []
return gr.update(open=True), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(value='Proprietary vs Non-proprietary'), gr.update(value=None)
elif target == "model_selection":
return gr.update(open=False), gr.update(open=False), gr.update(open=True, visible=True), gr.update(open=False), gr.update(open=False)
# Gradio interface
#with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as interface:
with gr.Row():
with gr.Column(scale=1):
gr.Image(
value=os.path.join(".", "qatch_logo.png"),
show_label=False,
container=False,
interactive=False,
show_fullscreen_button=False,
show_download_button=False,
show_share_button=False,
height=150,  # in pixels
width=300
)
with gr.Column(scale=1):
pass
data_state = gr.State(None)  # Stores the loaded data
upload_acc = gr.Accordion("Upload data section", open=True, visible=True)
select_table_acc = gr.Accordion("Select tables section", open=False, visible=False)
select_model_acc = gr.Accordion("Select models section", open=False, visible=False)
qatch_acc = gr.Accordion("QATCH execution section", open=False, visible=False)
metrics_acc = gr.Accordion("Metrics section", open=False, visible=False)
#################################
# DATABASE INSERTION #
#################################
with upload_acc:
gr.Markdown("## πŸ“₯Choose data input method")
with gr.Row():
default_checkbox = gr.Radio(label = "Explore the comparison between proprietary and non-proprietary databases or edit a toy table with the values you prefer", choices=['Proprietary vs Non-proprietary', 'Custom'], value='Proprietary vs Non-proprietary')
table_default = gr.Markdown(description, visible=True)
preview_output = gr.DataFrame(interactive=False, visible=False, value=df_default)
gr.Markdown("## πŸ“‚ Or upload your data")
file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"])
submit_button = gr.Button("Load Data")
output = gr.JSON(visible=False) # Dictionary output
# Function to enable the button if there is data to load
def enable_submit(file, use_default):
return gr.update(interactive=bool(file or use_default))
# Reset the radio to its default choice when a file is uploaded
def deselect_default(file):
if file:
return gr.update(value='Proprietary vs Non-proprietary')
return gr.update()
def enable_disable_first(enable):
return (
gr.update(interactive=enable),
gr.update(interactive=enable),
gr.update(interactive=enable),
gr.update(interactive=enable)
)
# Enable the button when inputs are provided
#file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
#default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
# Show the appropriate preview when the radio selection or the uploaded file changes
default_checkbox.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default])
file_input.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default])
preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output])
# Reset the radio when a file is uploaded
file_input.change(fn=deselect_default, inputs=[file_input], outputs=[default_checkbox])
def handle_output(file, use_default):
"""Handles the output when the 'Load Data' button is pressed."""
result = load_data(file, None, use_default)
if isinstance(result, dict): # If result is a dictionary of DataFrames
if len(result) == 1: # If there's only one table
input_data['data']['selected_tables'] = list(input_data['data']['data_frames'].keys())
return (
gr.update(visible=False), # Hide JSON output
result, # Save the data state
gr.update(visible=False), # Hide table selection
result, # Maintain the data state
gr.update(interactive=False), # Disable the submit button
gr.update(visible=True, open=True), # Proceed to select_model_acc
gr.update(visible=True, open=False)
)
else:
return (
gr.update(visible=False),
result,
gr.update(open=True, visible=True),
result,
gr.update(interactive=False),
gr.update(visible=False), # Keep current behavior
gr.update(visible=True, open=False)
)
else:
return (
gr.update(visible=False),
None,
gr.update(open=False, visible=True),
None,
gr.update(interactive=True),
gr.update(visible=False),
gr.update(visible=True, open=False)
)
submit_button.click(
fn=handle_output,
inputs=[file_input, default_checkbox],
outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc]
)
submit_button.click(
fn=enable_disable_first,
inputs=[gr.State(False)],
outputs=[
preview_output,
submit_button,
file_input,
default_checkbox
]
)
######################################
# TABLE SELECTION PART #
######################################
with select_table_acc:
previous_selection = gr.State([])
table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the chosen database", value=[])
excluded_tables_info = gr.HTML(label="Non-selectable tables (too many columns)", visible=False)
table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)]
selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
# Model selection button (initially disabled)
open_model_selection = gr.Button("Choose your models", interactive=False)
def update_table_list(data):
"""Dynamically updates the list of available tables and excluded ones."""
if isinstance(data, dict) and data:
table_names = []
excluded_tables = []
data_frames = input_data['data'].get('data_frames', {})
available_tables = []
for name, df in data.items():
df_real = data_frames.get(name, None)
if df_real is not None and df_real.shape[1] > 15:
excluded_tables.append(name)
else:
available_tables.append(name)
if input_data['input_method'] == "default" or len(available_tables) < 6:
table_names.append("All")
table_names.extend(available_tables)
# Prepare the warning text to display
if excluded_tables:
excluded_text = "<b>⚠️ The following tables have more than 15 columns and cannot be selected:</b><br>" + "<br>".join(f"- {t}" for t in excluded_tables)
excluded_visible = True
else:
excluded_text = ""
excluded_visible = False
return [
gr.update(choices=table_names, value=[]), # CheckboxGroup update
gr.update(value=excluded_text, visible=excluded_visible) # HTML display update
]
return [
gr.update(choices=[], value=[]),
gr.update(value="", visible=False)
]
def show_selected_tables(data, selected_tables):
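"""Return the gr.update list for [table_selector] + the 50 table previews +
the 'Choose your models' button. Tables with more than 15 columns cannot be
selected; 'All' is offered only for the default databases or when fewer than
6 tables are available, otherwise the selection is capped at 5 tables."""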
updates = []
data_frames = input_data['data'].get('data_frames', {})
available_tables = []
for name, df in data.items():
df_real = data_frames.get(name)
if df_real is not None and df_real.shape[1] <= 15:
available_tables.append(name)
input_method = input_data['input_method']
allow_all = input_method == "default" or len(available_tables) < 6
selected_set = set(selected_tables)
tables_set = set(available_tables)
if allow_all:
if "All" in selected_set:
selected_tables = ["All"] + available_tables
elif selected_set == tables_set:
selected_tables = []
else:
selected_tables = [t for t in selected_tables if t in available_tables]
else:
selected_tables = [t for t in selected_tables if t in available_tables and t != "All"][:5]
tables = {name: data[name] for name in selected_tables if name in data}
for i, (name, df) in enumerate(tables.items()):
updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False))
for _ in range(len(tables), 50):
updates.append(gr.update(visible=False))
updates.append(gr.update(interactive=bool(tables)))
if allow_all:
updates.insert(0, gr.update(
choices=["All"] + available_tables,
value=selected_tables
))
else:
if len(selected_tables) >= 5:
updates.insert(0, gr.update(
choices=selected_tables,
value=selected_tables
))
else:
updates.insert(0, gr.update(
choices=available_tables,
value=selected_tables
))
return updates
def show_selected_table_names(data, selected_tables):
"""Displays the names of the selected tables when the button is pressed."""
if selected_tables:
available_tables = list(data.keys()) # Actually available names
if "All" in selected_tables:
selected_tables = available_tables
input_data['data']['selected_tables'] = selected_tables
return gr.update(value=", ".join(selected_tables), visible=False)
return gr.update(value="", visible=False)
# Automatically updates the checkbox list when `data_state` changes
data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector, excluded_tables_info])
# Updates the visible tables and the button state based on user selections
#table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
table_selector.change(
fn=show_selected_tables,
inputs=[data_state, table_selector],
outputs=[table_selector] + table_outputs + [open_model_selection]
)
# Shows the list of selected tables when "Choose your models" is clicked
open_model_selection.click(fn=show_selected_table_names, inputs=[data_state, table_selector], outputs=[selected_table_names])
open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
reset_data = gr.Button("Back to upload data section")
reset_data.click(
fn=enable_disable_first,
inputs=[gr.State(True)],
outputs=[
preview_output,
submit_button,
file_input,
default_checkbox
]
)
reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
####################################
# MODEL SELECTION PART #
####################################
with select_model_acc:
gr.Markdown("# Model Selection")
# Assume that `us.read_models_csv` also returns the image path
model_list_dict = us.read_models_csv(models_path)
model_list = [model["code"] for model in model_list_dict]
model_images = [model["image_path"] for model in model_list_dict]
model_names = [model["name"] for model in model_list_dict]
# Create a mapping between model_list and model_images_names
model_mapping = dict(zip(model_list, model_names))
model_mapping_reverse = dict(zip(model_names, model_list))
model_checkboxes = []
rows = []
# Dynamically create checkboxes with images (3 per row)
for i in range(0, len(model_list), 3):
with gr.Row():
cols = []
for j in range(3):
if i + j < len(model_list):
model = model_list[i + j]
image_path = model_images[i + j]
with gr.Column():
gr.Image(image_path,
show_label=False,
container=False,
interactive=False,
show_fullscreen_button=False,
show_download_button=False,
show_share_button=False)
checkbox = gr.Checkbox(label=model_mapping[model], value=False)
model_checkboxes.append(checkbox)
cols.append(checkbox)
rows.append(cols)
selected_models_output = gr.JSON(visible=False)
# Function to get selected models
def get_selected_models(*model_selections):
selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
input_data['models'] = selected_models
button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
# Add the Textbox to the interface
prompt = gr.TextArea(
label="Customise the prompt for selected models here or leave the default one.",
placeholder=prompt_default,
elem_id="custom-textarea"
)
warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False)
# Submit button (initially disabled)
submit_models_button = gr.Button("Submit Models", interactive=False)
def check_prompt(prompt):
#TODO
missing_elements = []
if prompt == "":
input_data["prompt"] = prompt_default
button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
else:
input_data["prompt"] = prompt
if "{db_schema}" not in prompt:
missing_elements.append("{db_schema}")
if "{question}" not in prompt:
missing_elements.append("{question}")
button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
if missing_elements:
return gr.update(
value=f"<div style='text-align: center; font-size: 18px; font-weight: bold;'>"
f"❌ Missing {', '.join(missing_elements)} in the prompt ❌</div>",
visible=True
), gr.update(interactive=button_state)
return gr.update(visible=False), gr.update(interactive=button_state)
prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button])
# Link checkboxes to selection events
for checkbox in model_checkboxes:
checkbox.change(
fn=get_selected_models,
inputs=model_checkboxes,
outputs=[selected_models_output, select_model_acc, submit_models_button]
)
prompt.change(
fn=get_selected_models,
inputs=model_checkboxes,
outputs=[selected_models_output, select_model_acc, submit_models_button]
)
submit_models_button.click(
fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
inputs=model_checkboxes,
outputs=[selected_models_output, select_model_acc, qatch_acc]
)
def enable_disable(enable):
return (
*[gr.update(interactive=enable) for _ in model_checkboxes],
gr.update(interactive=enable),
gr.update(interactive=enable),
gr.update(interactive=enable),
gr.update(interactive=enable),
gr.update(interactive=enable),
gr.update(interactive=enable),
*[gr.update(interactive=enable) for _ in table_outputs],
gr.update(interactive=enable)
)
reset_data = gr.Button("Back to upload data section")
submit_models_button.click(
fn=enable_disable,
inputs=[gr.State(False)],
outputs=[
*model_checkboxes,
submit_models_button,
preview_output,
submit_button,
file_input,
default_checkbox,
table_selector,
*table_outputs,
open_model_selection
]
)
reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
reset_data.click(
fn=enable_disable,
inputs=[gr.State(True)],
outputs=[
*model_checkboxes,
submit_models_button,
preview_output,
submit_button,
file_input,
default_checkbox,
table_selector,
*table_outputs,
open_model_selection
]
)
#############################
# QATCH EXECUTION #
#############################
with qatch_acc:
def change_text(text):
return text
loading_symbols = {1: "𓆟",
2: "π“†ž π“†Ÿ",
3: "𓆛 π“†ž π“†Ÿ",
4: "π“†ž 𓆛 π“†ž π“†Ÿ",
5: "π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ",
6: "π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ",
7: "π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ",
8: "π“†ž π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ",
9: "π“†Ÿ π“†ž π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ",
10:"π“†ž π“†Ÿ π“†ž π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ",
}
def generate_loading_text(percent):
num_symbols = (round(percent) % 11) + 1
symbols = loading_symbols.get(num_symbols, "π“†Ÿ")
mirrored_symbols = f'<span class="mirrored">{symbols.strip()}</span>'
css_symbols = f'<span class="fish">{symbols.strip()}</span>'
return f"""
<div class='barcontainer'>
{css_symbols}
<span class='loading' style="font-family: 'Inter', sans-serif;">
Generation {percent}%
</span>
{mirrored_symbols}
</div>
"""
def generate_eval_text(text):
symbols = "𓆑 "
mirrored_symbols = f'<span class="mirrored">{symbols.strip()}</span>'
css_symbols = f'<span class="fish">{symbols.strip()}</span>'
return f"""
<div class='barcontainer'>
{css_symbols}
<span class='loading' style="font-family: 'Inter', sans-serif;">
{text}
</span>
{mirrored_symbols}
</div>
"""
def qatch_flow():
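"""Generator that runs (or, for the default databases, replays) predictions
and streams UI updates. Every yield matches the outputs wired to this
function in Gradio: (evaluation text, model logo, loading bar, NL question,
predicted SQL, metrics dataframe, one predictions dataframe per model in
model_list)."""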
# Pre-allocate one predictions DataFrame per model (the default databases reuse cached results)
global reset_flag
predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
metrics_conc = pd.DataFrame()
columns_to_visualize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
if input_data['input_method'] == "default":
target_df = us.load_csv(pnp_path) #target_df = us.load_csv("priority_non_priority_metrics.csv")
#predictions_dict = {model: pd.DataFrame(columns=target_df.columns) for model in model_list}
target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])]
target_df = target_df[target_df["model"].isin(input_data['models'])]
predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list}
reset_flag = False
for model in input_data['models']:
model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visualize] for model in model_list]
count = 1
for _, row in predictions_dict[model].iterrows():
if not reset_flag:
percent_complete = round(count / len(predictions_dict[model]) * 100, 2)
count += 1
load_text = generate_loading_text(percent_complete)
question = row['question']
display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Natural Language:</div>
<div style='display: flex; align-items: center;'>
<div class='sqlquery' font-family: 'Inter', sans-serif;>{question}</div>
<div style='font-size: 3rem'>➑️</div>
</div>
"""
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visualize] for model in model_list]
#time.sleep(0.02)
prediction = row['predicted_sql']
display_prediction = f"""<div class='loading' style="font-size: 1.7rem; font-family: 'Inter', sans-serif;">Predicted SQL:</div>
<div style='display: flex; align-items: center;'>
<div style='font-size: 3rem'>➡️</div>
<div class='sqlquery' style="font-family: 'Inter', sans-serif;">{prediction}</div>
</div>
"""
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visualize] for model in model_list]
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visualize] for model in model_list]
metrics_conc = target_df
if 'valid_efficiency_score' not in metrics_conc.columns:
metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
eval_text = generate_eval_text("End evaluation")
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visualize] for model in model_list]
else:
orchestrator_generator = OrchestratorGenerator()
# TODO: add to target_df column target_df["columns_used"], tables selection
# print(input_data['data']['db'])
#print(input_data['data']['selected_tables'])
target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
#target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)
predictor = ModelPrediction()
reset_flag = False
for model in input_data["models"]:
model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
count = 0
for index, row in target_df.iterrows():
if not reset_flag:
percent_complete = round(((index + 1) / len(target_df)) * 100, 2)
load_text = generate_loading_text(percent_complete)
question = row['question']
display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Natural Language:</div>
<div style='display: flex; align-items: center;'>
<div class='sqlquery' font-family: 'Inter', sans-serif;>{question}</div>
<div style='font-size: 3rem'>➑️</div>
</div>
"""
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
start_time = time.time()
samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
db_id = input_data["db_name"],
base_path = input_data["data_path"],
normalize=False,
sql=row["query"],
get_insert_into=True
)
#prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
prompt_to_send = input_data["prompt"]
#PREDICTION SQL
# TODO: add a button to choose QA or SP and pass it to make_prediction as the task parameter
response = predictor.make_prediction(
question=question,
db_schema=schema_text,
model_name=model,
prompt=f"{prompt_to_send}",
task="SP" # TODO change accordingly
)
prediction = response['response_parsed']
price = response['cost']
answer = response['response']
end_time = time.time()
display_prediction = f"""<div class='loading' style="font-size: 1.7rem; font-family: 'Inter', sans-serif;">Predicted SQL:</div>
<div style='display: flex; align-items: center;'>
<div style='font-size: 3rem'>➡️</div>
<div class='sqlquery' style="font-family: 'Inter', sans-serif;">{prediction}</div>
</div>
"""
# Create a new row as dataframe
new_row = pd.DataFrame([{
'id': index,
'question': question,
'predicted_sql': prediction,
'time': end_time - start_time,
'query': row["query"],
'db_path': input_data["data_path"],
'price': price,
'answer': answer,
'number_question': count,
'prompt': prompt_to_send
}]).dropna(how="all")  # Remove only completely empty rows
count += 1
# Carry over any remaining columns from the generated dataset row
for col in target_df.columns:
if col not in new_row.columns:
new_row[col] = row[col]
# Update model's prediction dataframe incrementally
if not new_row.empty:
predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
# yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
# END
eval_text = generate_eval_text("Evaluation")
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
evaluator = OrchestratorEvaluator()
for model in input_data["models"]:
metrics_df_model = evaluator.evaluate_df(
df=predictions_dict[model],
target_col_name="query",
prediction_col_name="predicted_sql",
db_path_name="db_path"
)
metrics_df_model['model'] = model
metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)
if 'valid_efficiency_score' not in metrics_conc.columns:
metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
eval_text = generate_eval_text("End evaluation")
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
# Loading Bar
with gr.Row():
# progress = gr.Progress()
variable = gr.Markdown()
# NL -> MODEL -> Generated Query
with gr.Row():
with gr.Column():
with gr.Column():
question_display = gr.Markdown()
with gr.Column():
model_logo = gr.Image(visible=True,
show_label=False,
container=False,
interactive=False,
show_fullscreen_button=False,
show_download_button=False,
show_share_button=False)
with gr.Column():
with gr.Column():
prediction_display = gr.Markdown()
dataframe_per_model = {}
with gr.Tabs() as model_tabs:
tab_dict = {}
for model, model_name in zip(model_list, model_names):
with gr.TabItem(model_name, visible=(model in input_data["models"])) as tab:
gr.Markdown(f"**Results for {model}**")
tab_dict[model] = tab
dataframe_per_model[model] = gr.DataFrame()
# download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False)
evaluation_loading = gr.Markdown()
def change_tab():
return [gr.update(visible=(model in input_data["models"])) for model in model_list]
submit_models_button.click(
change_tab,
inputs=[],
outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
)
selected_models_display = gr.JSON(label="Final input data", visible=False)
metrics_df = gr.DataFrame(visible=False)
metrics_df_out = gr.DataFrame(visible=False)
submit_models_button.click(
fn=qatch_flow,
inputs=[],
outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
)
submit_models_button.click(
fn=lambda: gr.update(value=input_data),
outputs=[selected_models_display]
)
# Works for METRICS
metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
proceed_to_metrics_button = gr.Button("Proceed to Metrics", visible=False)
proceed_to_metrics_button.click(
fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
outputs=[qatch_acc, metrics_acc]
)
def allow_download(metrics_df_out):
#path = os.path.join(".", "data", "data_results", "results.csv")
path = os.path.join(".", "results.csv")
metrics_df_out.to_csv(path, index=False)
return gr.update(value=path, visible=True), gr.update(visible=True), gr.update(interactive=True)
download_metrics = gr.DownloadButton(label="Download Metrics Evaluation", visible=False)
submit_models_button.click(
fn=lambda: gr.update(visible=False),
outputs=[download_metrics]
)
def refresh():
global reset_flag
reset_flag = True
reset_data = gr.Button("Back to upload data section", interactive=True)
metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics, proceed_to_metrics_button, reset_data])
reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
# FIXME: this update does not reliably hide the download button on reset
reset_data.click(
fn=lambda: gr.update(visible=False),
outputs=[download_metrics]
)
reset_data.click(refresh)
reset_data.click(
fn=enable_disable,
inputs=[gr.State(True)],
outputs=[
*model_checkboxes,
submit_models_button,
preview_output,
submit_button,
file_input,
default_checkbox,
table_selector,
*table_outputs,
open_model_selection
]
)
##########################################
# METRICS VISUALIZATION SECTION #
##########################################
with metrics_acc:
#data_path = 'test_results_metrics1.csv'
@gr.render(inputs=metrics_df_out)
def function_metrics(metrics_df_out):
####################################
# UTILS FUNCTIONS SECTION #
####################################
def load_data_csv_es():
if input_data["input_method"]=="default":
df = pd.read_csv(pnp_path)
df = df[df['model'].isin(input_data["models"])]
df['model'] = df['model'].replace('DeepSeek-R1-Distill-Llama-70B', 'DS-Llama3 70B')
df['model'] = df['model'].replace('gpt-3.5', 'GPT-3.5')
df['model'] = df['model'].replace('gpt-4o-mini', 'GPT-4o-mini')
df['model'] = df['model'].replace('llama-70', 'Llama-70B')
df['model'] = df['model'].replace('llama-8', 'Llama-8B')
df['test_category'] = df['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY')
return df
return metrics_df_out
def calculate_average_metrics(df, selected_metrics):
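"""Add an 'avg_metric' column holding the row-wise mean of the selected metric columns."""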
# Exclude 'tuple_order' from the selected metrics
# TODO: tuple_order contains NULL values
selected_metrics = [metric for metric in selected_metrics if metric != 'tuple_order']
df['avg_metric'] = df[selected_metrics].mean(axis=1)
return df
def generate_model_colors():
"""Generates a unique color map for models in the dataset."""
df = load_data_csv_es()
unique_models = df['model'].unique() # Extract unique models
num_models = len(unique_models)
# Use the Plotly color scale (you can change it if needed)
color_palette = ['#00B4D8', '#BCE784', '#C84630', '#F79256', '#D269FC']
#color_palette = pc.qualitative.Plotly # ['#636EFA', '#EF553B', '#00CC96', ...]
# If there are more models than colors, cycle through them
colors = {model: color_palette[i % len(color_palette)] for i, model in enumerate(unique_models)}
return colors
MODEL_COLORS = generate_model_colors()
def generate_db_category_colors():
"""Assigns 3 distinct colors to db_category groups."""
return {
"Spider": "#1f77b4", # blu
"Beaver": "#ff7f0e", # arancione
"Economic": "#2ca02c", # tutti gli altri verdi
"Financial": "#2ca02c",
"Medical": "#2ca02c",
"Miscellaneous": "#2ca02c"
}
DB_CATEGORY_COLORS = generate_db_category_colors()
def normalize_valid_efficiency_score(df):
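"""Min-max scale 'valid_efficiency_score' into [0, 1]; NaN/empty values become 0."""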
df['valid_efficiency_score'] = df['valid_efficiency_score'].replace([np.nan, ''], 0)
# Cast to float: an int cast would truncate fractional scores before scaling
df['valid_efficiency_score'] = df['valid_efficiency_score'].astype(float)
min_val = df['valid_efficiency_score'].min()
max_val = df['valid_efficiency_score'].max()
if min_val == max_val:
# All values are equal: assign 1.0 everywhere to avoid division by zero
df['valid_efficiency_score'] = 1.0
else:
df['valid_efficiency_score'] = (
df['valid_efficiency_score'] - min_val
) / (max_val - min_val)
return df
####################################
# GRAPH FUNCTIONS SECTION #
####################################
# BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
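"""Bar chart of the average of the selected metrics, grouped per model or
per (table, model) depending on group_by."""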
df = df[df['model'].isin(selected_models)]
df = normalize_valid_efficiency_score(df)
# Map human-readable metric names -> internal names
qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric]
selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal
df = calculate_average_metrics(df, selected_metrics)
if group_by == ["model"]:
# Bar plot per "model"
avg_metrics = df.groupby("model")['avg_metric'].mean().reset_index()
avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
fig = px.bar(
avg_metrics,
x="model",
y="avg_metric",
color="model",
color_discrete_map=MODEL_COLORS,
title='Average metrics per Model 🧠',
labels={"model": "Model", "avg_metric": "Average Metrics"},
template='simple_white',
#template='plotly_dark',
text='text_label'
)
else:
if group_by != ["tbl_name", "model"]:
group_by = ["tbl_name", "model"]
avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index()
avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
fig = px.bar(
avg_metrics,
x=group_by[0],
y='avg_metric',
color='model',
color_discrete_map=MODEL_COLORS,
barmode='group',
title=f'Average metrics per {group_by[0]} πŸ“Š',
labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metrics'},
template='simple_white',
#template='plotly_dark',
text='text_label'
)
fig.update_traces(textposition='outside', textfont_size=10)
# Apply the Inter font to the whole layout
fig.update_layout(
margin=dict(t=80),
title=dict(
font=dict(
family="Inter, sans-serif",
size=22,
#color="white"
),
x=0.5
),
xaxis=dict(
title=dict(
font=dict(
family="Inter, sans-serif",
size=18,
#color="white"
)
),
tickfont=dict(
family="Inter, sans-serif",
#color="white"
size=16
)
),
yaxis=dict(
title=dict(
font=dict(
family="Inter, sans-serif",
size=18,
#color="white"
)
),
tickfont=dict(
family="Inter, sans-serif",
#color="white"
)
),
legend=dict(
title=dict(
font=dict(
family="Inter, sans-serif",
size=16,
#color="white"
)
),
font=dict(
family="Inter, sans-serif",
#color="white"
)
)
)
return gr.Plot(fig, visible=True)
def update_plot(radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
df = load_data_csv_es()
return plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models)
# BAR CHART FOR THE PROPRIETARY DATASET WITH AVERAGE METRICS AND UPDATE FUNCTION
def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
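"""Grouped bar chart of the average metrics per database category for the
proprietary-vs-non-proprietary comparison ('All' expands to every model)."""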
if selected_models == "All":
selected_models = models
else:
selected_models = [selected_models]
df = df[df['model'].isin(selected_models)]
df = normalize_valid_efficiency_score(df)
# Convert human-readable metric names -> internal names
qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric]
selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal
df = calculate_average_metrics(df, selected_metrics)
avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
fig = px.bar(
avg_metrics,
x='db_category',
y='avg_metric',
color='model',
color_discrete_map=MODEL_COLORS,
barmode='group',
title='Average metrics per database types πŸ“Š',
labels={'db_category': 'Database Category', 'avg_metric': 'Average Metrics'},
template='simple_white',
text='text_label'
)
fig.update_traces(textposition='outside', textfont_size=14)
# Update the layout with the Inter font
fig.update_layout(
margin=dict(t=80),
title=dict(
font=dict(
family="Inter, sans-serif",
size=24,
color="black"
),
x=0.5
),
xaxis=dict(
title=dict(
text='Database Category',
font=dict(
family='Inter, sans-serif',
size=22,
color='black'
)
),
tickfont=dict(
family='Inter, sans-serif',
color='black',
size=20
)
),
yaxis=dict(
title=dict(
text='Average Metrics',
font=dict(
family='Inter, sans-serif',
size=22,
color='black'
)
),
tickfont=dict(
family='Inter, sans-serif',
color='black'
)
),
legend=dict(
title=dict(
text='Models',
font=dict(
family='Inter, sans-serif',
size=20,
color='black'
)
),
font=dict(
family='Inter, sans-serif',
color='black',
size=18
)
)
)
return gr.Plot(fig, visible=True)
"""
def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
if selected_models == "All":
selected_models = models
else:
selected_models = [selected_models]
df = df[df['model'].isin(selected_models)]
df = normalize_valid_efficiency_score(df)
if radio_metric == "Qatch":
selected_metrics = qatch_selected_metrics
else:
selected_metrics = external_selected_metric
df = calculate_average_metrics(df, selected_metrics)
# Raggruppamento per modello e categoria
avg_metrics = df.groupby(["model", "db_category"])['avg_metric'].mean().reset_index()
avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
# Plot orizzontale con modello sull'asse Y
fig = px.bar(
avg_metrics,
x='avg_metric',
y='model',
color='db_category', # categoria come colore
text='text_label',
barmode='group',
orientation='h',
color_discrete_map=DB_CATEGORY_COLORS, # devi avere questo dict come MODEL_COLORS
title='Average metric per model and db_category πŸ“Š',
labels={'avg_metric': 'AVG Metric', 'model': 'Model'},
template='plotly_dark'
)
fig.update_traces(textposition='outside', textfont_size=10)
fig.update_layout(
margin=dict(t=80),
yaxis=dict(title=''),
xaxis=dict(title='AVG Metrics'),
legend_title='DB Name',
height=600 # puoi aumentare se ci sono tanti modelli
)
return gr.Plot(fig, visible=True)
"""
def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
df = load_data_csv_es()
return plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models)
# DUMBBELL CHART: NON-PROPRIETARY (SPIDER) VS PROPRIETARY DATABASES
def lollipop_propietary(selected_models):
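"""Dumbbell chart: each model's average QATCH metric on Spider versus its
mean over the proprietary categories ("Others")."""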
df = load_data_csv_es()
# Keep only the relevant categories
target_cats = ["Spider", "Economic", "Financial", "Medical", "Miscellaneous"]
df = df[df['db_category'].isin(target_cats)]
df = df[df['model'].isin(selected_models)]
df = normalize_valid_efficiency_score(df)
df = calculate_average_metrics(df, qatch_metrics)
# Average per db_category and model
avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
# Separate Spider from the other four categories
spider_df = avg_metrics[avg_metrics["db_category"] == "Spider"]
other_df = avg_metrics[avg_metrics["db_category"] != "Spider"]
# Average of the other categories per model
other_mean_df = other_df.groupby("model")["avg_metric"].mean().reset_index()
other_mean_df["db_category"] = "Others"
# Rename for clarity and consistency
spider_df = spider_df.rename(columns={"avg_metric": "Spider"})
other_mean_df = other_mean_df.rename(columns={"avg_metric": "Others"})
# Merge the two datasets
merged_df = pd.merge(spider_df[["model", "Spider"]], other_mean_df[["model", "Others"]], on="model")
# Sort by model (or by value, if preferred)
merged_df = merged_df.sort_values(by="model")
fig = go.Figure()
# Add connector lines between the Spider and Others points
for _, row in merged_df.iterrows():
fig.add_trace(go.Scatter(
x=[row["Spider"], row["Others"]],
y=[row["model"]] * 2,
mode='lines',
line=dict(color='gray', width=2),
showlegend=False
))
# Marker for Spider
fig.add_trace(go.Scatter(
x=merged_df["Spider"],
y=merged_df["model"],
mode='markers',
name='Non-Proprietary (Spider)',
marker=dict(size=10, color='#C84630')
))
# Marker for Others (mean of the other four categories)
fig.add_trace(go.Scatter(
x=merged_df["Others"],
y=merged_df["model"],
mode='markers',
name='Proprietary Databases',
marker=dict(size=10, color='#0077B6')
))
fig.update_layout(
xaxis_title='Average Metrics',
yaxis_title='Models',
template='simple_white',
#template='plotly_dark',
margin=dict(t=80),
title=dict(
font=dict(
family="Inter, sans-serif",
size=22,
color="black"
),
x=0.5,
text='Dumbbell graph: Non-Proprietary (Spider πŸ•·οΈ) vs Proprietary Databases πŸ“Š'
),
legend_title='Type of Databases:',
height=600,
xaxis=dict(
title=dict(
text='Average Metrics',
font=dict(
family='Inter, sans-serif',
size=18,
color='black'
)
),
tickfont=dict(
family='Inter, sans-serif',
color='black'
)
),
yaxis=dict(
title=dict(
text='Models',
font=dict(
family='Inter, sans-serif',
size=18,
color='black'
)
),
tickfont=dict(
family='Inter, sans-serif',
color='black'
)
),
legend=dict(
title=dict(
text='Type of Databases:',
font=dict(
family='Inter, sans-serif',
size=18,
color='black'
)
),
font=dict(
family='Inter, sans-serif',
color='black',
size=14
)
)
)
return gr.Plot(fig, visible=True)
# RADAR OR BAR CHART BASED ON CATEGORY COUNT
def plot_radar(df, selected_models, selected_metrics, selected_categories):
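"""Radar plot of the average metric per SQL test category; with fewer than
three selected categories a grouped bar plot is drawn instead."""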
if "external" in selected_metrics:
selected_metrics = ["execution_accuracy", "valid_efficiency_score"]
else:
selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
# Filter models and normalize scores
df = df[df['model'].isin(selected_models)]
df = normalize_valid_efficiency_score(df)
df = calculate_average_metrics(df, selected_metrics)
avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
if avg_metrics.empty:
print("Error: No data available to compute averages.")
return go.Figure()
categories = selected_categories
if len(categories) < 3:
# πŸ”„ BAR PLOT
fig = go.Figure()
for model in selected_models:
model_data = avg_metrics[avg_metrics['model'] == model]
values = [
model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
if cat in model_data['test_category'].values else 0
for cat in categories
]
fig.add_trace(go.Bar(
x=categories,
y=values,
name=model,
marker=dict(color=MODEL_COLORS.get(model, "gray"))
))
fig.update_layout(
barmode='group',
title=dict(
text='πŸ“Š Bar Plot of Metrics per Model (Few Categories)',
font=dict(
family='Inter, sans-serif',
size=22,
#color='white'
),
x=0.5
),
template='simple_white',
#template='plotly_dark',
xaxis=dict(
title=dict(
text='Test Category',
font=dict(
family='Inter, sans-serif',
size=18,
#color='white'
)
),
tickfont=dict(
family='Inter, sans-serif',
size=16
#color='white'
)
),
yaxis=dict(
title=dict(
text='Average Metrics',
font=dict(
family='Inter, sans-serif',
size=18,
#color='white'
)
),
tickfont=dict(
family='Inter, sans-serif',
#color='white'
)
),
legend=dict(
title=dict(
text='Models',
font=dict(
family='Inter, sans-serif',
size=16,
#color='white'
)
),
font=dict(
family='Inter, sans-serif',
#color='white'
)
)
)
else:
# 🧭 RADAR PLOT
fig = go.Figure()
for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True):
model_data = avg_metrics[avg_metrics['model'] == model]
values = [
model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
if cat in model_data['test_category'].values else 0
for cat in categories
]
fig.add_trace(go.Scatterpolar(
r=values,
theta=categories,
fill='toself',
name=model,
line=dict(color=MODEL_COLORS.get(model, "gray"))
))
fig.update_layout(
polar=dict(
radialaxis=dict(
visible=True,
range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
tickfont=dict(
family='Inter, sans-serif',
#color='white'
)
),
angularaxis=dict(
tickfont=dict(
family='Inter, sans-serif',
size=16
#color='white'
)
)
),
title=dict(
text='❇️ Radar Plot of Metrics per Model (Average per SQL Category)',
font=dict(
family='Inter, sans-serif',
size=22,
#color='white'
),
x=0.5
),
legend=dict(
title=dict(
text='Models',
font=dict(
family='Inter, sans-serif',
size=18,
#color='white'
)
),
font=dict(
family='Inter, sans-serif',
size=16
#color='white'
)
),
template='simple_white'
#template='plotly_dark'
)
return fig
def update_radar(selected_models, selected_metrics, selected_categories):
df = load_data_csv_es()
return plot_radar(df, selected_models, selected_metrics, selected_categories)
# RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT
def plot_radar_sub(df, selected_models, selected_metrics, selected_category):
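"""Radar plot of the average metric per sql_tag (sub-category) within the
selected test category; falls back to a bar plot for fewer than three tags."""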
if "external" in selected_metrics:
selected_metrics = ["execution_accuracy", "valid_efficiency_score"]
else:
selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
df = df[df['model'].isin(selected_models)]
df = normalize_valid_efficiency_score(df)
df = calculate_average_metrics(df, selected_metrics)
if isinstance(selected_category, str):
selected_category = [selected_category]
df = df[df['test_category'].isin(selected_category)]
avg_metrics = df.groupby(['model', 'sql_tag'])['avg_metric'].mean().reset_index()
if avg_metrics.empty:
print("Error: No data available to compute averages.")
return go.Figure()
categories = df['sql_tag'].unique().tolist()
if len(categories) < 3:
# πŸ”„ BAR PLOT
fig = go.Figure()
for model in selected_models:
model_data = avg_metrics[avg_metrics['model'] == model]
values = [
model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0]
if cat in model_data['sql_tag'].values else 0
for cat in categories
]
fig.add_trace(go.Bar(
x=categories,
y=values,
name=model,
marker=dict(color=MODEL_COLORS.get(model, "gray"))
))
fig.update_layout(
barmode='group',
title=dict(
text='πŸ“Š Bar Plot of Metrics per Model (Few Sub-Categories)',
font=dict(
family='Inter, sans-serif',
size=22,
#color='white'
),
x=0.5
),
template='simple_white',
#template='plotly_dark',
xaxis=dict(
title=dict(
text='SQL Tag (Sub Category)',
font=dict(
family='Inter, sans-serif',
size=18,
#color='white'
)
),
tickfont=dict(
family='Inter, sans-serif',
#color='white'
)
),
yaxis=dict(
title=dict(
text='Average Metrics',
font=dict(
family='Inter, sans-serif',
size=18,
#color='white'
)
),
tickfont=dict(
family='Inter, sans-serif',
#color='white'
)
),
legend=dict(
title=dict(
text='Models',
font=dict(
family='Inter, sans-serif',
size=16,
#color='white'
)
),
font=dict(
family='Inter, sans-serif',
size=14
#color='white'
)
)
)
else:
# 🧭 RADAR PLOT
fig = go.Figure()
for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True):
model_data = avg_metrics[avg_metrics['model'] == model]
values = [
model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0]
if cat in model_data['sql_tag'].values else 0
for cat in categories
]
fig.add_trace(go.Scatterpolar(
r=values,
theta=categories,
fill='toself',
name=model,
line=dict(color=MODEL_COLORS.get(model, "gray"))
))
fig.update_layout(
polar=dict(
radialaxis=dict(
visible=True,
range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
tickfont=dict(
family='Inter, sans-serif',
#color='white'
)
),
angularaxis=dict(
tickfont=dict(
family='Inter, sans-serif',
size=16
#color='white'
)
)
),
title=dict(
text='❇️ Radar Plot of Metrics per Model (Average per SQL Sub-Category)',
font=dict(
family='Inter, sans-serif',
size=22,
#color='white'
),
x=0.5
),
legend=dict(
title=dict(
text='Models',
font=dict(
family='Inter, sans-serif',
size=16,
#color='white'
)
),
font=dict(
family='Inter, sans-serif',
size=14,
#color='white'
)
),
template='simple_white'
#template='plotly_dark'
)
return fig
def update_radar_sub(selected_models, selected_metrics, selected_category):
df = load_data_csv_es()
return plot_radar_sub(df, selected_models, selected_metrics, selected_category)
# RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION
def worst_cases_text(df, selected_models, selected_metrics, selected_categories):
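"""Return markdown blocks for the three worst predictions (lowest average
metric), plus the corresponding raw model answers."""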
if selected_models == "All":
selected_models = models
else:
selected_models = [selected_models]
if selected_categories == "All":
selected_categories = principal_categories
else:
selected_categories = [selected_categories]
df = df[df['model'].isin(selected_models)]
df = df[df['test_category'].isin(selected_categories)]
if "external" in selected_metrics:
selected_metrics = ["execution_accuracy", "valid_efficiency_score"]
else:
selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
df = normalize_valid_efficiency_score(df)
df = calculate_average_metrics(df, selected_metrics)
worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True)
worst_cases_top_3 = worst_cases_df.head(3).copy()  # copy() avoids pandas' SettingWithCopyWarning
worst_cases_top_3["avg_metric"] = worst_cases_top_3["avg_metric"].round(2)
worst_str = []
answer_str = []
medals = ["πŸ₯‡", "πŸ₯ˆ", "πŸ₯‰"]
for i, row in worst_cases_top_3.iterrows():
entry = (
f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span> \n"
f"<span style='font-size:16px;'>- <b>Question:</b> {row['question']}</span> \n"
f"<span style='font-size:16px;'>- <b>Original Query:</b> `{row['query']}`</span> \n"
f"<span style='font-size:16px;'>- <b>Predicted SQL:</b> `{row['predicted_sql']}`</span> \n\n"
)
worst_str.append(entry)
raw_answer = (
f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span> \n"
f"<span style='font-size:16px;'>- <b>Raw Answer:</b><br> `{row['answer']}`</span> \n"
)
answer_str.append(raw_answer)
return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2]
def update_worst_cases_text(selected_models, selected_metrics, selected_categories):
df = load_data_csv_es()
return worst_cases_text(df, selected_models, selected_metrics, selected_categories)
# LINE CHART FOR CUMULATIVE PRICE VS CUMULATIVE TIME WITH UPDATE FUNCTION
def plot_cumulative_flow(df, selected_models, max_points):
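"""Cumulative price versus cumulative time line chart per model, with a
per-question tooltip; max_points caps how many questions are included."""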
df = df[df['model'].isin(selected_models)]
df = normalize_valid_efficiency_score(df)
fig = go.Figure()
for model in selected_models:
model_df = df[df['model'] == model].copy()
# Limit the number of points if requested
if max_points is not None:
model_df = model_df.head(max_points + 1)
# Custom tooltip
model_df['hover_info'] = model_df.apply(
lambda row:
f"<b>Id question</b>: {row['number_question']}<br>"
f"<b>Question</b>: {row['question']}<br>"
f"<b>Target</b>: {row['query']}<br>"
f"<b>Prediction</b>: {row['predicted_sql']}<br>"
f"<b>Category</b>: {row['test_category']}",
axis=1
)
# Cumulative sums
model_df['cumulative_time'] = model_df['time'].cumsum()
model_df['cumulative_price'] = model_df['price'].cumsum()
# Model color
color = MODEL_COLORS.get(model, "gray")
fig.add_trace(go.Scatter(
x=model_df['cumulative_time'],
y=model_df['cumulative_price'],
mode='lines+markers',
name=model,
line=dict(width=2, color=color),
customdata=model_df['hover_info'],
hovertemplate=
"<b>Model:</b> " + model + "<br>" +
"<b>Cumulative Time:</b> %{x}s<br>" +
"<b>Cumulative Price:</b> $%{y:.2f}<br>" +
"<br><b>Details:</b><br>%{customdata}<extra></extra>"
))
# Layout with the Inter font
fig.update_layout(
title=dict(
text="Cumulative Price Flow Chart πŸ’°",
font=dict(
family="Inter, sans-serif",
size=24,
#color="white"
),
x=0.5
),
xaxis=dict(
title=dict(
text="Cumulative Time (s)",
font=dict(
family="Inter, sans-serif",
size=20,
#color="white"
)
),
tickfont=dict(
family="Inter, sans-serif",
size=18
#color="white"
)
),
yaxis=dict(
title=dict(
text="Cumulative Price ($)",
font=dict(
family="Inter, sans-serif",
size=20,
#color="white"
)
),
tickfont=dict(
family="Inter, sans-serif",
size=18
#color="white"
)
),
legend=dict(
title=dict(
text="Models",
font=dict(
family="Inter, sans-serif",
size=18,
#color="white"
)
),
font=dict(
family="Inter, sans-serif",
size=16,
#color="white"
)
),
template='simple_white',
#template="plotly_dark"
)
return fig
def update_query_rate(selected_models, max_points):
df = load_data_csv_es()
return plot_cumulative_flow(df, selected_models, max_points)
#######################
# PARAMETER SECTION #
#######################
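# The enforce_* helpers below guard against empty selections: when the user
# clears every checkbox, the handler restores the last non-empty selection so
# the plots never receive an empty metric/model/category list.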
qatch_metrics_dict = {
"Cell Precision": "cell_precision",
"Cell Recall": "cell_recall",
"Tuple Order": "tuple_order",
"Tuple Cardinality": "tuple_cardinality",
"Tuple Constraint": "tuple_constraint"
}
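# Maps the human-readable labels shown in the UI to the metric column names
# used in the evaluation dataframe.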
qatch_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
last_valid_qatch_metrics_selection = qatch_metrics.copy() # Keeps the last valid selection
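# The enforce_* callbacks below share one pattern: a CheckboxGroup lets the user
# untick every box, which would leave the plots with no data, so an empty
# selection is rolled back to the last non-empty one. Sketch of the behaviour:
#   enforce_qatch_metrics_selection([])                  -> restores the previous selection
#   enforce_qatch_metrics_selection(["cell_precision"])  -> keeps it and records it as the new fallback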
def enforce_qatch_metrics_selection(selected):
global last_valid_qatch_metrics_selection
if not selected: # If no metric is selected, restore the last valid choice
return gr.update(value=last_valid_qatch_metrics_selection)
last_valid_qatch_metrics_selection = selected # Otherwise remember the new valid selection
return gr.update(value=selected)
external_metrics_dict = {
"Execution Accuracy": "execution_accuracy",
"Valid Efficiency Score": "valid_efficiency_score"
}
external_metric = ["execution_accuracy", "valid_efficiency_score"]
last_valid_external_metric_selection = external_metric.copy()
def enforce_external_metric_selection(selected):
global last_valid_external_metric_selection
if not selected: # If no metric is selected, restore the last valid choice
return gr.update(value=last_valid_external_metric_selection)
last_valid_external_metric_selection = selected # Otherwise remember the new valid selection
return gr.update(value=selected)
all_metrics = {
"Qatch": ["qatch"],
"External": ["external"]
}
group_options = {
"Table": ["tbl_name", "model"],
"Model": ["model"]
}
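# Each entry lists the dataframe columns used as the group-by key by update_plot:
# "Table" aggregates per table and per model, "Model" per model only.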
df_initial = load_data_csv_es()
models = df_initial['model'].unique().tolist()
last_valid_model_selection = models.copy() # Keeps the last valid selection
def enforce_model_selection(selected):
global last_valid_model_selection
if not selected: # If no model is selected, restore the last valid choice
return gr.update(value=last_valid_model_selection)
last_valid_model_selection = selected # Otherwise remember the new valid selection
return gr.update(value=selected)
all_categories = df_initial['sql_tag'].unique().tolist()
principal_categories = df_initial['test_category'].unique().tolist()
last_valid_category_selection = principal_categories.copy() # Keeps the last valid selection
def enforce_category_selection(selected):
global last_valid_category_selection
if not selected: # If no category is selected, restore the last valid choice
return gr.update(value=last_valid_category_selection)
last_valid_category_selection = selected # Otherwise remember the new valid selection
return gr.update(value=selected)
all_categories_as_dic = {cat: [cat] for cat in principal_categories}
all_categories_as_dic_ranking = {cat: [cat] for cat in principal_categories}
all_categories_as_dic_ranking["All"] = principal_categories
all_model_as_dic = {cat: [cat] for cat in models}
all_model_as_dic["All"] = models
#with gr.Blocks(theme=gr.themes.Default(primary_hue='blue')) as demo:
###########################
# VISUALIZATION SECTION #
###########################
gr.Markdown("""# Model Performance Analysis""")
#FOR BAR
gr.Markdown("""## Section 1: Model - Data""")
with gr.Row():
with gr.Column(scale=1):
with gr.Row():
choose_metrics_bar = gr.Radio(
choices=list(all_metrics.keys()),
label="Select the metrics group that you want to use:",
value="Qatch"
)
with gr.Row():
qatch_info = gr.HTML("""
<div style='display: flex; align-items: center; margin-top: -8px; margin-bottom: 12px;'>
<span
title="Qatch metric info:
Cell Precision: Fraction of predicted table cells also in the ground truth result. High means many correct predictions.
Cell Recall: Fraction of ground truth cells retrieved by the prediction. High means relevant cells were captured.
Tuple Constraint: Fraction of ground truth tuples matched exactly in output (schema, values, cardinality).
Tuple Cardinality: Ratio of predicted to ground truth tuples. Checks only tuple count.
Tuple Order: Spearman correlation between predicted and ground truth tuple ranks."
style="margin-left: 6px; cursor: help; color: #00bfff; font-size: 16px; white-space: pre-line;"
>Qatch metric info ℹ️</span>
</div>
""", visible=True)
external_info = gr.HTML("""
<div style='display: flex; align-items: center; margin-top: -8px; margin-bottom: 12px;'>
<span
title="External metric info:
Execution Accuracy: Checks if the predicted query returns exactly the same result as the ground truth query when executed. It is a binary metric: 1 if the output matches, 0 otherwise.
Valid Efficiency Score: Evaluates the efficiency of a query by combining execution time and correctness. It rewards queries that are both accurate and fast."
style="margin-left: 6px; cursor: help; color: #00bfff; font-size: 16px; white-space: pre-line;"
>External metric info ℹ️</span>
</div>
""", visible=False)
qatch_metric_multiselect_bar = gr.CheckboxGroup(
choices=list(qatch_metrics_dict.keys()),
label="Select one or mode Qatch metrics:",
value=list(qatch_metrics_dict.keys()),
visible=True
)
external_metric_select_bar = gr.CheckboxGroup(
choices=list(external_metrics_dict.keys()),
label="Select one or more External metrics:",
visible=False
)
if input_data['input_method'] == 'default':
model_radio_bar = gr.Radio(
choices=list(all_model_as_dic.keys()),
label="Select the model that you want to use:",
value="All"
)
else:
model_multiselect_bar = gr.CheckboxGroup(
choices=models,
label="Select one or more models:",
value=models,
interactive=len(models) > 1
)
group_radio = gr.Radio(
choices=list(group_options.keys()),
label="Select the grouping view:",
value="Table"
)
def toggle_metric_selector(selected_type):
if selected_type == "Qatch":
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True, value=list(qatch_metrics_dict.keys())), gr.update(visible=False, value=[])
else:
return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False, value=[]), gr.update(visible=True, value=list(external_metrics_dict.keys()))
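# The four gr.update values are positional: they must stay aligned with the
# outputs list [qatch_info, external_info, qatch_metric_multiselect_bar,
# external_metric_select_bar] wired to choose_metrics_bar.change below.
# Resetting value=[] on the hidden group keeps stale selections out of the plot.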
output_plot = gr.Plot(visible=False)
if input_data['input_method'] == 'default':
with gr.Row():
lollipop_propietary(models)
#FOR RADAR
gr.Markdown("""## Section 2: Model - Category""")
with gr.Row():
all_metrics_radar = gr.Radio(
choices=list(all_metrics.keys()),
label="Select the metrics group that you want to use:",
value="Qatch"
)
model_multiselect_radar = gr.CheckboxGroup(
choices=models,
label="Select one or more models:",
value=models,
interactive=len(models) > 1
)
with gr.Row():
with gr.Column(scale=1):
category_multiselect_radar = gr.CheckboxGroup(
choices=principal_categories,
label="Select one or more categories:",
value=principal_categories
)
with gr.Column(scale=1):
category_radio_radar = gr.Radio(
choices=list(all_categories_as_dic.keys()),
label="Select the metrics that you want to use:",
value=list(all_categories_as_dic.keys())[0]
)
with gr.Row():
with gr.Column(scale=1):
radar_plot_multiselect = gr.Plot(value=update_radar(models, "Qatch", principal_categories))
with gr.Column(scale=1):
radar_plot_radio = gr.Plot(value=update_radar_sub(models, "Qatch", list(all_categories_as_dic.keys())[0]))
#FOR RANKING
with gr.Row():
all_metrics_ranking = gr.Radio(
choices=list(all_metrics.keys()),
label="Select the metrics group that you want to use:",
value="Qatch"
)
model_choices = list(all_model_as_dic.keys())
if len(model_choices) == 2:
model_choices = [model_choices[0]] # with a single model plus "All", keep only the model (listed first)
selected_value = model_choices[0]
else:
selected_value = "All"
model_radio_ranking = gr.Radio(
choices=model_choices,
label="Select the model that you want to use:",
value=selected_value
)
category_radio_ranking = gr.Radio(
choices=list(all_categories_as_dic_ranking.keys()),
label="Select the category that you want to use",
value="All"
)
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("## ❌ 3 Worst Cases\n")
worst_first, worst_second, worst_third, raw_first, raw_second, raw_third = update_worst_cases_text("All", "Qatch", "All")
with gr.Row():
first = gr.Markdown(worst_first)
with gr.Row():
first_button = gr.Button("Show raw answer for πŸ₯‡")
with gr.Row():
second = gr.Markdown(worst_second)
with gr.Row():
second_button = gr.Button("Show raw answer for πŸ₯ˆ")
with gr.Row():
third = gr.Markdown(worst_third)
with gr.Row():
third_button = gr.Button("Show raw answer for πŸ₯‰")
with gr.Column(scale=1):
gr.Markdown("""## Raw Answer""")
row_answer_first = gr.Markdown(value=raw_first, visible=True)
row_answer_second = gr.Markdown(value=raw_second, visible=False)
row_answer_third = gr.Markdown(value=raw_third, visible=False)
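# Only the first raw answer is visible at start-up; the medal buttons above swap
# visibility through show_first/show_second/show_third (wired in the on-click section).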
#FOR RATE
gr.Markdown("""## Section 3: Time - Price""")
with gr.Row():
model_multiselect_rate = gr.CheckboxGroup(
choices=models,
label="Select one or more models:",
value=models,
interactive=len(models) > 1
)
with gr.Row():
slicer = gr.Slider(minimum=0, maximum=max(df_initial["number_question"]), step=1, value=max(df_initial["number_question"]), label="Number of instances to visualize", elem_id="custom-slider")
query_rate_plot = gr.Plot(value=update_query_rate(models, len(df_initial["number_question"].unique())))
#FOR RESET
reset_data = gr.Button("Back to upload data section")
###############################
# CALLBACK FUNCTION SECTION #
###############################
#FOR BAR
def on_change(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_group, selected_models):
return update_plot(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, group_options[selected_group], selected_models)
def on_change_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models):
return update_plot_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models)
#FOR RADAR
def on_radar_multiselect_change(selected_models, selected_metrics, selected_categories):
return update_radar(selected_models, selected_metrics, selected_categories)
def on_radar_radio_change(selected_models, selected_metrics, selected_category):
return update_radar_sub(selected_models, selected_metrics, selected_category)
#FOR RANKING
def on_ranking_change(selected_models, selected_metrics, selected_categories):
return update_worst_cases_text(selected_models, selected_metrics, selected_categories)
def show_first():
return (
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False)
)
def show_second():
return (
gr.update(visible=False),
gr.update(visible=True),
gr.update(visible=False)
)
def show_third():
return (
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=True)
)
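# Each show_* helper returns three visibility updates that are positionally
# matched to [row_answer_first, row_answer_second, row_answer_third] when wired below.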
######################
# ON CLICK SECTION #
######################
#FOR BAR
if input_data['input_method'] == 'default':
proceed_to_metrics_button.click(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
qatch_metric_multiselect_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
external_metric_select_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
model_radio_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar])
external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
else:
proceed_to_metrics_button.click(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
qatch_metric_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
external_metric_select_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
group_radio.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
model_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
model_multiselect_bar.change(fn=enforce_model_selection, inputs=model_multiselect_bar, outputs=model_multiselect_bar)
choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar])
external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
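# Note: Gradio fires every listener registered on the same event, which is why
# e.g. qatch_metric_multiselect_bar.change is wired twice above: once to redraw
# the plot and once to enforce a non-empty selection.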
#FOR RADAR MULTISELECT
model_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
all_metrics_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
category_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
model_multiselect_radar.change(fn=enforce_model_selection, inputs=model_multiselect_radar, outputs=model_multiselect_radar)
category_multiselect_radar.change(fn=enforce_category_selection, inputs=category_multiselect_radar, outputs=category_multiselect_radar)
#FOR RADAR RADIO
model_multiselect_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
all_metrics_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
category_radio_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
#FOR RANKING
model_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
model_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
all_metrics_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
all_metrics_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
category_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
category_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
model_radio_ranking.change(fn=enforce_model_selection, inputs=model_radio_ranking, outputs=model_radio_ranking)
category_radio_ranking.change(fn=enforce_category_selection, inputs=category_radio_ranking, outputs=category_radio_ranking)
first_button.click(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
second_button.click(fn=show_second, outputs=[row_answer_first, row_answer_second, row_answer_third])
third_button.click(fn=show_third, outputs=[row_answer_first, row_answer_second, row_answer_third])
#FOR RATE
model_multiselect_rate.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
proceed_to_metrics_button.click(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
model_multiselect_rate.change(fn=enforce_model_selection, inputs=model_multiselect_rate, outputs=model_multiselect_rate)
slicer.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
#FOR RESET
reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics])
reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection])
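# share=True additionally exposes the app through a temporary public *.gradio.live link.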
interface.launch(share=True)