Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import time | |
| import re | |
| import csv | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import plotly.colors as pc | |
| from qatch.connectors.sqlite_connector import SqliteConnector | |
| from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator | |
| from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator | |
| import qatch.evaluate_dataset.orchestrator_evaluator as eva | |
| from prediction import ModelPrediction | |
| import utils_get_db_tables_info | |
| import utilities as us | |
| # @spaces.GPU | |
| # def model_prediction(): | |
| # pass | |
| # # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10 | |
| # if os.environ.get("SPACES_ZERO_GPU") is not None: | |
| # import spaces | |
| # else: | |
| # class spaces: | |
| # @staticmethod | |
| # def GPU(func): | |
| # def wrapper(*args, **kwargs): | |
| # return func(*args, **kwargs) | |
| # return wrapper | |
# Alternative location of the NL2SQL results file, kept for reference.
#pnp_path = os.path.join("data", "evaluation_p_np_metrics.csv")
# Pre-computed NL2SQL results for the 'Proprietary vs Non-proprietary' data set
# (loaded with us.load_csv when input_method == "default" and the SQL task is active).
pnp_path = "concatenated_output.csv"
# Pickled dict of tables loaded when 'Proprietary vs Non-proprietary' is chosen.
PATH_PKL_TABLES = 'tables_dict_beaver.pkl'
# Pre-computed TQA (table question answering) results for the same data set.
PNP_TQA_PATH = 'concatenated_output_tqa.csv'
# Client-side script passed to gr.Blocks(js=...): when the page URL does not carry
# __theme=light it sets that query parameter and reloads, forcing the light theme.
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
"""
# Checked by the generation loop and reset to False at the start of every run.
# NOTE(review): presumably set to True elsewhere to abort a run -- not visible here.
reset_flag = False
# True while the TQA task is selected (change_flag), False for NL2SQL (dis_flag).
flag_TQA = False
# NOTE(review): `css` is not referenced in the visible part of the file; gr.Blocks
# loads the same stylesheet through css_paths='style.css'.
with open('style.css', 'r') as file:
    css = file.read()
# Default toy table: shown as an editable preview when 'Custom' is selected.
df_default = pd.DataFrame({
    'Name': ['Alice', 'Bob', 'Charlie'],
    'Age': [25, 30, 35],
    'City': ['New York', 'Los Angeles', 'Chicago']
})
# CSV listing the selectable models (read with us.read_models_csv).
models_path ="models.csv"
# Module-level table currently being edited/loaded; rebound by load_data and update_df.
df_current = df_default.copy()
# Markdown describing the default data sets, shown in the upload section.
# NOTE(review): the emoji in this literal look mis-encoded (mojibake) -- compare
# with the original file before shipping.
description = """## π Comparison of Proprietary and Non-Proprietary Databases
### β€ **Proprietary** :
###     β Economic π°, Medical π₯, Financial π³, Miscellaneous π
###     β BEAVER (FAC BUILDING ADDRESS π’ , TIME QUARTER β±οΈ)
### β€ **Non-Proprietary**
###     β Spider 1.0 π·οΈ"""
# Default NL2SQL prompt template; {question} and {db_schema} are filled with
# str.format, and check_prompt requires both placeholders to be present.
prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n"
# Default TQA prompt template: asks for the answer as a list of lists inside <answer> tags.
prompt_default_tqa = "Return the answer of the following question based on the provided database. Return your answer as the result of a query executed over the database. Namely, as a list of list where the first list represent the tuples and the second list the values in that tuple.\n Return the answer in answer tag as <answer> </answer>.\n Question \n {question}\n Database Schema\n {db_schema}\n"
# Shared UI state mutated by the event handlers below.
# NOTE(review): module-level, so presumably shared by every concurrent user of the app.
input_data = {
    'input_method': "",  # "uploaded_file", "custom" or "default" (set by load_data)
    'data_path': "",  # path of the SQLite file backing the data
    'db_name': "",  # database name (file stem of data_path / upload)
    'data': {
        'data_frames': {}, # table name -> DataFrame
        'db': None, # database object; a SqliteConnector when built by load_data
        'selected_tables' :[]  # tables chosen in the table-selection step
    },
    'models': [],  # codes of the selected models
    'prompt': prompt_default  # active prompt template
}
def _attach_sqlite_connector():
    """Create a SqliteConnector over input_data's current tables and store it in input_data["data"]["db"].

    Every table gets 'id' as its primary key; the connector uses
    input_data["data_path"] and input_data["db_name"].
    """
    data = input_data["data"]
    table2primary_key = {table_name: 'id' for table_name in data['data_frames']}
    data["db"] = SqliteConnector(
        relative_db_path=input_data["data_path"],
        db_name=input_data["db_name"],
        tables=data['data_frames'],
        table2primary_key=table2primary_key,
    )


def load_data(file, path, use_default):
    """Load the data source selected in the UI into the global ``input_data``.

    Exactly one source is used, checked in this order:
      1. ``file``: an uploaded .sqlite / .csv / .xlsx path, parsed by ``us.load_data``.
      2. ``use_default == 'Custom'``: the (possibly edited) toy table ``df_current``.
      3. ``use_default == 'Proprietary vs Non-proprietary'``: the pickled tables
         at ``PATH_PKL_TABLES``.

    Args:
        file: path of the uploaded file, or None.
        path: only used when counting how many input methods were selected.
        use_default: value of the data-source radio (a string), or None.

    Returns:
        dict mapping table name -> DataFrame on success, or an error message string.
    """
    global df_current
    if file is not None:
        try:
            input_data["input_method"] = 'uploaded_file'
            input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
            if file.endswith('.sqlite'):
                input_data["data_path"] = file
            else:
                # Non-SQLite uploads get a .sqlite path derived from the file name
                # (presumably where the connector materialises the tables).
                input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite")
            input_data["data"] = us.load_data(file, input_data["db_name"])
            df_current = input_data["data"]['data_frames'].get('MyTable', df_default)
            # csv / xlsx uploads come back without a database object: build one.
            if input_data["data"]['data_frames'] and input_data["data"]["db"] is None:
                _attach_sqlite_connector()
            return input_data["data"]['data_frames']
        except Exception as e:
            return f'Errore nel caricamento del file: {e}'
    if use_default:
        if use_default == 'Custom':
            input_data["input_method"] = 'custom'
            input_data["data_path"] = os.path.join(".", "mytable_0.sqlite")
            # Never reuse an existing database file; us.increment_filename
            # presumably bumps the numeric suffix.
            while os.path.exists(input_data["data_path"]):
                input_data["data_path"] = us.increment_filename(input_data["data_path"])
            input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0]
            input_data["data"]['data_frames'] = {'MyTable': df_current}
            _attach_sqlite_connector()
            df_current = df_default.copy()  # restore the default toy table
            return input_data["data"]['data_frames']
        if use_default == 'Proprietary vs Non-proprietary':
            input_data["input_method"] = 'default'
            input_data["data"]['data_frames'] = us.load_tables_dict_from_pkl(PATH_PKL_TABLES)
            return input_data["data"]['data_frames']
    # use_default is a string or None: it must be converted with bool() before
    # summing, otherwise int + str / int + None raises TypeError.
    selected_inputs = sum([file is not None, bool(path), bool(use_default)])
    if selected_inputs > 1:
        return 'Error: Select only one input method at a time.'
    return input_data["data"]['data_frames']
def preview_default(use_default, file):
    """Pick the preview table and caption shown in the upload section.

    Returns a (DataFrame, Markdown update) pair:
      * a file is present   -> hidden table and an upload-confirmation caption;
      * 'Custom' selected   -> visible, editable toy table and a "Toy Table" caption;
      * anything else       -> hidden read-only table and the data-set description.
    """
    if file:
        table = gr.DataFrame(interactive=True, visible = False, value = df_default)
        caption = gr.update(value="## β File successfully uploaded!", visible=True)
    elif use_default == 'Custom':
        table = gr.DataFrame(interactive=True, visible = True, value = df_default)
        caption = gr.update(value="## π Toy Table", visible=True)
    else:
        table = gr.DataFrame(interactive=False, visible = False, value = df_default)
        caption = gr.update(value = description, visible=True)
    return table, caption
def update_df(new_df):
    """Make *new_df* the module-level current table and echo it back to the UI."""
    global df_current
    df_current = new_df
    return new_df
def open_accordion(target):
    """Compute the section/input updates for a navigation step.

    Args:
        target: "reset" to return to the upload step, or "model_selection" to
            move on to the model step.

    Returns:
        "reset": 7 updates -- upload section opened; the four other sections
            closed and hidden; data-source radio set back to
            'Proprietary vs Non-proprietary'; file input cleared. The session
            state in ``input_data`` (except 'prompt') and the module-level
            ``df_current`` are cleared first.
        "model_selection": 5 updates -- only the model-selection section is
            opened (and made visible); the others are collapsed.
        Any other target: None.
    """
    # Without this declaration the reset below would only bind a local name and
    # the user-edited toy table would survive the reset.
    global df_current
    if target == "reset":
        df_current = df_default.copy()
        input_data['input_method'] = ""
        input_data['data_path'] = ""
        input_data['db_name'] = ""
        input_data['data']['data_frames'] = {}
        input_data['data']['selected_tables'] = []
        input_data['data']['db'] = None
        input_data['models'] = []
        return (
            gr.update(open=True),
            gr.update(open=False, visible=False),
            gr.update(open=False, visible=False),
            gr.update(open=False, visible=False),
            gr.update(open=False, visible=False),
            gr.update(value='Proprietary vs Non-proprietary'),
            gr.update(value=None),
        )
    if target == "model_selection":
        return (
            gr.update(open=False),
            gr.update(open=False),
            gr.update(open=True, visible=True),
            gr.update(open=False),
            gr.update(open=False),
        )
    return None
| # Interfaccia Gradio | |
| #with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface: | |
| with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as interface: | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Image( | |
| value=os.path.join(".", "qatch_logo.png"), | |
| show_label=False, | |
| container=False, | |
| interactive=False, | |
| show_fullscreen_button=False, | |
| show_download_button=False, | |
| show_share_button=False, | |
| height=150, # in pixel | |
| width=300 | |
| ) | |
| with gr.Column(scale=1): | |
| pass | |
| data_state = gr.State(None) # Memorizza i dati caricati | |
| upload_acc = gr.Accordion("Upload data section", open=True, visible=True) | |
| select_table_acc = gr.Accordion("Select tables section", open=False, visible=False) | |
| select_model_acc = gr.Accordion("Select models section", open=False, visible=False) | |
| qatch_acc = gr.Accordion("QATCH execution section", open=False, visible=False) | |
| metrics_acc = gr.Accordion("Metrics section", open=False, visible=False) | |
| ################################# | |
| # DATABASE INSERTION # | |
| ################################# | |
| with upload_acc: | |
| gr.Markdown("## π₯Choose data input method") | |
| with gr.Row(): | |
| default_checkbox = gr.Radio(label = "Explore the comparison between proprietary and non-proprietary databases or edit a toy table with the values you prefer", choices=['Proprietary vs Non-proprietary', 'Custom'], value='Proprietary vs Non-proprietary') | |
| #default_checkbox = gr.Checkbox(label="Use default DataFrame" | |
| table_default = gr.Markdown(description, visible=True) | |
| preview_output = gr.DataFrame(interactive=False, visible=False, value=df_default) | |
| gr.Markdown("## π Or upload your data") | |
| file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"]) | |
| submit_button = gr.Button("Load Data") # Disabled by default | |
| output = gr.JSON(visible=False) # Dictionary output | |
| # Function to enable the button if there is data to load | |
def enable_submit(file, use_default):
    """Return a button update that is interactive only when a file or a radio option is set."""
    has_input = bool(file or use_default)
    return gr.update(interactive=has_input)
| # Function to uncheck the checkbox if a file is uploaded | |
def deselect_default(file):
    """Put the data-source radio back on 'Proprietary vs Non-proprietary' once a file is uploaded.

    With no file, an empty update leaves the radio unchanged.
    """
    return gr.update(value='Proprietary vs Non-proprietary') if file else gr.update()
def enable_disable_first(enable):
    """Set the interactivity of the four upload-section widgets at once.

    Returns a 4-tuple of identical updates, one per output
    (preview table, submit button, file input, data-source radio).
    """
    return tuple(gr.update(interactive=enable) for _ in range(4))
| # Enable the button when inputs are provided | |
| #file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button]) | |
| #default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button]) | |
| # Show preview of the default DataFrame when checkbox is selected | |
| default_checkbox.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default]) | |
| file_input.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default]) | |
| preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output]) | |
| # Uncheck the checkbox when a file is uploaded | |
| file_input.change(fn=deselect_default, inputs=[file_input], outputs=[default_checkbox]) | |
def handle_output(file, use_default):
    """Handle the 'Load Data' click: load the data and route the UI to the next step.

    Returns a 7-tuple for the outputs
    (output, output, select_table_acc, data_state, submit_button,
    select_model_acc, upload_acc):
      * a single table: it is auto-selected and the model section opens directly;
      * several tables: the table-selection section opens;
      * load failure: the error message from ``load_data`` is shown as a warning
        toast (it was previously discarded) and the submit button stays usable.
    NOTE(review): submit_button.click also runs enable_disable_first(False), which
    may disable the button again after a failure.
    """
    result = load_data(file, None, use_default)
    if not isinstance(result, dict):
        # load_data reports failures as a message string.
        gr.Warning(str(result))
        return (
            gr.update(visible=False),
            None,
            gr.update(open=False, visible=True),
            None,
            gr.update(interactive=True),
            gr.update(visible=False),
            gr.update(visible=True, open=False),
        )
    single_table = len(result) == 1
    if single_table:
        input_data['data']['selected_tables'] = list(input_data['data']['data_frames'].keys())
    return (
        gr.update(visible=False),  # hide the JSON output
        result,  # data value for the JSON output slot
        gr.update(visible=False) if single_table else gr.update(open=True, visible=True),
        result,  # data_state
        gr.update(interactive=False),  # disable the submit button
        gr.update(visible=True, open=True) if single_table else gr.update(visible=False),
        gr.update(visible=True, open=False),  # collapse the upload section
    )
| submit_button.click( | |
| fn=handle_output, | |
| inputs=[file_input, default_checkbox], | |
| outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc] | |
| ) | |
| submit_button.click( | |
| fn=enable_disable_first, | |
| inputs=[gr.State(False)], | |
| outputs=[ | |
| preview_output, | |
| submit_button, | |
| file_input, | |
| default_checkbox | |
| ] | |
| ) | |
| ###################################### | |
| # TABLE SELECTION PART # | |
| ###################################### | |
| with select_table_acc: | |
| previous_selection = gr.State([]) | |
| table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the choosen database", value=[]) | |
| excluded_tables_info = gr.HTML(label="Non-selectable tables (too many columns)", visible=False) | |
| table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)] | |
| selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False) | |
| # Model selection button (initially disabled) | |
| open_model_selection = gr.Button("Choose your models", interactive=False) | |
def update_table_list(data):
    """Rebuild the table CheckboxGroup choices and the "excluded tables" notice.

    Args:
        data: value of ``data_state`` -- a dict keyed by table name, or None.

    Returns:
        [CheckboxGroup update, HTML update]. Outside the "default" input method,
        tables whose entry in input_data['data']['data_frames'] has more than 15
        columns are not offered and are listed in the HTML notice. An "All" choice
        is prepended for the default data set or when fewer than 6 tables remain.
    """
    if not isinstance(data, dict) or not data:
        return [
            gr.update(choices=[], value=[]),
            gr.update(value="", visible=False)
        ]
    is_default = input_data['input_method'] == "default"
    known_frames = input_data['data'].get('data_frames', {})
    selectable = []
    too_wide = []
    for table_name in data:
        frame = known_frames.get(table_name, None)
        if not is_default and frame is not None and frame.shape[1] > 15:
            too_wide.append(table_name)
        else:
            selectable.append(table_name)
    offer_all = is_default or len(selectable) < 6
    choices = (["All"] if offer_all else []) + selectable
    if too_wide:
        notice = "<b>β οΈ The following tables have more than 15 columns and cannot be selected:</b><br>" + "<br>".join(f"- {t}" for t in too_wide)
        notice_visible = True
    else:
        notice = ""
        notice_visible = False
    return [
        gr.update(choices=choices, value=[]),
        gr.update(value=notice, visible=notice_visible)
    ]
def show_selected_tables(data, selected_tables):
    """Normalise the table selection and refresh the table previews.

    Args:
        data: value of ``data_state`` (dict of table name -> DataFrame), or None
            when nothing has been loaded (e.g. after a failed load resets the
            selector, which previously crashed on ``None.items()``).
        selected_tables: current CheckboxGroup value; may contain "All".

    Returns:
        A list of 52 updates in output order:
        [table_selector, 50 table preview slots, open_model_selection button].

    Selection rules:
      * Outside the "default" input method only tables with <= 15 columns count.
      * When "All" is offered (default data set, or fewer than 6 tables): ticking
        "All" selects everything; a selection equal to exactly the full table set
        without "All" (i.e. "All" was just unticked) clears the selection.
      * Otherwise at most 5 tables can be picked; once 5 are picked the choices
        are narrowed to those 5.
    At most 50 previews are shown, matching the number of preview slots.
    """
    max_previews = 50  # number of gr.DataFrame slots created for the previews
    data = data or {}
    selected_tables = list(selected_tables or [])
    data_frames = input_data['data'].get('data_frames', {})
    is_default = input_data['input_method'] == "default"

    available_tables = []
    for name in data:
        df_real = data_frames.get(name)
        if is_default or (df_real is not None and df_real.shape[1] <= 15):
            available_tables.append(name)

    allow_all = is_default or len(available_tables) < 6
    if allow_all:
        selected_set = set(selected_tables)
        if "All" in selected_set:
            selected_tables = ["All"] + available_tables
        elif selected_set == set(available_tables):
            selected_tables = []
        else:
            selected_tables = [t for t in selected_tables if t in available_tables]
        selector_update = gr.update(choices=["All"] + available_tables, value=selected_tables)
    else:
        selected_tables = [t for t in selected_tables if t in available_tables and t != "All"][:5]
        choices = selected_tables if len(selected_tables) >= 5 else available_tables
        selector_update = gr.update(choices=choices, value=selected_tables)

    tables = {name: data[name] for name in selected_tables if name in data}
    # Cap at the number of preview slots so the output count always matches.
    shown = list(tables.items())[:max_previews]
    preview_updates = [
        gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False)
        for name, df in shown
    ]
    preview_updates += [gr.update(visible=False) for _ in range(max_previews - len(shown))]
    button_update = gr.update(interactive=bool(tables))
    return [selector_update] + preview_updates + [button_update]
def show_selected_table_names(data, selected_tables):
    """Commit the chosen tables to input_data and return a hidden Textbox update listing them.

    "All" expands to every table in *data*; outside the "default" input method,
    tables with more than 15 columns are dropped.
    """
    if not selected_tables:
        return gr.update(value="", visible=False)
    every_table = list(data.keys())
    chosen = every_table if "All" in selected_tables else selected_tables
    if input_data['input_method'] != "default":
        chosen = [t for t in chosen if len(data[t].columns) <= 15]
    input_data['data']['selected_tables'] = chosen
    return gr.update(value=", ".join(chosen), visible=False)
| # Automatically updates the checkbox list when `data_state` changes | |
| data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector, excluded_tables_info]) | |
| # Updates the visible tables and the button state based on user selections | |
| #table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection]) | |
| table_selector.change( | |
| fn=show_selected_tables, | |
| inputs=[data_state, table_selector], | |
| outputs=[table_selector] + table_outputs + [open_model_selection] | |
| ) | |
| # Shows the list of selected tables when "Choose your models" is clicked | |
| open_model_selection.click(fn=show_selected_table_names, inputs=[data_state, table_selector], outputs=[selected_table_names]) | |
| open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc]) | |
| reset_data = gr.Button("Back to upload data section") | |
| reset_data.click( | |
| fn=enable_disable_first, | |
| inputs=[gr.State(True)], | |
| outputs=[ | |
| preview_output, | |
| submit_button, | |
| file_input, | |
| default_checkbox | |
| ] | |
| ) | |
| reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) | |
| #################################### | |
| # MODEL SELECTION PART # | |
| #################################### | |
| with select_model_acc: | |
| gr.Markdown("# Model Selection") | |
| # Assume that `us.read_models_csv` also returns the image path | |
| model_list_dict = us.read_models_csv(models_path) | |
| model_list = [model["code"] for model in model_list_dict] | |
| model_images = [model["image_path"] for model in model_list_dict] | |
| model_names = [model["name"] for model in model_list_dict] | |
| # Create a mapping between model_list and model_images_names | |
| model_mapping = dict(zip(model_list, model_names)) | |
| model_mapping_reverse = dict(zip(model_names, model_list)) | |
| model_checkboxes = [] | |
| rows = [] | |
| # Dynamically create checkboxes with images (3 per row) | |
| for i in range(0, len(model_list), 3): | |
| with gr.Row(): | |
| cols = [] | |
| for j in range(3): | |
| if i + j < len(model_list): | |
| model = model_list[i + j] | |
| image_path = model_images[i + j] | |
| with gr.Column(): | |
| gr.Image(image_path, | |
| show_label=False, | |
| container=False, | |
| interactive=False, | |
| show_fullscreen_button=False, | |
| show_download_button=False, | |
| show_share_button=False) | |
| checkbox = gr.Checkbox(label=model_mapping[model], value=False) | |
| model_checkboxes.append(checkbox) | |
| cols.append(checkbox) | |
| rows.append(cols) | |
| selected_models_output = gr.JSON(visible=False) | |
| # Function to get selected models | |
def get_selected_models(*model_selections):
    """Record the ticked models and refresh the model section and submit button.

    *model_selections* are the checkbox values, positionally aligned with
    ``model_list``. The submit button becomes interactive only when at least one
    model is ticked and the active prompt contains {db_schema} and {question}.
    """
    chosen = [code for code, ticked in zip(model_list, model_selections) if ticked]
    input_data['models'] = chosen
    template = input_data["prompt"]
    ready = bool(chosen and '{db_schema}' in template and '{question}' in template)
    return chosen, gr.update(open=True, visible=True), gr.update(interactive=ready)
| # Add the Textbox to the interface | |
| with gr.Row(): | |
| button_prompt_nlsql = gr.Button("Choose NL2SQL task") | |
| button_prompt_tqa = gr.Button("Choose TQA task") | |
| prompt = gr.TextArea( | |
| label="Customise the prompt for selected models here or leave the default one.", | |
| placeholder=prompt_default, | |
| elem_id="custom-textarea" | |
| ) | |
| warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False) | |
| # Submit button (initially disabled) | |
| with gr.Row(): | |
| submit_models_button = gr.Button("Submit Models", interactive=False) | |
def check_prompt(prompt):
    """Validate the user prompt and store the active prompt in input_data.

    An empty prompt selects the built-in default for the current task
    (prompt_default_tqa when flag_TQA is set, otherwise prompt_default). A
    non-empty prompt is stored verbatim and must contain both {db_schema} and
    {question}; missing placeholders are reported in the warning Markdown.

    Returns:
        (warning Markdown update, submit-button update, prompt TextArea whose
        placeholder shows the active prompt).
    """
    global flag_TQA
    missing_elements = []
    if prompt == "":
        input_data["prompt"] = prompt_default_tqa if flag_TQA else prompt_default
    else:
        input_data["prompt"] = prompt
        missing_elements = [tag for tag in ("{db_schema}", "{question}") if tag not in prompt]
    active = input_data["prompt"]
    button_state = bool(input_data['models'] and '{db_schema}' in active and '{question}' in active)
    if not missing_elements:
        return gr.update(visible=False), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"])
    warning = gr.update(
        value=f"<div style='text-align: center; font-size: 18px; font-weight: bold;'>"
              f"β Missing {', '.join(missing_elements)} in the prompt β</div>",
        visible=True
    )
    return warning, gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"])
| prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button]) | |
| # Link checkboxes to selection events | |
| for checkbox in model_checkboxes: | |
| checkbox.change( | |
| fn=get_selected_models, | |
| inputs=model_checkboxes, | |
| outputs=[selected_models_output, select_model_acc, submit_models_button] | |
| ) | |
| prompt.change( | |
| fn=get_selected_models, | |
| inputs=model_checkboxes, | |
| outputs=[selected_models_output, select_model_acc, submit_models_button] | |
| ) | |
| submit_models_button.click( | |
| fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)), | |
| inputs=model_checkboxes, | |
| outputs=[selected_models_output, select_model_acc, qatch_acc] | |
| ) | |
def change_flag():
    """Switch the app to the TQA (table question answering) task."""
    global flag_TQA
    flag_TQA = True


def dis_flag():
    """Switch the app back to the NL2SQL task."""
    global flag_TQA
    flag_TQA = False
| button_prompt_tqa.click(fn = change_flag, inputs=[], outputs=[]) | |
| button_prompt_nlsql.click(fn = dis_flag, inputs=[], outputs=[]) | |
| button_prompt_tqa.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt]) | |
| button_prompt_nlsql.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt]) | |
def enable_disable(enable):
    """Set the interactivity of every widget locked while models are submitted.

    Returns one identical update per output: each model checkbox, then six
    single widgets (submit-models button, preview table, submit button, file
    input, data-source radio, table selector), then each table preview, then the
    "Choose your models" button.
    """
    widget_count = len(model_checkboxes) + 6 + len(table_outputs) + 1
    return tuple(gr.update(interactive=enable) for _ in range(widget_count))
| reset_data = gr.Button("Back to upload data section") | |
| submit_models_button.click( | |
| fn=enable_disable, | |
| inputs=[gr.State(False)], | |
| outputs=[ | |
| *model_checkboxes, | |
| submit_models_button, | |
| preview_output, | |
| submit_button, | |
| file_input, | |
| default_checkbox, | |
| table_selector, | |
| *table_outputs, | |
| open_model_selection | |
| ] | |
| ) | |
| reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) | |
| reset_data.click( | |
| fn=enable_disable, | |
| inputs=[gr.State(True)], | |
| outputs=[ | |
| *model_checkboxes, | |
| submit_models_button, | |
| preview_output, | |
| submit_button, | |
| file_input, | |
| default_checkbox, | |
| table_selector, | |
| *table_outputs, | |
| open_model_selection | |
| ] | |
| ) | |
| ############################# | |
| # QATCH EXECUTION # | |
| ############################# | |
| with qatch_acc: | |
def change_text(text):
    """Identity helper: return *text* unchanged."""
    return text
| loading_symbols= {1:"π", | |
| 2: "π π", | |
| 3: "π π π", | |
| 4: "π π π π", | |
| 5: "π π π π π", | |
| 6: "π π π π π π", | |
| 7: "π π π π π π π", | |
| 8: "π π π π π π π π", | |
| 9: "π π π π π π π π π", | |
| 10:"π π π π π π π π π π", | |
| } | |
def generate_loading_text(percent):
    """Render the generation progress banner for *percent* completion.

    The icon count cycles with the rounded percentage as
    (round(percent) % 11) + 1; a count of 11 has no entry in loading_symbols
    and falls back to the single-icon default.
    """
    symbol_count = round(percent) % 11 + 1
    glyphs = loading_symbols.get(symbol_count, "π").strip()
    left_icons = f'<span class="fish">{glyphs}</span>'
    right_icons = f'<span class="mirrored">{glyphs}</span>'
    return f"""
<div class='barcontainer'>
{left_icons}
<span class='loading' style="font-family: 'Inter', sans-serif;">
Generation {percent}%
</span>
{right_icons}
</div>
"""
def generate_eval_text(text):
    """Render *text* inside the 'barcontainer' banner, flanked by an icon and its mirrored copy.

    The CSS classes (fish, mirrored, loading, barcontainer) are presumably
    defined in style.css.
    """
    icon = "π‘ ".strip()
    left_icon = f'<span class="fish">{icon}</span>'
    right_icon = f'<span class="mirrored">{icon}</span>'
    return f"""
<div class='barcontainer'>
{left_icon}
<span class='loading' style="font-family: 'Inter', sans-serif;">
{text}
</span>
{right_icon}
</div>
"""
| def qatch_flow_nl_sql(): | |
| global reset_flag | |
| global flag_TQA | |
| predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list} | |
| metrics_conc = pd.DataFrame() | |
| columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"] | |
| if (input_data['input_method']=="default"): | |
| #target_df = us.load_csv(pnp_path) #target_df = us.load_csv("priority_non_priority_metrics.csv") | |
| target_df = us.load_csv(pnp_path) if not flag_TQA else us.load_csv(PNP_TQA_PATH) | |
| #predictions_dict = {model: pd.DataFrame(columns=target_df.columns) for model in model_list} | |
| target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])] | |
| target_df = target_df[target_df["model"].isin(input_data['models'])] | |
| predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list} | |
| reset_flag = False | |
| for model in input_data['models']: | |
| model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None) | |
| yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] | |
| count=1 | |
| for _, row in predictions_dict[model].iterrows(): | |
| #for index, row in target_df.iterrows(): | |
| if (reset_flag == False): | |
| percent_complete = round(count / len(predictions_dict[model]) * 100, 2) | |
| count=count+1 | |
| load_text = f"{generate_loading_text(percent_complete)}" | |
| question = row['question'] | |
| display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Natural Language:</div> | |
| <div style='display: flex; align-items: center;'> | |
| <div class='sqlquery' font-family: 'Inter', sans-serif;>{question}</div> | |
| <div style='font-size: 3rem'>β‘οΈ</div> | |
| </div> | |
| """ | |
| yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] | |
| prediction = row['predicted_sql'] | |
| display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted SQL:</div> | |
| <div style='display: flex; align-items: center;'> | |
| <div style='font-size: 3rem'>β‘οΈ</div> | |
| <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div> | |
| </div> | |
| """ | |
| yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] | |
| yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] | |
| metrics_conc = target_df | |
| if 'valid_efficency_score' not in metrics_conc.columns: | |
| metrics_conc['valid_efficency_score'] = metrics_conc['VES'] | |
| if 'VES' not in metrics_conc.columns: | |
| metrics_conc['VES'] = metrics_conc['valid_efficency_score'] | |
| eval_text = generate_eval_text("End evaluation") | |
| yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list] | |
| else: | |
| global flag_TQA | |
| orchestrator_generator = OrchestratorGenerator() | |
| target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables']) | |
| #create target_df[target_answer] | |
| if flag_TQA : | |
| # if (input_data["prompt"] == prompt_default): | |
| # input_data["prompt"] = prompt_default_tqa | |
| target_df['db_schema'] = target_df.apply( | |
| lambda row: utils_get_db_tables_info.utils_extract_db_schema_as_string( | |
| db_id=input_data["db_name"], | |
| base_path=input_data["data_path"], | |
| normalize=False, | |
| sql=row["query"], | |
| get_insert_into=True, | |
| model=None, | |
| prompt=input_data["prompt"].format(question=row["question"], db_schema="") | |
| ), | |
| axis=1 | |
| ) | |
| target_df = us.extract_answer(target_df) | |
| predictor = ModelPrediction() | |
| reset_flag = False | |
| for model in input_data["models"]: | |
| model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None) | |
| yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] | |
| count=0 | |
| for index, row in target_df.iterrows(): | |
| if (reset_flag == False): | |
| percent_complete = round(((index+1) / len(target_df)) * 100, 2) | |
| load_text = f"{generate_loading_text(percent_complete)}" | |
| question = row['question'] | |
| display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Natural Language:</div> | |
| <div style='display: flex; align-items: center;'> | |
| <div class='sqlquery' font-family: 'Inter', sans-serif;>{question}</div> | |
| <div style='font-size: 3rem'>β‘οΈ</div> | |
| </div> | |
| """ | |
| yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] | |
| #samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"]) | |
| model_to_send = None if not flag_TQA else model | |
| db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string( | |
| db_id = input_data["db_name"], | |
| base_path = input_data["data_path"], | |
| normalize=False, | |
| sql=row["query"], | |
| get_insert_into=True, | |
| model = model_to_send, | |
| prompt = input_data["prompt"].format(question=question, db_schema=""), | |
| ) | |
| #prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples) | |
| prompt_to_send = input_data["prompt"] | |
| #PREDICTION SQL | |
| # TODO add button for QA or SP and pass to .make_prediction parameter TASK | |
| if flag_TQA: task="QA" | |
| else: task="SP" | |
| start_time = time.time() | |
| response = predictor.make_prediction( | |
| question=question, | |
| db_schema=db_schema_text, | |
| model_name=model, | |
| prompt=f"{prompt_to_send}", | |
| task=task | |
| ) | |
| #if flag_TQA: response = {'response_parsed': "[['Alice'],['Bob'],['Charlie']]", 'cost': 0, 'response': "[['Alice'],['Bob'],['Charlie']]"} # TODO remove this line | |
| #else : response = {'response_parsed': "SELECT * FROM 'MyTable'", 'cost': 0, 'response': "SQL_QUERY"} | |
| end_time = time.time() | |
| prediction = response['response_parsed'] | |
| price = response['cost'] | |
| answer = response['response'] | |
| if flag_TQA: | |
| task_string = "Answer" | |
| else: | |
| task_string = "SQL" | |
| display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted {task_string}:</div> | |
| <div style='display: flex; align-items: center;'> | |
| <div style='font-size: 3rem'>β‘οΈ</div> | |
| <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div> | |
| </div> | |
| """ | |
| # Create a new row as dataframe | |
| new_row = pd.DataFrame([{ | |
| 'id': index, | |
| 'question': question, | |
| 'predicted_sql': prediction, | |
| 'time': end_time - start_time, | |
| 'query': row["query"], | |
| 'db_path': input_data["data_path"], | |
| 'price':price, | |
| 'answer': answer, | |
| 'number_question':count, | |
| 'target_answer' : row["target_answer"] if flag_TQA else None, | |
| }]).dropna(how="all") # Remove only completely empty rows | |
| count=count+1 | |
| # TODO: use a for loop | |
| if (flag_TQA) : | |
| new_row['predicted_answer'] = prediction | |
| for col in target_df.columns: | |
| if col not in new_row.columns: | |
| new_row[col] = row[col] | |
| # Update model's prediction dataframe incrementally | |
| if not new_row.empty: | |
| predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True) | |
| # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None | |
| yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list] | |
| yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list] | |
| # END | |
| eval_text = generate_eval_text("Evaluation") | |
| yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] | |
| evaluator = OrchestratorEvaluator() | |
| for model in input_data["models"]: | |
| if not flag_TQA: | |
| metrics_df_model = evaluator.evaluate_df( | |
| df=predictions_dict[model], | |
| target_col_name="query", | |
| prediction_col_name="predicted_sql", | |
| db_path_name="db_path" | |
| ) | |
| else: | |
| metrics_df_model = us.evaluate_answer(predictions_dict[model]) | |
| metrics_df_model['model'] = model | |
| metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True) | |
| if 'VES' not in metrics_conc.columns and 'valid_efficency_score' not in metrics_conc.columns: | |
| metrics_conc['VES'] = 0 | |
| metrics_conc['valid_efficency_score'] = 0 | |
| if 'valid_efficency_score' not in metrics_conc.columns: | |
| metrics_conc['valid_efficency_score'] = metrics_conc['VES'] | |
| if 'VES' not in metrics_conc.columns: | |
| metrics_conc['VES'] = metrics_conc['valid_efficency_score'] | |
| eval_text = generate_eval_text("End evaluation") | |
| yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list] | |
# Loading Bar
# NOTE(review): the original indentation of this section was lost; the nesting of
# the Row/Column/Tabs containers below is reconstructed from context — verify it
# against the rendered layout.
with gr.Row():
    # progress = gr.Progress()
    # Receives the loading/progress text (third value yielded by qatch_flow_nl_sql).
    variable = gr.Markdown()
# NL -> MODEL -> Generated Query
with gr.Row():
    with gr.Column():
        with gr.Column():
            # Current natural-language question being sent to the model.
            question_display = gr.Markdown()
    with gr.Column():
        # Logo of the model currently predicting (image_path from model_list_dict).
        model_logo = gr.Image(visible=True,
                              show_label=False,
                              container=False,
                              interactive=False,
                              show_fullscreen_button=False,
                              show_download_button=False,
                              show_share_button=False)
    with gr.Column():
        with gr.Column():
            # Latest prediction (SQL or answer) returned by the model.
            prediction_display = gr.Markdown()
# model code -> gr.DataFrame showing that model's predictions; insertion order
# follows model_list, which matches the order qatch_flow_nl_sql yields them.
dataframe_per_model = {}
with gr.Tabs() as model_tabs:
    # model code -> TabItem, used later to toggle tab visibility.
    tab_dict = {}
    for model, model_name in zip(model_list, model_names):
        # A tab starts visible only if the model is among the selected ones.
        with gr.TabItem(model_name, visible=(model in input_data["models"])) as tab:
            gr.Markdown(f"**Results for {model}**")
            tab_dict[model] = tab
            dataframe_per_model[model] = gr.DataFrame()
            #TODO download metrics per model
            # download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False)
# Receives the "Evaluation" / "End evaluation" status text (first yielded value).
evaluation_loading = gr.Markdown()
def change_tab():
    """Return one visibility update per tab in model_list: shown iff the model is selected."""
    chosen = input_data["models"]
    return [gr.update(visible=code in chosen) for code in model_list]
# When the models are submitted, show only the tabs of the selected models.
submit_models_button.click(
    change_tab,
    inputs=[],
    outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
)
# Hidden holders: the final input_data as JSON, the raw metrics streamed by
# qatch_flow_nl_sql, and metrics_df_out, which is filled from metrics_df by
# change_text (defined elsewhere).
selected_models_display = gr.JSON(label="Final input data", visible=False)
metrics_df = gr.DataFrame(visible=False)
metrics_df_out = gr.DataFrame(visible=False)
# Run prediction + evaluation. The outputs list must stay in the same order as the
# tuples yielded by qatch_flow_nl_sql: five status widgets, metrics, then one
# DataFrame per model in model_list order.
submit_models_button.click(
    fn=qatch_flow_nl_sql,
    inputs=[],
    outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
)
# Mirror the submitted input_data into the hidden JSON component.
submit_models_button.click(
    fn=lambda: gr.update(value=input_data),
    outputs=[selected_models_display]
)
# Works for METRICS
metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
proceed_to_metrics_button = gr.Button("Proceed to Metrics", visible=False)
# Collapse the QATCH accordion and open the metrics accordion (both kept visible).
proceed_to_metrics_button.click(
    fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
    outputs=[qatch_acc, metrics_acc]
)
def allow_download(metrics_df_out):
    """Save the metrics to ./results.csv and reveal the follow-up controls.

    Args:
        metrics_df_out: DataFrame of evaluation metrics to export.

    Returns:
        Updates for (download button -> file path and visible, proceed button
        visible, reset button interactive).
    """
    # NOTE(review): this path is fixed and relative to the working directory, so
    # concurrent sessions would overwrite each other's file — confirm acceptable.
    #path = os.path.join(".", "data", "data_results", "results.csv")
    csv_path = os.path.join(".", "results.csv")
    metrics_df_out.to_csv(csv_path, index=False)
    download_update = gr.update(value=csv_path, visible=True)
    proceed_update = gr.update(visible=True)
    reset_update = gr.update(interactive=True)
    return download_update, proceed_update, reset_update
# Download button for the exported metrics; revealed later by allow_download.
download_metrics = gr.DownloadButton(label="Download Metrics Evaluation", visible=False)
# Hide the download button again whenever a new set of models is submitted.
submit_models_button.click(
    fn=lambda: gr.update(visible=False),
    outputs=[download_metrics]
)
def refresh():
    """Set the module-level reset_flag and clear flag_TQA (TQA mode off)."""
    global reset_flag, flag_TQA
    reset_flag, flag_TQA = True, False
reset_data = gr.Button("Back to upload data section", interactive=True)
# Whenever the post-processed metrics change, export them and reveal the
# download / proceed / reset controls (see allow_download).
metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics, proceed_to_metrics_button, reset_data])
# Reset: let open_accordion (defined elsewhere) reconfigure the accordions and
# the upload inputs for the "reset" state.
reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
# NOTE(review): the original author marked this handler "WHY NOT WORKING?"; the
# cause is not visible from this section.
reset_data.click(
    fn=lambda: gr.update(visible=False),
    outputs=[download_metrics]
)
# Set reset_flag and clear flag_TQA.
reset_data.click(refresh)
# Presumably re-enables the earlier input widgets; enable_disable is defined
# elsewhere — confirm.
reset_data.click(
    fn=enable_disable,
    inputs=[gr.State(True)],
    outputs=[
        *model_checkboxes,
        submit_models_button,
        preview_output,
        submit_button,
        file_input,
        default_checkbox,
        table_selector,
        *table_outputs,
        open_model_selection
    ]
)
##########################################
# METRICS VISUALIZATION SECTION #
##########################################
| with metrics_acc: | |
| #data_path = 'test_results_metrics1.csv' | |
| def function_metrics(metrics_df_out): | |
| #################################### | |
| # UTILS FUNCTIONS SECTION # | |
| #################################### | |
def load_data_csv_es():
    """Return the evaluation rows the metric charts are built from.

    For the "default" input method, the precomputed results CSV is loaded (the
    TQA file when flag_TQA is set). It is restricted to the selected models and
    tables, and model / test-category identifiers are mapped to display labels.
    For any other input method, the enclosing function's metrics_df_out
    argument is returned unchanged.
    """
    if input_data["input_method"] != "default":
        return metrics_df_out
    global flag_TQA
    csv_path = PNP_TQA_PATH if flag_TQA else pnp_path
    frame = us.load_csv(csv_path)
    frame = frame[frame['model'].isin(input_data["models"])]
    frame = frame[frame['tbl_name'].isin(input_data["data"]["selected_tables"])]
    # Raw model identifiers -> labels shown in the charts.
    display_names = {
        'DeepSeek-R1-Distill-Llama-70B': 'DS-Llama3 70B',
        'gpt-3.5': 'GPT-3.5',
        'gpt-4o-mini': 'GPT-4o-mini',
        'llama-70': 'Llama-70B',
        'llama-8': 'Llama-8B',
    }
    frame['model'] = frame['model'].replace(display_names)
    frame['test_category'] = frame['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY')
    return frame
def calculate_average_metrics(df, selected_metrics):
    """Add an 'avg_metric' column holding the row-wise mean of the chosen metrics.

    'tuple_order' is always left out of the average. The frame is modified in
    place and also returned.

    Args:
        df: DataFrame containing the metric columns.
        selected_metrics: iterable of metric column names.
    """
    #TODO tuple_order has NULL VALUE
    usable = list(filter(lambda name: name != 'tuple_order', selected_metrics))
    df['avg_metric'] = df[usable].mean(axis=1)
    return df
def generate_model_colors(palette=None):
    """Map each model present in the loaded evaluation data to a display color.

    Args:
        palette: optional sequence of color strings. When omitted or empty, the
            app's default five-color palette is used. Colors are reused
            cyclically when there are more models than colors.

    Returns:
        dict of model name -> color, in order of first appearance in the data.
    """
    if not palette:
        palette = ['#00B4D8', '#BCE784', '#C84630', '#F79256', '#D269FC']
    unique_models = load_data_csv_es()['model'].unique()
    return {model: palette[i % len(palette)] for i, model in enumerate(unique_models)}
MODEL_COLORS = generate_model_colors()
def generate_db_category_colors():
    """Return a fresh db_category -> color map: Spider blue, Beaver orange, all others green."""
    blue, orange, green = "#1f77b4", "#ff7f0e", "#2ca02c"
    palette = {"Spider": blue, "Beaver": orange}
    palette.update(dict.fromkeys(("Economic", "Financial", "Medical", "Miscellaneous"), green))
    return palette
DB_CATEGORY_COLORS = generate_db_category_colors()
def normalize_valid_efficency_score(df):
    """Min-max scale the 'valid_efficency_score' column into [0, 1].

    Missing, empty-string, or non-numeric scores are treated as 0. When every
    score is identical there is no range to scale by:
    - an all-zero column (e.g. every score missing) stays 0;
    - any other constant becomes 1.0.

    Args:
        df: DataFrame with a 'valid_efficency_score' column.

    Returns:
        A copy of df with the normalized column. The caller's frame, often a
        filtered slice, is not mutated.
    """
    df = df.copy()
    # Keep fractional scores: the previous int cast truncated e.g. 0.8 -> 0.
    scores = pd.to_numeric(df['valid_efficency_score'], errors='coerce').fillna(0.0)
    min_val = scores.min()
    max_val = scores.max()
    if pd.isna(min_val) or min_val == max_val:
        # Degenerate range: avoid division by zero.
        no_signal = pd.isna(min_val) or min_val == 0
        df['valid_efficency_score'] = 0.0 if no_signal else 1.0
    else:
        df['valid_efficency_score'] = (scores - min_val) / (max_val - min_val)
    return df
####################################
# GRAPH FUNCTIONS SECTION #
####################################
# BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
    """Bar chart of the average selected metric, per model or per table and model.

    Args:
        df: evaluation rows; only rows of selected_models are used.
        radio_metric: "Qatch" uses the QATCH metric labels. Any other value uses
            the external ones.
        qatch_selected_metrics: human-readable QATCH labels, translated through
            qatch_metrics_dict.
        external_selected_metric: human-readable external labels, translated
            through external_metrics_dict. Both label lists are always
            translated.
        group_by: ["model"] groups by model. Any other value is treated as
            ["tbl_name", "model"].
        selected_models: model names to keep.

    Returns:
        A visible gr.Plot wrapping the Plotly bar chart.
    """
    inter = "Inter, sans-serif"

    def font(**extra):
        # Every text element uses the Inter family plus per-element overrides.
        return dict(family=inter, **extra)

    subset = normalize_valid_efficency_score(df[df['model'].isin(selected_models)])

    # Human-readable labels -> internal column names.
    qatch_cols = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
    external_cols = [external_metrics_dict[label] for label in external_selected_metric]
    if radio_metric == "Qatch":
        metric_cols = qatch_cols
    else:
        metric_cols = external_cols
    subset = calculate_average_metrics(subset, metric_cols)

    if group_by == ["model"]:
        averaged = subset.groupby("model")['avg_metric'].mean().reset_index()
        averaged['text_label'] = averaged['avg_metric'].apply(lambda v: f'{v:.2f}')
        fig = px.bar(
            averaged,
            x="model",
            y="avg_metric",
            color="model",
            color_discrete_map=MODEL_COLORS,
            title='Average metrics per Model π§ ',
            labels={"model": "Model", "avg_metric": "Average Metrics"},
            template='simple_white',
            text='text_label',
        )
    else:
        # Any grouping other than ["model"] falls back to table + model.
        keys = ["tbl_name", "model"]
        x_col = keys[0]
        averaged = subset.groupby(keys)['avg_metric'].mean().reset_index()
        averaged['text_label'] = averaged['avg_metric'].apply(lambda v: f'{v:.2f}')
        fig = px.bar(
            averaged,
            x=x_col,
            y='avg_metric',
            color='model',
            color_discrete_map=MODEL_COLORS,
            barmode='group',
            title=f'Average metrics per {x_col} π',
            labels={x_col: x_col.capitalize(), 'avg_metric': 'Average Metrics'},
            template='simple_white',
            text='text_label',
        )

    fig.update_traces(textposition='outside', textfont_size=10)
    fig.update_layout(
        margin=dict(t=80),
        title=dict(font=font(size=22), x=0.5),
        xaxis=dict(title=dict(font=font(size=18)), tickfont=font(size=16)),
        yaxis=dict(title=dict(font=font(size=18)), tickfont=font()),
        legend=dict(title=dict(font=font(size=16)), font=font()),
    )
    return gr.Plot(fig, visible=True)
def update_plot(radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
    """Reload the evaluation rows and rebuild the average-metrics bar chart."""
    return plot_metric(
        load_data_csv_es(),
        radio_metric,
        qatch_selected_metrics,
        external_selected_metric,
        group_by,
        selected_models,
    )
# BAR CHART FOR PROPIETARY DATASET WITH AVERAGE METRICS WITH UPDATE FUNCTION
def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
    """Grouped bar chart of the average selected metric per db_category and model.

    Args:
        df: evaluation rows.
        radio_metric: "Qatch" uses the QATCH metric labels. Any other value uses
            the external ones.
        qatch_selected_metrics: human-readable QATCH labels, mapped through
            qatch_metrics_dict.
        external_selected_metric: human-readable external labels, mapped through
            external_metrics_dict. Both label lists are always mapped.
        selected_models: a single model name, or "All" for every model in `models`.

    Returns:
        A visible gr.Plot wrapping the Plotly bar chart.
    """
    def black_font(**extra):
        # Inter family in black, plus per-element overrides.
        return dict(family='Inter, sans-serif', color='black', **extra)

    chosen = models if selected_models == "All" else [selected_models]
    subset = normalize_valid_efficency_score(df[df['model'].isin(chosen)])

    # Human-readable labels -> internal column names.
    qatch_cols = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
    external_cols = [external_metrics_dict[label] for label in external_selected_metric]
    metric_cols = qatch_cols if radio_metric == "Qatch" else external_cols
    subset = calculate_average_metrics(subset, metric_cols)

    averaged = subset.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
    averaged['text_label'] = averaged['avg_metric'].apply(lambda v: f'{v:.2f}')
    fig = px.bar(
        averaged,
        x='db_category',
        y='avg_metric',
        color='model',
        color_discrete_map=MODEL_COLORS,
        barmode='group',
        title='Average metrics per database types π',
        labels={'db_path': 'DB Path', 'avg_metric': 'Average Metrics'},
        template='simple_white',
        text='text_label',
    )
    fig.update_traces(textposition='outside', textfont_size=14)
    fig.update_layout(
        margin=dict(t=80),
        title=dict(font=black_font(size=24), x=0.5),
        xaxis=dict(
            title=dict(text='Database Category', font=black_font(size=22)),
            tickfont=black_font(size=20),
        ),
        yaxis=dict(
            title=dict(text='Average Metrics', font=black_font(size=22)),
            tickfont=black_font(),
        ),
        legend=dict(
            title=dict(text='Models', font=black_font(size=20)),
            font=black_font(size=18),
        ),
    )
    return gr.Plot(fig, visible=True)
def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
    """Reload the evaluation rows and rebuild the per-database-category bar chart."""
    return plot_metric_propietary(
        load_data_csv_es(),
        radio_metric,
        qatch_selected_metrics,
        external_selected_metric,
        selected_models,
    )
# DUMBBELL CHART: SPIDER VS PROPRIETARY DATABASES
def lollipop_propietary(selected_models):
    """Dumbbell chart comparing each model on Spider vs the proprietary databases.

    The QATCH metrics (qatch_metrics) are averaged per db_category and model.
    "Others" is the unweighted mean of the per-category averages of the five
    non-Spider categories (Economic, Financial, Medical, Miscellaneous, Beaver).
    Models lacking either a Spider or an "Others" value are dropped by the
    inner merge.

    Args:
        selected_models: model names to include.

    Returns:
        A visible gr.Plot with one gray segment per model joining its two markers.
    """
    def black_font(**extra):
        # Inter family in black, plus per-element overrides.
        return dict(family='Inter, sans-serif', color='black', **extra)

    df = load_data_csv_es()
    # Keep only the relevant categories.
    target_cats = ["Spider", "Economic", "Financial", "Medical", "Miscellaneous", "Beaver"]
    df = df[df['db_category'].isin(target_cats)]
    df = df[df['model'].isin(selected_models)]
    df = normalize_valid_efficency_score(df)
    df = calculate_average_metrics(df, qatch_metrics)

    # Average per db_category and model.
    avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()

    # Spider on one side; the mean of all other categories per model on the other.
    spider_df = avg_metrics[avg_metrics["db_category"] == "Spider"]
    other_df = avg_metrics[avg_metrics["db_category"] != "Spider"]
    other_mean_df = other_df.groupby("model")["avg_metric"].mean().reset_index()

    spider_df = spider_df.rename(columns={"avg_metric": "Spider"})
    other_mean_df = other_mean_df.rename(columns={"avg_metric": "Others"})
    merged_df = pd.merge(spider_df[["model", "Spider"]], other_mean_df[["model", "Others"]], on="model")
    merged_df = merged_df.sort_values(by="model")

    fig = go.Figure()
    # Horizontal segment between the Spider and Others values of each model.
    for _, row in merged_df.iterrows():
        fig.add_trace(go.Scatter(
            x=[row["Spider"], row["Others"]],
            y=[row["model"]] * 2,
            mode='lines',
            line=dict(color='gray', width=2),
            showlegend=False
        ))
    fig.add_trace(go.Scatter(
        x=merged_df["Spider"],
        y=merged_df["model"],
        mode='markers',
        name='Non-Proprietary (Spider)',
        marker=dict(size=10, color='#C84630')
    ))
    fig.add_trace(go.Scatter(
        x=merged_df["Others"],
        y=merged_df["model"],
        mode='markers',
        name='Proprietary Databases',
        marker=dict(size=10, color='#0077B6')
    ))

    # x carries the metric values and y the models; the legend lists database types.
    # (Previously the nested axis/legend dicts overrode xaxis_title/yaxis_title/
    # legend_title with swapped labels.)
    fig.update_layout(
        template='simple_white',
        margin=dict(t=80),
        title=dict(
            font=black_font(size=22),
            x=0.5,
            text='Dumbbell graph: Non-Proprietary (Spider π·οΈ) vs Proprietary Databases π'
        ),
        height=600,
        xaxis=dict(
            title=dict(text='Average Metrics', font=black_font(size=18)),
            tickfont=black_font()
        ),
        yaxis=dict(
            title=dict(text='Models', font=black_font(size=18)),
            tickfont=black_font()
        ),
        legend=dict(
            title=dict(text='Type of Databases:', font=black_font(size=18)),
            font=black_font(size=14)
        )
    )
    return gr.Plot(fig, visible=True)
# RADAR OR BAR CHART BASED ON CATEGORY COUNT
def plot_radar(df, selected_models, selected_metrics, selected_categories):
    """Compare models across test categories.

    Uses grouped bars when fewer than three categories are selected, and a
    filled radar chart otherwise.

    Args:
        df: evaluation rows.
        selected_models: models to draw, one trace each.
        selected_metrics: if it contains "External", execution_accuracy and
            valid_efficency_score are averaged. Otherwise the five QATCH
            metrics are averaged.
        selected_categories: test_category values to display, in display order.
            A model with no rows for a category scores 0 there.

    Returns:
        A go.Figure. If no rows survive filtering, an empty figure is returned
        and an error is printed.
    """
    def font(**extra):
        # Inter family plus per-element overrides.
        return dict(family='Inter, sans-serif', **extra)

    metric_cols = (
        ["execution_accuracy", "valid_efficency_score"]
        if "External" in selected_metrics
        else ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
    )
    subset = normalize_valid_efficency_score(df[df['model'].isin(selected_models)])
    subset = calculate_average_metrics(subset, metric_cols)
    averaged = subset.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
    if averaged.empty:
        print("Error: No data available to compute averages.")
        return go.Figure()

    categories = selected_categories

    def scores_for(model_name):
        # One value per category in `categories` order; 0 where the model has no rows.
        rows = averaged[averaged['model'] == model_name]
        present = rows['test_category'].values
        return [
            rows[rows['test_category'] == cat]['avg_metric'].values[0] if cat in present else 0
            for cat in categories
        ]

    fig = go.Figure()
    if len(categories) < 3:
        # Too few axes for a radar: fall back to grouped bars.
        for model_name in selected_models:
            fig.add_trace(go.Bar(
                x=categories,
                y=scores_for(model_name),
                name=model_name,
                marker=dict(color=MODEL_COLORS.get(model_name, "gray"))
            ))
        fig.update_layout(
            barmode='group',
            title=dict(text='π Bar Plot of Metrics per Model (Few Categories)', font=font(size=22), x=0.5),
            template='simple_white',
            xaxis=dict(title=dict(text='Test Category', font=font(size=18)), tickfont=font(size=16)),
            yaxis=dict(title=dict(text='Average Metrics', font=font(size=18)), tickfont=font()),
            legend=dict(title=dict(text='Models', font=font(size=16)), font=font()),
        )
        return fig

    # Traces are added from highest to lowest mean score (later traces draw on top).
    ranked = sorted(
        selected_models,
        key=lambda m: averaged[averaged['model'] == m]['avg_metric'].mean(),
        reverse=True,
    )
    for model_name in ranked:
        fig.add_trace(go.Scatterpolar(
            r=scores_for(model_name),
            theta=categories,
            fill='toself',
            name=model_name,
            line=dict(color=MODEL_COLORS.get(model_name, "gray"))
        ))
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, max(averaged['avg_metric'].max(), 0.5)],
                tickfont=font(),
            ),
            angularaxis=dict(tickfont=font(size=16)),
        ),
        title=dict(text='βοΈ Radar Plot of Metrics per Model (Average per SQL Category)', font=font(size=22), x=0.5),
        legend=dict(title=dict(text='Models', font=font(size=18)), font=font(size=16)),
        template='simple_white',
    )
    return fig
def update_radar(selected_models, selected_metrics, selected_categories):
    """Reload the evaluation rows and rebuild the per-category radar/bar chart."""
    return plot_radar(load_data_csv_es(), selected_models, selected_metrics, selected_categories)
# RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT
def plot_radar_sub(df, selected_models, selected_metrics, selected_category):
    """Compare models across the SQL sub-categories (sql_tag) of chosen test categories.

    Uses grouped bars when fewer than three sub-categories are present, and a
    filled radar chart otherwise.

    Args:
        df: evaluation rows.
        selected_models: models to draw, one trace each.
        selected_metrics: if it contains "External", execution_accuracy and
            valid_efficency_score are averaged. Otherwise the five QATCH
            metrics are averaged.
        selected_category: a test_category string or a list of them.

    Returns:
        A go.Figure. If no rows survive filtering, an empty figure is returned
        and an error is printed.
    """
    def font(**extra):
        # Inter family plus per-element overrides.
        return dict(family='Inter, sans-serif', **extra)

    if "External" in selected_metrics:
        metric_cols = ["execution_accuracy", "valid_efficency_score"]
    else:
        metric_cols = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
    subset = normalize_valid_efficency_score(df[df['model'].isin(selected_models)])
    subset = calculate_average_metrics(subset, metric_cols)
    wanted = [selected_category] if isinstance(selected_category, str) else selected_category
    subset = subset[subset['test_category'].isin(wanted)]
    averaged = subset.groupby(['model', 'sql_tag'])['avg_metric'].mean().reset_index()
    if averaged.empty:
        print("Error: No data available to compute averages.")
        return go.Figure()

    tags = subset['sql_tag'].unique().tolist()

    def scores_for(model_name):
        # One value per sql_tag in `tags` order; 0 where the model has no rows.
        rows = averaged[averaged['model'] == model_name]
        present = rows['sql_tag'].values
        return [
            rows[rows['sql_tag'] == tag]['avg_metric'].values[0] if tag in present else 0
            for tag in tags
        ]

    fig = go.Figure()
    if len(tags) < 3:
        # Too few axes for a radar: fall back to grouped bars.
        for model_name in selected_models:
            fig.add_trace(go.Bar(
                x=tags,
                y=scores_for(model_name),
                name=model_name,
                marker=dict(color=MODEL_COLORS.get(model_name, "gray"))
            ))
        fig.update_layout(
            barmode='group',
            title=dict(text='π Bar Plot of Metrics per Model (Few Sub-Categories)', font=font(size=22), x=0.5),
            template='simple_white',
            xaxis=dict(title=dict(text='SQL Tag (Sub Category)', font=font(size=18)), tickfont=font()),
            yaxis=dict(title=dict(text='Average Metrics', font=font(size=18)), tickfont=font()),
            legend=dict(title=dict(text='Models', font=font(size=16)), font=font(size=14)),
        )
        return fig

    # Traces are added from highest to lowest mean score (later traces draw on top).
    ranked = sorted(
        selected_models,
        key=lambda m: averaged[averaged['model'] == m]['avg_metric'].mean(),
        reverse=True,
    )
    for model_name in ranked:
        fig.add_trace(go.Scatterpolar(
            r=scores_for(model_name),
            theta=tags,
            fill='toself',
            name=model_name,
            line=dict(color=MODEL_COLORS.get(model_name, "gray"))
        ))
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, max(averaged['avg_metric'].max(), 0.5)],
                tickfont=font(),
            ),
            angularaxis=dict(tickfont=font(size=16)),
        ),
        title=dict(text='βοΈ Radar Plot of Metrics per Model (Average per SQL Sub-Category)', font=font(size=22), x=0.5),
        legend=dict(title=dict(text='Models', font=font(size=16)), font=font(size=14)),
        template='simple_white',
    )
    return fig
def update_radar_sub(selected_models, selected_metrics, selected_category):
    """Reload the evaluation rows and rebuild the per-sub-category radar/bar chart."""
    return plot_radar_sub(load_data_csv_es(), selected_models, selected_metrics, selected_category)
# RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION
def worst_cases_text(df, selected_models, selected_metrics, selected_categories):
    """Build Markdown/HTML snippets for the three lowest-scoring predictions.

    Rows are filtered by model and test category, scored with
    ``calculate_average_metrics`` over the chosen metric group, grouped and
    sorted ascending by ``avg_metric``.

    Args:
        df: Evaluation results. Must contain 'model', 'tbl_name',
            'test_category', 'question', 'answer', 'sql_tag', plus
            'target_answer'/'predicted_answer' when the module-level
            ``flag_TQA`` is set, or 'query'/'predicted_sql' otherwise.
        selected_models: One model name, or "All" for every model in ``models``.
        selected_metrics: Metric group chosen in the UI ("Qatch" or
            "External"), or a list of group names. "external" is matched
            case-insensitively.
        selected_categories: One test category, or "All" for
            ``principal_categories``.

    Returns:
        A 6-tuple: three entries describing the worst cases (1st..3rd),
        then three entries with the corresponding raw model answers.
        Slots without a matching case are empty strings.
    """
    import ast

    global flag_TQA

    def _parse_predicted_answer(value):
        # The predicted answer is model output read back from the results
        # file: parse it as a Python literal only, never with eval().
        try:
            return ast.literal_eval(value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return value

    if selected_models == "All":
        selected_models = models
    else:
        selected_models = [selected_models]
    if selected_categories == "All":
        selected_categories = principal_categories
    else:
        selected_categories = [selected_categories]

    df = df[df['model'].isin(selected_models)]
    df = df[df['test_category'].isin(selected_categories)]

    # The UI passes the radio label ("Qatch" / "External"). A case-sensitive
    # `"external" in "External"` is False, so compare in lower case.
    if isinstance(selected_metrics, str):
        metric_group = selected_metrics.lower()
    else:
        metric_group = [str(metric).lower() for metric in selected_metrics]
    if "external" in metric_group:
        metric_columns = ["execution_accuracy", "valid_efficency_score"]
    else:
        metric_columns = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]

    df = normalize_valid_efficency_score(df)
    df = calculate_average_metrics(df, metric_columns)

    if flag_TQA:
        # Render every target answer as "[a, b, ...]".
        df["target_answer"] = df["target_answer"].apply(lambda x: "[" + ", ".join(map(str, x)) + "]")
        group_keys = ['model', 'tbl_name', 'test_category', 'question', 'target_answer', 'predicted_answer', 'answer', 'sql_tag']
    else:
        group_keys = ['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag']

    # dropna=False keeps rows whose key columns are missing (e.g. a model that
    # produced no SQL). groupby would otherwise silently discard exactly the
    # cases most likely to be among the worst.
    worst_cases_df = (
        df.groupby(group_keys, dropna=False)['avg_metric']
        .mean()
        .reset_index()
        .sort_values(by="avg_metric", ascending=True)
        .reset_index(drop=True)
    )
    # Copy so the rounding below writes to an independent frame, not a slice.
    worst_cases_top_3 = worst_cases_df.head(3).copy()
    worst_cases_top_3["avg_metric"] = worst_cases_top_3["avg_metric"].round(2)

    medals = ["🥇", "🥈", "🥉"]
    worst_str = []
    answer_str = []
    for medal, (_, row) in zip(medals, worst_cases_top_3.iterrows()):
        header = f"<span style='font-size:18px;'><b>{medal} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span> \n"
        question = f"<span style='font-size:16px;'>- <b>Question:</b> {row['question']}</span> \n"
        if flag_TQA:
            details = (
                f"<span style='font-size:16px;'>- <b>Original Answer:</b> `{row['target_answer']}`</span> \n"
                f"<span style='font-size:16px;'>- <b>Predicted Answer:</b> `{_parse_predicted_answer(row['predicted_answer'])}`</span> \n\n"
            )
        else:
            details = (
                f"<span style='font-size:16px;'>- <b>Original Query:</b> `{row['query']}`</span> \n"
                f"<span style='font-size:16px;'>- <b>Predicted SQL:</b> `{row['predicted_sql']}`</span> \n\n"
            )
        worst_str.append(header + question + details)
        answer_str.append(
            header
            + f"<span style='font-size:16px;'>- <b>Raw Answer:</b><br> `{row['answer']}`</span> \n"
        )

    # Fewer than three matching cases: fill the remaining slots with empty
    # text instead of raising IndexError.
    missing = 3 - len(worst_str)
    worst_str.extend([""] * missing)
    answer_str.extend([""] * missing)
    return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2]
def update_worst_cases_text(selected_models, selected_metrics, selected_categories):
    """Reload the evaluation CSV and rebuild the worst-cases ranking text.

    Returns the 6-tuple produced by ``worst_cases_text``.
    """
    results = load_data_csv_es()
    return worst_cases_text(
        results,
        selected_models,
        selected_metrics,
        selected_categories,
    )
# LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION
def plot_cumulative_flow(df, selected_models, max_points):
    """Draw one cumulative-price vs. cumulative-time line per selected model.

    Args:
        df: Evaluation results with at least 'model', 'time', 'price',
            'number_question', 'question', 'query', 'predicted_sql' and
            'test_category' columns.
        selected_models: Iterable of model names; one trace per model.
        max_points: When not None, only the first ``max_points + 1`` rows of
            each model (in ``df`` row order) are plotted.

    Returns:
        A plotly ``go.Figure``.
    """
    df = df[df['model'].isin(selected_models)]
    df = normalize_valid_efficency_score(df)

    def _font(size):
        # Shared font spec for every text element of the chart.
        return dict(family="Inter, sans-serif", size=size)

    fig = go.Figure()
    for model in selected_models:
        model_df = df[df['model'] == model].copy()

        # Limit the number of points if requested. int() because a slider
        # value may arrive as a float, which DataFrame.head() rejects.
        if max_points is not None:
            model_df = model_df.head(int(max_points) + 1)

        # Per-point tooltip. Built with a comprehension rather than
        # DataFrame.apply(axis=1): on an empty frame apply() returns a
        # DataFrame that cannot be assigned to a single column.
        model_df['hover_info'] = [
            f"<b>Id question</b>: {row['number_question']}<br>"
            f"<b>Question</b>: {row['question']}<br>"
            f"<b>Target</b>: {row['query']}<br>"
            f"<b>Prediction</b>: {row['predicted_sql']}<br>"
            f"<b>Category</b>: {row['test_category']}"
            for _, row in model_df.iterrows()
        ]

        # Running totals in the row order of model_df.
        model_df['cumulative_time'] = model_df['time'].cumsum()
        model_df['cumulative_price'] = model_df['price'].cumsum()

        fig.add_trace(go.Scatter(
            x=model_df['cumulative_time'],
            y=model_df['cumulative_price'],
            mode='lines+markers',
            name=model,
            line=dict(width=2, color=MODEL_COLORS.get(model, "gray")),
            customdata=model_df['hover_info'],
            hovertemplate=
            "<b>Model:</b> " + model + "<br>" +
            "<b>Cumulative Time:</b> %{x}s<br>" +
            "<b>Cumulative Price:</b> $%{y:.2f}<br>" +
            "<br><b>Details:</b><br>%{customdata}<extra></extra>"
        ))

    fig.update_layout(
        title=dict(text="Cumulative Price Flow Chart π°", font=_font(24), x=0.5),
        xaxis=dict(
            title=dict(text="Cumulative Time (s)", font=_font(20)),
            tickfont=_font(18),
        ),
        yaxis=dict(
            title=dict(text="Cumulative Price ($)", font=_font(20)),
            tickfont=_font(18),
        ),
        legend=dict(
            title=dict(text="Models", font=_font(18)),
            font=_font(16),
        ),
        template='simple_white',
    )
    return fig
def update_query_rate(selected_models, max_points):
    """Reload the evaluation CSV and redraw the cumulative time/price chart.

    ``max_points`` is forwarded unchanged to ``plot_cumulative_flow``
    (None plots every row).
    """
    return plot_cumulative_flow(load_data_csv_es(), selected_models, max_points)
#######################
#  PARAMETER SECTION  #
#######################
# Display label -> result-column name for the QATCH metrics.
qatch_metrics_dict = {
    "Cell Precision": "cell_precision",
    "Cell Recall": "cell_recall",
    "Tuple Order": "tuple_order",
    "Tuple Cardinality": "tuple_cardinality",
    "Tuple Constraint": "tuple_constraint",
}
# Column names in the same order as the dictionary above.
qatch_metrics = list(qatch_metrics_dict.values())
# Last non-empty QATCH metric selection (an independent copy).
last_valid_qatch_metrics_selection = list(qatch_metrics)
def enforce_qatch_metrics_selection(selected):
    """Keep at least one QATCH metric checked.

    A non-empty selection is remembered and echoed back. An empty selection
    is replaced by the last remembered one.
    """
    global last_valid_qatch_metrics_selection
    if selected:
        last_valid_qatch_metrics_selection = selected
        return gr.update(value=selected)
    return gr.update(value=last_valid_qatch_metrics_selection)
# Display label -> result-column name for the external metrics.
external_metrics_dict = {
    "Execution Accuracy": "execution_accuracy",
    "Valid Efficency Score": "valid_efficency_score",
}
# Column names in the same order as the dictionary above.
external_metric = list(external_metrics_dict.values())
# Last non-empty external metric selection (an independent copy).
last_valid_external_metric_selection = list(external_metric)
def enforce_external_metric_selection(selected):
    """Keep at least one external metric checked.

    A non-empty selection is remembered and echoed back. An empty selection
    is replaced by the last remembered one.
    """
    global last_valid_external_metric_selection
    if selected:
        last_valid_external_metric_selection = selected
        return gr.update(value=selected)
    return gr.update(value=last_valid_external_metric_selection)
# Metric groups; the keys are the choices of the "metrics group" radios.
all_metrics = {
    "Qatch": ["qatch"],
    "External": ["external"],
}
# Grouping view -> columns handed to update_plot by on_change.
group_options = {
    "Table": ["tbl_name", "model"],
    "Model": ["model"],
}
# Results loaded once while building the UI, to derive the selectable values.
df_initial = load_data_csv_es()
models = df_initial['model'].unique().tolist()
# Last non-empty model selection (an independent copy).
last_valid_model_selection = list(models)
def enforce_model_selection(selected):
    """Keep at least one model selected.

    An empty selection is replaced by the last remembered non-empty one.
    Otherwise the selection is echoed back unchanged.

    Only list selections (from CheckboxGroup components) are remembered. This
    handler is also attached to a Radio, whose value is a single string such
    as "All". Storing that string would later be restored into a
    CheckboxGroup, where it is not a valid value.
    """
    global last_valid_model_selection
    if not selected:
        return gr.update(value=last_valid_model_selection)
    if isinstance(selected, (list, tuple)):
        last_valid_model_selection = selected
    return gr.update(value=selected)
# Distinct SQL tags and test categories present in the initial results.
all_categories = df_initial['sql_tag'].unique().tolist()
principal_categories = df_initial['test_category'].unique().tolist()
# Last non-empty category selection (an independent copy).
last_valid_category_selection = list(principal_categories)
def enforce_category_selection(selected):
    """Keep at least one category selected.

    An empty selection is replaced by the last remembered non-empty one.
    Otherwise the selection is echoed back unchanged.

    Only list selections (from CheckboxGroup components) are remembered. This
    handler is also attached to a Radio, whose value is a single string such
    as "All". Storing that string would later be restored into a
    CheckboxGroup, where it is not a valid value.
    """
    global last_valid_category_selection
    if not selected:
        return gr.update(value=last_valid_category_selection)
    if isinstance(selected, (list, tuple)):
        last_valid_category_selection = selected
    return gr.update(value=selected)
# Radio choice -> list of values it stands for. The ranking-category and model
# dictionaries additionally end with an "All" entry covering everything.
all_categories_as_dic = {category_name: [f"{category_name}"] for category_name in principal_categories}

all_categories_as_dic_ranking = {category_name: [f"{category_name}"] for category_name in principal_categories}
all_categories_as_dic_ranking["All"] = principal_categories

all_model_as_dic = {model_name: [f"{model_name}"] for model_name in models}
all_model_as_dic["All"] = models
| ########################### | |
| # VISUALIZATION SECTION # | |
| ########################### | |
| # Section 1 controls: metric-group radio, per-group info tooltip, metric | |
| # checkboxes, a model selector and a grouping selector. | |
| gr.Markdown("""# Model Performance Analysis""") | |
| #FOR BAR | |
| gr.Markdown("""## Section 1: Model - Data""") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| with gr.Row(): | |
| choose_metrics_bar = gr.Radio( | |
| choices=list(all_metrics.keys()), | |
| label="Select the metrics group that you want to use:", | |
| value="Qatch" | |
| ) | |
| with gr.Row(): | |
| # Hover tooltips describing each metric group. Only the one matching | |
| # choose_metrics_bar is visible; toggle_metric_selector switches them. | |
| qatch_info = gr.HTML(""" | |
| <div style='display: flex; align-items: center; margin-top: -8px; margin-bottom: 12px;'> | |
| <span | |
| title="Qatch metric info: | |
| Cell Precision: Fraction of predicted table cells also in the ground truth result. High means many correct predictions. | |
| Cell Recall: Fraction of ground truth cells retrieved by the prediction. High means relevant cells were captured. | |
| Tuple Constraint: Fraction of ground truth tuples matched exactly in output (schema, values, cardinality). | |
| Tuple Cardinality: Ratio of predicted to ground truth tuples. Checks only tuple count. | |
| Tuple Order: Spearman correlation between predicted and ground truth tuple ranks." | |
| style="margin-left: 6px; cursor: help; color: #00bfff; font-size: 16px; white-space: pre-line;" | |
| >Qatch metric info βΉοΈ</span> | |
| </div> | |
| """, visible=True) | |
| external_info = gr.HTML(""" | |
| <div style='display: flex; align-items: center; margin-top: -8px; margin-bottom: 12px;'> | |
| <span | |
| title="External metric info: | |
| Execution Accuracy: Checks if the predicted query returns exactly the same result as the ground truth query when executed. It is a binary metric: 1 if the output matches, 0 otherwise. | |
| Valid Efficency Score: Evaluates the efficency of a query by combining execution time and correctness. It rewards queries that are both accurate and fast." | |
| style="margin-left: 6px; cursor: help; color: #00bfff; font-size: 16px; white-space: pre-line;" | |
| >External metric info βΉοΈ</span> | |
| </div> | |
| """, visible=False) | |
| # All QATCH metrics are checked initially. | |
| # NOTE(review): the label reads "one or mode"; presumably "one or more" was meant. | |
| qatch_metric_multiselect_bar = gr.CheckboxGroup( | |
| choices=list(qatch_metrics_dict.keys()), | |
| label="Select one or mode Qatch metrics:", | |
| value=list(qatch_metrics_dict.keys()), | |
| visible=True | |
| ) | |
| # Hidden and empty until "External" is chosen (see toggle_metric_selector). | |
| external_metric_select_bar = gr.CheckboxGroup( | |
| choices=list(external_metrics_dict.keys()), | |
| label="Select one or more External metrics:", | |
| visible=False | |
| ) | |
| # Default dataset: a single-choice radio ("All" or one model). Uploaded data: | |
| # a multi-select over every model, locked when only one model exists. Only | |
| # one of model_radio_bar / model_multiselect_bar is created, matching the | |
| # same input_method branch used when wiring the events. | |
| if(input_data['input_method'] == 'default'): | |
| model_radio_bar = gr.Radio( | |
| choices=list(all_model_as_dic.keys()), | |
| label="Select the model that you want to use:", | |
| value="All" | |
| ) | |
| else: | |
| model_multiselect_bar = gr.CheckboxGroup( | |
| choices=models, | |
| label="Select one or more models:", | |
| value=models, | |
| interactive=len(models) > 1 | |
| ) | |
| # Grouping view; on_change maps the choice to columns via group_options. | |
| group_radio = gr.Radio( | |
| choices=list(group_options.keys()), | |
| label="Select the grouping view:", | |
| value="Table" | |
| ) | |
def toggle_metric_selector(selected_type):
    """Switch Section 1 between the QATCH and the external metric controls.

    Returns four updates, in order: QATCH info tooltip, external info tooltip,
    QATCH metric checkboxes, external metric checkboxes. The visible group is
    fully checked and the hidden one is cleared.
    """
    show_qatch = selected_type == "Qatch"
    if show_qatch:
        qatch_value = list(qatch_metrics_dict.keys())
        external_value = []
    else:
        qatch_value = []
        external_value = list(external_metrics_dict.keys())
    return (
        gr.update(visible=show_qatch),
        gr.update(visible=not show_qatch),
        gr.update(visible=show_qatch, value=qatch_value),
        gr.update(visible=not show_qatch, value=external_value),
    )
| # Section 1 chart, filled by on_change / on_change_propietary when the events | |
| # fire. It is created hidden; presumably the update_plot helpers (not shown | |
| # here) make it visible — confirm. | |
| output_plot = gr.Plot(visible=False) | |
| # Extra chart built only for the default dataset. | |
| if(input_data['input_method'] == 'default'): | |
| with gr.Row(): | |
| lollipop_propietary(models) | |
| #FOR RADAR | |
| # Section 2: two radar charts side by side. The left one covers the selected | |
| # categories; the right one covers a single category via update_radar_sub. | |
| gr.Markdown("""## Section 2: Model - Category""") | |
| with gr.Row(): | |
| all_metrics_radar = gr.Radio( | |
| choices=list(all_metrics.keys()), | |
| label="Select the metrics group that you want to use:", | |
| value="Qatch" | |
| ) | |
| model_multiselect_radar = gr.CheckboxGroup( | |
| choices=models, | |
| label="Select one or more models:", | |
| value=models, | |
| interactive=len(models) > 1 | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| category_multiselect_radar = gr.CheckboxGroup( | |
| choices=principal_categories, | |
| label="Select one or more categories:", | |
| value=principal_categories | |
| ) | |
| with gr.Column(scale=1): | |
| # NOTE(review): the choices are test categories, but the label says "metrics". | |
| category_radio_radar = gr.Radio( | |
| choices=list(all_categories_as_dic.keys()), | |
| label="Select the metrics that you want to use:", | |
| value=list(all_categories_as_dic.keys())[0] | |
| ) | |
| # Initial figures are computed while building the UI: all models, QATCH metrics. | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| radar_plot_multiselect = gr.Plot(value=update_radar(models, "Qatch", principal_categories)) | |
| with gr.Column(scale=1): | |
| radar_plot_radio = gr.Plot(value=update_radar_sub(models, "Qatch", list(all_categories_as_dic.keys())[0])) | |
| #FOR RANKING | |
| with gr.Row(): | |
| all_metrics_ranking = gr.Radio( | |
| choices=list(all_metrics.keys()), | |
| label="Select the metrics group that you want to use:", | |
| value="Qatch" | |
| ) | |
| # all_model_as_dic holds one key per model plus "All". Two keys therefore | |
| # means a single model: drop "All" and preselect that model. | |
| model_choices = list(all_model_as_dic.keys()) | |
| if len(model_choices) == 2: | |
| model_choices = [model_choices[0]] # assume the model is in the first position | |
| selected_value = model_choices[0] | |
| else: | |
| selected_value = "All" | |
| model_radio_ranking = gr.Radio( | |
| choices=model_choices, | |
| label="Select the model that you want to use:", | |
| value=selected_value | |
| ) | |
| category_radio_ranking = gr.Radio( | |
| choices=list(all_categories_as_dic_ranking.keys()), | |
| label="Select the category that you want to use", | |
| value="All" | |
| ) | |
| # Left column: the three worst cases, each with a button that reveals its raw | |
| # model answer in the right column. Initial text uses all models/categories. | |
| # NOTE(review): the ranking radio may preselect a single model (see above), | |
| # while this initial text is always computed for "All" — confirm intended. | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("## β 3 Worst Cases\n") | |
| worst_first, worst_second, worst_third, raw_first, raw_second, raw_third = update_worst_cases_text("All", "Qatch", "All") | |
| with gr.Row(): | |
| first = gr.Markdown(worst_first) | |
| with gr.Row(): | |
| first_button = gr.Button("Show raw answer for π₯") | |
| with gr.Row(): | |
| second = gr.Markdown(worst_second) | |
| with gr.Row(): | |
| second_button = gr.Button("Show raw answer for π₯") | |
| with gr.Row(): | |
| third = gr.Markdown(worst_third) | |
| with gr.Row(): | |
| third_button = gr.Button("Show raw answer for π₯") | |
| # Only one raw answer is visible at a time; the first one initially. | |
| with gr.Column(scale=1): | |
| gr.Markdown("""## Raw Answer""") | |
| row_answer_first = gr.Markdown(value=raw_first, visible=True) | |
| row_answer_second = gr.Markdown(value=raw_second, visible=False) | |
| row_answer_third = gr.Markdown(value=raw_third, visible=False) | |
| #FOR RATE | |
| # Section 3: cumulative price vs. cumulative time per model. | |
| gr.Markdown("""## Section 3: Time - Price""") | |
| with gr.Row(): | |
| model_multiselect_rate = gr.CheckboxGroup( | |
| choices=models, | |
| label="Select one or more models:", | |
| value=models, | |
| interactive=len(models) > 1 | |
| ) | |
| # NOTE(review): the slider's max/initial value is max(number_question), but | |
| # the initial plot uses the count of distinct number_question values. These | |
| # differ unless ids run 1..N — confirm which is intended. | |
| with gr.Row(): | |
| slicer = gr.Slider(minimum=0, maximum=max(df_initial["number_question"]), step=1, value=max(df_initial["number_question"]), label="Number of instances to visualize", elem_id="custom-slider") | |
| query_rate_plot = gr.Plot(value=update_query_rate(models, len(df_initial["number_question"].unique()))) | |
| #FOR RESET | |
| # Returns to the upload step (wired to open_accordion("reset") below). | |
| reset_data = gr.Button("Back to upload data section") | |
###############################
#  CALLBACK FUNCTION SECTION  #
###############################
# Thin adapters between Gradio event inputs and the plotting/update helpers.

# FOR BAR
def on_change(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_group, selected_models):
    """Redraw the Section 1 chart; ``selected_group`` is a key of ``group_options``."""
    grouping_columns = group_options[selected_group]
    return update_plot(
        radio_metric,
        qatch_metric_multiselect_bar,
        external_metric_select_bar,
        grouping_columns,
        selected_models,
    )


def on_change_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models):
    """Redraw the Section 1 chart for the default dataset."""
    return update_plot_propietary(
        radio_metric,
        qatch_metric_multiselect_bar,
        external_metric_select_bar,
        selected_models,
    )


# FOR RADAR
def on_radar_multiselect_change(selected_models, selected_metrics, selected_categories):
    """Redraw the multi-category radar chart."""
    return update_radar(selected_models, selected_metrics, selected_categories)


def on_radar_radio_change(selected_models, selected_metrics, selected_category):
    """Redraw the single-category radar chart."""
    return update_radar_sub(selected_models, selected_metrics, selected_category)


# FOR RANKING
def on_ranking_change(selected_models, selected_metrics, selected_categories):
    """Recompute the worst-cases texts and raw answers (6-tuple)."""
    return update_worst_cases_text(selected_models, selected_metrics, selected_categories)
def _raw_answer_visibility(shown_index):
    """Return three visibility updates where only slot ``shown_index`` is visible."""
    return tuple(gr.update(visible=(slot == shown_index)) for slot in range(3))


def show_first():
    """Show only the raw answer of the worst case."""
    return _raw_answer_visibility(0)


def show_second():
    """Show only the raw answer of the second-worst case."""
    return _raw_answer_visibility(1)


def show_third():
    """Show only the raw answer of the third-worst case."""
    return _raw_answer_visibility(2)
| ###################### | |
| # ON CLICK SECTION # | |
| ###################### | |
| # Register event listeners. Several listeners may be attached to the same | |
| # event; each one receives only the inputs it lists. | |
| #FOR BAR | |
| # Default dataset: single-choice model_radio_bar, update_plot_propietary path. | |
| # Uploaded data: model_multiselect_bar plus the grouping radio, update_plot path. | |
| if(input_data['input_method'] == 'default'): | |
| proceed_to_metrics_button.click(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) | |
| qatch_metric_multiselect_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) | |
| external_metric_select_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) | |
| model_radio_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot) | |
| # Guards: an emptied CheckboxGroup is reset to its last non-empty value. | |
| qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar) | |
| choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar]) | |
| external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar) | |
| else: | |
| proceed_to_metrics_button.click(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) | |
| qatch_metric_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) | |
| external_metric_select_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) | |
| group_radio.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) | |
| model_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot) | |
| # Guards: an emptied CheckboxGroup is reset to its last non-empty value. | |
| qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar) | |
| model_multiselect_bar.change(fn=enforce_model_selection, inputs=model_multiselect_bar, outputs=model_multiselect_bar) | |
| choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar]) | |
| external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar) | |
| #FOR RADAR MULTISELECT | |
| # NOTE(review): enforce_model_selection keeps one global "last valid" value | |
| # shared by the bar, radar and rate model selectors, so emptying one may | |
| # restore a selection made in another. | |
| model_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect) | |
| all_metrics_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect) | |
| category_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect) | |
| model_multiselect_radar.change(fn=enforce_model_selection, inputs=model_multiselect_radar, outputs=model_multiselect_radar) | |
| category_multiselect_radar.change(fn=enforce_category_selection, inputs=category_multiselect_radar, outputs=category_multiselect_radar) | |
| #FOR RADAR RADIO | |
| model_multiselect_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio) | |
| all_metrics_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio) | |
| category_radio_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio) | |
| #FOR RANKING | |
| # Any ranking control change recomputes the three worst cases and shows the | |
| # first raw answer again. | |
| model_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third]) | |
| model_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) | |
| all_metrics_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third]) | |
| all_metrics_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) | |
| category_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third]) | |
| category_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) | |
| # NOTE(review): these inputs are Radios, which always hold a single string, so | |
| # the enforce_* guards have nothing to restore here. Check that they do not | |
| # record the Radio string as the "last valid" CheckboxGroup selection, or drop | |
| # these two listeners. | |
| model_radio_ranking.change(fn=enforce_model_selection, inputs=model_radio_ranking, outputs=model_radio_ranking) | |
| category_radio_ranking.change(fn=enforce_category_selection, inputs=category_radio_ranking, outputs=category_radio_ranking) | |
| first_button.click(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third]) | |
| second_button.click(fn=show_second, outputs=[row_answer_first, row_answer_second, row_answer_third]) | |
| third_button.click(fn=show_third, outputs=[row_answer_first, row_answer_second, row_answer_third]) | |
| #FOR RATE | |
| model_multiselect_rate.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot) | |
| proceed_to_metrics_button.click(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot) | |
| model_multiselect_rate.change(fn=enforce_model_selection, inputs=model_multiselect_rate, outputs=model_multiselect_rate) | |
| slicer.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot) | |
| #FOR RESET | |
| # Back to the upload step: reset the accordions, hide the metrics download and | |
| # re-enable the input widgets via enable_disable(True). | |
| reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input]) | |
| reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics]) | |
| reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection]) | |
| # share=True asks Gradio to also create a public share link for the app. | |
| interface.launch(share = True) | |