import os import sys import time import re import csv import gradio as gr import pandas as pd import numpy as np import plotly.express as px import plotly.graph_objects as go import plotly.colors as pc from qatch.connectors.sqlite_connector import SqliteConnector from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator from prediction import ModelPrediction import utils_get_db_tables_info import utilities as us # @spaces.GPU # def model_prediction(): # pass # # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10 # if os.environ.get("SPACES_ZERO_GPU") is not None: # import spaces # else: # class spaces: # @staticmethod # def GPU(func): # def wrapper(*args, **kwargs): # return func(*args, **kwargs) # return wrapper pnp_path = "evaluation_p_np_metrics.csv" js_func = """ function refresh() { const url = new URL(window.location); if (url.searchParams.get('__theme') !== 'light') { url.searchParams.set('__theme', 'light'); window.location.href = url.href; } } """ reset_flag=False with open('style.css', 'r') as file: css = file.read() # DataFrame di default df_default = pd.DataFrame({ 'Name': ['Alice', 'Bob', 'Charlie'], 'Age': [25, 30, 35], 'City': ['New York', 'Los Angeles', 'Chicago'] }) models_path ="models.csv" # Variabile globale per tenere traccia dei dati correnti df_current = df_default.copy() description = """## πŸ“Š Comparison of Proprietary and Non-Proprietary Databases ### ➀ **Proprietary** (πŸ’° Economic, πŸ₯ Medical, πŸ’³ Financial, πŸ“‚ Miscellaneous) ### ➀ **Non-Proprietary** (πŸ•·οΈ Spider 1.0)""" prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n" input_data = { 'input_method': "", 'data_path': "", 'db_name': "", 'data': { 'data_frames': {}, # dictionary of dataframes 'db': None, # SQLITE3 database object 'selected_tables' :[] }, 'models': [], 'prompt': prompt_default } def load_data(file, path, use_default): """Carica i dati da un file, un percorso o usa il DataFrame di default.""" global df_current if file is not None: try: input_data["input_method"] = 'uploaded_file' input_data["db_name"] = os.path.splitext(os.path.basename(file))[0] if file.endswith('.sqlite'): #return 'Error: The uploaded file is not a valid SQLite database.' input_data["data_path"] = file #os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite") else: #change path input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite") input_data["data"] = us.load_data(file, input_data["db_name"]) df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files table2primary_key = {} for table_name, df in input_data["data"]['data_frames'].items(): # Assign primary keys for each table table2primary_key[table_name] = 'id' input_data["data"]["db"] = SqliteConnector( relative_db_path=input_data["data_path"], db_name=input_data["db_name"], tables= input_data["data"]['data_frames'], table2primary_key=table2primary_key ) return input_data["data"]['data_frames'] except Exception as e: return f'Errore nel caricamento del file: {e}' if use_default: if(use_default == 'Custom'): input_data["input_method"] = 'custom' #input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable_0.sqlite") input_data["data_path"] = os.path.join(".","mytable_0.sqlite") #if file already exist while os.path.exists(input_data["data_path"]): input_data["data_path"] = us.increment_filename(input_data["data_path"]) input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0] input_data["data"]['data_frames'] = {'MyTable': df_current} if(input_data["data"]['data_frames']): table2primary_key = {} for table_name, df in input_data["data"]['data_frames'].items(): # Assign primary keys for each table table2primary_key[table_name] = 'id' input_data["data"]["db"] = SqliteConnector( relative_db_path=input_data["data_path"], db_name=input_data["db_name"], tables= input_data["data"]['data_frames'], table2primary_key=table2primary_key ) df_current = df_default.copy() # Ripristina i dati di default return input_data["data"]['data_frames'] if(use_default == 'Proprietary vs Non-proprietary'): input_data["input_method"] = 'default' #input_data["data_path"] = os.path.join(".", "data", "data_interface", "default.sqlite") #input_data["data_path"] = os.path.join(".", "data", "spider_databases", "defeault.sqlite") #input_data["db_name"] = "default" #input_data["data"]['db'] = SqliteConnector(relative_db_path=input_data["data_path"], db_name=input_data["db_name"]) input_data["data"]['data_frames'] = us.load_tables_dict_from_pkl('tables_dict.pkl') return input_data["data"]['data_frames'] selected_inputs = sum([file is not None, bool(path), use_default]) if selected_inputs > 1: return 'Error: Select only one input method at a time.' return input_data["data"]['data_frames'] def preview_default(use_default, file): if file: return gr.DataFrame(interactive=True, visible = False, value = df_default), gr.update(value="## βœ… File successfully uploaded!", visible=True) else : if use_default == 'Custom': return gr.DataFrame(interactive=True, visible = True, value = df_default), gr.update(value="## πŸ“ Toy Table", visible=True) else: return gr.DataFrame(interactive=False, visible = False, value = df_default), gr.update(value = description, visible=True) #return gr.DataFrame(interactive=True, value = df_current) # Mostra il DataFrame corrente, che potrebbe essere stato modificato def update_df(new_df): """Aggiorna il DataFrame corrente.""" global df_current # Usa la variabile globale per aggiornarla df_current = new_df return df_current def open_accordion(target): # Apre uno e chiude l'altro if target == "reset": df_current = df_default.copy() input_data['input_method'] = "" input_data['data_path'] = "" input_data['db_name'] = "" input_data['data']['data_frames'] = {} input_data['data']['selected_tables'] = [] input_data['data']['db'] = None input_data['models'] = [] return gr.update(open=True), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(value='Proprietary vs Non-proprietary'), gr.update(value=None) elif target == "model_selection": return gr.update(open=False), gr.update(open=False), gr.update(open=True, visible=True), gr.update(open=False), gr.update(open=False) # Interfaccia Gradio #with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as interface: with gr.Row(): with gr.Column(scale=1): gr.Image( value=os.path.join(".", "qatch_logo.png"), show_label=False, container=False, interactive=False, show_fullscreen_button=False, show_download_button=False, show_share_button=False, height=150, # in pixel width=300 ) with gr.Column(scale=1): pass data_state = gr.State(None) # Memorizza i dati caricati upload_acc = gr.Accordion("Upload data section", open=True, visible=True) select_table_acc = gr.Accordion("Select tables section", open=False, visible=False) select_model_acc = gr.Accordion("Select models section", open=False, visible=False) qatch_acc = gr.Accordion("QATCH execution section", open=False, visible=False) metrics_acc = gr.Accordion("Metrics section", open=False, visible=False) ################################# # DATABASE INSERTION # ################################# with upload_acc: gr.Markdown("## πŸ“₯Choose data input method") with gr.Row(): default_checkbox = gr.Radio(label = "Explore the comparison between proprietary and non-proprietary databases or edit a toy table with the values you prefer", choices=['Proprietary vs Non-proprietary', 'Custom'], value='Proprietary vs Non-proprietary') #default_checkbox = gr.Checkbox(label="Use default DataFrame" table_default = gr.Markdown(description, visible=True) preview_output = gr.DataFrame(interactive=False, visible=False, value=df_default) gr.Markdown("## πŸ“‚ Or upload your data") file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"]) submit_button = gr.Button("Load Data") # Disabled by default output = gr.JSON(visible=False) # Dictionary output # Function to enable the button if there is data to load def enable_submit(file, use_default): return gr.update(interactive=bool(file or use_default)) # Function to uncheck the checkbox if a file is uploaded def deselect_default(file): if file: return gr.update(value='Proprietary vs Non-proprietary') return gr.update() def enable_disable_first(enable): return ( gr.update(interactive=enable), gr.update(interactive=enable), gr.update(interactive=enable), gr.update(interactive=enable) ) # Enable the button when inputs are provided #file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button]) #default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button]) # Show preview of the default DataFrame when checkbox is selected default_checkbox.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default]) file_input.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default]) preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output]) # Uncheck the checkbox when a file is uploaded file_input.change(fn=deselect_default, inputs=[file_input], outputs=[default_checkbox]) def handle_output(file, use_default): """Handles the output when the 'Load Data' button is pressed.""" result = load_data(file, None, use_default) if isinstance(result, dict): # If result is a dictionary of DataFrames if len(result) == 1: # If there's only one table input_data['data']['selected_tables'] = list(input_data['data']['data_frames'].keys()) return ( gr.update(visible=False), # Hide JSON output result, # Save the data state gr.update(visible=False), # Hide table selection result, # Maintain the data state gr.update(interactive=False), # Disable the submit button gr.update(visible=True, open=True), # Proceed to select_model_acc gr.update(visible=True, open=False) ) else: return ( gr.update(visible=False), result, gr.update(open=True, visible=True), result, gr.update(interactive=False), gr.update(visible=False), # Keep current behavior gr.update(visible=True, open=False) ) else: return ( gr.update(visible=False), None, gr.update(open=False, visible=True), None, gr.update(interactive=True), gr.update(visible=False), gr.update(visible=True, open=False) ) submit_button.click( fn=handle_output, inputs=[file_input, default_checkbox], outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc] ) submit_button.click( fn=enable_disable_first, inputs=[gr.State(False)], outputs=[ preview_output, submit_button, file_input, default_checkbox ] ) ###################################### # TABLE SELECTION PART # ###################################### with select_table_acc: previous_selection = gr.State([]) table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the choosen database", value=[]) excluded_tables_info = gr.HTML(label="Non-selectable tables (too many columns)", visible=False) table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)] selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False) # Model selection button (initially disabled) open_model_selection = gr.Button("Choose your models", interactive=False) def update_table_list(data): """Dynamically updates the list of available tables and excluded ones.""" if isinstance(data, dict) and data: table_names = [] excluded_tables = [] data_frames = input_data['data'].get('data_frames', {}) available_tables = [] for name, df in data.items(): df_real = data_frames.get(name, None) if df_real is not None and df_real.shape[1] > 15: excluded_tables.append(name) else: available_tables.append(name) if input_data['input_method'] == "default" or len(available_tables) < 6: table_names.append("All") table_names.extend(available_tables) # Prepara il testo da mostrare if excluded_tables: excluded_text = "⚠️ The following tables have more than 15 columns and cannot be selected:
" + "
".join(f"- {t}" for t in excluded_tables) excluded_visible = True else: excluded_text = "" excluded_visible = False return [ gr.update(choices=table_names, value=[]), # CheckboxGroup update gr.update(value=excluded_text, visible=excluded_visible) # HTML display update ] return [ gr.update(choices=[], value=[]), gr.update(value="", visible=False) ] def show_selected_tables(data, selected_tables): updates = [] data_frames = input_data['data'].get('data_frames', {}) available_tables = [] for name, df in data.items(): df_real = data_frames.get(name) if df_real is not None and df_real.shape[1] <= 15: available_tables.append(name) input_method = input_data['input_method'] allow_all = input_method == "default" or len(available_tables) < 6 selected_set = set(selected_tables) tables_set = set(available_tables) if allow_all: if "All" in selected_set: selected_tables = ["All"] + available_tables elif selected_set == tables_set: selected_tables = [] else: selected_tables = [t for t in selected_tables if t in available_tables] else: selected_tables = [t for t in selected_tables if t in available_tables and t != "All"][:5] tables = {name: data[name] for name in selected_tables if name in data} for i, (name, df) in enumerate(tables.items()): updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False)) for _ in range(len(tables), 50): updates.append(gr.update(visible=False)) updates.append(gr.update(interactive=bool(tables))) if allow_all: updates.insert(0, gr.update( choices=["All"] + available_tables, value=selected_tables )) else: if len(selected_tables) >= 5: updates.insert(0, gr.update( choices=selected_tables, value=selected_tables )) else: updates.insert(0, gr.update( choices=available_tables, value=selected_tables )) return updates def show_selected_table_names(data, selected_tables): """Displays the names of the selected tables when the button is pressed.""" if selected_tables: available_tables = list(data.keys()) # Actually available names if "All" in selected_tables: selected_tables = available_tables input_data['data']['selected_tables'] = selected_tables return gr.update(value=", ".join(selected_tables), visible=False) return gr.update(value="", visible=False) # Automatically updates the checkbox list when `data_state` changes data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector, excluded_tables_info]) # Updates the visible tables and the button state based on user selections #table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection]) table_selector.change( fn=show_selected_tables, inputs=[data_state, table_selector], outputs=[table_selector] + table_outputs + [open_model_selection] ) # Shows the list of selected tables when "Choose your models" is clicked open_model_selection.click(fn=show_selected_table_names, inputs=[data_state, table_selector], outputs=[selected_table_names]) open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc]) reset_data = gr.Button("Back to upload data section") reset_data.click( fn=enable_disable_first, inputs=[gr.State(True)], outputs=[ preview_output, submit_button, file_input, default_checkbox ] ) reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) #################################### # MODEL SELECTION PART # #################################### with select_model_acc: gr.Markdown("# Model Selection") # Assume that `us.read_models_csv` also returns the image path model_list_dict = us.read_models_csv(models_path) model_list = [model["code"] for model in model_list_dict] model_images = [model["image_path"] for model in model_list_dict] model_names = [model["name"] for model in model_list_dict] # Create a mapping between model_list and model_images_names model_mapping = dict(zip(model_list, model_names)) model_mapping_reverse = dict(zip(model_names, model_list)) model_checkboxes = [] rows = [] # Dynamically create checkboxes with images (3 per row) for i in range(0, len(model_list), 3): with gr.Row(): cols = [] for j in range(3): if i + j < len(model_list): model = model_list[i + j] image_path = model_images[i + j] with gr.Column(): gr.Image(image_path, show_label=False, container=False, interactive=False, show_fullscreen_button=False, show_download_button=False, show_share_button=False) checkbox = gr.Checkbox(label=model_mapping[model], value=False) model_checkboxes.append(checkbox) cols.append(checkbox) rows.append(cols) selected_models_output = gr.JSON(visible=False) # Function to get selected models def get_selected_models(*model_selections): selected_models = [model for model, selected in zip(model_list, model_selections) if selected] input_data['models'] = selected_models button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"]) return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state) # Add the Textbox to the interface prompt = gr.TextArea( label="Customise the prompt for selected models here or leave the default one.", placeholder=prompt_default, elem_id="custom-textarea" ) warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False) # Submit button (initially disabled) submit_models_button = gr.Button("Submit Models", interactive=False) def check_prompt(prompt): #TODO missing_elements = [] if(prompt==""): input_data["prompt"]=prompt_default button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"]) else: input_data["prompt"]=prompt if "{db_schema}" not in prompt: missing_elements.append("{db_schema}") if "{question}" not in prompt: missing_elements.append("{question}") button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"]) if missing_elements: return gr.update( value=f"
" f"❌ Missing {', '.join(missing_elements)} in the prompt ❌
", visible=True ), gr.update(interactive=button_state) return gr.update(visible=False), gr.update(interactive=button_state) prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button]) # Link checkboxes to selection events for checkbox in model_checkboxes: checkbox.change( fn=get_selected_models, inputs=model_checkboxes, outputs=[selected_models_output, select_model_acc, submit_models_button] ) prompt.change( fn=get_selected_models, inputs=model_checkboxes, outputs=[selected_models_output, select_model_acc, submit_models_button] ) submit_models_button.click( fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)), inputs=model_checkboxes, outputs=[selected_models_output, select_model_acc, qatch_acc] ) def enable_disable(enable): return ( *[gr.update(interactive=enable) for _ in model_checkboxes], gr.update(interactive=enable), gr.update(interactive=enable), gr.update(interactive=enable), gr.update(interactive=enable), gr.update(interactive=enable), gr.update(interactive=enable), *[gr.update(interactive=enable) for _ in table_outputs], gr.update(interactive=enable) ) reset_data = gr.Button("Back to upload data section") submit_models_button.click( fn=enable_disable, inputs=[gr.State(False)], outputs=[ *model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection ] ) reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) reset_data.click( fn=enable_disable, inputs=[gr.State(True)], outputs=[ *model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection ] ) ############################# # QATCH EXECUTION # ############################# with qatch_acc: def change_text(text): return text loading_symbols= {1:"π“†Ÿ", 2: "π“†ž π“†Ÿ", 3: "𓆛 π“†ž π“†Ÿ", 4: "π“†ž 𓆛 π“†ž π“†Ÿ", 5: "π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", 6: "π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", 7: "π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", 8: "π“†ž π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", 9: "π“†Ÿ π“†ž π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", 10:"π“†ž π“†Ÿ π“†ž π“†œ π“†ž π“†Ÿ π“†ž 𓆛 π“†ž π“†Ÿ", } def generate_loading_text(percent): num_symbols = (round(percent) % 11) + 1 symbols = loading_symbols.get(num_symbols, "π“†Ÿ") mirrored_symbols = f'{symbols.strip()}' css_symbols = f'{symbols.strip()}' return f"""
{css_symbols} Generation {percent}% {mirrored_symbols}
""" def generate_eval_text(text): symbols = "𓆑 " mirrored_symbols = f'{symbols.strip()}' css_symbols = f'{symbols.strip()}' return f"""
{css_symbols} {text} {mirrored_symbols}
""" def qatch_flow(): #caching global reset_flag predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list} metrics_conc = pd.DataFrame() columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"] if (input_data['input_method']=="default"): target_df = us.load_csv(pnp_path) #target_df = us.load_csv("priority_non_priority_metrics.csv") #predictions_dict = {model: pd.DataFrame(columns=target_df.columns) for model in model_list} target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])] target_df = target_df[target_df["model"].isin(input_data['models'])] predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list} reset_flag = False for model in input_data['models']: model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None) yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] count=1 for _, row in predictions_dict[model].iterrows(): #for index, row in target_df.iterrows(): if (reset_flag == False): percent_complete = round(count / len(predictions_dict[model]) * 100, 2) count=count+1 load_text = f"{generate_loading_text(percent_complete)}" question = row['question'] display_question = f"""
Natural Language:
{question}
➑️
""" yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] #time.sleep(0.02) prediction = row['predicted_sql'] display_prediction = f"""
Predicted SQL:
➑️
{prediction}
""" yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] metrics_conc = target_df if 'valid_efficiency_score' not in metrics_conc.columns: metrics_conc['valid_efficiency_score'] = metrics_conc['VES'] eval_text = generate_eval_text("End evaluation") yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] else: orchestrator_generator = OrchestratorGenerator() # TODO: add to target_df column target_df["columns_used"], tables selection # print(input_data['data']['db']) #print(input_data['data']['selected_tables']) target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables']) #target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None) predictor = ModelPrediction() reset_flag = False for model in input_data["models"]: model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None) yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] count=0 for index, row in target_df.iterrows(): if (reset_flag == False): percent_complete = round(((index+1) / len(target_df)) * 100, 2) load_text = f"{generate_loading_text(percent_complete)}" question = row['question'] display_question = f"""
Natural Language:
{question}
➑️
""" yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list] start_time = time.time() samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"]) schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string( db_id = input_data["db_name"], base_path = input_data["data_path"], normalize=False, sql=row["query"], get_insert_into=True ) #prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples) prompt_to_send = input_data["prompt"] #PREDICTION SQL # TODO add button for QA or SP and pass to .make_prediction parameter TASK response = predictor.make_prediction( question=question, db_schema=schema_text, model_name=model, prompt=f"{prompt_to_send}", task="SP" # TODO change accordingly ) prediction = response['response_parsed'] price = response['cost'] answer = response['response'] end_time = time.time() display_prediction = f"""
Predicted SQL:
➑️
{prediction}
""" # Create a new row as dataframe new_row = pd.DataFrame([{ 'id': index, 'question': question, 'predicted_sql': prediction, 'time': end_time - start_time, 'query': row["query"], 'db_path': input_data["data_path"], 'price':price, 'answer':answer, 'number_question':count, 'prompt': prompt_to_send }]).dropna(how="all") # Remove only completely empty rows count=count+1 # TODO: use a for loop for col in target_df.columns: if col not in new_row.columns: new_row[col] = row[col] # Update model's prediction dataframe incrementally if not new_row.empty: predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True) # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model]for model in model_list] yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list] # END eval_text = generate_eval_text("Evaluation") yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] evaluator = OrchestratorEvaluator() for model in input_data["models"]: metrics_df_model = evaluator.evaluate_df( df=predictions_dict[model], target_col_name="query", prediction_col_name="predicted_sql", db_path_name="db_path" ) metrics_df_model['model'] = model metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True) if 'valid_efficiency_score' not in metrics_conc.columns: metrics_conc['valid_efficiency_score'] = metrics_conc['VES'] eval_text = generate_eval_text("End evaluation") yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] # Loading Bar with gr.Row(): # progress = gr.Progress() variable = gr.Markdown() # NL -> MODEL -> Generated Query with gr.Row(): with gr.Column(): with gr.Column(): question_display = gr.Markdown() with gr.Column(): model_logo = gr.Image(visible=True, show_label=False, container=False, interactive=False, show_fullscreen_button=False, show_download_button=False, show_share_button=False) with gr.Column(): with gr.Column(): prediction_display = gr.Markdown() dataframe_per_model = {} with gr.Tabs() as model_tabs: tab_dict = {} # for model, model_name in zip(model_list, model_names): # with gr.TabItem(model_name, visible=(model in input_data["models"])) as tab: # gr.Markdown(f"**Results for {model_name}**") # tab_dict[model] = tab # dataframe_per_model[model] = gr.DataFrame() #model_mapping = dict(zip(model_list, model_names)) #model_mapping_reverse = dict(zip(model_names, model_list)) for model, model_name in zip(model_list, model_names): with gr.TabItem(model_name, visible=(model in input_data["models"])) as tab: gr.Markdown(f"**Results for {model}**") tab_dict[model] = tab dataframe_per_model[model] = gr.DataFrame() # download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False) evaluation_loading = gr.Markdown() def change_tab(): return [gr.update(visible=(model in input_data["models"])) for model in model_list] submit_models_button.click( change_tab, inputs=[], outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility ) selected_models_display = gr.JSON(label="Final input data", visible=False) metrics_df = gr.DataFrame(visible=False) metrics_df_out = gr.DataFrame(visible=False) submit_models_button.click( fn=qatch_flow, inputs=[], outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values()) ) submit_models_button.click( fn=lambda: gr.update(value=input_data), outputs=[selected_models_display] ) # Works for METRICS metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out]) proceed_to_metrics_button = gr.Button("Proceed to Metrics", visible=False) proceed_to_metrics_button.click( fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)), outputs=[qatch_acc, metrics_acc] ) def allow_download(metrics_df_out): #path = os.path.join(".", "data", "data_results", "results.csv") path = os.path.join(".", "results.csv") metrics_df_out.to_csv(path, index=False) return gr.update(value=path, visible=True), gr.update(visible=True), gr.update(interactive=True) download_metrics = gr.DownloadButton(label="Download Metrics Evaluation", visible=False) submit_models_button.click( fn=lambda: gr.update(visible=False), outputs=[download_metrics] ) #TODO WHY? # download_metrics.click( # fn=lambda: gr.update(open=True, visible=True), # outputs=[download_metrics] # ) def refresh(): global reset_flag reset_flag = True reset_data = gr.Button("Back to upload data section", interactive=True) metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics, proceed_to_metrics_button, reset_data]) reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) #WHY NOT WORKING? reset_data.click( fn=lambda: gr.update(visible=False), outputs=[download_metrics] ) reset_data.click(refresh) reset_data.click( fn=enable_disable, inputs=[gr.State(True)], outputs=[ *model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection ] ) ########################################## # METRICS VISUALIZATION SECTION # ########################################## with metrics_acc: #data_path = 'test_results_metrics1.csv' @gr.render(inputs=metrics_df_out) def function_metrics(metrics_df_out): #################################### # UTILS FUNCTIONS SECTION # #################################### def load_data_csv_es(): #return pd.read_csv(data_path) #print("---------------->",metrics_df_out) if input_data["input_method"]=="default": df = pd.read_csv(pnp_path) df = df[df['model'].isin(input_data["models"])] df['model'] = df['model'].replace('DeepSeek-R1-Distill-Llama-70B', 'DS-Llama3 70B') df['model'] = df['model'].replace('gpt-3.5', 'GPT-3.5') df['model'] = df['model'].replace('gpt-4o-mini', 'GPT-4o-mini') df['model'] = df['model'].replace('llama-70', 'Llama-70B') df['model'] = df['model'].replace('llama-8', 'Llama-8B') df['test_category'] = df['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY') return df return metrics_df_out def calculate_average_metrics(df, selected_metrics): # Exclude the 'tuple_order' column from the selected metrics #TODO tuple_order has NULL VALUE selected_metrics = [metric for metric in selected_metrics if metric != 'tuple_order'] #print(df[selected_metrics]) df['avg_metric'] = df[selected_metrics].mean(axis=1) return df def generate_model_colors(): """Generates a unique color map for models in the dataset.""" df = load_data_csv_es() unique_models = df['model'].unique() # Extract unique models num_models = len(unique_models) # Use the Plotly color scale (you can change it if needed) color_palette = ['#00B4D8', '#BCE784', '#C84630', '#F79256', '#D269FC'] #color_palette = pc.qualitative.Plotly # ['#636EFA', '#EF553B', '#00CC96', ...] # If there are more models than colors, cycle through them colors = {model: color_palette[i % len(color_palette)] for i, model in enumerate(unique_models)} return colors MODEL_COLORS = generate_model_colors() def generate_db_category_colors(): """Assigns 3 distinct colors to db_category groups.""" return { "Spider": "#1f77b4", # blu "Beaver": "#ff7f0e", # arancione "Economic": "#2ca02c", # tutti gli altri verdi "Financial": "#2ca02c", "Medical": "#2ca02c", "Miscellaneous": "#2ca02c" } DB_CATEGORY_COLORS = generate_db_category_colors() def normalize_valid_efficiency_score(df): #TODO valid_efficiency_score #print(df['valid_efficiency_score']) df['valid_efficiency_score'] = df['valid_efficiency_score'].replace([np.nan, ''], 0) df['valid_efficiency_score'] = df['valid_efficiency_score'].astype(int) min_val = df['valid_efficiency_score'].min() max_val = df['valid_efficiency_score'].max() if min_val == max_val: # Tutti i valori sono uguali, assegna 1.0 a tutto per evitare divisione per zero df['valid_efficiency_score'] = 1.0 else: df['valid_efficiency_score'] = ( df['valid_efficiency_score'] - min_val ) / (max_val - min_val) return df #################################### # GRAPH FUNCTIONS SECTION # #################################### # BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models): df = df[df['model'].isin(selected_models)] df = normalize_valid_efficiency_score(df) # Mappatura nomi leggibili -> tecnici qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics] external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric] selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal df = calculate_average_metrics(df, selected_metrics) if group_by == ["model"]: # Bar plot per "model" avg_metrics = df.groupby("model")['avg_metric'].mean().reset_index() avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}') fig = px.bar( avg_metrics, x="model", y="avg_metric", color="model", color_discrete_map=MODEL_COLORS, title='Average metrics per Model 🧠', labels={"model": "Model", "avg_metric": "Average Metrics"}, template='simple_white', #template='plotly_dark', text='text_label' ) else: if group_by != ["tbl_name", "model"]: group_by = ["tbl_name", "model"] avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index() avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}') fig = px.bar( avg_metrics, x=group_by[0], y='avg_metric', color='model', color_discrete_map=MODEL_COLORS, barmode='group', title=f'Average metrics per {group_by[0]} πŸ“Š', labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metrics'}, template='simple_white', #template='plotly_dark', text='text_label' ) fig.update_traces(textposition='outside', textfont_size=10) # Applica font Inter a tutto il layout fig.update_layout( margin=dict(t=80), title=dict( font=dict( family="Inter, sans-serif", size=22, #color="white" ), x=0.5 ), xaxis=dict( title=dict( font=dict( family="Inter, sans-serif", size=18, #color="white" ) ), tickfont=dict( family="Inter, sans-serif", #color="white" size=16 ) ), yaxis=dict( title=dict( font=dict( family="Inter, sans-serif", size=18, #color="white" ) ), tickfont=dict( family="Inter, sans-serif", #color="white" ) ), legend=dict( title=dict( font=dict( family="Inter, sans-serif", size=16, #color="white" ) ), font=dict( family="Inter, sans-serif", #color="white" ) ) ) return gr.Plot(fig, visible=True) def update_plot(radio_metric, qatch_selected_metrics, external_selected_metric,group_by, selected_models): df = load_data_csv_es() return plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models) # BAR CHART FOR PROPIETARY DATASET WITH AVERAGE METRICS WITH UPDATE FUNCTION def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models): if selected_models == "All": selected_models = models else: selected_models = [selected_models] df = df[df['model'].isin(selected_models)] df = normalize_valid_efficiency_score(df) # Converti nomi leggibili -> tecnici qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics] external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric] selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal df = calculate_average_metrics(df, selected_metrics) avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index() avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}') #MIAO fig = px.bar( avg_metrics, x='db_category', y='avg_metric', color='model', color_discrete_map=MODEL_COLORS, barmode='group', title='Average metrics per database types πŸ“Š', labels={'db_path': 'DB Path', 'avg_metric': 'Average Metrics'}, template='simple_white', text='text_label' ) fig.update_traces(textposition='outside', textfont_size=14) # Aggiorna layout con font Inter fig.update_layout( margin=dict(t=80), title=dict( font=dict( family="Inter, sans-serif", size=24, color="black" ), x=0.5 ), xaxis=dict( title=dict( text='Database Category', font=dict( family='Inter, sans-serif', size=22, color='black' ) ), tickfont=dict( family='Inter, sans-serif', color='black', size=20 ) ), yaxis=dict( title=dict( text='Average Metrics', font=dict( family='Inter, sans-serif', size=22, color='black' ) ), tickfont=dict( family='Inter, sans-serif', color='black' ) ), legend=dict( title=dict( text='Models', font=dict( family='Inter, sans-serif', size=20, color='black' ) ), font=dict( family='Inter, sans-serif', color='black', size=18 ) ) ) return gr.Plot(fig, visible=True) """ def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models): if selected_models == "All": selected_models = models else: selected_models = [selected_models] df = df[df['model'].isin(selected_models)] df = normalize_valid_efficiency_score(df) if radio_metric == "Qatch": selected_metrics = qatch_selected_metrics else: selected_metrics = external_selected_metric df = calculate_average_metrics(df, selected_metrics) # Raggruppamento per modello e categoria avg_metrics = df.groupby(["model", "db_category"])['avg_metric'].mean().reset_index() avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}') # Plot orizzontale con modello sull'asse Y fig = px.bar( avg_metrics, x='avg_metric', y='model', color='db_category', # categoria come colore text='text_label', barmode='group', orientation='h', color_discrete_map=DB_CATEGORY_COLORS, # devi avere questo dict come MODEL_COLORS title='Average metric per model and db_category πŸ“Š', labels={'avg_metric': 'AVG Metric', 'model': 'Model'}, template='plotly_dark' ) fig.update_traces(textposition='outside', textfont_size=10) fig.update_layout( margin=dict(t=80), yaxis=dict(title=''), xaxis=dict(title='AVG Metrics'), legend_title='DB Name', height=600 # puoi aumentare se ci sono tanti modelli ) return gr.Plot(fig, visible=True) """ def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models): df = load_data_csv_es() return plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models) # BAR CHART FOR PROPIETARY DATASET WITH AVERAGE METRICS WITH UPDATE FUNCTION def lollipop_propietary(selected_models): df = load_data_csv_es() # Filtra solo le categorie rilevanti target_cats = ["Spider", "Economic", "Financial", "Medical", "Miscellaneous"] df = df[df['db_category'].isin(target_cats)] df = df[df['model'].isin(selected_models)] df = normalize_valid_efficiency_score(df) df = calculate_average_metrics(df, qatch_metrics) # Calcola la media per db_category e modello avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index() # Separa Spider e le altre 4 categorie spider_df = avg_metrics[avg_metrics["db_category"] == "Spider"] other_df = avg_metrics[avg_metrics["db_category"] != "Spider"] # Calcola media delle altre categorie per ciascun modello other_mean_df = other_df.groupby("model")["avg_metric"].mean().reset_index() other_mean_df["db_category"] = "Others" # Rinominare per chiarezza e uniformitΓ  spider_df = spider_df.rename(columns={"avg_metric": "Spider"}) other_mean_df = other_mean_df.rename(columns={"avg_metric": "Others"}) # Unione dei due dataset merged_df = pd.merge(spider_df[["model", "Spider"]], other_mean_df[["model", "Others"]], on="model") # Ordina per modello o per valore se vuoi merged_df = merged_df.sort_values(by="model") fig = go.Figure() # Aggiungi linee orizzontali tra Spider e Others for _, row in merged_df.iterrows(): fig.add_trace(go.Scatter( x=[row["Spider"], row["Others"]], y=[row["model"]] * 2, mode='lines', line=dict(color='gray', width=2), showlegend=False )) # Punto per Spider fig.add_trace(go.Scatter( x=merged_df["Spider"], y=merged_df["model"], mode='markers', name='Non-Proprietary (Spider)', marker=dict(size=10, color='#C84630') )) # Punto per Others (media delle altre 4 categorie) fig.add_trace(go.Scatter( x=merged_df["Others"], y=merged_df["model"], mode='markers', name='Proprietary Databases', marker=dict(size=10, color='#0077B6') )) fig.update_layout( xaxis_title='Average Metrics', yaxis_title='Models', template='simple_white', #template='plotly_dark', margin=dict(t=80), title=dict( font=dict( family="Inter, sans-serif", size=22, color="black" ), x=0.5, text='Dumbbell graph: Non-Proprietary (Spider πŸ•·οΈ) vs Proprietary Databases πŸ“Š' ), legend_title='Type of Databases:', height=600, xaxis=dict( title=dict( text='DB Category', font=dict( family='Inter, sans-serif', size=18, color='black' ) ), tickfont=dict( family='Inter, sans-serif', color='black' ) ), yaxis=dict( title=dict( text='Average Metrics', font=dict( family='Inter, sans-serif', size=18, color='black' ) ), tickfont=dict( family='Inter, sans-serif', color='black' ) ), legend=dict( title=dict( text='Models', font=dict( family='Inter, sans-serif', size=18, color='black' ) ), font=dict( family='Inter, sans-serif', color='black', size=14 ) ) ) return gr.Plot(fig, visible=True) # RADAR OR BAR CHART BASED ON CATEGORY COUNT def plot_radar(df, selected_models, selected_metrics, selected_categories): if "external" in selected_metrics: selected_metrics = ["execution_accuracy", "valid_efficiency_score"] else: selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"] # Filtro modelli e normalizzazione df = df[df['model'].isin(selected_models)] df = normalize_valid_efficiency_score(df) df = calculate_average_metrics(df, selected_metrics) avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index() if avg_metrics.empty: print("Error: No data available to compute averages.") return go.Figure() categories = selected_categories if len(categories) < 3: # πŸ”„ BAR PLOT fig = go.Figure() for model in selected_models: model_data = avg_metrics[avg_metrics['model'] == model] values = [ model_data[model_data['test_category'] == cat]['avg_metric'].values[0] if cat in model_data['test_category'].values else 0 for cat in categories ] fig.add_trace(go.Bar( x=categories, y=values, name=model, marker=dict(color=MODEL_COLORS.get(model, "gray")) )) fig.update_layout( barmode='group', title=dict( text='πŸ“Š Bar Plot of Metrics per Model (Few Categories)', font=dict( family='Inter, sans-serif', size=22, #color='white' ), x=0.5 ), template='simple_white', #template='plotly_dark', xaxis=dict( title=dict( text='Test Category', font=dict( family='Inter, sans-serif', size=18, #color='white' ) ), tickfont=dict( family='Inter, sans-serif', size=16 #color='white' ) ), yaxis=dict( title=dict( text='Average Metrics', font=dict( family='Inter, sans-serif', size=18, #color='white' ) ), tickfont=dict( family='Inter, sans-serif', #color='white' ) ), legend=dict( title=dict( text='Models', font=dict( family='Inter, sans-serif', size=16, #color='white' ) ), font=dict( family='Inter, sans-serif', #color='white' ) ) ) else: # 🧭 RADAR PLOT fig = go.Figure() for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True): model_data = avg_metrics[avg_metrics['model'] == model] values = [ model_data[model_data['test_category'] == cat]['avg_metric'].values[0] if cat in model_data['test_category'].values else 0 for cat in categories ] fig.add_trace(go.Scatterpolar( r=values, theta=categories, fill='toself', name=model, line=dict(color=MODEL_COLORS.get(model, "gray")) )) fig.update_layout( polar=dict( radialaxis=dict( visible=True, range=[0, max(avg_metrics['avg_metric'].max(), 0.5)], tickfont=dict( family='Inter, sans-serif', #color='white' ) ), angularaxis=dict( tickfont=dict( family='Inter, sans-serif', size=16 #color='white' ) ) ), title=dict( text='❇️ Radar Plot of Metrics per Model (Average per SQL Category)', font=dict( family='Inter, sans-serif', size=22, #color='white' ), x=0.5 ), legend=dict( title=dict( text='Models', font=dict( family='Inter, sans-serif', size=18, #color='white' ) ), font=dict( family='Inter, sans-serif', size=16 #color='white' ) ), template='simple_white' #template='plotly_dark' ) return fig def update_radar(selected_models, selected_metrics, selected_categories): df = load_data_csv_es() return plot_radar(df, selected_models, selected_metrics, selected_categories) # RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT def plot_radar_sub(df, selected_models, selected_metrics, selected_category): if "external" in selected_metrics: selected_metrics = ["execution_accuracy", "valid_efficiency_score"] else: selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"] df = df[df['model'].isin(selected_models)] df = normalize_valid_efficiency_score(df) df = calculate_average_metrics(df, selected_metrics) if isinstance(selected_category, str): selected_category = [selected_category] df = df[df['test_category'].isin(selected_category)] avg_metrics = df.groupby(['model', 'sql_tag'])['avg_metric'].mean().reset_index() if avg_metrics.empty: print("Error: No data available to compute averages.") return go.Figure() categories = df['sql_tag'].unique().tolist() if len(categories) < 3: # πŸ”„ BAR PLOT fig = go.Figure() for model in selected_models: model_data = avg_metrics[avg_metrics['model'] == model] values = [ model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0] if cat in model_data['sql_tag'].values else 0 for cat in categories ] fig.add_trace(go.Bar( x=categories, y=values, name=model, marker=dict(color=MODEL_COLORS.get(model, "gray")) )) fig.update_layout( barmode='group', title=dict( text='πŸ“Š Bar Plot of Metrics per Model (Few Sub-Categories)', font=dict( family='Inter, sans-serif', size=22, #color='white' ), x=0.5 ), template='simple_white', #template='plotly_dark', xaxis=dict( title=dict( text='SQL Tag (Sub Category)', font=dict( family='Inter, sans-serif', size=18, #color='white' ) ), tickfont=dict( family='Inter, sans-serif', #color='white' ) ), yaxis=dict( title=dict( text='Average Metrics', font=dict( family='Inter, sans-serif', size=18, #color='white' ) ), tickfont=dict( family='Inter, sans-serif', #color='white' ) ), legend=dict( title=dict( text='Models', font=dict( family='Inter, sans-serif', size=16, #color='white' ) ), font=dict( family='Inter, sans-serif', size=14 #color='white' ) ) ) else: # 🧭 RADAR PLOT fig = go.Figure() for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True): model_data = avg_metrics[avg_metrics['model'] == model] values = [ model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0] if cat in model_data['sql_tag'].values else 0 for cat in categories ] fig.add_trace(go.Scatterpolar( r=values, theta=categories, fill='toself', name=model, line=dict(color=MODEL_COLORS.get(model, "gray")) )) fig.update_layout( polar=dict( radialaxis=dict( visible=True, range=[0, max(avg_metrics['avg_metric'].max(), 0.5)], tickfont=dict( family='Inter, sans-serif', #color='white' ) ), angularaxis=dict( tickfont=dict( family='Inter, sans-serif', size=16 #color='white' ) ) ), title=dict( text='❇️ Radar Plot of Metrics per Model (Average per SQL Sub-Category)', font=dict( family='Inter, sans-serif', size=22, #color='white' ), x=0.5 ), legend=dict( title=dict( text='Models', font=dict( family='Inter, sans-serif', size=16, #color='white' ) ), font=dict( family='Inter, sans-serif', size=14, #color='white' ) ), template='simple_white' #template='plotly_dark' ) return fig def update_radar_sub(selected_models, selected_metrics, selected_category): df = load_data_csv_es() return plot_radar_sub(df, selected_models, selected_metrics, selected_category) # RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION def worst_cases_text(df, selected_models, selected_metrics, selected_categories): if selected_models == "All": selected_models = models else: selected_models = [selected_models] if selected_categories == "All": selected_categories = principal_categories else: selected_categories = [selected_categories] df = df[df['model'].isin(selected_models)] df = df[df['test_category'].isin(selected_categories)] if "external" in selected_metrics: selected_metrics = ["execution_accuracy", "valid_efficiency_score"] else: selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"] df = normalize_valid_efficiency_score(df) df = calculate_average_metrics(df, selected_metrics) worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index() worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True) worst_cases_top_3 = worst_cases_df.head(3) worst_cases_top_3["avg_metric"] = worst_cases_top_3["avg_metric"].round(2) worst_str = [] answer_str = [] medals = ["πŸ₯‡", "πŸ₯ˆ", "πŸ₯‰"] for i, row in worst_cases_top_3.iterrows(): entry = ( f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n" f"- Question: {row['question']} \n" f"- Original Query: `{row['query']}` \n" f"- Predicted SQL: `{row['predicted_sql']}` \n\n" ) worst_str.append(entry) raw_answer = ( f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']}) \n" f"- Raw Answer:
`{row['answer']}`
\n" ) answer_str.append(raw_answer) return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2] def update_worst_cases_text(selected_models, selected_metrics, selected_categories): df = load_data_csv_es() return worst_cases_text(df, selected_models, selected_metrics, selected_categories) # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION def plot_cumulative_flow(df, selected_models, max_points): df = df[df['model'].isin(selected_models)] df = normalize_valid_efficiency_score(df) fig = go.Figure() for model in selected_models: model_df = df[df['model'] == model].copy() # Limita il numero di punti se richiesto if max_points is not None: model_df = model_df.head(max_points + 1) # Tooltip personalizzato model_df['hover_info'] = model_df.apply( lambda row: f"Id question: {row['number_question']}
" f"Question: {row['question']}
" f"Target: {row['query']}
" f"Prediction: {row['predicted_sql']}
" f"Category: {row['test_category']}", axis=1 ) # Calcoli cumulativi model_df['cumulative_time'] = model_df['time'].cumsum() model_df['cumulative_price'] = model_df['price'].cumsum() # Colore del modello color = MODEL_COLORS.get(model, "gray") fig.add_trace(go.Scatter( x=model_df['cumulative_time'], y=model_df['cumulative_price'], mode='lines+markers', name=model, line=dict(width=2, color=color), customdata=model_df['hover_info'], hovertemplate= "Model: " + model + "
" + "Cumulative Time: %{x}s
" + "Cumulative Price: $%{y:.2f}
" + "
Details:
%{customdata}" )) # Layout con font elegante fig.update_layout( title=dict( text="Cumulative Price Flow Chart πŸ’°", font=dict( family="Inter, sans-serif", size=24, #color="white" ), x=0.5 ), xaxis=dict( title=dict( text="Cumulative Time (s)", font=dict( family="Inter, sans-serif", size=20, #color="white" ) ), tickfont=dict( family="Inter, sans-serif", size=18 #color="white" ) ), yaxis=dict( title=dict( text="Cumulative Price ($)", font=dict( family="Inter, sans-serif", size=20, #color="white" ) ), tickfont=dict( family="Inter, sans-serif", size=18 #color="white" ) ), legend=dict( title=dict( text="Models", font=dict( family="Inter, sans-serif", size=18, #color="white" ) ), font=dict( family="Inter, sans-serif", size=16, #color="white" ) ), template='simple_white', #template="plotly_dark" ) return fig def update_query_rate(selected_models, max_points): df = load_data_csv_es() return plot_cumulative_flow(df, selected_models, max_points) ####################### # PARAMETER SECTION # ####################### qatch_metrics_dict = { "Cell Precision": "cell_precision", "Cell Recall": "cell_recall", "Tuple Order": "tuple_order", "Tuple Cardinality": "tuple_cardinality", "Tuple Constraint": "tuple_constraint" } qatch_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"] last_valid_qatch_metrics_selection = qatch_metrics.copy() # Per salvare l’ultima selezione valida def enforce_qatch_metrics_selection(selected): global last_valid_qatch_metrics_selection if not selected: # Se nessuna metrica Γ¨ selezionata return gr.update(value=last_valid_qatch_metrics_selection) last_valid_qatch_metrics_selection = selected # Altrimenti aggiorna la selezione valida return gr.update(value=selected) external_metrics_dict = { "Execution Accuracy": "execution_accuracy", "Valid Efficiency Score": "valid_efficiency_score" } external_metric = ["execution_accuracy", "valid_efficiency_score"] last_valid_external_metric_selection = external_metric.copy() def enforce_external_metric_selection(selected): global last_valid_external_metric_selection if not selected: # Se nessuna metrica Γ¨ selezionata return gr.update(value=last_valid_external_metric_selection) last_valid_external_metric_selection = selected # Altrimenti aggiorna la selezione valida return gr.update(value=selected) all_metrics = { "Qatch": ["qatch"], "External": ["external"] } group_options = { "Table": ["tbl_name", "model"], "Model": ["model"] } df_initial = load_data_csv_es() models = models = df_initial['model'].unique().tolist() last_valid_model_selection = models.copy() # Per salvare l’ultima selezione valida def enforce_model_selection(selected): global last_valid_model_selection if not selected: # Se nessuna metrica Γ¨ selezionata return gr.update(value=last_valid_model_selection) last_valid_model_selection = selected # Altrimenti aggiorna la selezione valida return gr.update(value=selected) all_categories = df_initial['sql_tag'].unique().tolist() principal_categories = df_initial['test_category'].unique().tolist() last_valid_category_selection = principal_categories.copy() # Per salvare l’ultima selezione valida def enforce_category_selection(selected): global last_valid_category_selection if not selected: # Se nessuna metrica Γ¨ selezionata return gr.update(value=last_valid_category_selection) last_valid_category_selection = selected # Altrimenti aggiorna la selezione valida return gr.update(value=selected) all_categories_as_dic = {cat: [f"{cat}"] for cat in principal_categories} all_categories_as_dic_ranking = {cat: [f"{cat}"] for cat in principal_categories} all_categories_as_dic_ranking["All"] = principal_categories all_model_as_dic = {cat: [f"{cat}"] for cat in models} all_model_as_dic["All"] = models #with gr.Blocks(theme=gr.themes.Default(primary_hue='blue')) as demo: ########################### # VISUALIZATION SECTION # ########################### gr.Markdown("""# Model Performance Analysis""") #FOR BAR gr.Markdown("""## Section 1: Model - Data""") with gr.Row(): with gr.Column(scale=1): with gr.Row(): choose_metrics_bar = gr.Radio( choices=list(all_metrics.keys()), label="Select the metrics group that you want to use:", value="Qatch" ) with gr.Row(): qatch_info = gr.HTML("""
Qatch metric info ℹ️
""", visible=True) external_info = gr.HTML("""
External metric info ℹ️
""", visible=False) qatch_metric_multiselect_bar = gr.CheckboxGroup( choices=list(qatch_metrics_dict.keys()), label="Select one or mode Qatch metrics:", value=list(qatch_metrics_dict.keys()), visible=True ) external_metric_select_bar = gr.CheckboxGroup( choices=list(external_metrics_dict.keys()), label="Select one or more External metrics:", visible=False ) if(input_data['input_method'] == 'default'): model_radio_bar = gr.Radio( choices=list(all_model_as_dic.keys()), label="Select the model that you want to use:", value="All" ) else: model_multiselect_bar = gr.CheckboxGroup( choices=models, label="Select one or more models:", value=models, interactive=len(models) > 1 ) group_radio = gr.Radio( choices=list(group_options.keys()), label="Select the grouping view:", value="Table" ) def toggle_metric_selector(selected_type): if selected_type == "Qatch": return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True, value=list(qatch_metrics_dict.keys())), gr.update(visible=False, value=[]) else: return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False, value=[]), gr.update(visible=True, value=list(external_metrics_dict.keys())) output_plot = gr.Plot(visible=False) if(input_data['input_method'] == 'default'): with gr.Row(): lollipop_propietary(models) #FOR RADAR gr.Markdown("""## Section 2: Model - Category""") with gr.Row(): all_metrics_radar = gr.Radio( choices=list(all_metrics.keys()), label="Select the metrics group that you want to use:", value="Qatch" ) model_multiselect_radar = gr.CheckboxGroup( choices=models, label="Select one or more models:", value=models, interactive=len(models) > 1 ) with gr.Row(): with gr.Column(scale=1): category_multiselect_radar = gr.CheckboxGroup( choices=principal_categories, label="Select one or more categories:", value=principal_categories ) with gr.Column(scale=1): category_radio_radar = gr.Radio( choices=list(all_categories_as_dic.keys()), label="Select the metrics that you want to use:", value=list(all_categories_as_dic.keys())[0] ) with gr.Row(): with gr.Column(scale=1): radar_plot_multiselect = gr.Plot(value=update_radar(models, "Qatch", principal_categories)) with gr.Column(scale=1): radar_plot_radio = gr.Plot(value=update_radar_sub(models, "Qatch", list(all_categories_as_dic.keys())[0])) #FOR RANKING with gr.Row(): all_metrics_ranking = gr.Radio( choices=list(all_metrics.keys()), label="Select the metrics group that you want to use:", value="Qatch" ) model_choices = list(all_model_as_dic.keys()) if len(model_choices) == 2: model_choices = [model_choices[0]] # supponiamo che il modello sia in prima posizione selected_value = model_choices[0] else: selected_value = "All" model_radio_ranking = gr.Radio( choices=model_choices, label="Select the model that you want to use:", value=selected_value ) category_radio_ranking = gr.Radio( choices=list(all_categories_as_dic_ranking.keys()), label="Select the category that you want to use", value="All" ) with gr.Row(): with gr.Column(scale=1): gr.Markdown("## ❌ 3 Worst Cases\n") worst_first, worst_second, worst_third, raw_first, raw_second, raw_third = update_worst_cases_text("All", "Qatch", "All") with gr.Row(): first = gr.Markdown(worst_first) with gr.Row(): first_button = gr.Button("Show raw answer for πŸ₯‡") with gr.Row(): second = gr.Markdown(worst_second) with gr.Row(): second_button = gr.Button("Show raw answer for πŸ₯ˆ") with gr.Row(): third = gr.Markdown(worst_third) with gr.Row(): third_button = gr.Button("Show raw answer for πŸ₯‰") with gr.Column(scale=1): gr.Markdown("""## Raw Answer""") row_answer_first = gr.Markdown(value=raw_first, visible=True) row_answer_second = gr.Markdown(value=raw_second, visible=False) row_answer_third = gr.Markdown(value=raw_third, visible=False) #FOR RATE gr.Markdown("""## Section 3: Time - Price""") with gr.Row(): model_multiselect_rate = gr.CheckboxGroup( choices=models, label="Select one or more models:", value=models, interactive=len(models) > 1 ) with gr.Row(): slicer = gr.Slider(minimum=0, maximum=max(df_initial["number_question"]), step=1, value=max(df_initial["number_question"]), label="Number of instances to visualize", elem_id="custom-slider") query_rate_plot = gr.Plot(value=update_query_rate(models, len(df_initial["number_question"].unique()))) #FOR RESET reset_data = gr.Button("Back to upload data section") ############################### # CALLBACK FUNCTION SECTION # ############################### #FOR BAR def on_change(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_group, selected_models): return update_plot(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, group_options[selected_group], selected_models) def on_change_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models): return update_plot_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models) #FOR RADAR def on_radar_multiselect_change(selected_models, selected_metrics, selected_categories): return update_radar(selected_models, selected_metrics, selected_categories) def on_radar_radio_change(selected_models, selected_metrics, selected_category): return update_radar_sub(selected_models, selected_metrics, selected_category) #FOR RANKING def on_ranking_change(selected_models, selected_metrics, selected_categories): return update_worst_cases_text(selected_models, selected_metrics, selected_categories) def show_first(): return ( gr.update(visible=True), gr.update(visible=False), gr.update(visible=False) ) def show_second(): return ( gr.update(visible=False), gr.update(visible=True), gr.update(visible=False) ) def show_third(): return ( gr.update(visible=False), gr.update(visible=False), gr.update(visible=True) ) ###################### # ON CLICK SECTION # ###################### #FOR BAR if(input_data['input_method'] == 'default'): proceed_to_metrics_button.click(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) qatch_metric_multiselect_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) external_metric_select_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) model_radio_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar) choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar]) external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar) else: proceed_to_metrics_button.click(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) qatch_metric_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) external_metric_select_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) group_radio.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) model_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar) model_multiselect_bar.change(fn=enforce_model_selection, inputs=model_multiselect_bar, outputs=model_multiselect_bar) choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar]) external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar) #FOR RADAR MULTISELECT model_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect) all_metrics_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect) category_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect) model_multiselect_radar.change(fn=enforce_model_selection, inputs=model_multiselect_radar, outputs=model_multiselect_radar) category_multiselect_radar.change(fn=enforce_category_selection, inputs=category_multiselect_radar, outputs=category_multiselect_radar) #FOR RADAR RADIO model_multiselect_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio) all_metrics_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio) category_radio_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio) #FOR RANKING model_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third]) model_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) all_metrics_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third]) all_metrics_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) category_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third]) category_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) model_radio_ranking.change(fn=enforce_model_selection, inputs=model_radio_ranking, outputs=model_radio_ranking) category_radio_ranking.change(fn=enforce_category_selection, inputs=category_radio_ranking, outputs=category_radio_ranking) first_button.click(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) second_button.click(fn=show_second, outputs=[row_answer_first, row_answer_second, row_answer_third]) third_button.click(fn=show_third, outputs=[row_answer_first, row_answer_second, row_answer_third]) #FOR RATE model_multiselect_rate.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot) proceed_to_metrics_button.click(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot) model_multiselect_rate.change(fn=enforce_model_selection, inputs=model_multiselect_rate, outputs=model_multiselect_rate) slicer.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot) #FOR RESET reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics]) reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection]) interface.launch(share = True)